#!/usr/bin/env python3

import cv2
import numpy as np
from cv_bridge import CvBridge, CvBridgeError
from sensor_msgs.msg import Image, CompressedImage, CameraInfo
import rospy
from copy import copy
import yaml
import rospkg
import std_srvs.srv
from minibot_msgs.srv import segment_sign_command, segment_sign_commandRequest, segment_sign_commandResponse
import time

# *** hyper params ***
# Default image resolutions as (height, width). update_camera_params() overwrites
# them with the resolutions reported by the camera_info topics.
IMG_RES = (480, 640)
DEPTH_RES = (480, 640)
# Patch size (height, width) expected by the image classifier.
TENSOR_RES = (224, 224)
# Tuned (canny, accum) values: gazebo 70/30; rviz 70/20; real 100/40
canny = rospy.get_param("sign_detector/canny_param1", 100)  #100
accum_thresh = rospy.get_param("sign_detector/canny_param2", 40) #30
VISUALIZE = rospy.get_param("sign_detector/visualize", True) # False
ZOOM_THREASHOLD = rospy.get_param("sign_detector/zoom_threshold", 1.15)     #(1.15) multiplied percentage to the detected radius

# Accepted sign distance range in meters (closer/farther detections are discarded).
MIN_DEPTH = rospy.get_param("sign_detector/min_depth", 0.2)  # 12
MAX_DEPTH = rospy.get_param("sign_detector/max_depth", 1.0)  # 20

# Radius range in pixels passed to the Hough circle detector.
MIN_RADIUS = rospy.get_param("sign_detector/min_radius", 15)  # 15
MAX_RADIUS = rospy.get_param("sign_detector/max_radius", 128)  # 128

# *** Globals ***
cv_bridge = CvBridge()
# Latest frames written by the image callbacks; blank images until the first message arrives.
img_rgb_stream = np.zeros((IMG_RES[0], IMG_RES[1], 3), np.uint8)
img_depth_stream = np.zeros((IMG_RES[0], IMG_RES[1], 1), np.uint8)
ros_pack = None
segment_enable = False          # set through the ~enable service
segment_rate = None
toggle_patch_visualization = True
bridge = CvBridge()
def update_camera_params():
    """Read one CameraInfo from the color and the depth camera and store their resolutions.

    Waits for a message on each camera_info topic, then sets the module-level
    IMG_RES (color) and DEPTH_RES (depth) to (height, width) tuples.
    """
    global IMG_RES, DEPTH_RES

    color_info = rospy.wait_for_message("camera/color/camera_info", CameraInfo)
    depth_info = rospy.wait_for_message("camera/depth/camera_info", CameraInfo)

    IMG_RES, DEPTH_RES = (
        (color_info.height, color_info.width),
        (depth_info.height, depth_info.width),
    )

def fetch_rosparams():
    """Reload the detector hyper parameters from the ROS parameter server.

    Updates the module-level parameters read by the detection functions. The two
    Hough thresholds are clamped to at least 1, because HoughCircles expects
    positive thresholds.
    """
    # Without this declaration every assignment below only creates a local
    # variable and the reload silently has no effect.
    global canny, accum_thresh, ZOOM_THREASHOLD, MIN_DEPTH, MAX_DEPTH, MIN_RADIUS, MAX_RADIUS

    canny = max(rospy.get_param("sign_detector/canny_param1"), 1)
    accum_thresh = max(rospy.get_param("sign_detector/canny_param2"), 1)
    ZOOM_THREASHOLD = rospy.get_param("sign_detector/zoom_threshold")

    MIN_DEPTH = rospy.get_param("sign_detector/min_depth")
    MAX_DEPTH = rospy.get_param("sign_detector/max_depth")

    MIN_RADIUS = rospy.get_param("sign_detector/min_radius")
    MAX_RADIUS = rospy.get_param("sign_detector/max_radius")

def image_color_callback(data):
    """Subscriber callback: store the latest color frame as a BGR8 image in img_rgb_stream.

    :param data: sensor_msgs/Image message
    """
    global img_rgb_stream

    try:
        img_rgb_stream = cv_bridge.imgmsg_to_cv2(data, "bgr8")
    except CvBridgeError as e:
        # Keep the previous frame; one bad message should not stop the node.
        rospy.logerr("({}) Could not convert color image: {}".format(rospy.get_name(), e))

def image_depth_callback(data):
    """Subscriber callback: store the latest depth frame as a 16UC1 image in img_depth_stream.

    :param data: sensor_msgs/Image message
    """
    global img_depth_stream

    try:
        img_depth_stream = cv_bridge.imgmsg_to_cv2(data, "16UC1")
    except CvBridgeError as e:
        # Keep the previous frame; one bad message should not stop the node.
        rospy.logerr("({}) Could not convert depth image: {}".format(rospy.get_name(), e))

def circular_mean(p, r, arr : np.array):
    """Return the mean intensity of a grey image inside a circle.

    :param p: circle center as (x, y) in pixels (x = column, y = row)
    :param r: circle radius in pixels; only pixels strictly inside (distance < r) count
    :param arr: single-channel image indexed as arr[row, col]; a trailing
        singleton channel axis, i.e. shape (h, w, 1), is tolerated
    :return: the mean value as float, or 0 if no image pixel lies inside the circle
    """
    height, width = arr.shape[0], arr.shape[1]
    cx, cy = float(p[0]), float(p[1])

    # Bounding box of the circle, clipped to the image. The clipping matters:
    # negative indices would otherwise wrap around to the opposite border.
    x_lo = max(int(np.floor(cx - r)), 0)
    x_hi = min(int(np.ceil(cx + r)), width - 1)
    y_lo = max(int(np.floor(cy - r)), 0)
    y_hi = min(int(np.ceil(cy + r)), height - 1)
    if x_lo > x_hi or y_lo > y_hi:
        return 0

    ys, xs = np.mgrid[y_lo:y_hi + 1, x_lo:x_hi + 1]
    inside = (xs - cx) ** 2 + (ys - cy) ** 2 < r ** 2
    if not inside.any():
        return 0

    values = arr[y_lo:y_hi + 1, x_lo:x_hi + 1][inside]
    # Accumulate in float64: summing uint16 depth values in their own dtype can overflow.
    return float(np.mean(values, dtype=np.float64))

def do_hough_circle_detection(img_rgb, img_depth, VISUALIZE=False):
    """Detect the strongest circle in a BGR image and estimate its depth.

    :param img_rgb: BGR image; annotated in place when VISUALIZE is True
    :param img_depth: depth image aligned to img_rgb; values are divided by 1000
        to get meters (assumes millimeter units — confirm against the camera driver)
    :param VISUALIZE: if True, draw the circle plus a depth/radius label into img_rgb
    :return: a list with at most one keypoint dict
        {"center": (x, y), "radius": r, "depth": d}; empty if no circle was found
        or its depth lies outside [MIN_DEPTH, MAX_DEPTH]
    """
    global canny, accum_thresh

    gray = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2GRAY)
    gray = cv2.medianBlur(gray, 5)      # reduce noise
    # TODO try
    #   It also helps to smooth image a bit unless it's already soft. For example,
    #   GaussianBlur() with 7x7 kernel and 1.5x1.5 sigma or similar blurring may help.
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, 1, gray.shape[0] / 4,
                               # higher threshold of the Canny edge detector (the lower
                               # one is half of it); HoughCircles requires it to be > 0
                               param1=max(canny, 1),
                               # accumulator threshold for circle centers: the smaller it
                               # is, the more false circles are detected; must be > 0
                               param2=max(accum_thresh, 1),
                               minRadius=MIN_RADIUS, maxRadius=MAX_RADIUS)
    keypoint = []
    if circles is None:
        return keypoint

    # Circles with larger accumulator values are returned first: take the first one.
    x, y, r = np.around(circles[0, 0]).astype(int)
    # Plain Python ints: with uint16 values, expressions such as
    # center[1] - radius - 5 wrap around near the top image border.
    center = (int(x), int(y))
    radius = int(r)

    # get depth in [m] (was radius * 0.4 in real world)
    d = circular_mean(center, radius * 0.2, img_depth) / 1000     # todo this is a slow implementation. You might want to speedup
    # filter if sign is too close (circle detector will struggle) or too far (background)
    if d < MIN_DEPTH or d > MAX_DEPTH:
        return []
    keypoint.append({"center": center, "radius": radius, "depth": d})

    if VISUALIZE:
        cv2.putText(img_rgb, "d:{:1.3f} r:{:1.0f}".format(d, radius), (center[0], center[1] - radius - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), thickness=1)
        # circle center
        cv2.circle(img_rgb, center, 1, (0, 100, 100), 3)
        # circle outline
        cv2.circle(img_rgb, center, radius, (255, 0, 255), 3)

    return keypoint

def crop_to_bounds(crop_bounds, max_val):
    """Shift a [start, end) window so that it lies inside [0, max_val].

    A window that sticks out on one side is moved inwards, keeping its length.
    A window longer than max_val is clipped to [0, max_val].

    :param crop_bounds: two-element list [start, end]; modified in place
    :param max_val: size of the dimension (exclusive upper bound)
    :return: the adjusted crop_bounds list
    """
    if crop_bounds[0] < 0:
        crop_bounds[1] += 0 - crop_bounds[0]
        crop_bounds[0] = 0
    if crop_bounds[1] > max_val:
        crop_bounds[0] -= crop_bounds[1] - max_val
        crop_bounds[1] = max_val
    # For a window longer than the dimension, the shift above pushes the start
    # below 0. A negative start would be read as "from the end" when slicing.
    crop_bounds[0] = max(crop_bounds[0], 0)

    return crop_bounds

def get_tensor_patches(img_rgb, keypoints, zoom=True):
    """Turn a set of keypoints into image patches cut from the original image.

    Each patch has the size needed by the tensorflow model (TENSOR_RES), so it can
    be fed directly into the image classifier.

    :param img_rgb: original image
    :param keypoints: list of detected keypoints (dicts with "center" (x, y) and "radius")
    :param zoom: if True, the image is rescaled around each keypoint so that
        radius*2*ZOOM_THREASHOLD fills the patch; keypoints with radius <= 0 are
        cut unzoomed
    :return: a list of image patches, one per keypoint
    """
    half_h, half_w = TENSOR_RES[0] // 2, TENSOR_RES[1] // 2
    img_patches = []
    for k in keypoints:
        # Keypoint centers are (x, y); switch to (row, col) for numpy indexing.
        # The int() casts avoid uint16 wrap-around in the subtractions below.
        center = np.array([int(k["center"][1]), int(k["center"][0])])
        r = k["radius"]

        if zoom and r > 0:
            # zoom_factor is (row, col) like center, i.e. (fy, fx)
            zoom_factor = np.array(TENSOR_RES) / (r * 2 * ZOOM_THREASHOLD)
            zoomed_image = cv2.resize(img_rgb, dsize=None, fx=zoom_factor[1], fy=zoom_factor[0], interpolation=cv2.INTER_NEAREST)
            img_center_zoomed = (center * zoom_factor).astype(int)
        else:
            zoomed_image = img_rgb
            img_center_zoomed = center

        # keep the patch inside the (zoomed) image borders
        y = crop_to_bounds([img_center_zoomed[0] - half_h, img_center_zoomed[0] + half_h], np.shape(zoomed_image)[0])
        x = crop_to_bounds([img_center_zoomed[1] - half_w, img_center_zoomed[1] + half_w], np.shape(zoomed_image)[1])
        img_patches.append(zoomed_image[y[0]:y[1], x[0]:x[1], :])

    return img_patches

def visualize_patches(keypoints, patches, text, img_rgb):
    """Paste each classifier patch back into the image at its keypoint and label it.

    This shows exactly what is fed to the classifier.

    :param keypoints: keypoint dicts with a "center" entry given as (x, y)
    :param patches: one image patch per keypoint (e.g. from get_tensor_patches)
    :param text: one label string per keypoint
    :param img_rgb: image to draw on; modified in place
    :return: the annotated img_rgb
    """
    half_h, half_w = TENSOR_RES[0] // 2, TENSOR_RES[1] // 2
    for k, patch, label in zip(keypoints, patches, text):
        # int() avoids uint16 wrap-around (e.g. cy - 112 for a center near the top),
        # which would yield an empty slice and a shape-mismatch error below.
        cx, cy = int(k["center"][0]), int(k["center"][1])

        # The exact window in the non-zoomed image has to be recomputed.
        y = crop_to_bounds([cy - half_h, cy + half_h], np.shape(img_rgb)[0])
        x = crop_to_bounds([cx - half_w, cx + half_w], np.shape(img_rgb)[1])
        # Replace the window with the patch, so the classifier input is visible directly.
        img_rgb[y[0]:y[1], x[0]:x[1]] = patch
        cv2.rectangle(img_rgb, (x[0], y[0]), (x[1], y[1]), (255, 255, 255), thickness=1)
        # Label on a filled box above the patch; the box width follows the patch width.
        cv2.rectangle(img_rgb, (x[0] - 10, y[0] - 25), (x[0] + TENSOR_RES[1] + 10, y[0]), (255, 255, 255), -1)
        cv2.putText(img_rgb, label, (x[0] - 10, y[0] - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), thickness=2)

    return img_rgb

def filter_duplicate_keypoints(keypoints, radius_tolerance=None):
    """Drop keypoints whose radius is close to that of an already kept keypoint.

    Keypoints are processed in order. A keypoint is kept only if every
    previously kept keypoint differs in radius by at least radius_tolerance.
    Only radii are compared, not centers.

    :param keypoints: list of keypoint dicts with a "radius" entry
    :param radius_tolerance: radius difference below which two keypoints count
        as duplicates; defaults to MIN_RADIUS
    :return: the kept keypoints, in input order
    """
    if radius_tolerance is None:
        radius_tolerance = MIN_RADIUS

    required_keypoints = []
    for point in keypoints:
        # int() casts: abs() of a uint16 difference would wrap instead of going negative
        is_duplicate = any(
            abs(int(point['radius']) - int(kept['radius'])) < radius_tolerance
            for kept in required_keypoints
        )
        if not is_duplicate:
            required_keypoints.append(point)
    return required_keypoints

def call_fetch_rosparams():
    """Call the sign_detector/update_rosparams Trigger service.

    :return: the service response, or None if the call failed
    """
    service = "sign_detector/update_rosparams"

    try:
        # ServiceProxy expects the service class, not an instance of it.
        call_place = rospy.ServiceProxy(service, std_srvs.srv.Trigger)
        return call_place()
    except rospy.ServiceException as e:
        rospy.logerr("Service call failed: %s" % e)
        return None

def save_params():
    """During VISUALIZE mode: persist the current Hough parameters.

    Writes the module-level canny and accum_thresh into config/sign_detector.yaml
    of the minibot_vision package (other keys in the file are preserved) and also
    loads them into rosparam. On a YAML parse/serialize error, nothing is changed
    and an error is logged.
    """
    global canny, accum_thresh, ros_pack

    path = "{}/config/sign_detector.yaml".format(ros_pack.get_path("minibot_vision"))

    # read the existing config so that unrelated keys survive the rewrite
    with open(path, "r") as stream:
        try:
            data = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            rospy.logerr("({}) Could not parse {}: {}".format(rospy.get_name(), path, exc))
            return
    if not isinstance(data, dict):   # e.g. an empty file parses to None
        data = {}
    section = data.get('sign_detector')
    if not isinstance(section, dict):
        section = {}
        data['sign_detector'] = section
    section['canny_param1'] = canny
    section['canny_param2'] = accum_thresh

    # Serialize before opening for writing: mode "w" truncates immediately, so a
    # dump error would otherwise leave an empty config file behind.
    try:
        serialized = yaml.safe_dump(data)
    except yaml.YAMLError as exc:
        rospy.logerr("({}) Could not serialize params: {}".format(rospy.get_name(), exc))
        return
    with open(path, "w") as stream:
        stream.write(serialized)

    # save to rosparam
    rospy.set_param("sign_detector/canny_param1", canny)
    rospy.set_param("sign_detector/canny_param2", accum_thresh)

    rospy.loginfo("({}) Saved new params persistent".format(rospy.get_name()))

def load_default_params():
    """During VISUALIZE mode: load the default Hough parameters.

    Reads config/sign_detector_default.yaml of the minibot_vision package, stores
    the values in the module-level canny and accum_thresh, and loads them into
    rosparam. They are NOT written to the persistent config/sign_detector.yaml
    (save_params does that). On a YAML parse error, nothing is changed and an
    error is logged.
    """
    global canny, accum_thresh, ros_pack

    path = "{}/config/sign_detector_default.yaml".format(ros_pack.get_path("minibot_vision"))
    with open(path) as stream:
        try:
            data = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            rospy.logerr("({}) Could not parse {}: {}".format(rospy.get_name(), path, exc))
            return

    canny = data['sign_detector']['canny_param1']
    accum_thresh = data['sign_detector']['canny_param2']
    rospy.set_param("sign_detector/canny_param1", canny)
    rospy.set_param("sign_detector/canny_param2", accum_thresh)

    rospy.loginfo("({}) Loaded default params".format(rospy.get_name()))

def enable_callback(req):
    """SetBool service handler that enables or disables sign segmentation.

    Sets segment_rate to 30 Hz while enabled and to 5 Hz while disabled
    (low power mode while the node is doing nothing).

    :param req: std_srvs/SetBool request; req.data is the new enable state
    :return: (success, message) for the SetBool response
    """
    global segment_enable, segment_rate

    segment_enable = req.data
    rospy.loginfo("({}) set enable to {}".format(rospy.get_name(), segment_enable))
    # go in low power mode if the node is doing nothing
    if segment_enable:
        segment_rate = rospy.Rate(30)
    else:
        segment_rate = rospy.Rate(5)

    return True, ""

def command_callback(req):
    """Remote control for the visualization of the segment sign hyper param adjustment.

    PERSISTENT_SAVE calls save_params(); TOGGLE_PATCH_VISUALIZATION flips
    toggle_patch_visualization; LOAD_DEFAULT calls load_default_params().
    Unknown commands are logged as a warning.

    :param req: segment_sign_command request
    :return: an empty segment_sign_commandResponse
    """
    global toggle_patch_visualization

    if req.command == segment_sign_commandRequest.PERSISTENT_SAVE:
        rospy.loginfo("({}) save params persistent".format(rospy.get_name()))
        save_params()
    elif req.command == segment_sign_commandRequest.TOGGLE_PATCH_VISUALIZATION:
        # flip first, so the log reports the new state rather than the old one
        toggle_patch_visualization = not toggle_patch_visualization
        rospy.loginfo("({}) Toggle patch visualisation is {}".format(rospy.get_name(), toggle_patch_visualization))
    elif req.command == segment_sign_commandRequest.LOAD_DEFAULT:
        rospy.loginfo("({}) load default params".format(rospy.get_name()))
        load_default_params()
    else:
        rospy.logwarn("({}) command {} not known".format(rospy.get_name(), req.command))

    return segment_sign_commandResponse()

def publish_image(img, pub_cmpr_img):
    """Publish an OpenCV image as a CompressedImage, e.g. for website visualization.

    :param img: OpenCV image to publish
    :param pub_cmpr_img: rospy Publisher for sensor_msgs/CompressedImage
    """
    # use same timestamp for synchronisation
    timestamp = rospy.Time.now()

    # publish compressed img for website visualization
    cmprsmsg = bridge.cv2_to_compressed_imgmsg(img)
    cmprsmsg.header.stamp = timestamp
    pub_cmpr_img.publish(cmprsmsg)

if __name__=="__main__":
    # The node must be initialized first: rospy.Rate and the services need it.
    rospy.init_node("sign_detector")
    ros_pack = rospkg.RosPack()
    ns = rospy.get_namespace()
    # start in low power mode; enable_callback switches the rate
    segment_rate = rospy.Rate(5)

    # *** TOPICS
    img_color_topic = "{}camera/color/image_raw".format(ns)
    img_depth_topic = "{}camera/aligned_depth_to_color/image_raw".format(ns)
    rospy.Subscriber(img_color_topic, Image, image_color_callback)
    rospy.Subscriber(img_depth_topic, Image, image_depth_callback)
    pub_cmpr_img = rospy.Publisher("~result_image/compressed", CompressedImage, queue_size=10)

    # *** SERVICES
    rospy.Service("~enable", std_srvs.srv.SetBool, enable_callback)
    rospy.Service("~command", segment_sign_command, command_callback)

    # visualization in soft sleep, until awake service called
    try:
        while not rospy.is_shutdown():
            if segment_enable:
                # Re-read each cycle so live parameter changes apply; keep the
                # current value if the parameter is not set.
                canny = rospy.get_param("sign_detector/canny_param1", canny)
                accum_thresh = rospy.get_param("sign_detector/canny_param2", accum_thresh)

                img_processed = copy(img_rgb_stream)
                keypoints = do_hough_circle_detection(img_processed, copy(img_depth_stream), VISUALIZE=True)
                if toggle_patch_visualization:
                    img_processed = copy(img_rgb_stream)
                    patches = get_tensor_patches(copy(img_rgb_stream), keypoints)
                    img_processed = visualize_patches(keypoints, patches, ["d:{:1.3f}".format(k["depth"]) for k in keypoints], img_processed)
                publish_image(img_processed, pub_cmpr_img)

            # Throttle the loop (otherwise it busy-spins); segment_rate is replaced
            # by enable_callback, so the current global is read on every iteration.
            segment_rate.sleep()
    except rospy.ROSInterruptException:
        # raised by Rate.sleep() when the node shuts down
        pass