Commit

Merge branch 'vlm-pipeline-testing' into change-perceiver-to-have-history

ashay-bdai authored Sep 16, 2024
2 parents 199698e + 5a461cc commit 9ce43d3
Showing 6 changed files with 254 additions and 37 deletions.
1 change: 0 additions & 1 deletion predicators/approaches/bilevel_planning_approach.py
@@ -62,7 +62,6 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
import pdb
pdb.set_trace()
# utils.abstract(task.init, preds, self._vlm)

# Run task planning only and then greedily sample and execute in the
# policy.
if self._plan_without_sim:
126 changes: 108 additions & 18 deletions predicators/envs/spot_env.py
@@ -24,12 +24,12 @@
from predicators.settings import CFG
from predicators.spot_utils.perception.object_detection import \
AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \
LanguageObjectDetectionID, ObjectDetectionID, detect_objects, \
visualize_all_artifacts
LanguageObjectDetectionID, ObjectDetectionID, _query_detic_sam2, \
detect_objects, visualize_all_artifacts
from predicators.spot_utils.perception.object_specific_grasp_selection import \
brush_prompt, bucket_prompt, football_prompt, train_toy_prompt
from predicators.spot_utils.perception.perception_structs import \
RGBDImageWithContext
from predicators.spot_utils.perception.perception_structs import RGBDImage, \
RGBDImageWithContext, SegmentedBoundingBox
from predicators.spot_utils.perception.spot_cameras import capture_images, \
capture_images_without_context
from predicators.spot_utils.skills.spot_find_objects import \
@@ -107,6 +107,8 @@ class _TruncatedSpotObservation:
# nonpercept_atoms: Set[GroundAtom]
# nonpercept_predicates: Set[Predicate]
executed_skill: Optional[_Option] = None
# Object detections per camera in self.rgbd_images.
object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]]


class _PartialPerceptionState(State):
@@ -2465,8 +2467,18 @@ def _get_dry_task(self, train_or_test: str,

@property
def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]:
"""Get an object from a perception detection ID."""
raise NotImplementedError("No dry task for VLMTestEnv.")

detection_id_to_obj: Dict[ObjectDetectionID, Object] = {}
objects = {
Object("pan", _movable_object_type),
Object("cup", _movable_object_type),
Object("chair", _movable_object_type),
Object("bowl", _movable_object_type),
}
for o in objects:
detection_id = LanguageObjectDetectionID(o.name)
detection_id_to_obj[detection_id] = o
return detection_id_to_obj

def _create_operators(self) -> Iterator[STRIPSOperator]:
# Pick object
@@ -2539,16 +2551,87 @@ def __init__(self, use_gui: bool = True) -> None:
self._strips_operators = {op_to_name[o] for o in op_names_to_keep}
self._train_tasks = []
self._test_tasks = []


def detect_objects(self, rgbd_images: Dict[str, RGBDImage]) -> Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]]:
object_ids = self._detection_id_to_obj.keys()
object_id_to_img_detections = _query_detic_sam2(object_ids, rgbd_images)
# This ^ is currently a mapping of object_id -> camera_name -> SegmentedBoundingBox.
# We want to do our annotations per camera image, so invert it into a mapping of
# camera_name -> list of (object_id, SegmentedBoundingBox) pairs.
detections = {k: [] for k in rgbd_images.keys()}
for object_id, d in object_id_to_img_detections.items():
for camera_name, seg_bb in d.items():
detections[camera_name].append((object_id, seg_bb))
return detections

def _actively_construct_env_task(self) -> EnvironmentTask:
assert self._robot is not None
rgbd_images = capture_images_without_context(self._robot)
gripper_open_percentage = get_robot_gripper_open_percentage(
self._robot)
# import PIL
# imgs = [v.rgb for _, v in rgbd_images.items()]
# rot_imgs = [v.rotated_rgb for _, v in rgbd_images.items()]
# ex1 = PIL.Image.fromarray(imgs[0])
# ex2 = PIL.Image.fromarray(rot_imgs[0])
# import pdb; pdb.set_trace()
gripper_open_percentage = get_robot_gripper_open_percentage(self._robot)
objects_in_view = []
obs = _TruncatedSpotObservation(rgbd_images, set(objects_in_view),
set(), set(), self._spot_object,
gripper_open_percentage, None)

# Perform object detection.
object_detections_per_camera = self.detect_objects(rgbd_images)


# artifacts = {"language": {"rgbds": rgbd_images, "object_id_to_img_detections": ret}}
# detections_outfile = Path(".") / "object_detection_artifacts.png"
# no_detections_outfile = Path(".") / "no_detection_artifacts.png"
# visualize_all_artifacts(artifacts, detections_outfile, no_detections_outfile)

# # Draw object bounding box on images.
# rgbds = artifacts["language"]["rgbds"]
# detections = artifacts["language"]["object_id_to_img_detections"]
# flat_detections: List[Tuple[RGBDImage,
# LanguageObjectDetectionID,
# SegmentedBoundingBox]] = []
# for obj_id, img_detections in detections.items():
# for camera, seg_bb in img_detections.items():
# rgbd = rgbds[camera]
# flat_detections.append((rgbd, obj_id, seg_bb))

# # For now assume we only have 1 image, front-left.
# import pdb; pdb.set_trace()
# import PIL
# from PIL import ImageDraw, ImageFont
# bb_pil_imgs = []
# img = list(rgbd_images.values())[0].rotated_rgb
# pil_img = PIL.Image.fromarray(img)
# draw = ImageDraw.Draw(pil_img)
# for i, (rgbd, obj_id, seg_bb) in enumerate(flat_detections):
# # img = rgbd.rotated_rgb
# # pil_img = PIL.Image.fromarray(img)
# x0, y0, x1, y1 = seg_bb.bounding_box
# draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2)
# text = f"{obj_id.language_id}"
# font = ImageFont.load_default()
# # font = utils.get_scaled_default_font(draw, 4)
# # text_width, text_height = draw.textsize(text, font)
# # text_width = draw.textlength(text, font)
# # text_height = font.getsize("hg")[1]
# text_mask = font.getmask(text)
# text_width, text_height = text_mask.size
# text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)]
# draw.rectangle(text_bbox, fill='green')
# draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font)

# import pdb; pdb.set_trace()

obs = _TruncatedSpotObservation(
rgbd_images,
set(objects_in_view),
set(),
set(),
self._spot_object,
gripper_open_percentage,
object_detections_per_camera
)
goal_description = self._generate_goal_description()
task = EnvironmentTask(obs, goal_description)
return task
@@ -2606,13 +2689,20 @@ def step(self, action: Action) -> Observation:
logging.warning("WARNING: the following retryable error "
f"was encountered. Trying again.\n{e}")
rgbd_images = capture_images_without_context(self._robot)
gripper_open_percentage = get_robot_gripper_open_percentage(
self._robot)
print(gripper_open_percentage)
gripper_open_percentage = get_robot_gripper_open_percentage(self._robot)
objects_in_view = []
obs = _TruncatedSpotObservation(rgbd_images, set(objects_in_view),
set(), set(), self._spot_object,
gripper_open_percentage, action.get_option())
# Perform object detection.
object_detections_per_camera = self.detect_objects(rgbd_images)

obs = _TruncatedSpotObservation(
rgbd_images,
set(objects_in_view),
set(),
set(),
self._spot_object,
gripper_open_percentage,
object_detections_per_camera
)
return obs
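
For orientation, here is a minimal sketch of how the per-camera detection plumbing added in this file fits together. It is not part of the commit: it assumes a connected robot, a running DETIC-SAM server, and an already-constructed environment instance `env` of the class above (the variable names are illustrative only).

```python
# Illustrative sketch, not part of the commit. Assumes `env` is an instance
# of the environment defined above with `env._robot` connected, and that the
# DETIC-SAM server used by _query_detic_sam2 is reachable.
from predicators.spot_utils.perception.spot_cameras import \
    capture_images_without_context

# Capture RGBDImage objects keyed by camera name.
rgbd_images = capture_images_without_context(env._robot)

# New helper from this commit: camera name -> list of
# (ObjectDetectionID, SegmentedBoundingBox) pairs.
object_detections_per_camera = env.detect_objects(rgbd_images)

for camera_name, detections in object_detections_per_camera.items():
    for obj_id, seg_bb in detections:
        # Bounding boxes are (x0, y0, x1, y1) in the rotated image frame.
        print(camera_name, obj_id.language_id, seg_bb.bounding_box)
```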


40 changes: 33 additions & 7 deletions predicators/perception/spot_perceiver.py
Expand Up @@ -17,6 +17,10 @@
_PartialPerceptionState, _SpotObservation, in_general_view_classifier
from predicators.perception.base_perceiver import BasePerceiver
from predicators.settings import CFG
from predicators.spot_utils.perception.object_detection import \
AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \
LanguageObjectDetectionID, ObjectDetectionID, _query_detic_sam2, \
detect_objects, visualize_all_artifacts
from predicators.spot_utils.utils import _container_type, \
_immovable_object_type, _movable_object_type, _robot_type, \
get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom
@@ -680,13 +684,37 @@ def reset(self, env_task: EnvironmentTask) -> Task:
return Task(state, goal)

def step(self, observation: Observation) -> State:
import pdb; pdb.set_trace()
self._waiting_for_observation = False
self._robot = observation.robot
imgs = observation.rgbd_images
img_names = [v.camera_name for _, v in imgs.items()]
imgs = [v.rgb for _, v in imgs.items()]
import pdb
pdb.set_trace()
img_objects = observation.rgbd_images # RGBDImage objects
img_names = [v.camera_name for _, v in img_objects.items()]
imgs = [v.rotated_rgb for _, v in img_objects.items()]
import PIL
from PIL import ImageDraw, ImageFont
pil_imgs = [PIL.Image.fromarray(img) for img in imgs]
# Annotate images with detected objects (names + bounding box)
# and camera name.
object_detections_per_camera = observation.object_detections_per_camera
for i, camera_name in enumerate(img_names):
draw = ImageDraw.Draw(pil_imgs[i])
# Annotate with camera name.
font = utils.get_scaled_default_font(draw, 4)
_ = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[camera_name], font)
# Annotate with object detections.
detections = object_detections_per_camera[camera_name]
for obj_id, seg_bb in detections:
x0, y0, x1, y1 = seg_bb.bounding_box
draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2)
text = f"{obj_id.language_id}"
font = utils.get_scaled_default_font(draw, 3)
text_mask = font.getmask(text)
text_width, text_height = text_mask.size
text_bbox = [(x0, y0 - 1.5*text_height), (x0 + text_width + 1, y0)]
draw.rectangle(text_bbox, fill='green')
draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font)

import pdb; pdb.set_trace()
import PIL
from PIL import ImageDraw
annotated_pil_imgs = []
@@ -698,8 +726,6 @@ def step(self, observation: Observation) -> State:
draw, (0, 0), self.camera_name_to_annotation[img_name], font)
annotated_pil_imgs.append(pil_img)
annotated_imgs = [np.array(img) for img in annotated_pil_imgs]
import pdb
pdb.set_trace()
self._gripper_open_percentage = observation.gripper_open_percentage
self._curr_state = self._create_state()
self._curr_state.simulator_state["images"] = annotated_imgs
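
For reference, a self-contained sketch of the label-drawing pattern used in the annotation loop above, using plain PIL with a made-up image, box, and label, and PIL's default font in place of `utils.get_scaled_default_font`:

```python
# Illustrative sketch only; the image, box, and label are stand-ins.
import numpy as np
import PIL.Image
from PIL import ImageDraw, ImageFont

img = np.zeros((480, 640, 3), dtype=np.uint8)  # stands in for rotated_rgb
x0, y0, x1, y1 = 100, 120, 220, 260            # stands in for seg_bb.bounding_box
text = "cup"                                   # stands in for obj_id.language_id

pil_img = PIL.Image.fromarray(img)
draw = ImageDraw.Draw(pil_img)
# Draw the detection's bounding box.
draw.rectangle([(x0, y0), (x1, y1)], outline="green", width=2)
# Measure the label, then paint a filled strip just above the box so the
# white label text stays readable against any background.
font = ImageFont.load_default()
text_width, text_height = font.getmask(text).size
draw.rectangle([(x0, y0 - 1.5 * text_height), (x0 + text_width + 1, y0)],
               fill="green")
draw.text((x0 + 1, y0 - 1.5 * text_height), text, fill="white", font=font)
annotated = np.array(pil_img)
```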
109 changes: 106 additions & 3 deletions predicators/spot_utils/perception/object_detection.py
@@ -41,7 +41,7 @@
from predicators.spot_utils.perception.perception_structs import \
AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \
LanguageObjectDetectionID, ObjectDetectionID, PythonicObjectDetectionID, \
RGBDImageWithContext, SegmentedBoundingBox
RGBDImage, RGBDImageWithContext, SegmentedBoundingBox
from predicators.spot_utils.utils import get_april_tag_transform, \
get_graph_nav_dir
from predicators.utils import rotate_point_in_image
@@ -351,6 +351,108 @@ def _query_detic_sam(

return object_id_to_img_detections

def _query_detic_sam2(
object_ids: Collection[LanguageObjectDetectionID],
rgbds: Dict[str, RGBDImage],
max_server_retries: int = 5,
detection_threshold: float = CFG.spot_vision_detection_threshold
) -> Dict[ObjectDetectionID, Dict[str, SegmentedBoundingBox]]:
"""Returns object ID to image ID (camera) to segmented bounding box."""

object_id_to_img_detections: Dict[ObjectDetectionID,
Dict[str, SegmentedBoundingBox]] = {
obj_id: {}
for obj_id in object_ids
}

# Create buffer dictionary to send to server.
buf_dict = {}
for camera_name, rgbd in rgbds.items():
pil_rotated_img = PIL.Image.fromarray(rgbd.rotated_rgb) # type: ignore
buf_dict[camera_name] = _image_to_bytes(pil_rotated_img)

# Extract all the classes that we want to detect.
classes = sorted(o.language_id for o in object_ids)

# Query server, retrying to handle possible wifi issues.
# import pdb; pdb.set_trace()
# imgs = [v.rotated_rgb for _, v in rgbds.items()]
# pil_img = PIL.Image.fromarray(imgs[0])
# import pdb; pdb.set_trace()

for _ in range(max_server_retries):
try:
r = requests.post("http://localhost:5550/batch_predict",
files=buf_dict,
data={"classes": ",".join(classes)})
break
except requests.exceptions.ConnectionError:
continue
else:
logging.warning("DETIC-SAM FAILED, POSSIBLE SERVER/WIFI ISSUE")
return object_id_to_img_detections

# If the status code is not 200, then fail.
if r.status_code != 200:
logging.warning(f"DETIC-SAM FAILED! STATUS CODE: {r.status_code}")
return object_id_to_img_detections

# Querying the server succeeded; unpack the contents.
with io.BytesIO(r.content) as f:
try:
server_results = np.load(f, allow_pickle=True)
# Corrupted results.
except pkl.UnpicklingError:
logging.warning("DETIC-SAM FAILED DURING UNPICKLING!")
return object_id_to_img_detections

# Process the results and save all detections per object ID.
for camera_name, rgbd in rgbds.items():
rot_boxes = server_results[f"{camera_name}_boxes"]
ret_classes = server_results[f"{camera_name}_classes"]
rot_masks = server_results[f"{camera_name}_masks"]
scores = server_results[f"{camera_name}_scores"]

# Invert the rotation immediately so we don't need to worry about
# them henceforth.
# h, w = rgbd.rgb.shape[:2]
# image_rot = rgbd.image_rot
# boxes = [
# _rotate_bounding_box(bb, -image_rot, h, w) for bb in rot_boxes
# ]
# masks = [
# ndimage.rotate(m.squeeze(), -image_rot, reshape=False)
# for m in rot_masks
# ]
boxes = rot_boxes
masks = rot_masks

# Filter out detections by confidence. We threshold detections
# at a set confidence level minimum, and if there are multiple,
# we only select the most confident one. This structure makes
# it easy for us to select multiple detections if that's ever
# necessary in the future.
for obj_id in object_ids:
# If there were no detections (which means all the
# returned values will be numpy arrays of shape (0, 0))
# then just skip this source.
if ret_classes.size == 0:
continue
obj_id_mask = (ret_classes == obj_id.language_id)
if not np.any(obj_id_mask):
continue
max_score = np.max(scores[obj_id_mask])
best_idx = np.where(scores == max_score)[0].item()
if scores[best_idx] < detection_threshold:
continue
# Save the detection.
seg_bb = SegmentedBoundingBox(boxes[best_idx], masks[best_idx],
scores[best_idx])
object_id_to_img_detections[obj_id][rgbd.camera_name] = seg_bb

# import pdb; pdb.set_trace()
return object_id_to_img_detections
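
A hedged usage sketch of `_query_detic_sam2`, assuming the DETIC-SAM server is listening on localhost:5550 as the function expects and that `robot` is an already-authenticated Spot robot object (both are placeholders here):

```python
# Illustrative only; requires a connected robot and a running DETIC-SAM server.
from predicators.spot_utils.perception.object_detection import \
    _query_detic_sam2
from predicators.spot_utils.perception.perception_structs import \
    LanguageObjectDetectionID
from predicators.spot_utils.perception.spot_cameras import \
    capture_images_without_context

object_ids = [LanguageObjectDetectionID("cup"),
              LanguageObjectDetectionID("bowl")]
rgbds = capture_images_without_context(robot)  # camera name -> RGBDImage

# Result maps object ID -> camera name -> SegmentedBoundingBox; objects that
# were never detected above the confidence threshold keep empty inner dicts.
detections = _query_detic_sam2(object_ids, rgbds)
for obj_id, per_camera in detections.items():
    for camera_name, seg_bb in per_camera.items():
        print(obj_id.language_id, camera_name, seg_bb.bounding_box)
```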


def _image_to_bytes(img: PIL.Image.Image) -> io.BytesIO:
"""Helper function to convert from a PIL image into a bytes object."""
@@ -522,7 +624,8 @@ def visualize_all_artifacts(artifacts: Dict[str,
ax_row[2].imshow(rgbd.depth, cmap='Greys_r', vmin=0, vmax=10000)

# Bounding box.
ax_row[3].imshow(rgbd.rgb)
# ax_row[3].imshow(rgbd.rgb)
ax_row[3].imshow(rgbd.rotated_rgb)
box = seg_bb.bounding_box
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
@@ -534,7 +637,7 @@
facecolor=(0, 0, 0, 0),
lw=1))

ax_row[4].imshow(seg_bb.mask, cmap="binary_r", vmin=0, vmax=1)
# ax_row[4].imshow(seg_bb.mask, cmap="binary_r", vmin=0, vmax=1)

# Labels.
abbreviated_name = obj_id.language_id