Skip to content

Commit

Permalink
Merge pull request #26 from Naveen-Dodda/master
Browse files Browse the repository at this point in the history
Updating Pose_engine.py: tto fix heatmap parsing failure
  • Loading branch information
Naveen-Dodda authored Aug 5, 2021
2 parents 445735b + ab3f990 commit d580888
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pose_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@ def __init__(self, model_path, mirror=False):
def calcStride(h,w,L):
return int((2*h*w)/(math.sqrt(h**2 + 4*h*L*w - 2*h*w + w**2) - h - w))

details = self._interpreter.get_output_details()[4]
details = self._interpreter.get_output_details()[5]
self.heatmap_zero_point = details['quantization_parameters']['zero_points'][0]
self.heatmap_scale = details['quantization_parameters']['scales'][0]
heatmap_size = self._interpreter.tensor(details['index'])().nbytes
self.stride = calcStride(self.image_height, self.image_width, heatmap_size)
self.heatmap_size = (self.image_width // self.stride + 1, self.image_height // self.stride + 1)
details = self._interpreter.get_output_details()[5]
details = self._interpreter.get_output_details()[6]
self.parts_zero_point = details['quantization_parameters']['zero_points'][0]
self.parts_scale = details['quantization_parameters']['scales'][0]

Expand Down Expand Up @@ -235,9 +235,9 @@ def softmax(self, y, axis):

def _parse_heatmaps(self, outputs):
# Heatmaps are really float32.
heatmap = (outputs[4].astype(np.float32) - self.heatmap_zero_point) * self.heatmap_scale
heatmap = (outputs[5].astype(np.float32) - self.heatmap_zero_point) * self.heatmap_scale
heatmap = np.reshape(heatmap, [self.heatmap_size[1], self.heatmap_size[0]])
part_heatmap = (outputs[5].astype(np.float32) - self.parts_zero_point) * self.parts_scale
part_heatmap = (outputs[6].astype(np.float32) - self.parts_zero_point) * self.parts_scale
part_heatmap = np.reshape(part_heatmap, [self.heatmap_size[1], self.heatmap_size[0], -1])
part_heatmap = self.softmax(part_heatmap, axis=2)
return heatmap, part_heatmap
Expand Down

0 comments on commit d580888

Please sign in to comment.