From 9096f55c5cf2dff8b7bccd38e3f4d906c68a47bc Mon Sep 17 00:00:00 2001
From: minibot-1 <paddy-hofmann@web.de>
Date: Fri, 17 Mar 2023 10:31:08 +0000
Subject: [PATCH] minibot_vision: refactoring and added toggle hough transform
 capabilities

---
 minibot_vision/config/sign_detector.yaml      |  2 +-
 .../config/sign_detector_default.yaml         |  4 +-
 minibot_vision/scripts/Capture_Images.py      |  2 +
 minibot_vision/scripts/SegmentSign.py         | 15 ++-
 minibot_vision/scripts/SignDetector.py        | 95 ++++++++++++++++---
 .../launch/traffic_sign_workshop.launch       |  4 +-
 6 files changed, 101 insertions(+), 21 deletions(-)

diff --git a/minibot_vision/config/sign_detector.yaml b/minibot_vision/config/sign_detector.yaml
index b90f8a8..d7545ca 100644
--- a/minibot_vision/config/sign_detector.yaml
+++ b/minibot_vision/config/sign_detector.yaml
@@ -1,3 +1,3 @@
-sign_detector: {canny_param1: 47, canny_param2: 46, img_height: 1080, img_width: 1920,
+sign_detector: {canny_param1: 47, canny_param2: 46, img_height: 480, img_width: 640,
   max_depth: 1.0, max_radius: 128, min_depth: 0.2, min_radius: 15, visualize: false,
   zoom_threshold: 1.15}
diff --git a/minibot_vision/config/sign_detector_default.yaml b/minibot_vision/config/sign_detector_default.yaml
index f9188fc..1cd549c 100644
--- a/minibot_vision/config/sign_detector_default.yaml
+++ b/minibot_vision/config/sign_detector_default.yaml
@@ -1,7 +1,7 @@
 # default params that got not overriden
 sign_detector:
-  img_height: 1080
-  img_width: 1920
+  img_height: 480
+  img_width: 640
   canny_param1: 100
   canny_param2: 100 # good light: 40; bad light: 100
   min_depth: 0.2
diff --git a/minibot_vision/scripts/Capture_Images.py b/minibot_vision/scripts/Capture_Images.py
index b20bf45..77388d2 100755
--- a/minibot_vision/scripts/Capture_Images.py
+++ b/minibot_vision/scripts/Capture_Images.py
@@ -98,6 +98,8 @@ def enable_callback(req):
 if __name__ == "__main__":
     rospy.init_node("capture_images")
 
+    rospy.logwarn("({}) This node is deprecated".format(rospy.get_name()))
+
     if rospy.has_param("~remote_node"):
         REMOTE_NODE = rospy.get_param("~remote_node")
     if rospy.has_param("~save_dir"):
diff --git a/minibot_vision/scripts/SegmentSign.py b/minibot_vision/scripts/SegmentSign.py
index 64f7a64..9a7c72d 100644
--- a/minibot_vision/scripts/SegmentSign.py
+++ b/minibot_vision/scripts/SegmentSign.py
@@ -129,31 +129,36 @@ def crop_to_bounds(crop_bounds, max_val):
     return crop_bounds
 
 
-def get_tensor_patches(img_rgb, keypoints):
+def get_tensor_patches(img_rgb, keypoints, zoom=True):
     """
     Turns a set of keypoints into image patches from the original image.
     Each patch has the size needed by the tensorflow model so that the patch can be directly fet into the image classifier.
     Each patch is zoomed such that the detected object fills the entire patch.
     :param img_rgb: original image
     :param keypoints: list of detected keypoints
+    :param zoom: If True, the image is rescaled by TENSOR_RES / (radius*2 * ZOOM_THREASHOLD) before cropping, so the detected object plus some margin fills the patch. If False, the patch is cropped from the unscaled image around the keypoint center.
     :return: A set of image patches.
     """
     global TENSOR_RES, ZOOM_THREASHOLD
 
     img_patches = []
     for k in keypoints:
-        img = copy(img_rgb)
+        img = np.copy(img_rgb)
         d = k["depth"]
         center = k["center"]
         center = [center[1], center[0]]
         r = k["radius"]
 
         # zoom into images based on radius?
-        zoom_factor = np.array(TENSOR_RES) / ((r*2 * ZOOM_THREASHOLD))
-        zoomed_image = cv2.resize(img, dsize=None, fx=zoom_factor[0], fy=zoom_factor[1], interpolation=cv2.INTER_NEAREST)
+        if zoom:
+            zoom_factor = np.array(TENSOR_RES) / ((r*2 * ZOOM_THREASHOLD))  # scale so that r*2 * ZOOM_THREASHOLD pixels map onto one patch edge
+            zoomed_image = cv2.resize(img, dsize=None, fx=zoom_factor[0], fy=zoom_factor[1], interpolation=cv2.INTER_NEAREST)
+            img_center_zoomed = (center * zoom_factor).astype(int)  # keypoint center in the rescaled image (numpy int array)
+        else:
+            zoomed_image = img  # no rescaling: crop directly from the copied input image
+            img_center_zoomed = center  # stays the plain [y, x] list built above, not a numpy array
 
         # handle border
-        img_center_zoomed = (center * zoom_factor).astype(int)
         y = [img_center_zoomed[0] - TENSOR_RES[0] // 2, img_center_zoomed[0] + TENSOR_RES[0] // 2]
         y = crop_to_bounds(y, np.shape(zoomed_image)[0])
         x = [img_center_zoomed[1] - TENSOR_RES[1] // 2, img_center_zoomed[1] + TENSOR_RES[1] // 2]
diff --git a/minibot_vision/scripts/SignDetector.py b/minibot_vision/scripts/SignDetector.py
index d843427..4c8cd42 100755
--- a/minibot_vision/scripts/SignDetector.py
+++ b/minibot_vision/scripts/SignDetector.py
@@ -19,6 +19,7 @@ from minibot_msgs.srv import set_url
 visualize = True
 camera_frame = "camera_aligned_depth_to_color_frame"
 #OLD CONFIG: IMG_RES = (480, 640)
+# TODO get this from camera topic
 IMG_RES = (rospy.get_param("sign_detector/img_height"), rospy.get_param("sign_detector/img_width")) 
 TF_RES = (224, 224)     # tf is cropping the image
 
@@ -31,6 +32,11 @@ img_depth_stream = np.zeros((IMG_RES[0], IMG_RES[1], 1), np.uint8)
 img_rgb = img_rgb_stream
 pub_keypoint = None
 pub_result_img = None
+enable_sign_detection = True  # False: detect_sign skips the classifier and reports label/precision -1
+enable_capture_images = False  # True: detect_sign publishes the first image patch via publish_img_patch
+toggle_hough_detection = True  # True: keypoints from Hough/shape detection; False: fixed patch at the image center
+pub_raw_img = None  # rospy.Publisher for the raw patch (Image), created in __main__
+pub_cmpr_img = None  # rospy.Publisher for the compressed patch (CompressedImage), created in __main__
 
 # subscribe to RGB img
 def image_color_callback(data):
@@ -77,16 +83,46 @@ def publish_results(point, radius, depth, label, precision, timestamp):
 
     pub_keypoint.publish(detection_msg)
 
+
+def publish_img_patch(img_patch):
+    """Publish one image patch on pub_raw_img (Image) and on pub_cmpr_img (CompressedImage)."""
+    global bridge, pub_raw_img, pub_cmpr_img
+    # use same timestamp for synchronisation: both messages get the identical header stamp
+    timestamp = rospy.Time.now()
+
+    # publish non compressed image for saving (converted with encoding "bgr8")
+    rawmsg = bridge.cv2_to_imgmsg(img_patch, "bgr8")
+    rawmsg.header.stamp = timestamp
+    pub_raw_img.publish(rawmsg)
+    # publish compressed img for website visualization
+    cmprsmsg = bridge.cv2_to_compressed_imgmsg(img_patch)
+    cmprsmsg.header.stamp = timestamp
+    pub_cmpr_img.publish(cmprsmsg)
+
+
 def detect_sign(img_rgb_stream, image_timestamp):
-    global img_depth_stream, pub_result_img
+    global img_depth_stream, pub_result_img, toggle_hough_detection, enable_sign_detection
 
-    img_orig = copy(img_rgb_stream)
+    img_orig = np.copy(img_rgb_stream)
+    # TODO the ratio between img_depth_stream and img_rgb_stream might be different!
 
     # get sign location in img
-    keypoints = SegmentSign.do_hough_circle_detection(copy(img_orig), copy(img_depth_stream))
-    keypoints += ShapeDetector.do_shape_detection(copy(img_orig), copy(img_depth_stream))
-    keypoints = SegmentSign.filter_duplicate_keypoints(keypoints)
-    patches = SegmentSign.get_tensor_patches(copy(img_orig), keypoints)
+    if toggle_hough_detection:
+        keypoints = SegmentSign.do_hough_circle_detection(copy(img_orig), copy(img_depth_stream))
+        keypoints += ShapeDetector.do_shape_detection(copy(img_orig), copy(img_depth_stream))
+        keypoints = SegmentSign.filter_duplicate_keypoints(keypoints)
+        # only use first keypoint (this should be the most accurate guess); stays empty if nothing was detected
+        if len(keypoints) >= 1:
+            keypoints = [keypoints[0]]
+    else:
+        # detection toggled off: use one fixed keypoint at the center of the image, cut out unzoomed below
+        keypoints = [{"center": (IMG_RES[1]//2, IMG_RES[0]//2), "radius": TF_RES[0] // 2, "depth": -1}]
+
+    patches = SegmentSign.get_tensor_patches(copy(img_orig), keypoints, zoom=toggle_hough_detection)
+
+    # publish patch for capture images; one patch is built per keypoint, so there is none when nothing was detected
+    if enable_capture_images and len(patches) > 0:
+        publish_img_patch(patches[0])
 
     # cut to multiple images at keypoints
     text = []
@@ -98,11 +134,13 @@ def detect_sign(img_rgb_stream, image_timestamp):
         r = k["radius"]
 
         # classify image batches
-        label, precision = sign_classifier.predictImage(p)  # returns tupel (label, precision), if no model / error is set up label= -1
-        if label >= 0:
-            # publish results
-            publish_results(center, r, d, label, precision, image_timestamp)
-            text.append("c: {} p: {:1.3f} d:{:1.3f}".format(sign_classifier.labelOfClass(label), precision, d))
+        label, precision = -1, -1  # defaults used when classification is switched off
+        if enable_sign_detection:
+            label, precision = sign_classifier.predictImage(p)  # returns tuple (label, precision), if no model / error is set up label= -1
+
+        # publish results -- now also when label is -1 (the previous code only published for label >= 0)
+        publish_results(center, r, d, label, precision, image_timestamp)
+        text.append("c: {} p: {:1.3f} d:{:1.3f}".format(sign_classifier.labelOfClass(label), precision, d))  # NOTE(review): labelOfClass is now called with -1 too -- confirm it handles that
 
     if visualize:
         if len(text) > 0:
@@ -125,6 +163,33 @@ def set_visualize_callback(req):
     return True, ""
 
 
+def enable_capture_callback(req):
+    """SetBool service handler: store req.data in enable_capture_images; always answers (True, "")."""
+    global enable_capture_images
+    enable_capture_images = req.data
+    rospy.loginfo("({}) set enable_capture_images to {}".format(rospy.get_name(), enable_capture_images))
+
+    return True, ""
+
+
+def toggle_hough_detection_callback(req):
+    """SetBool service handler: True selects Hough/shape keypoint detection, False the fixed image-center patch."""
+    global toggle_hough_detection
+    toggle_hough_detection = req.data
+    rospy.loginfo("({}) set toggle_hough_detection to {}".format(rospy.get_name(), toggle_hough_detection))
+
+    return True, ""
+
+
+def enable_sign_detector_callback(req):
+    """SetBool service handler: when set to False, detect_sign skips the classifier and uses label/precision -1."""
+    global enable_sign_detection
+    enable_sign_detection = req.data
+    rospy.loginfo("({}) set enable_sign_detection to {}".format(rospy.get_name(), enable_sign_detection))
+
+    return True, ""
+
+
 if __name__ == "__main__":
     rospy.init_node("sign_detector")
 
@@ -135,9 +200,17 @@ if __name__ == "__main__":
     rospy.Subscriber(img_depth_topic, Image, image_depth_callback, queue_size=1)
     rospy.Service('sign_detector/set_model', set_url, set_model_callback)
     rospy.Service('sign_detector/set_visualize', std_srvs.srv.SetBool, set_visualize_callback)
+    # TODO enable sign detection
+    rospy.Service('sign_detector/enable', std_srvs.srv.SetBool, enable_sign_detector_callback)
     pub_keypoint = rospy.Publisher('sign_detector/keypoints', Detection2D, queue_size=10)
     pub_result_img = rospy.Publisher("sign_detector/result_image/compressed", CompressedImage, queue_size=10)
 
+    # capture images stuff: names still live under capture_imgs/ (presumably kept from the deprecated Capture_Images.py node -- confirm)
+    rospy.Service("capture_imgs/enable", std_srvs.srv.SetBool, enable_capture_callback)     # TODO migrate service uri to this node!
+    rospy.Service("capture_imgs/toggle_hough_detection", std_srvs.srv.SetBool, toggle_hough_detection_callback)     # TODO migrate service uri to this node!
+    pub_raw_img = rospy.Publisher("capture_imgs/result_image", Image, queue_size=10)        # TODO migrate topic name to this node!
+    pub_cmpr_img = rospy.Publisher("capture_imgs/result_image/compressed", CompressedImage, queue_size=10)      # TODO migrate topic name to this node!
+
     rate = rospy.Rate(30)       # currently this is impossible, but then the rate is defined by the detect_sign evaluation time
     while not rospy.is_shutdown():
         detect_sign(img_rgb_stream, img_rgb_timestamp)
diff --git a/minibot_workshops/launch/traffic_sign_workshop.launch b/minibot_workshops/launch/traffic_sign_workshop.launch
index 4f26916..7a2069d 100644
--- a/minibot_workshops/launch/traffic_sign_workshop.launch
+++ b/minibot_workshops/launch/traffic_sign_workshop.launch
@@ -1,8 +1,8 @@
 <?xml version="1.0"?>
 <launch>
-    <include file="$(find minibot_vision)/launch/capture_imgs.launch" >
+    <!--include file="$(find minibot_vision)/launch/capture_imgs.launch" >
         <arg name="remote_node" value="true" />
-    </include>
+    </include-->
 
     <group ns="$(env ROBOT)" >
         <node name="sign_detector" pkg="minibot_vision" type="SignDetector.py" output="screen" />
-- 
GitLab