Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
M
Minibot Vision
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Institut für Informatik
stair
Minibot Vision
Commits
40bf2048
Commit
40bf2048
authored
2 years ago
by
User expired
Browse files
Options
Downloads
Patches
Plain Diff
model download url check
parent
9d340d1e
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
minibot_vision/scripts/SignDetector.py
+3
-3
3 additions, 3 deletions
minibot_vision/scripts/SignDetector.py
minibot_vision/scripts/TmClassification.py
+36
-15
36 additions, 15 deletions
minibot_vision/scripts/TmClassification.py
with
39 additions
and
18 deletions
minibot_vision/scripts/SignDetector.py
+
3
−
3
View file @
40bf2048
...
...
@@ -164,9 +164,9 @@ def detect_sign(img_rgb_stream, image_timestamp):
def
set_model_callback
(
req
):
sign_classifier
.
setNewModel
(
req
.
url
)
rospy
.
logwarn
(
"
TODO implement url error check
"
)
return
False
# TODO implement url error check
success
=
sign_classifier
.
setNewModel
(
req
.
url
)
return
success
def
set_visualize_callback
(
req
):
...
...
This diff is collapsed.
Click to expand it.
minibot_vision/scripts/TmClassification.py
+
36
−
15
View file @
40bf2048
...
...
@@ -24,6 +24,8 @@ class TmClassification:
self
.
tfjs_dir
=
ros_dir
+
self
.
tfjs_dir
self
.
h5_dir
=
ros_dir
+
self
.
h5_dir
self
.
metadata
=
{
"
labels
"
:
[]}
if
url
is
not
None
:
self
.
setNewModel
(
url
)
...
...
@@ -36,16 +38,21 @@ class TmClassification:
def
setNewModel
(
self
,
url
):
print
(
"
TF: Downloading model from url: {}
"
.
format
(
url
))
self
.
_prepareDirectories
()
self
.
_downloadFiles
(
url
)
if
not
self
.
_downloadFiles
(
url
):
return
False
self
.
_convertFromFfjsFoKeras
()
self
.
loadNewModel
()
return
True
def
loadNewModel
(
self
):
# TODO if there is an existing model, else error
if
not
os
.
path
.
exists
(
f
'
{
self
.
tfjs_dir
}
/
{
self
.
files
[
1
]
}
'
)
or
not
os
.
path
.
exists
(
f
'
{
self
.
h5_dir
}
/
{
self
.
h5_file
}
'
):
rospy
.
logwarn
(
"
({}) There is no existing tensorflow model on your machine. You can set a new model by calling the /set_model service.
"
.
format
(
rospy
.
get_name
()))
self
.
model
=
None
# ensure that object exists in class
if
not
os
.
path
.
exists
(
f
'
{
self
.
tfjs_dir
}
/
{
self
.
files
[
1
]
}
'
)
or
not
os
.
path
.
exists
(
f
'
{
self
.
h5_dir
}
/
{
self
.
h5_file
}
'
):
rospy
.
logwarn
(
"
({}) There is no existing tensorflow model on your machine. You can set a new model by calling the /set_model service.
"
.
format
(
rospy
.
get_name
()))
self
.
model
=
None
# ensure that object exists in class
return
# Load the model
self
.
model
=
load_model
(
f
'
{
self
.
h5_dir
}
/
{
self
.
h5_file
}
'
,
compile
=
False
)
...
...
@@ -63,17 +70,32 @@ class TmClassification:
os
.
mkdir
(
self
.
tfjs_dir
)
def
_downloadFiles
(
self
,
url
):
# check url (only allow teachable machine models)
if
not
url
.
startswith
(
"
https://teachablemachine.withgoogle.com/models/
"
):
rospy
.
logerr
(
"
({}) No teachable machine url
"
.
format
(
rospy
.
get_name
()))
return
False
for
f
in
self
.
files
:
request_url
=
url
+
f
storage_file
=
f
'
{
self
.
tfjs_dir
}
/
{
f
}
'
r
=
requests
.
get
(
request_url
,
allow_redirects
=
True
)
open
(
storage_file
,
'
wb
'
).
write
(
r
.
content
)
try
:
response
=
requests
.
get
(
request_url
,
allow_redirects
=
True
)
if
response
.
status_code
!=
200
:
raise
Exception
(
"
Request Response code is not 200
"
)
except
Exception
as
e
:
rospy
.
logerr
(
"
({}) {}
"
.
format
(
rospy
.
get_name
(),
e
))
return
False
open
(
storage_file
,
'
wb
'
).
write
(
response
.
content
)
return
True
def
_convertFromFfjsFoKeras
(
self
):
os
.
system
(
f
'
tensorflowjs_converter --input_format=tfjs_layers_model --output_format=keras
{
self
.
tfjs_dir
}
/
{
self
.
files
[
0
]
}
{
self
.
h5_dir
}
/
{
self
.
h5_file
}
'
)
os
.
system
(
f
'
tensorflowjs_converter --input_format=tfjs_layers_model --output_format=keras
{
self
.
tfjs_dir
}
/
{
self
.
files
[
0
]
}
{
self
.
h5_dir
}
/
{
self
.
h5_file
}
'
)
def
_loadMetadata
(
self
):
f
=
open
(
self
.
tfjs_dir
+
'
/
'
+
self
.
files
[
1
])
f
=
open
(
self
.
tfjs_dir
+
'
/
'
+
self
.
files
[
1
])
return
json
.
load
(
f
)
def
predictImage
(
self
,
image
):
...
...
@@ -84,11 +106,11 @@ class TmClassification:
# determined by the first position in the shape tuple, in this case 1.
data
=
np
.
ndarray
(
shape
=
(
1
,
224
,
224
,
3
),
dtype
=
np
.
float32
)
# Replace this with the path to your image
#image = Image.open('rosa.jpg')
#resize the image to a 224x224 with the same strategy as in TM2:
#resizing the image to be at least 224x224 and then cropping from the center
#
image = Image.open('rosa.jpg')
#
resize the image to a 224x224 with the same strategy as in TM2:
#
resizing the image to be at least 224x224 and then cropping from the center
#turn the image into a numpy array
#
turn the image into a numpy array
image_array
=
np
.
asarray
(
image
)
# Normalize the image
normalized_image_array
=
(
image_array
.
astype
(
np
.
float32
)
/
127.0
)
-
1
...
...
@@ -97,7 +119,7 @@ class TmClassification:
# run the inference
prediction
=
self
.
model
.
predict
(
data
)
# Generate arg maxes of predictions
class_nr
=
np
.
argmax
(
prediction
,
axis
=
1
)[
0
]
return
class_nr
,
np
.
max
(
prediction
,
axis
=
1
)[
0
]
...
...
@@ -118,4 +140,3 @@ class TmClassification:
if
class_number
<
0
or
class_number
>
len
(
labels
):
return
'
unkown
'
return
labels
[
class_number
]
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment