gets multi object vision working without throwing an error!
NishanthJKumar committed Jul 24, 2023
1 parent 17f2244 commit cbb8b81
Showing 3 changed files with 72 additions and 68 deletions.
4 changes: 2 additions & 2 deletions predicators/envs/spot_env.py
@@ -1031,9 +1031,9 @@ def _generate_task_goal(self) -> Set[GroundAtom]:
hex_screwdriver = self._obj_name_to_obj("hex_screwdriver")
bag = self._obj_name_to_obj("toolbag")
return {
GroundAtom(self._InBag, [hammer, bag]),
# GroundAtom(self._InBag, [hammer, bag]),
GroundAtom(self._InBag, [brush, bag]),
GroundAtom(self._InBag, [hex_key, bag]),
# GroundAtom(self._InBag, [hex_key, bag]),
GroundAtom(self._InBag, [hex_screwdriver, bag]),
}

79 changes: 44 additions & 35 deletions predicators/spot_utils/perception_utils.py
@@ -2,7 +2,7 @@
models with the Boston Dynamics Spot robot."""

import io
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple

import bosdyn.client
import bosdyn.client.util
@@ -14,6 +14,7 @@
from bosdyn.api import image_pb2
from PIL import Image
from scipy import ndimage
import dill as pkl

from predicators.settings import CFG

@@ -97,7 +98,7 @@ def show_box(box: np.ndarray, ax: matplotlib.axes.Axes) -> None:


def query_detic_sam(image_in: np.ndarray, classes: List[str],
viz: bool) -> Optional[Dict[str, List[np.ndarray]]]:
viz: bool) -> Dict[str, List[np.ndarray]]:
"""Send a query to SAM and return the response.
The response is a dictionary that contains 4 keys: 'boxes',
@@ -109,12 +110,22 @@ def query_detic_sam(image_in: np.ndarray, classes: List[str],
files={"file": buf},
data={"classes": ",".join(classes)})

d_filtered: Dict[str, List[np.ndarray]] = {
"boxes": [],
"classes": [],
"masks": [],
"scores": []
}
# If the status code is not 200, then fail.
if r.status_code != 200:
return None
return d_filtered

with io.BytesIO(r.content) as f:
arr = np.load(f, allow_pickle=True)
try:
arr = np.load(f, allow_pickle=True)
except pkl.UnpicklingError:
return d_filtered

boxes = arr['boxes']
ret_classes = arr['classes']
masks = arr['masks']
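For context, here is a minimal, self-contained sketch of the request/response round trip this hunk relies on: the RGB image is serialized into an in-memory buffer, POSTed to the DETIC-SAM server together with comma-separated class prompts, and the reply is decoded with np.load. The server URL and the PNG encoding below are placeholder assumptions, not values taken from the repo.

```python
import io
from typing import Dict, List

import numpy as np
import requests
from PIL import Image

# Placeholder address; the real server URL is configured elsewhere.
SAM_SERVER_URL = "http://localhost:5550/predict"


def post_image_for_detection(image_in: np.ndarray,
                             classes: List[str]) -> Dict[str, np.ndarray]:
    """POST an image and class prompts, then decode the .npz-style reply."""
    buf = io.BytesIO()
    # PNG is an assumption; any encoding the server accepts would do.
    Image.fromarray(image_in).save(buf, format="PNG")
    buf.seek(0)
    r = requests.post(SAM_SERVER_URL,
                      files={"file": buf},
                      data={"classes": ",".join(classes)})
    r.raise_for_status()
    with io.BytesIO(r.content) as f:
        arr = np.load(f, allow_pickle=True)
        return {key: arr[key] for key in ("boxes", "classes", "masks", "scores")}
```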
@@ -137,24 +148,20 @@ def query_detic_sam(image_in: np.ndarray, classes: List[str],
#, we only select the most confident one. This structure makes
# it easy for us to select multiple detections if that's ever
# necessary in the future.
import ipdb; ipdb.set_trace()
d_filtered: Dict[str, List[np.ndarray]] = {
"boxes": [],
"classes": [],
"masks": [],
"scores": []
}
for obj_class in classes:
obj_idxs_with_classes = np.where(d['classes'] == obj_class)
if len(obj_idxs_with_classes) == 0:
class_mask = (d['classes'] == obj_class)
if np.all(class_mask == False):
continue
# TODO: continue from here!

selected_idx = np.argmax(d['scores'])
if d['scores'][selected_idx] < CFG.spot_vision_detection_threshold:
return None
for key, value in d.items():
d_filtered[key].append(value[selected_idx])
max_score = np.max(d['scores'][class_mask])
max_score_idx = np.where(d['scores'] == max_score)[0]
if d['scores'][max_score_idx] < CFG.spot_vision_detection_threshold:
continue
for key, value in d.items():
# Sanity check to ensure that we're selecting a value from
# the class we're looking for.
if key == "classes":
assert value[max_score_idx] == obj_class
d_filtered[key].append(value[max_score_idx])

return d_filtered
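To make the new control flow easier to follow, here is a standalone sketch of the per-class filtering the reworked loop performs: for each requested class, keep only the single highest-scoring detection, and skip the class entirely if that score is below the detection threshold. The threshold constant and the flat numpy arrays are illustrative stand-ins for CFG.spot_vision_detection_threshold and the server's response arrays.

```python
from typing import Dict, List

import numpy as np

# Stand-in for CFG.spot_vision_detection_threshold.
DETECTION_THRESHOLD = 0.5


def filter_best_per_class(detections: Dict[str, np.ndarray],
                          classes: List[str]) -> Dict[str, List[np.ndarray]]:
    """Keep at most one (highest-scoring) detection per requested class."""
    filtered: Dict[str, List[np.ndarray]] = {
        "boxes": [], "classes": [], "masks": [], "scores": []
    }
    for obj_class in classes:
        class_mask = detections["classes"] == obj_class
        if not class_mask.any():
            continue  # this class was not detected at all
        # Highest-scoring detection restricted to this class.
        class_scores = np.where(class_mask, detections["scores"], -np.inf)
        best_idx = int(np.argmax(class_scores))
        if detections["scores"][best_idx] < DETECTION_THRESHOLD:
            continue  # best detection is too uncertain; drop it
        for key in filtered:
            filtered[key].append(detections[key][best_idx])
    return filtered
```

With this shape, a call such as filter_best_per_class(d, ["yellow brush", "work bag"]) yields at most one box/mask/score per prompt, which is what get_object_locations_with_detic_sam iterates over below.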

@@ -355,15 +362,15 @@ def get_object_locations_with_detic_sam(
res_image: Dict[str, np.ndarray],
res_image_responses: Dict[str, bosdyn.api.image_pb2.ImageResponse],
source_name: str,
plot: bool = False) -> List[Tuple[float, float, float]]:
plot: bool = False) -> Dict[str, Tuple[float, float, float]]:
"""Given a list of string queries (classes), call SAM on these and return
the positions of the centroids of these detections in the world frame.
the positions of the centroids of these detections in the camera frame.
Importantly, note that a number of cameras on the Spot robot are
rotated by various degrees. Since SAM doesn't do so well on rotated
images, we first rotate these images to be upright, pass them to
SAM, then rotate the result back so that we can correctly compute
the 3D position in the world frame.
the 3D position in the camera frame.
"""
# First, rotate the rgb and depth images by the correct angle.
# Importantly, DO NOT reshape the image, because this will
@@ -384,20 +391,22 @@ def get_object_locations_with_detic_sam(
res_segment = query_detic_sam(image_in=rotated_rgb,
classes=classes,
viz=plot)
if res_segment is None:
return []

import ipdb; ipdb.set_trace()
if len(res_segment['classes']) == 0:
return []

# Detect multiple objects with their masks
obj_num = len(res_segment['masks'])
res_locations = []
for i in range(obj_num):
ret_obj_positions: Dict[str, Tuple[float, float, float]] = {}
for i, obj_class in enumerate(res_segment['classes']):
# Check that this particular class is one of the
# classes we passed in, and that there was only one
# instance of this class that was found.
assert obj_class in classes
assert res_segment['classes'].count(obj_class) == 1
# Compute median value of depth
depth_median = np.median(rotated_depth[res_segment['masks'][i][0]
depth_median = np.median(rotated_depth[res_segment['masks'][i][0].squeeze()
& (rotated_depth > 2)[:, :, 0]])
# Compute geometric center of object bounding box
x1, y1, x2, y2 = res_segment['boxes'][i]
x1, y1, x2, y2 = res_segment['boxes'][i].squeeze()
x_c = (x1 + x2) / 2
y_c = (y1 + y2) / 2
# Create a transformation matrix for the rotation. Be very
@@ -432,7 +441,7 @@ def get_object_locations_with_detic_sam(
if plot:
inverse_rotation_angle = -ROTATION_ANGLE[source_name]
plt.imshow(res_image['rgb'])
plt.imshow(ndimage.rotate(res_segment['masks'][i][0],
plt.imshow(ndimage.rotate(res_segment['masks'][i][0].squeeze(),
inverse_rotation_angle,
reshape=False),
alpha=0.5,
@@ -455,6 +464,6 @@ def get_object_locations_with_detic_sam(
depth_value=depth_median,
point_x=x_c_rotated,
point_y=y_c_rotated)
res_locations.append((x0, y0, z0))
ret_obj_positions[obj_class.item()] = (x0, y0, z0)

return res_locations
return ret_obj_positions
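Since the function above ends by converting a pixel coordinate plus a median depth into a 3D point in the camera frame, here is a minimal sketch of that standard pinhole back-projection. The intrinsics and the millimeter depth scale are assumptions for illustration; on the robot they come from the depth camera's ImageResponse.

```python
from typing import Tuple


def pixel_to_camera_frame(point_x: float, point_y: float, depth_value: float,
                          fx: float, fy: float, cx: float, cy: float,
                          depth_scale: float = 1000.0
                          ) -> Tuple[float, float, float]:
    """Back-project a pixel at a given depth into camera-frame coordinates."""
    # Raw depth images are typically 16-bit millimeter values, so divide by
    # the depth scale to get meters (assumption; check the actual response).
    z = depth_value / depth_scale
    x = (point_x - cx) * z / fx
    y = (point_y - cy) * z / fy
    return x, y, z
```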
57 changes: 26 additions & 31 deletions predicators/spot_utils/spot_utils.py
@@ -89,6 +89,7 @@ def get_memorized_waypoint(obj_name: str) -> Optional[Tuple[str, Array]]:
"hex_screwdriver": "screwdriver",
"toolbag": "work bag",
}
vision_prompt_to_obj_name = {value: key for key, value in obj_name_to_vision_prompt.items()}

OBJECT_CROPS = {
# min_x, max_x, min_y, max_y
@@ -324,22 +325,19 @@ def get_objects_in_view_by_camera(
viewable_obj_poses = self.get_apriltag_pose_from_camera(
source_name=source_name)
else:
# NOTE: we now hard-code the 'yellow brush' to be a
# stand-in for the cube, which is quite a hack.
# We will remove this and do correct object classing
# in a future PR
# First, get a dictionary mapping vision prompts
# to the corresponding location of that object in the
# scene by camera
sam_pose_results = self.get_sam_object_loc_from_camera(
source_rgb=source_name,
source_depth=RGB_TO_DEPTH_CAMERAS[source_name],
classes=[obj_name_to_vision_prompt['brush'], obj_name_to_vision_prompt['hex_screwdriver']],
classes=list(obj_name_to_vision_prompt.values()),
)

if 'yellow brush' in sam_pose_results:
viewable_obj_poses = {
410: sam_pose_results['yellow brush']
}
else:
viewable_obj_poses = {}
# Next, convert the keys of this dictionary to be april
# tag id's instead.
viewable_obj_poses: Dict[int, Tuple[float, float, float]] = {}
for k, v in sam_pose_results.items():
viewable_obj_poses[obj_name_to_apriltag_id[vision_prompt_to_obj_name[k]]] = v
tag_to_pose[source_name].update(viewable_obj_poses)

apriltag_id_to_obj_name = {
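As a reading aid for the hunk above, this is the re-keying it performs in isolation: SAM detections come back keyed by their natural-language vision prompt, so each prompt is inverted to its object name and then mapped to the corresponding AprilTag id before being merged into tag_to_pose. The mappings below are abbreviated stand-ins; only the brush/410 pairing appears in the original code, and the toolbag id is a made-up example.

```python
from typing import Dict, Tuple

# Abbreviated stand-ins for the repo's real mappings; 411 is a made-up id.
obj_name_to_vision_prompt = {"brush": "yellow brush", "toolbag": "work bag"}
obj_name_to_apriltag_id = {"brush": 410, "toolbag": 411}
vision_prompt_to_obj_name = {
    prompt: name for name, prompt in obj_name_to_vision_prompt.items()
}


def prompts_to_tag_poses(
    sam_pose_results: Dict[str, Tuple[float, float, float]]
) -> Dict[int, Tuple[float, float, float]]:
    """Re-key SAM detections from vision prompt to AprilTag id."""
    viewable_obj_poses: Dict[int, Tuple[float, float, float]] = {}
    for prompt, pose in sam_pose_results.items():
        obj_name = vision_prompt_to_obj_name[prompt]
        viewable_obj_poses[obj_name_to_apriltag_id[obj_name]] = pose
    return viewable_obj_poses
```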
@@ -498,31 +496,28 @@ def get_sam_object_loc_from_camera(
'depth': depth_img_response[0],
}

res_locations = get_object_locations_with_detic_sam(
res_location_dict = get_object_locations_with_detic_sam(
classes=classes,
res_image=image,
res_image_responses=image_responses,
source_name=source_rgb,
plot=CFG.spot_visualize_vision_model_outputs)

# We only want the most likely sample (for now).
# NOTE: we make the hard assumption here that
# we will only see one instance of a particular object
# type. We can relax this later.
if len(res_locations) > 0:
assert len(res_locations) == 1
camera_tform_body = get_a_tform_b(
image_responses['depth'].shot.transforms_snapshot,
image_responses['depth'].shot.frame_name_image_sensor,
BODY_FRAME_NAME)
object_rt_gn_origin = self.convert_obj_location(
camera_tform_body, *res_locations[0])

# Use the input class name as the identifier for object(s) and
# their positions
return {classes[0]: object_rt_gn_origin}

return {}
transformed_location_dict: Dict[str, Tuple[float, float, float]] = {}
for obj_class in classes:
if obj_class in res_location_dict:
camera_tform_body = get_a_tform_b(
image_responses['depth'].shot.transforms_snapshot,
image_responses['depth'].shot.frame_name_image_sensor,
BODY_FRAME_NAME)
x, y, z = res_location_dict[obj_class]
object_rt_gn_origin = self.convert_obj_location(
camera_tform_body, x, y, z)
transformed_location_dict[obj_class] = object_rt_gn_origin

# Use the input class name as the identifier for object(s) and
# their positions
return transformed_location_dict

def convert_obj_location(
self, camera_tform_body: bosdyn.client.math_helpers.SE3Pose,
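Finally, a generic sketch of the frame change that convert_obj_location is responsible for: applying an SE(3) pose (rotation plus translation) to the camera-frame point so it is expressed in another frame. This uses plain numpy rather than bosdyn.client.math_helpers; note that in the Boston Dynamics convention a_tform_b maps points expressed in frame b into frame a, so whether the retrieved pose or its inverse is needed depends on the direction of the conversion.

```python
from typing import Tuple

import numpy as np


def apply_se3(rotation: np.ndarray, translation: np.ndarray,
              point_xyz: Tuple[float, float, float]
              ) -> Tuple[float, float, float]:
    """Apply p' = R @ p + t to a 3D point (R is 3x3, t has length 3)."""
    p = np.asarray(point_xyz, dtype=float)
    p_out = rotation @ p + translation
    return float(p_out[0]), float(p_out[1]), float(p_out[2])
```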
