Skip to content
Snippets Groups Projects
Commit 9096f55c authored by User expired's avatar User expired :spy_tone1:
Browse files

minibot_vision: refactoring and added toggle hough transform capabilities

parent d20fbdb0
No related branches found
No related tags found
No related merge requests found
sign_detector: {canny_param1: 47, canny_param2: 46, img_height: 1080, img_width: 1920,
sign_detector: {canny_param1: 47, canny_param2: 46, img_height: 480, img_width: 640,
max_depth: 1.0, max_radius: 128, min_depth: 0.2, min_radius: 15, visualize: false,
zoom_threshold: 1.15}
# default params that were not overridden
sign_detector:
img_height: 1080
img_width: 1920
img_height: 480
img_width: 640
canny_param1: 100
canny_param2: 100 # good light: 40; bad light: 100
min_depth: 0.2
......
......@@ -98,6 +98,8 @@ def enable_callback(req):
if __name__ == "__main__":
rospy.init_node("capture_images")
rospy.logwarn("({}) This node is deprecated".format(rospy.get_name()))
if rospy.has_param("~remote_node"):
REMOTE_NODE = rospy.get_param("~remote_node")
if rospy.has_param("~save_dir"):
......
......@@ -129,31 +129,36 @@ def crop_to_bounds(crop_bounds, max_val):
return crop_bounds
def get_tensor_patches(img_rgb, keypoints):
def get_tensor_patches(img_rgb, keypoints, zoom=True):
"""
Turns a set of keypoints into image patches from the original image.
Each patch has the size needed by the tensorflow model so that the patch can be directly fed into the image classifier.
Each patch is zoomed such that the detected object fills the entire patch.
:param img_rgb: original image
:param keypoints: list of detected keypoints
:param zoom: If true the resulting patch will be zoomed in such a way that it fits the radius*2 + some margin
:return: A set of image patches.
"""
global TENSOR_RES, ZOOM_THREASHOLD
img_patches = []
for k in keypoints:
img = copy(img_rgb)
img = np.copy(img_rgb)
d = k["depth"]
center = k["center"]
center = [center[1], center[0]]
r = k["radius"]
# zoom into images based on radius?
zoom_factor = np.array(TENSOR_RES) / ((r*2 * ZOOM_THREASHOLD))
zoomed_image = cv2.resize(img, dsize=None, fx=zoom_factor[0], fy=zoom_factor[1], interpolation=cv2.INTER_NEAREST)
if zoom:
zoom_factor = np.array(TENSOR_RES) / ((r*2 * ZOOM_THREASHOLD))
zoomed_image = cv2.resize(img, dsize=None, fx=zoom_factor[0], fy=zoom_factor[1], interpolation=cv2.INTER_NEAREST)
img_center_zoomed = (center * zoom_factor).astype(int)
else:
zoomed_image = img
img_center_zoomed = center
# handle border
img_center_zoomed = (center * zoom_factor).astype(int)
y = [img_center_zoomed[0] - TENSOR_RES[0] // 2, img_center_zoomed[0] + TENSOR_RES[0] // 2]
y = crop_to_bounds(y, np.shape(zoomed_image)[0])
x = [img_center_zoomed[1] - TENSOR_RES[1] // 2, img_center_zoomed[1] + TENSOR_RES[1] // 2]
......
......@@ -19,6 +19,7 @@ from minibot_msgs.srv import set_url
visualize = True
camera_frame = "camera_aligned_depth_to_color_frame"
#OLD CONFIG: IMG_RES = (480, 640)
# TODO get this from camera topic
IMG_RES = (rospy.get_param("sign_detector/img_height"), rospy.get_param("sign_detector/img_width"))
TF_RES = (224, 224) # tf is cropping the image
......@@ -31,6 +32,11 @@ img_depth_stream = np.zeros((IMG_RES[0], IMG_RES[1], 1), np.uint8)
img_rgb = img_rgb_stream
pub_keypoint = None
pub_result_img = None
enable_sign_detection = True
enable_capture_images = False
toggle_hough_detection = True
pub_raw_img = None
pub_cmpr_img = None
# subscribe to RGB img
def image_color_callback(data):
......@@ -77,16 +83,46 @@ def publish_results(point, radius, depth, label, precision, timestamp):
pub_keypoint.publish(detection_msg)
def publish_img_patch(img_patch):
    """
    Publish one image patch on both capture-image publishers.
    The patch is sent once as a raw image message and once as a compressed image message; both messages
    carry the same header timestamp.
    :param img_patch: image patch handed to the cv bridge; the raw message is encoded as "bgr8"
    """
    # bridge, pub_raw_img and pub_cmpr_img are module-level names that are only read here,
    # so no global declaration is required.

    # one stamp shared by both messages so they can be synchronised
    stamp = rospy.Time.now()

    # raw (uncompressed) variant, meant for saving
    raw_msg = bridge.cv2_to_imgmsg(img_patch, "bgr8")
    raw_msg.header.stamp = stamp
    pub_raw_img.publish(raw_msg)

    # compressed variant, meant for the website visualization
    compressed_msg = bridge.cv2_to_compressed_imgmsg(img_patch)
    compressed_msg.header.stamp = stamp
    pub_cmpr_img.publish(compressed_msg)
def detect_sign(img_rgb_stream, image_timestamp):
global img_depth_stream, pub_result_img
global img_depth_stream, pub_result_img, toggle_hough_detection, enable_sign_detection
img_orig = copy(img_rgb_stream)
img_orig = np.copy(img_rgb_stream)
# TODO the ratio between img_depth_stream and img_rgb_stream might be different!
# get sign location in img
keypoints = SegmentSign.do_hough_circle_detection(copy(img_orig), copy(img_depth_stream))
keypoints += ShapeDetector.do_shape_detection(copy(img_orig), copy(img_depth_stream))
keypoints = SegmentSign.filter_duplicate_keypoints(keypoints)
patches = SegmentSign.get_tensor_patches(copy(img_orig), keypoints)
if toggle_hough_detection:
keypoints = SegmentSign.do_hough_circle_detection(copy(img_orig), copy(img_depth_stream))
keypoints += ShapeDetector.do_shape_detection(copy(img_orig), copy(img_depth_stream))
keypoints = SegmentSign.filter_duplicate_keypoints(keypoints)
# only use first keypoint (this should be the most accurate guess)
if len(keypoints) >= 1:
keypoints = [keypoints[0]]
else:
# use center of image
keypoints = [{"center": (IMG_RES[1]//2, IMG_RES[0]//2), "radius": TF_RES[0] // 2, "depth": -1}]
patches = SegmentSign.get_tensor_patches(copy(img_orig), keypoints, zoom=toggle_hough_detection)
# publish patch for capture images
if enable_capture_images:
publish_img_patch(patches[0])
# cut to multiple images at keypoints
text = []
......@@ -98,11 +134,13 @@ def detect_sign(img_rgb_stream, image_timestamp):
r = k["radius"]
# classify image batches
label, precision = sign_classifier.predictImage(p) # returns tupel (label, precision), if no model / error is set up label= -1
if label >= 0:
# publish results
publish_results(center, r, d, label, precision, image_timestamp)
text.append("c: {} p: {:1.3f} d:{:1.3f}".format(sign_classifier.labelOfClass(label), precision, d))
label, precision = -1, -1
if enable_sign_detection:
label, precision = sign_classifier.predictImage(p) # returns tupel (label, precision), if no model / error is set up label= -1
# publish results
publish_results(center, r, d, label, precision, image_timestamp)
text.append("c: {} p: {:1.3f} d:{:1.3f}".format(sign_classifier.labelOfClass(label), precision, d))
if visualize:
if len(text) > 0:
......@@ -125,6 +163,33 @@ def set_visualize_callback(req):
return True, ""
def enable_capture_callback(req):
    """
    Service handler that stores the requested flag in the module-level enable_capture_images.
    :param req: service request; its data field becomes the new value of enable_capture_images
    :return: the tuple (True, "")
    """
    global enable_capture_images

    enable_capture_images = req.data

    # log the new state together with the name of this node
    log_text = "({}) set enable_capture_images to {}".format(rospy.get_name(), enable_capture_images)
    rospy.loginfo(log_text)

    return True, ""
def toggle_hough_detection_callback(req):
    """
    Service handler that stores the requested flag in the module-level toggle_hough_detection.
    :param req: service request; its data field becomes the new value of toggle_hough_detection
    :return: the tuple (True, "")
    """
    global toggle_hough_detection

    toggle_hough_detection = req.data

    # log the new state together with the name of this node
    log_text = "({}) set toggle_hough_detection to {}".format(rospy.get_name(), toggle_hough_detection)
    rospy.loginfo(log_text)

    return True, ""
def enable_sign_detector_callback(req):
    """
    Service handler that stores the requested flag in the module-level enable_sign_detection.
    :param req: service request; its data field becomes the new value of enable_sign_detection
    :return: the tuple (True, "")
    """
    global enable_sign_detection

    enable_sign_detection = req.data

    # log the new state together with the name of this node
    log_text = "({}) set enable_sign_detection to {}".format(rospy.get_name(), enable_sign_detection)
    rospy.loginfo(log_text)

    return True, ""
if __name__ == "__main__":
rospy.init_node("sign_detector")
......@@ -135,9 +200,17 @@ if __name__ == "__main__":
rospy.Subscriber(img_depth_topic, Image, image_depth_callback, queue_size=1)
rospy.Service('sign_detector/set_model', set_url, set_model_callback)
rospy.Service('sign_detector/set_visualize', std_srvs.srv.SetBool, set_visualize_callback)
# TODO enable sign detection
rospy.Service('sign_detector/enable', std_srvs.srv.SetBool, enable_sign_detector_callback)
pub_keypoint = rospy.Publisher('sign_detector/keypoints', Detection2D, queue_size=10)
pub_result_img = rospy.Publisher("sign_detector/result_image/compressed", CompressedImage, queue_size=10)
# capture images stuff
rospy.Service("capture_imgs/enable", std_srvs.srv.SetBool, enable_capture_callback) # TODO migrate service uri to this node!
rospy.Service("capture_imgs/toggle_hough_detection", std_srvs.srv.SetBool, toggle_hough_detection_callback) # TODO migrate service uri to this node!
pub_raw_img = rospy.Publisher("capture_imgs/result_image", Image, queue_size=10) # TODO migrate service uri to this node!
pub_cmpr_img = rospy.Publisher("capture_imgs/result_image/compressed", CompressedImage, queue_size=10) # TODO migrate service uri to this node!
rate = rospy.Rate(30) # currently this is impossible, but then the rate is defined by the detect_sign evaluation time
while not rospy.is_shutdown():
detect_sign(img_rgb_stream, img_rgb_timestamp)
......
<?xml version="1.0"?>
<launch>
<include file="$(find minibot_vision)/launch/capture_imgs.launch" >
<!--include file="$(find minibot_vision)/launch/capture_imgs.launch" >
<arg name="remote_node" value="true" />
</include>
</include-->
<group ns="$(env ROBOT)" >
<node name="sign_detector" pkg="minibot_vision" type="SignDetector.py" output="screen" />
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment