#!/usr/bin/env python3

import rospy
import std_srvs.srv

import ShapeDetector
import SegmentSign
from cv_bridge import CvBridge, CvBridgeError
import numpy as np
from sensor_msgs.msg import Image
from sensor_msgs.msg import CompressedImage
from TmClassification import TmClassification
import cv2
from copy import copy
from vision_msgs.msg import Detection2D, ObjectHypothesisWithPose
from minibot_msgs.srv import set_url

# *** CONSTANTS ***
# When True, detect_sign() also publishes the result image.
# Not a true constant: set_visualize_callback() reassigns it at runtime.
visualize = True
# frame_id written into the header of every published Detection2D message
camera_frame = "camera_aligned_depth_to_color_frame"
#OLD CONFIG: IMG_RES = (480, 640)
# (height, width) read from the parameter server while the module is loaded.
# NOTE(review): no default is passed to get_param, so startup presumably fails
# if these parameters are unset - confirm the launch file always sets them.
IMG_RES = (rospy.get_param("sign_detector/img_height"), rospy.get_param("sign_detector/img_width")) 
# Not referenced anywhere else in this file.
TF_RES = (224, 224)     # tf is cropping the image

# *** GLOBALS ***
# Classifier used by detect_sign() and reconfigured by set_model_callback()
sign_classifier = TmClassification()
# Shared converter between ROS image messages and OpenCV arrays
bridge = CvBridge()
# Latest color frame; an all-zero placeholder until image_color_callback() runs
img_rgb_stream = np.zeros((IMG_RES[0], IMG_RES[1], 3), np.uint8)
# Timestamp belonging to img_rgb_stream; zero until the first frame arrives
img_rgb_timestamp = rospy.Time(0, 0)
# Latest depth frame; placeholder until image_depth_callback() runs.
# NOTE(review): the placeholder is uint8 with a trailing channel axis, while the
# callback requests "16UC1" - confirm the detectors tolerate the placeholder.
img_depth_stream = np.zeros((IMG_RES[0], IMG_RES[1], 1), np.uint8)
# Alias of the initial placeholder frame; not referenced elsewhere in this file
img_rgb = img_rgb_stream
# Publishers are created in the __main__ block below; None until then
pub_keypoint = None
pub_result_img = None

# subscribe to RGB img
def image_color_callback(data):
    """Cache the most recent color frame together with its timestamp.

    Subscriber callback for the color image topic. On success the module
    globals ``img_rgb_stream`` and ``img_rgb_timestamp`` are replaced; on a
    conversion error both keep their previous values.

    :param data: incoming sensor_msgs/Image message
    """
    global img_rgb_stream, img_rgb_timestamp

    try:
        frame = bridge.imgmsg_to_cv2(data, "bgr8")
    except CvBridgeError as e:
        # Report through the ROS logger instead of stdout so the failure is
        # visible in the node's log output.
        rospy.logerr("sign_detector: could not convert color image: %s", e)
        return

    # Detections are stamped with "the time when the image was taken", so
    # prefer the acquisition time from the message header. Fall back to the
    # receive time only when the driver left the stamp unset.
    stamp = data.header.stamp
    if stamp.is_zero():
        stamp = rospy.Time.now()

    img_rgb_stream = frame
    img_rgb_timestamp = stamp


def image_depth_callback(data):
    """Cache the most recent depth frame.

    Subscriber callback for the aligned depth image topic. On success the
    module global ``img_depth_stream`` is replaced by the frame converted
    with the "16UC1" encoding; on a conversion error the previous frame is
    kept.

    :param data: incoming sensor_msgs/Image message
    """
    global img_depth_stream

    try:
        img_depth_stream = bridge.imgmsg_to_cv2(data, "16UC1")
    except CvBridgeError as e:
        # Report through the ROS logger instead of stdout so the failure is
        # visible in the node's log output.
        rospy.logerr("sign_detector: could not convert depth image: %s", e)


def publish_results(point, radius, depth, label, precision, timestamp):
    """Publish one classified sign as a Detection2D message on ``pub_keypoint``.

    :param point: indexable pair; element 0 becomes bbox.center.x, element 1 bbox.center.y
    :param radius: sign radius; the bounding box is 2 * radius wide and high
    :param depth: value stored as the z position of the hypothesis pose
    :param label: class id stored as the hypothesis id
    :param precision: classifier confidence stored as the hypothesis score
    :param timestamp: stamp written into the message header
    """
    # Single hypothesis describing the classified sign.
    # TODO calc x and y in img frame (only z is filled in so far)
    hypothesis = ObjectHypothesisWithPose()
    hypothesis.id = label
    hypothesis.score = precision
    hypothesis.pose.pose.position.z = depth

    diameter = radius*2

    msg = Detection2D()
    # the time when the image was taken
    msg.header.stamp = timestamp
    msg.header.frame_id = camera_frame
    msg.bbox.center.x, msg.bbox.center.y = point[0], point[1]
    msg.bbox.size_x = diameter
    msg.bbox.size_y = diameter
    msg.results = [hypothesis]

    pub_keypoint.publish(msg)

def detect_sign(img_rgb_stream, image_timestamp):
    """Locate, classify and publish the signs found in one color frame.

    Candidate locations come from the Hough-circle and the shape detector
    (both also receive the latest depth frame), are de-duplicated and cut
    into patches. Every patch the classifier accepts (label >= 0) is
    published via publish_results(). When ``visualize`` is set, the frame is
    additionally published as a compressed image.

    :param img_rgb_stream: color frame to analyse; only a copy of it is used
    :param image_timestamp: stamp attached to every published detection
    """
    frame = copy(img_rgb_stream)

    # get sign location in img
    candidates = SegmentSign.do_hough_circle_detection(copy(frame), copy(img_depth_stream))
    candidates += ShapeDetector.do_shape_detection(copy(frame), copy(img_depth_stream))
    candidates = SegmentSign.filter_duplicate_keypoints(candidates)
    # cut to multiple images at keypoints
    crops = SegmentSign.get_tensor_patches(copy(frame), candidates)

    captions = []
    for idx, candidate in enumerate(candidates):
        crop = crops[idx]
        depth = candidate["depth"]
        # The two center components are swapped before publishing
        # (presumably row/col -> x/y; confirm against SegmentSign).
        swapped_center = [candidate["center"][1], candidate["center"][0]]
        radius = candidate["radius"]

        # classify image batches
        # returns tuple (label, precision); label is -1 if no model is set up or on error
        label, precision = sign_classifier.predictImage(crop)
        if label < 0:
            continue

        # publish results
        publish_results(swapped_center, radius, depth, label, precision, image_timestamp)
        captions.append("c: {} p: {:1.3f} d:{:1.3f}".format(sign_classifier.labelOfClass(label), precision, depth))

    if not visualize:
        return

    if captions:
        # NOTE(review): captions only holds entries for accepted candidates,
        # while candidates/crops are passed unfiltered - verify that
        # visualize_patches copes with the differing lengths.
        SegmentSign.visualize_patches(candidates, crops, captions, frame)
    # compress and publish
    pub_result_img.publish(bridge.cv2_to_compressed_imgmsg(frame))


def set_model_callback(req):
    """Service handler: hand the requested model URL to the classifier.

    :param req: service request; its ``url`` field is passed to setNewModel()
    :return: always False for now, because the outcome of loading the URL is
             not checked yet
    """
    model_url = req.url
    sign_classifier.setNewModel(model_url)
    rospy.logwarn("TODO implement url error check")
    # TODO implement url error check
    return False


def set_visualize_callback(req):
    """Service handler: switch publishing of the result image on or off.

    :param req: service request; its ``data`` field becomes the new value of
                the module-level ``visualize`` flag
    :return: tuple (True, "") - success flag and an empty message
    """
    global visualize

    visualize = req.data
    success, message = True, ""
    return success, message


if __name__ == "__main__":
    # Node entry point: wire up subscribers, services and publishers, then
    # run detection on the most recently cached color frame until shutdown.
    rospy.init_node("sign_detector")

    # Camera topics are resolved inside this node's namespace.
    img_color_topic = "{}camera/color/image_raw".format(rospy.get_namespace())
    img_depth_topic = "{}camera/aligned_depth_to_color/image_raw".format(rospy.get_namespace())

    # queue_size=1: only the newest frame matters, older ones may be dropped
    rospy.Subscriber(img_color_topic, Image, image_color_callback, queue_size=1)
    rospy.Subscriber(img_depth_topic, Image, image_depth_callback, queue_size=1)
    rospy.Service('sign_detector/set_model', set_url, set_model_callback)
    rospy.Service('sign_detector/set_visualize', std_srvs.srv.SetBool, set_visualize_callback)
    pub_keypoint = rospy.Publisher('sign_detector/keypoints', Detection2D, queue_size=10)
    pub_result_img = rospy.Publisher("sign_detector/result_image/compressed", CompressedImage, queue_size=10)

    rate = rospy.Rate(30)       # currently this is impossible, but then the rate is defined by the detect_sign evaluation time
    try:
        while not rospy.is_shutdown():
            detect_sign(img_rgb_stream, img_rgb_timestamp)
            rate.sleep()
    except rospy.ROSInterruptException:
        # rate.sleep() raises this when the node is shut down while sleeping.
        # That is the normal way to leave the loop, so exit quietly instead
        # of printing a traceback.
        pass