# g_detect_moto.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A demo which runs object detection on camera frames using GStreamer.
Run default object detection:
python3 detect.py
Choose different camera and input encoding
python3 detect.py --videosrc /dev/video1 --videofmt jpeg
TEST_DATA=../all_models
Run face detection model:
python3 detect.py \
--model ${TEST_DATA}/mobilenet_ssd_v2_face_quant_postprocess_edgetpu.tflite
Run coco model:
python3 detect.py \
--model ${TEST_DATA}/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite \
--labels ${TEST_DATA}/coco_labels.txt
"""
import argparse
import os
import time

import gstreamer
from common import avg_fps_counter, SVG
from pycoral.adapters.common import input_size
from pycoral.adapters.detect import get_objects
from pycoral.utils.dataset import read_label_file
from pycoral.utils.edgetpu import make_interpreter
from pycoral.utils.edgetpu import run_inference
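

# The SVG overlay is drawn in source-frame coordinates. The camera frame is
# scaled into the model's input tensor, so each detection box (reported in
# input-tensor space) must be shifted by the inference-box offset and rescaled
# before it lines up with the original video.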
def generate_svg(src_size, inference_box, objs, labels, text_lines):
    svg = SVG(src_size)
    src_w, src_h = src_size
    box_x, box_y, box_w, box_h = inference_box
    scale_x, scale_y = src_w / box_w, src_h / box_h
    for y, line in enumerate(text_lines, start=1):
        svg.add_text(10, y * 20, line, 20)
    for obj in objs:
        bbox = obj.bbox
        if not bbox.valid:
            continue
        # Absolute coordinates, input tensor space.
        x, y = bbox.xmin, bbox.ymin
        w, h = bbox.width, bbox.height
        # Subtract boxing offset.
        x, y = x - box_x, y - box_y
        # Scale to source coordinate space.
        x, y, w, h = x * scale_x, y * scale_y, w * scale_x, h * scale_y
        percent = int(100 * obj.score)
        label = '{}% {}'.format(percent, labels.get(obj.id, obj.id))
        svg.add_text(x, y - 5, label, 20)
        svg.add_rect(x, y, w, h, 'red', 2)
    return svg.finish()
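

# Entry point: parse CLI arguments, load the .tflite model on the Edge TPU,
# and run the GStreamer camera pipeline with a per-frame inference callback.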
def main():
    default_model_dir = '../all_models'
    default_model = 'mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite'
    default_labels = 'coco_labels.txt'
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='.tflite model path',
                        default=os.path.join(default_model_dir, default_model))
    parser.add_argument('--labels', help='label file path',
                        default=os.path.join(default_model_dir, default_labels))
    parser.add_argument('--top_k', type=int, default=3,
                        help='number of categories with highest score to display')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='detector score threshold')
    parser.add_argument('--videosrc', help='Which video source to use.',
                        default='/dev/video0')
    parser.add_argument('--videofmt', help='Input video format.',
                        default='raw',
                        choices=['raw', 'h264', 'jpeg'])
    args = parser.parse_args()

    print('Loading {} with {} labels.'.format(args.model, args.labels))
    interpreter = make_interpreter(args.model)
    interpreter.allocate_tensors()
    labels = read_label_file(args.labels)
    inference_size = input_size(interpreter)

    # Average fps over last 30 frames.
    fps_counter = avg_fps_counter(30)
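
    # Called by gstreamer.run_pipeline for every frame. input_tensor already
    # matches the model's input size (appsink_size=inference_size below);
    # inference_box describes where the scaled frame sits inside the source.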
    def user_callback(input_tensor, src_size, inference_box):
        nonlocal fps_counter
        start_time = time.monotonic()
        run_inference(interpreter, input_tensor)
        # Keep only the top_k detections above the score threshold.
        objs = get_objects(interpreter, args.threshold)[:args.top_k]
        end_time = time.monotonic()
        text_lines = [
            'Inference: {:.2f} ms'.format((end_time - start_time) * 1000),
            'FPS: {} fps'.format(round(next(fps_counter))),
        ]
        print(' '.join(text_lines))
        return generate_svg(src_size, inference_box, objs, labels, text_lines)
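
    # src_size is the camera capture resolution; appsink_size asks the
    # pipeline to also deliver frames scaled to the model's input size.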
    result = gstreamer.run_pipeline(user_callback,
                                    src_size=(640, 480),
                                    appsink_size=inference_size,
                                    videosrc=args.videosrc,
                                    videofmt=args.videofmt)


if __name__ == '__main__':
    main()