Skip to content
Snippets Groups Projects
Commit 40bf2048 authored by User expired's avatar User expired :spy_tone1:
Browse files

model download url check

parent 9d340d1e
No related branches found
No related tags found
No related merge requests found
......@@ -164,9 +164,9 @@ def detect_sign(img_rgb_stream, image_timestamp):
def set_model_callback(req):
sign_classifier.setNewModel(req.url)
rospy.logwarn("TODO implement url error check")
return False # TODO implement url error check
success = sign_classifier.setNewModel(req.url)
return success
def set_visualize_callback(req):
......
......@@ -24,6 +24,8 @@ class TmClassification:
self.tfjs_dir = ros_dir + self.tfjs_dir
self.h5_dir = ros_dir + self.h5_dir
self.metadata = {"labels": []}
if url is not None:
self.setNewModel(url)
......@@ -36,16 +38,21 @@ class TmClassification:
def setNewModel(self, url):
print("TF: Downloading model from url: {}".format(url))
self._prepareDirectories()
self._downloadFiles(url)
if not self._downloadFiles(url):
return False
self._convertFromFfjsFoKeras()
self.loadNewModel()
return True
def loadNewModel(self):
# TODO if there is an existing model, else error
if not os.path.exists(f'{self.tfjs_dir}/{self.files[1]}') or not os.path.exists(f'{self.h5_dir}/{self.h5_file}'):
rospy.logwarn("({}) There is no existing tensorflow model on your machine. You can set a new model by calling the /set_model service.".format(rospy.get_name()))
self.model = None # ensure that object exists in class
if not os.path.exists(f'{self.tfjs_dir}/{self.files[1]}') or not os.path.exists(
f'{self.h5_dir}/{self.h5_file}'):
rospy.logwarn(
"({}) There is no existing tensorflow model on your machine. You can set a new model by calling the /set_model service.".format(
rospy.get_name()))
self.model = None # ensure that object exists in class
return
# Load the model
self.model = load_model(f'{self.h5_dir}/{self.h5_file}', compile=False)
......@@ -63,17 +70,32 @@ class TmClassification:
os.mkdir(self.tfjs_dir)
def _downloadFiles(self, url):
# check url (only allow teachable machine models)
if not url.startswith("https://teachablemachine.withgoogle.com/models/"):
rospy.logerr("({}) No teachable machine url".format(rospy.get_name()))
return False
for f in self.files:
request_url = url + f
storage_file = f'{self.tfjs_dir}/{f}'
r = requests.get(request_url, allow_redirects=True)
open(storage_file, 'wb').write(r.content)
try:
response = requests.get(request_url, allow_redirects=True)
if response.status_code != 200:
raise Exception("Request Response code is not 200")
except Exception as e:
rospy.logerr("({}) {}".format(rospy.get_name(), e))
return False
open(storage_file, 'wb').write(response.content)
return True
def _convertFromFfjsFoKeras(self):
os.system(f'tensorflowjs_converter --input_format=tfjs_layers_model --output_format=keras {self.tfjs_dir}/{self.files[0]} {self.h5_dir}/{self.h5_file}')
os.system(
f'tensorflowjs_converter --input_format=tfjs_layers_model --output_format=keras {self.tfjs_dir}/{self.files[0]} {self.h5_dir}/{self.h5_file}')
def _loadMetadata(self):
f = open(self.tfjs_dir+'/'+self.files[1])
f = open(self.tfjs_dir + '/' + self.files[1])
return json.load(f)
def predictImage(self, image):
......@@ -84,11 +106,11 @@ class TmClassification:
# determined by the first position in the shape tuple, in this case 1.
data = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32)
# Replace this with the path to your image
#image = Image.open('rosa.jpg')
#resize the image to a 224x224 with the same strategy as in TM2:
#resizing the image to be at least 224x224 and then cropping from the center
# image = Image.open('rosa.jpg')
# resize the image to a 224x224 with the same strategy as in TM2:
# resizing the image to be at least 224x224 and then cropping from the center
#turn the image into a numpy array
# turn the image into a numpy array
image_array = np.asarray(image)
# Normalize the image
normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
......@@ -97,7 +119,7 @@ class TmClassification:
# run the inference
prediction = self.model.predict(data)
# Generate arg maxes of predictions
class_nr = np.argmax(prediction, axis=1)[0]
return class_nr, np.max(prediction, axis=1)[0]
......@@ -118,4 +140,3 @@ class TmClassification:
if class_number < 0 or class_number > len(labels):
return 'unkown'
return labels[class_number]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment