forked from FIRST-Tech-Challenge/fmltc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tracking.py
211 lines (182 loc) · 8.82 KB
/
tracking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__author__ = "lizlooney@google.com (Liz Looney)"
# Inspired by
# https://github.com/google/ftc-object-detection/tree/46197ce4ecaee954c2164d257d7dc24e85678285/training/tracking.py
# Python Standard Library
from datetime import datetime, timedelta, timezone
import logging
import os
import time
import uuid
# Other Modules
import cv2
import numpy as np
# My Modules
import action
import bbox_writer
import blob_storage
import exceptions
import storage
import util
# Maps each tracker name accepted by start_tracking to the cv2 factory function
# that creates a fresh tracker of that kind (called with no arguments by
# __create_trackers).
# NOTE(review): in newer OpenCV releases several of these factories (e.g.
# MedianFlow, MOSSE, TLD, Boosting) moved to cv2.legacy; confirm the pinned
# OpenCV version still exposes them at the top level.
tracker_fns = dict(
    CSRT=cv2.TrackerCSRT_create,
    MedianFlow=cv2.TrackerMedianFlow_create,
    MIL=cv2.TrackerMIL_create,
    MOSSE=cv2.TrackerMOSSE_create,
    TLD=cv2.TrackerTLD_create,
    KCF=cv2.TrackerKCF_create,
    Boosting=cv2.TrackerBoosting_create,
)

# Two minutes expressed in milliseconds.
# NOTE(review): not referenced by the code in this module; the staleness checks
# use timedelta(minutes=2) directly.
TWO_MINUTES_IN_MS = 120 * 1000
def prepare_to_start_tracking(team_uuid, video_uuid, tracker_name, scale, init_frame_number, init_bboxes_text):
    """Register a new tracker in storage and trigger the tracking action for it.

    Args:
      team_uuid: The team that owns the video.
      video_uuid: The video to track objects in.
      tracker_name: Name of the cv2 tracker to use (a key of tracker_fns).
      scale: Scale factor passed through to storage for the tracker.
      init_frame_number: The frame to start tracking from.
      init_bboxes_text: The initial bounding boxes, in bbox_writer text form.

    Returns:
      The tracker_uuid produced by storage.tracker_starting.
    """
    new_tracker_uuid = storage.tracker_starting(
        team_uuid, video_uuid, tracker_name, scale, init_frame_number, init_bboxes_text)

    tracking_params = action.create_action_parameters(action.ACTION_NAME_TRACKING)
    # start_tracking reads both of these keys back out of the parameters.
    tracking_params['video_uuid'] = video_uuid
    tracking_params['tracker_uuid'] = new_tracker_uuid
    action.trigger_action_via_blob(tracking_params)

    return new_tracker_uuid
def start_tracking(action_parameters):
    """Run a tracking session for the tracker identified by action_parameters.

    Downloads the tracker's video to a temporary file and skips to the tracker's
    starting frame. It then alternates between two steps:
      * waiting for the client to approve or adjust the bounding boxes for the
        current frame;
      * advancing one frame and storing the bounding boxes produced by the cv2
        trackers.
    If the client changes the boxes, the trackers are recreated from the
    client's boxes.

    Tracking ends when:
      * the video runs out;
      * the client requests a stop, or its entity's update_time is more than
        two minutes old;
      * the tracker client entity can no longer be retrieved.

    Args:
      action_parameters: The action's parameters. Must contain 'video_uuid' and
        'tracker_uuid'. Also handed to __should_stop, which may retrigger the
        action.

    Raises:
      exceptions.HttpErrorNotFound: if the stored tracker name is not a key of
        tracker_fns.
      exceptions.HttpErrorInternalServerError: if cv2 cannot open the video.
    """
    video_uuid = action_parameters['video_uuid']
    tracker_uuid = action_parameters['tracker_uuid']
    tracker_entity = storage.retrieve_tracker_entity(video_uuid, tracker_uuid)
    if tracker_entity is None:
        util.log('Unexpected: storage.retrieve_tracker_entity returned None')
        return
    team_uuid = tracker_entity['team_uuid']
    tracker_client_entity = storage.retrieve_tracker_client_entity(video_uuid, tracker_uuid)
    if tracker_client_entity is None:
        util.log('Unexpected: storage.retrieve_tracker_client_entity returned None')
        return
    # Don't start if the client already asked to stop or has gone quiet. This is
    # the same test as __should_stop, minus the action retrigger.
    if (tracker_client_entity['tracking_stop_requested'] or
            datetime.now(timezone.utc) - tracker_client_entity['update_time'] > timedelta(minutes=2)):
        storage.tracker_stopping(team_uuid, video_uuid, tracker_uuid)
        return
    tracker_name = tracker_entity['tracker_name']
    scale = tracker_entity['scale']
    frame_number = tracker_entity['frame_number']
    if tracker_name not in tracker_fns:
        message = 'Error: Tracker named %s not found.' % tracker_name
        logging.critical(message)
        raise exceptions.HttpErrorNotFound(message)
    tracker_fn = tracker_fns[tracker_name]

    # Write the video out to a temporary file. The download is inside the try so
    # that the file is cleaned up even if the download fails partway through.
    video_filename = '/tmp/%s' % uuid.uuid4().hex
    os.makedirs(os.path.dirname(video_filename), exist_ok=True)
    try:
        blob_storage.write_video_to_file(tracker_entity['video_blob_name'], video_filename)
        # Open the video file with cv2.
        vid = cv2.VideoCapture(video_filename)
        try:
            if not vid.isOpened():
                message = "Error: Unable to open video for video_uuid=%s." % video_uuid
                logging.critical(message)
                raise exceptions.HttpErrorInternalServerError(message)
            if frame_number > 0:
                # We are tracking from a frame that is not the beginning of the
                # video. Skip to that frame. Setting the CAP_PROP_POS_FRAMES
                # property is not reliable. Instead, we skip through frames using
                # vid.grab().
                for _ in range(frame_number):
                    vid.grab()
            # Read the starting frame from the video file.
            success, frame = vid.read()
            if not success:
                # We've reached the end of the video.
                storage.tracker_stopping(team_uuid, video_uuid, tracker_uuid)
                return
            # Wait for the bboxes to be approved/adjusted.
            tracker_client_entity = __wait_for_client_frame(
                team_uuid, video_uuid, tracker_uuid, tracker_client_entity, frame_number,
                action_parameters)
            if tracker_client_entity is None:
                return
            # Separate bboxes_text into bboxes and classes.
            bboxes, classes = bbox_writer.parse_bboxes_text(tracker_client_entity['bboxes_text'], scale)
            # Create the trackers, one per bbox.
            trackers = __create_trackers(tracker_fn, frame, bboxes, classes)
            while True:
                # Read the next frame from the video file.
                frame_number += 1
                success, frame = vid.read()
                if not success:
                    # We've reached the end of the video.
                    storage.tracker_stopping(team_uuid, video_uuid, tracker_uuid)
                    return
                # Get the updated bboxes from the trackers. A None entry marks an
                # object whose tracker reported failure on this frame.
                bboxes = []
                for i, tracker in enumerate(trackers):
                    ok, box = tracker.update(frame)
                    if ok:
                        bboxes.append(np.array(box))
                    else:
                        logging.error('Tracking failure for object %d on frame %d', i, frame_number)
                        bboxes.append(None)
                # Store the new bboxes.
                tracked_bboxes_text = bbox_writer.format_bboxes_text(bboxes, classes, scale)
                storage.store_tracked_bboxes(video_uuid, tracker_uuid, frame_number, tracked_bboxes_text)
                if __should_stop(team_uuid, video_uuid, tracker_uuid, tracker_client_entity,
                                 action_parameters):
                    return
                # Wait for the bboxes to be approved/adjusted.
                tracker_client_entity = storage.retrieve_tracker_client_entity(video_uuid, tracker_uuid)
                if tracker_client_entity is None:
                    util.log('Unexpected: storage.retrieve_tracker_client_entity returned None')
                    return
                tracker_client_entity = __wait_for_client_frame(
                    team_uuid, video_uuid, tracker_uuid, tracker_client_entity, frame_number,
                    action_parameters)
                if tracker_client_entity is None:
                    return
                if tracker_client_entity['bboxes_text'] != tracked_bboxes_text:
                    # The client adjusted the boxes; restart tracking from them.
                    bboxes, classes = bbox_writer.parse_bboxes_text(tracker_client_entity['bboxes_text'], scale)
                    trackers = __create_trackers(tracker_fn, frame, bboxes, classes)
        finally:
            # Release the cv2 video (also when it failed to open).
            vid.release()
    finally:
        # Delete the temporary file. It may not exist if the download failed
        # before creating it; in that case there is nothing to clean up.
        try:
            os.remove(video_filename)
        except FileNotFoundError:
            pass


def __wait_for_client_frame(team_uuid, video_uuid, tracker_uuid, tracker_client_entity, frame_number,
                            action_parameters):
    """Poll storage until the client entity's frame_number equals frame_number.

    Between polls, checks __should_stop (which may end tracking) and sleeps 0.1s.

    Returns:
      The up-to-date tracker client entity. Returns None if tracking should end,
      either because __should_stop returned True or because the client entity
      could no longer be retrieved.
    """
    while tracker_client_entity['frame_number'] != frame_number:
        if __should_stop(team_uuid, video_uuid, tracker_uuid, tracker_client_entity,
                         action_parameters):
            return None
        time.sleep(0.1)
        tracker_client_entity = storage.retrieve_tracker_client_entity(video_uuid, tracker_uuid)
        if tracker_client_entity is None:
            util.log('Unexpected: storage.retrieve_tracker_client_entity returned None')
            return None
    return tracker_client_entity
def __should_stop(team_uuid, video_uuid, tracker_uuid, tracker_client_entity, action_parameters):
    """Decide whether tracking should end, keeping the action alive if it should not.

    Returns:
      True when the client entity has tracking_stop_requested set, or when its
      update_time is more than two minutes in the past. In that case the tracker
      is first marked as stopping in storage. Otherwise calls
      action.retrigger_if_necessary and returns False.
    """
    # Evaluate the explicit stop request first; update_time is only consulted
    # when no stop was requested.
    stop = tracker_client_entity['tracking_stop_requested']
    if not stop:
        idle_for = datetime.now(timezone.utc) - tracker_client_entity['update_time']
        stop = idle_for > timedelta(minutes=2)

    if not stop:
        action.retrigger_if_necessary(action_parameters)
        return False

    storage.tracker_stopping(team_uuid, video_uuid, tracker_uuid)
    return True
def __create_trackers(tracker_fn, frame, init_bboxes, classes):
    """Create and initialize one cv2 tracker per initial bounding box.

    Args:
      tracker_fn: A zero-argument cv2 tracker factory (a value of tracker_fns).
      frame: The video frame on which each tracker is initialized.
      init_bboxes: The initial bounding boxes, one per object. Each box must
        support item assignment, because for cv2.TrackerBoosting its
        coordinates are rounded in place.
      classes: Labels parallel to init_bboxes. Used only in the error message
        when a tracker fails to initialize.

    Returns:
      A list of the trackers that initialized successfully. Failed trackers are
      logged and omitted.
      NOTE(review): omitting a tracker makes the result shorter than `classes`,
      so callers that zip tracker output with `classes` will misalign labels
      after such a failure.
    """
    # For cv2.TrackerBoosting, round the box coordinates to prevent a bus error.
    round_coordinates = tracker_fn == cv2.TrackerBoosting_create
    trackers = []
    for i, bbox in enumerate(init_bboxes):
        if round_coordinates:
            for j, value in enumerate(bbox):
                bbox[j] = round(value)
        tracker = tracker_fn()
        success = tracker.init(frame, tuple(bbox))
        if not success:
            # Previously this referenced an undefined name `f`, which raised
            # NameError instead of logging and skipping the box.
            logging.error('Unable to initialize tracker %d, labeled %s', i, classes[i])
            continue
        trackers.append(tracker)
    return trackers