diff --git a/minibot_vision/config/sign_detector.yaml b/minibot_vision/config/sign_detector.yaml index b90f8a85e23145a8ca3aae624509137805333866..d7545ca367dc23f5d32bab62adf640a74c7912af 100644 --- a/minibot_vision/config/sign_detector.yaml +++ b/minibot_vision/config/sign_detector.yaml @@ -1,3 +1,3 @@ -sign_detector: {canny_param1: 47, canny_param2: 46, img_height: 1080, img_width: 1920, +sign_detector: {canny_param1: 47, canny_param2: 46, img_height: 480, img_width: 640, max_depth: 1.0, max_radius: 128, min_depth: 0.2, min_radius: 15, visualize: false, zoom_threshold: 1.15} diff --git a/minibot_vision/config/sign_detector_default.yaml b/minibot_vision/config/sign_detector_default.yaml index f9188fc21ebca596d2ef3490077b85664dc6ff07..1cd549cc9e028cb30aa1c0c64c9a508eece443ff 100644 --- a/minibot_vision/config/sign_detector_default.yaml +++ b/minibot_vision/config/sign_detector_default.yaml @@ -1,7 +1,7 @@ # default params that got not overriden sign_detector: - img_height: 1080 - img_width: 1920 + img_height: 480 + img_width: 640 canny_param1: 100 canny_param2: 100 # good light: 40; bad light: 100 min_depth: 0.2 diff --git a/minibot_vision/scripts/Capture_Images.py b/minibot_vision/scripts/Capture_Images.py index b20bf45b28c7e3496d50ed81bd56d9a9025e8df9..77388d2e3363c367291942b1dc2415cfc5e5af0b 100755 --- a/minibot_vision/scripts/Capture_Images.py +++ b/minibot_vision/scripts/Capture_Images.py @@ -98,6 +98,8 @@ def enable_callback(req): if __name__ == "__main__": rospy.init_node("capture_images") + rospy.logwarn("({}) This node is deprecated".format(rospy.get_name())) + if rospy.has_param("~remote_node"): REMOTE_NODE = rospy.get_param("~remote_node") if rospy.has_param("~save_dir"): diff --git a/minibot_vision/scripts/SegmentSign.py b/minibot_vision/scripts/SegmentSign.py index 64f7a6436011bb4997f9eb8b9d0691aa32feaaee..9a7c72d2c1b272a82d9ee1bf9aa1cb969d7f85b3 100644 --- a/minibot_vision/scripts/SegmentSign.py +++ b/minibot_vision/scripts/SegmentSign.py @@ -129,31 +129,36 @@ def crop_to_bounds(crop_bounds, max_val): return crop_bounds -def get_tensor_patches(img_rgb, keypoints): +def get_tensor_patches(img_rgb, keypoints, zoom=True): """ Turns a set of keypoints into image patches from the original image. Each patch has the size needed by the tensorflow model so that the patch can be directly fet into the image classifier. Each patch is zoomed such that the detected object fills the entire patch. :param img_rgb: original image :param keypoints: list of detected keypoints + :param zoom: If true the resulting patch will be zoomed in such a way that it fits the radius*2 + some margin :return: A set of image patches. """ global TENSOR_RES, ZOOM_THREASHOLD img_patches = [] for k in keypoints: - img = copy(img_rgb) + img = np.copy(img_rgb) d = k["depth"] center = k["center"] center = [center[1], center[0]] r = k["radius"] # zoom into images based on radius? - zoom_factor = np.array(TENSOR_RES) / ((r*2 * ZOOM_THREASHOLD)) - zoomed_image = cv2.resize(img, dsize=None, fx=zoom_factor[0], fy=zoom_factor[1], interpolation=cv2.INTER_NEAREST) + if zoom: + zoom_factor = np.array(TENSOR_RES) / ((r*2 * ZOOM_THREASHOLD)) + zoomed_image = cv2.resize(img, dsize=None, fx=zoom_factor[0], fy=zoom_factor[1], interpolation=cv2.INTER_NEAREST) + img_center_zoomed = (center * zoom_factor).astype(int) + else: + zoomed_image = img + img_center_zoomed = center # handle border - img_center_zoomed = (center * zoom_factor).astype(int) y = [img_center_zoomed[0] - TENSOR_RES[0] // 2, img_center_zoomed[0] + TENSOR_RES[0] // 2] y = crop_to_bounds(y, np.shape(zoomed_image)[0]) x = [img_center_zoomed[1] - TENSOR_RES[1] // 2, img_center_zoomed[1] + TENSOR_RES[1] // 2] diff --git a/minibot_vision/scripts/SignDetector.py b/minibot_vision/scripts/SignDetector.py index d8434279bf080a7d747e797d2e61b779ea8135ba..4c8cd421f51bf37dbdac66e04ff9e993fdfe7be1 100755 --- a/minibot_vision/scripts/SignDetector.py +++ b/minibot_vision/scripts/SignDetector.py @@ -19,6 +19,7 @@ from minibot_msgs.srv import set_url visualize = True camera_frame = "camera_aligned_depth_to_color_frame" #OLD CONFIG: IMG_RES = (480, 640) +# TODO get this from camera topic IMG_RES = (rospy.get_param("sign_detector/img_height"), rospy.get_param("sign_detector/img_width")) TF_RES = (224, 224) # tf is cropping the image @@ -31,6 +32,11 @@ img_depth_stream = np.zeros((IMG_RES[0], IMG_RES[1], 1), np.uint8) img_rgb = img_rgb_stream pub_keypoint = None pub_result_img = None +enable_sign_detection = True +enable_capture_images = False +toggle_hough_detection = True +pub_raw_img = None +pub_cmpr_img = None # subscribe to RGB img def image_color_callback(data): @@ -77,16 +83,46 @@ def publish_results(point, radius, depth, label, precision, timestamp): pub_keypoint.publish(detection_msg) + +def publish_img_patch(img_patch): + global bridge, pub_raw_img, pub_cmpr_img + + # use same timestamp for synchronisation + timestamp = rospy.Time.now() + + # publish non compressed image for saving + rawmsg = bridge.cv2_to_imgmsg(img_patch, "bgr8") + rawmsg.header.stamp = timestamp + pub_raw_img.publish(rawmsg) + # publish compressed img for website visualization + cmprsmsg = bridge.cv2_to_compressed_imgmsg(img_patch) + cmprsmsg.header.stamp = timestamp + pub_cmpr_img.publish(cmprsmsg) + + def detect_sign(img_rgb_stream, image_timestamp): - global img_depth_stream, pub_result_img + global img_depth_stream, pub_result_img, toggle_hough_detection, enable_sign_detection - img_orig = copy(img_rgb_stream) + img_orig = np.copy(img_rgb_stream) + # TODO the ratio between img_depth_stream and img_rgb_stream might be different! # get sign location in img - keypoints = SegmentSign.do_hough_circle_detection(copy(img_orig), copy(img_depth_stream)) - keypoints += ShapeDetector.do_shape_detection(copy(img_orig), copy(img_depth_stream)) - keypoints = SegmentSign.filter_duplicate_keypoints(keypoints) - patches = SegmentSign.get_tensor_patches(copy(img_orig), keypoints) + if toggle_hough_detection: + keypoints = SegmentSign.do_hough_circle_detection(copy(img_orig), copy(img_depth_stream)) + keypoints += ShapeDetector.do_shape_detection(copy(img_orig), copy(img_depth_stream)) + keypoints = SegmentSign.filter_duplicate_keypoints(keypoints) + # only use first keypoint (this should be the most accurate guess) + if len(keypoints) >= 1: + keypoints = [keypoints[0]] + else: + # use center of image + keypoints = [{"center": (IMG_RES[1]//2, IMG_RES[0]//2), "radius": TF_RES[0] // 2, "depth": -1}] + + patches = SegmentSign.get_tensor_patches(copy(img_orig), keypoints, zoom=toggle_hough_detection) + + # publish patch for capture images + if enable_capture_images: + publish_img_patch(patches[0]) # cut to multiple images at keypoints text = [] @@ -98,11 +134,13 @@ def detect_sign(img_rgb_stream, image_timestamp): r = k["radius"] # classify image batches - label, precision = sign_classifier.predictImage(p) # returns tupel (label, precision), if no model / error is set up label= -1 - if label >= 0: - # publish results - publish_results(center, r, d, label, precision, image_timestamp) - text.append("c: {} p: {:1.3f} d:{:1.3f}".format(sign_classifier.labelOfClass(label), precision, d)) + label, precision = -1, -1 + if enable_sign_detection: + label, precision = sign_classifier.predictImage(p) # returns tupel (label, precision), if no model / error is set up label= -1 + + # publish results + publish_results(center, r, d, label, precision, image_timestamp) + text.append("c: {} p: {:1.3f} d:{:1.3f}".format(sign_classifier.labelOfClass(label), precision, d)) if visualize: if len(text) > 0: @@ -125,6 +163,33 @@ def set_visualize_callback(req): return True, "" +def enable_capture_callback(req): + global enable_capture_images + + enable_capture_images = req.data + rospy.loginfo("({}) set enable_capture_images to {}".format(rospy.get_name(), enable_capture_images)) + + return True, "" + + +def toggle_hough_detection_callback(req): + global toggle_hough_detection + + toggle_hough_detection = req.data + rospy.loginfo("({}) set toggle_hough_detection to {}".format(rospy.get_name(), toggle_hough_detection)) + + return True, "" + + +def enable_sign_detector_callback(req): + global enable_sign_detection + + enable_sign_detection = req.data + rospy.loginfo("({}) set enable_sign_detection to {}".format(rospy.get_name(), enable_sign_detection)) + + return True, "" + + if __name__ == "__main__": rospy.init_node("sign_detector") @@ -135,9 +200,17 @@ if __name__ == "__main__": rospy.Subscriber(img_depth_topic, Image, image_depth_callback, queue_size=1) rospy.Service('sign_detector/set_model', set_url, set_model_callback) rospy.Service('sign_detector/set_visualize', std_srvs.srv.SetBool, set_visualize_callback) + # TODO enable sign detection + rospy.Service('sign_detector/enable', std_srvs.srv.SetBool, enable_sign_detector_callback) pub_keypoint = rospy.Publisher('sign_detector/keypoints', Detection2D, queue_size=10) pub_result_img = rospy.Publisher("sign_detector/result_image/compressed", CompressedImage, queue_size=10) + # capture images stuff + rospy.Service("capture_imgs/enable", std_srvs.srv.SetBool, enable_capture_callback) # TODO migrate service uri to this node! + rospy.Service("capture_imgs/toggle_hough_detection", std_srvs.srv.SetBool, toggle_hough_detection_callback) # TODO migrate service uri to this node! + pub_raw_img = rospy.Publisher("capture_imgs/result_image", Image, queue_size=10) # TODO migrate service uri to this node! + pub_cmpr_img = rospy.Publisher("capture_imgs/result_image/compressed", CompressedImage, queue_size=10) # TODO migrate service uri to this node! + rate = rospy.Rate(30) # currently this is impossible, but then the rate is defined by the detect_sign evaluation time while not rospy.is_shutdown(): detect_sign(img_rgb_stream, img_rgb_timestamp) diff --git a/minibot_workshops/launch/traffic_sign_workshop.launch b/minibot_workshops/launch/traffic_sign_workshop.launch index 4f26916cc079fe8c6c0581e99adf48ca5c48054e..7a2069d458d5781355654ba6c149aa801988c516 100644 --- a/minibot_workshops/launch/traffic_sign_workshop.launch +++ b/minibot_workshops/launch/traffic_sign_workshop.launch @@ -1,8 +1,8 @@ <?xml version="1.0"?> <launch> - <include file="$(find minibot_vision)/launch/capture_imgs.launch" > + <!--include file="$(find minibot_vision)/launch/capture_imgs.launch" > <arg name="remote_node" value="true" /> - </include> + </include--> <group ns="$(env ROBOT)" > <node name="sign_detector" pkg="minibot_vision" type="SignDetector.py" output="screen" />