From 4e1383676a80cdc651919d7b470cdbc17a28a08f Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Mon, 9 Sep 2024 18:20:50 -0400 Subject: [PATCH 01/24] Initial commit. --- predicators/envs/spot_env.py | 174 +++++++++++++++++- .../ground_truth_models/spot_env/nsrts.py | 3 + .../ground_truth_models/spot_env/options.py | 37 +++- .../perception/perception_structs.py | 15 ++ .../spot_utils/perception/spot_cameras.py | 81 +++++++- predicators/spot_utils/spot_localization.py | 76 ++++---- tests/envs/test_spot_envs.py | 16 +- 7 files changed, 358 insertions(+), 44 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 60ab8e33d2..fb95a4f5a2 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -30,7 +30,7 @@ brush_prompt, bucket_prompt, football_prompt, train_toy_prompt from predicators.spot_utils.perception.perception_structs import \ RGBDImageWithContext -from predicators.spot_utils.perception.spot_cameras import capture_images +from predicators.spot_utils.perception.spot_cameras import capture_images, capture_images_without_context from predicators.spot_utils.skills.spot_find_objects import \ init_search_for_objects from predicators.spot_utils.skills.spot_hand_move import \ @@ -84,6 +84,29 @@ class _SpotObservation: nonpercept_predicates: Set[Predicate] +@dataclass(frozen=True) +class _TruncatedSpotObservation: + """An observation for a SpotEnv.""" + # Camera name to image + images: Dict[str, RGBDImageWithContext] + # Objects that are seen in the current image and their positions in world + objects_in_view: Dict[Object, math_helpers.SE3Pose] + # Objects seen only by the hand camera + objects_in_hand_view: Set[Object] + # Objects seen by any camera except the back camera + objects_in_any_view_except_back: Set[Object] + # Expose the robot object. + robot: Object + # Status of the robot gripper. + gripper_open_percentage: float + # # Robot SE3 Pose + # robot_pos: math_helpers.SE3Pose + # # Ground atoms without ground-truth classifiers + # # A placeholder until all predicates have classifiers + # nonpercept_atoms: Set[GroundAtom] + # nonpercept_predicates: Set[Predicate] + + class _PartialPerceptionState(State): """Some continuous object features, and ground atoms in simulator_state. 
@@ -158,9 +181,23 @@ def get_robot( return_at_exit=True) assert path.exists() localizer = SpotLocalizer(robot, path, lease_client, lease_keepalive) + # localizer = None return robot, localizer, lease_client +@functools.lru_cache(maxsize=None) +def get_robot_only(self) -> Tuple[Optional[Robot], Optional[LeaseClient]]: + hostname = CFG.spot_robot_ip + sdk = create_standard_sdk("PredicatorsClient-") + robot = sdk.create_robot(hostname) + robot.authenticate("user", "bbbdddaaaiii") + verify_estop(robot) + lease_client = robot.ensure_client(LeaseClient.default_service_name) + lease_client.take() + lease_keepalive = LeaseKeepAlive(lease_client, must_acquire=True, return_at_exit=True) + return robot, lease_client + + @functools.lru_cache(maxsize=None) def get_detection_id_for_object(obj: Object) -> ObjectDetectionID: """Exposed for wrapper and options.""" @@ -228,6 +265,7 @@ def __init__(self, use_gui: bool = True) -> None: if not CFG.bilevel_plan_without_sim: self._initialize_pybullet() _SIMULATED_SPOT_ROBOT = self._sim_robot + import pdb; pdb.set_trace() robot, localizer, lease_client = get_robot() self._robot = robot self._localizer = localizer @@ -1443,12 +1481,21 @@ def _get_sweeping_surface_for_container(container: Object, _IsSemanticallyGreaterThan = Predicate( "IsSemanticallyGreaterThan", [_base_object_type, _base_object_type], _is_semantically_greater_than_classifier) + +def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str: + return pred_name + "(" + ", ".join(str(obj.name) for obj in objects) + ")" # pragma: no cover +_VLMOn = utils.create_vlm_predicate( + "On" + [_movable_object_type, _immovable_object_type], + _get_vlm_query_str +) + _ALL_PREDICATES = { _NEq, _On, _TopAbove, _Inside, _NotInsideAnyContainer, _FitsInXY, _HandEmpty, _Holding, _NotHolding, _InHandView, _InView, _Reachable, _Blocking, _NotBlocked, _ContainerReadyForSweeping, _IsPlaceable, _IsNotPlaceable, _IsSweeper, _HasFlatTopSurface, _RobotReadyForSweeping, - _IsSemanticallyGreaterThan + _IsSemanticallyGreaterThan, _VLMOn } _NONPERCEPT_PREDICATES: Set[Predicate] = set() @@ -2372,6 +2419,128 @@ def _dry_simulate_pick_and_dump_container( return next_obs +############################################################################### +# VLM Test Env # +############################################################################### +class VLMTestEnv(SpotRearrangementEnv): + """An environment to start testing the VLM pipeline.""" + + @classmethod + def get_name(cls) -> str: + return "spot_vlm_test_env" + + def _create_operators() -> Iterator[STRIPSOperator]: + # Pick object + robot = Variable("?robot", _robot_type) + obj = Variable("?object", _movable_object_type) + table = Variable("?table", _immovable_object_type) + parameters = [robot, obj, table] + preconds: Set[LiftedAtom] = { + LiftedAtom(_HandEmpty, [robot]), + LiftedAtom(_NotHolding, [robot, obj]), + LiftedAtom(_VLMOn, [obj, table]) + } + add_effs: Set[LiftedAtom] = { + LiftedAtom(_Holding, [robot, obj]) + } + del_effs: Set[LiftedAtom] = { + LiftedAtom(_HandEmpty, [robot]), + LiftedAtom(_NotHolding, [robot, obj]), + LiftedAtom(_VLMOn, [obj, table]) + } + ignore_effs: Set[LiftedAtom] = set() + yield STRIPSOperator("Pick", parameters, preconds, add_effs, del_effs, ignore_effs) + + # Place object + robot = Variable("?robot", _robot_type) + obj = Variable("?object", _movable_object_type) + pan = Variable("?pan", _container_type) + parameters = [robot, obj, pan] + preconds: Set[LiftedAtom] = { + LiftedAtom(_Holding, [robot, obj]) + } + add_effs: 
Set[LiftedAtom] = { + LiftedAtom(_HandEmpty, [robot]), + LiftedAtom(_NotHolding, [robot, obj]), + LiftedAtom(_VLMOn, [obj, pan]) + } + del_effs: Set[LiftedAtom] = { + LiftedAtom(_Holding, [robot, obj]) + } + ignore_effs: Set[LiftedAtom] = set() + yield STRIPSOperator("Place", parameters, preconds, add_effs, del_effs, ignore_effs) + + # def _generate_train_tasks(self) -> List[EnvironmentTask]: + # goal = self._generate_goal_description() # currently just one goal + # return [ + # EnvironmentTask(None, goal) for _ in range(CFG.num_train_tasks) + # ] + + def _generate_test_tasks(self) -> List[EnvironmentTask]: + goal = self._generate_goal_description() # currently just one goal + return [EnvironmentTask(None, goal) for _ in range(CFG.num_test_tasks)] + + def __init__(self, use_ui: bool = True) -> None: + super().__init__(use_gui) + robot, lease_client = get_robot_only() + self._robot = robot + self._lease_cient = lease_client + self._strips_operators: Set[STRIPSOperator] = set() + # Used to do [something] when the agent thinks the goal is reached + # but the human says it is not. + self._current_task_goal_reached = False + # Used when we want to doa special check for a specific + # action. + self._last_action: Optional[Action] = None + # Create constant objects. + self._spot_object = Object("robot", _robot_type) + op_to_name = {o.name for o in _create_operators()} + op_names_to_keep = { + "Pick", + "Place" + } + self._strips_operators = {op_to_name[o] for o in op_names_to_keep} + + def _actively_construct_env_task(self) -> EnvironmentTask: + assert self._robot is not None + rgbd_images = capture_images_without_context(self._robot) + gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) + obs = _TruncatedSpotObservation( + rgbd_images, + dict(), + set(), + set(), + self._spot_object, + gripper_open_percentage + ) + goal_description = self._generate_goal_description() + task = EnvironmentTask(obs, goal_description) + return task + + def _generate_goal_description(self) -> GoalDescription: + return "put the cup in the pan." + + def reset(self, train_or_test: str, task_idx: int) -> Observation: + prompt = f"Please set up {train_or_test} task {task_idx}!" + utils.prompt_user(prompt) + assert self._lease_client is not None + # Automatically retry if a retryable error is encountered. + while True: + try: + self._lease_client.take() + self._current_task = self._actively_construct_env_task() + break + except RetryableRpcError as e: + logging.warning("WARNING: the following retryable error " + f"was encountered. 
Trying again.\n{e}") + self._current_observation = self._current_task.init_obs + self._current_task_goal_reached = False + self._last_action = None + return self._current_task.init_obs + + def step(self, action: Action) -> Observation: + pass + ############################################################################### # Cube Table Env # ############################################################################### @@ -2382,6 +2551,7 @@ class SpotCubeEnv(SpotRearrangementEnv): attempts to place an April Tag cube onto a particular table.""" def __init__(self, use_gui: bool = True) -> None: + import pdb; pdb.set_trace() super().__init__(use_gui) op_to_name = {o.name: o for o in _create_operators()} diff --git a/predicators/ground_truth_models/spot_env/nsrts.py b/predicators/ground_truth_models/spot_env/nsrts.py index 8ab36470a6..7d4e5dad45 100644 --- a/predicators/ground_truth_models/spot_env/nsrts.py +++ b/predicators/ground_truth_models/spot_env/nsrts.py @@ -285,6 +285,7 @@ class SpotEnvsGroundTruthNSRTFactory(GroundTruthNSRTFactory): @classmethod def get_env_names(cls) -> Set[str]: return { + "spot_vlm_test_env", "spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env", "spot_soda_bucket_env", "spot_soda_chair_env", "spot_main_sweep_env", "spot_ball_and_cup_sticky_table_env", @@ -320,6 +321,8 @@ def get_nsrts(env_name: str, types: Dict[str, Type], "PrepareContainerForSweeping": _prepare_sweeping_sampler, "DropNotPlaceableObject": utils.null_sampler, "MoveToReadySweep": utils.null_sampler, + "Pick": utils.null_sampler, + "Place": utils.null_sampler } # If we're doing proper bilevel planning with a simulator, then diff --git a/predicators/ground_truth_models/spot_env/options.py b/predicators/ground_truth_models/spot_env/options.py index 4729a50493..f28372b753 100644 --- a/predicators/ground_truth_models/spot_env/options.py +++ b/predicators/ground_truth_models/spot_env/options.py @@ -2,18 +2,20 @@ import time from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple +import logging import numpy as np import pbrspot from bosdyn.client import math_helpers from bosdyn.client.sdk import Robot +from bosdyn.client.lease import LeaseClient from gym.spaces import Box from predicators import utils from predicators.envs import get_or_create_env from predicators.envs.spot_env import HANDEMPTY_GRIPPER_THRESHOLD, \ SpotRearrangementEnv, _get_sweeping_surface_for_container, \ - get_detection_id_for_object, get_robot, \ + get_detection_id_for_object, get_robot, get_robot_only, \ get_robot_gripper_open_percentage, get_simulated_object, \ get_simulated_robot from predicators.ground_truth_models import GroundTruthOptionFactory @@ -897,6 +899,32 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, robot_obj_idx, target_obj_idx, do_gaze, state, memory, objects, params) +def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action: + del state, memory, params + + robot, lease_client = get_robot_only() + + def _teleop(robot: Robot, lease_client: LeaseClient): + prompt = "Press (y) when you are done with teleop." + while True: + response = utils.prompt_user(prompt).strip() + if response == "y": + break + logging.info("Invalid input. Press (y) when y") + # Take back control. 
+ robot, lease_client = get_robot_only() + lease_client.take() + + fn = _teleop + fn_args = (robot, lease_client) + sim_fn = lambda _: None + sim_fn_args = () + name = "teleop" + action_extra_info = SpotActionExtraInfo( + name, objects, fn, fn_args, sim_fn, sim_fn_args + ) + return utils.create_spot_env_action(action_extra_info) + ############################################################################### # Parameterized option factory # @@ -928,7 +956,9 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, "PrepareContainerForSweeping": Box(-np.inf, np.inf, (3, )), # dx, dy, dyaw "DropNotPlaceableObject": Box(0, 1, (0, )), # empty "MoveToReadySweep": Box(0, 1, (0, )), # empty -} + "Pick": Box(0, 1, (0, )), # empty + "Place": Box(0, 1, (0, )) # empty +} # NOTE: the policies MUST be unique because they output actions with extra info # that includes the name of the operators. @@ -951,6 +981,8 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, "PrepareContainerForSweeping": _prepare_container_for_sweeping_policy, "DropNotPlaceableObject": _drop_not_placeable_object_policy, "MoveToReadySweep": _move_to_ready_sweep_policy, + "Pick": _teleop_policy, + "Place": _teleop_policy } @@ -987,6 +1019,7 @@ class SpotEnvsGroundTruthOptionFactory(GroundTruthOptionFactory): @classmethod def get_env_names(cls) -> Set[str]: return { + "spot_vlm_test_env" "spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env", diff --git a/predicators/spot_utils/perception/perception_structs.py b/predicators/spot_utils/perception/perception_structs.py index 321907c3b6..d676d20cf2 100644 --- a/predicators/spot_utils/perception/perception_structs.py +++ b/predicators/spot_utils/perception/perception_structs.py @@ -29,6 +29,21 @@ def rotated_rgb(self) -> NDArray[np.uint8]: """The image rotated to be upright.""" return ndimage.rotate(self.rgb, self.image_rot, reshape=False) +@dataclass +class RGBDImage: + """An RGBD image""" + rgb: NDArray[np.uint8] + depth: NDArray[np.uint16] + image_rot: float + camera_name: str + depth_scale: float + camera_model: Any # bosdyn.api.image_pb2.PinholeModel, but not available + + @property + def rotated_rgb(self) -> NDArray[np.uint8]: + """The image rotated to be upright.""" + return ndimage.rotate(self.rgb, self.image_rot, reshape=False) + @dataclass(frozen=True) class ObjectDetectionID: diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index cbc8d4dff3..7bf5a1fae5 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -10,7 +10,7 @@ from numpy.typing import NDArray from predicators.spot_utils.perception.perception_structs import \ - RGBDImageWithContext + RGBDImageWithContext, RGBDImage from predicators.spot_utils.spot_localization import SpotLocalizer ROTATION_ANGLE = { @@ -120,6 +120,85 @@ def capture_images( return rgbds +def capture_images_without_context( + robot: Robot, + camera_names: Optional[Collection[str]] = None, + quality_percent: int = 100, +) -> Dict[str, RGBDImageWithContext]: + """Build an image request and get the responses. + + If no camera names are provided, all RGB cameras are used. 
+ """ + # global _LAST_CAPTURED_IMAGES # pylint: disable=global-statement + + if camera_names is None: + camera_names = set(RGB_TO_DEPTH_CAMERAS) + + image_client = robot.ensure_client(ImageClient.default_service_name) + + rgbds: Dict[str, RGBDImage] = {} + + # # Get the world->robot transform so we can store world->camera transforms + # # in the RGBDWithContexts. + # if relocalize: + # localizer.localize() + # world_tform_body = localizer.get_last_robot_pose() + # body_tform_world = world_tform_body.inverse() + + # Package all the requests together. + img_reqs: image_pb2.ImageRequest = [] + for camera_name in camera_names: + # Build RGB image request. + if "hand" in camera_name: + rgb_pixel_format = None + else: + rgb_pixel_format = image_pb2.Image.PIXEL_FORMAT_RGB_U8 # pylint: disable=no-member + rgb_img_req = build_image_request(camera_name, + quality_percent=quality_percent, + pixel_format=rgb_pixel_format) + img_reqs.append(rgb_img_req) + # Build depth image request. + depth_camera_name = RGB_TO_DEPTH_CAMERAS[camera_name] + depth_img_req = build_image_request(depth_camera_name, + quality_percent=quality_percent, + pixel_format=None) + img_reqs.append(depth_img_req) + + # Send the request. + responses = image_client.get_image(img_reqs) + name_to_response = {r.source.name: r for r in responses} + + # Build RGBDImageWithContexts. + for camera_name in camera_names: + rgb_img_resp = name_to_response[camera_name] + depth_img_resp = name_to_response[RGB_TO_DEPTH_CAMERAS[camera_name]] + rgb_img = _image_response_to_image(rgb_img_resp) + depth_img = _image_response_to_image(depth_img_resp) + # # Create transform. + # camera_tform_body = get_a_tform_b( + # rgb_img_resp.shot.transforms_snapshot, + # rgb_img_resp.shot.frame_name_image_sensor, BODY_FRAME_NAME) + # camera_tform_world = camera_tform_body * body_tform_world + # world_tform_camera = camera_tform_world.inverse() + # Extract other context. + rot = ROTATION_ANGLE[camera_name] + depth_scale = depth_img_resp.source.depth_scale + # transforms_snapshot = rgb_img_resp.shot.transforms_snapshot + # frame_name_image_sensor = rgb_img_resp.shot.frame_name_image_sensor + camera_model = rgb_img_resp.source.pinhole + # Finish RGBDImageWithContext. + # rgbd = RGBDImageWithContext(rgb_img, depth_img, rot, camera_name, + # world_tform_camera, depth_scale, + # transforms_snapshot, + # frame_name_image_sensor, camera_model) + rgbd = RGBDImage(rgb_img, depth_img, rot, camera_name, depth_scale, camera_model) + rgbds[camera_name] = rgbd + + # _LAST_CAPTURED_IMAGES = rgbds + + return rgbds + + def _image_response_to_image( image_response: image_pb2.ImageResponse, ) -> NDArray: """Extract an image from an image response. diff --git a/predicators/spot_utils/spot_localization.py b/predicators/spot_utils/spot_localization.py index 3c4c557b6d..27283b248e 100644 --- a/predicators/spot_utils/spot_localization.py +++ b/predicators/spot_utils/spot_localization.py @@ -55,26 +55,28 @@ def __init__(self, robot: Robot, upload_path: Path, self._robot_pose = math_helpers.SE3Pose(0, 0, 0, math_helpers.Quat()) # Initialize the robot's position in the map. 
robot_state = get_robot_state(self._robot) - current_odom_tform_body = get_odom_tform_body( - robot_state.kinematic_state.transforms_snapshot).to_proto() - localization = nav_pb2.Localization() - for r in range(NUM_LOCALIZATION_RETRIES + 1): - try: - self.graph_nav_client.set_localization( - initial_guess_localization=localization, - ko_tform_body=current_odom_tform_body) - break - except (ResponseError, TimedOutError) as e: - # Retry or fail. - if r == NUM_LOCALIZATION_RETRIES: - msg = f"Localization failed permanently: {e}." - logging.warning(msg) - raise LocalizationFailure(msg) - logging.warning("Localization failed once, retrying.") - time.sleep(LOCALIZATION_RETRY_WAIT_TIME) - - # Run localize once to start. - self.localize() + z_position = robot_state.kinematic_state.transforms_snapshot.child_to_parent_edge_map["gpe"].parent_tform_child.position.z + import pdb; pdb.set_trace() + # current_odom_tform_body = get_odom_tform_body( + # robot_state.kinematic_state.transforms_snapshot).to_proto() + # localization = nav_pb2.Localization() + # for r in range(NUM_LOCALIZATION_RETRIES + 1): + # try: + # self.graph_nav_client.set_localization( + # initial_guess_localization=localization, + # ko_tform_body=current_odom_tform_body) + # break + # except (ResponseError, TimedOutError) as e: + # # Retry or fail. + # if r == NUM_LOCALIZATION_RETRIES: + # msg = f"Localization failed permanently: {e}." + # logging.warning(msg) + # raise LocalizationFailure(msg) + # logging.warning("Localization failed once, retrying.") + # time.sleep(LOCALIZATION_RETRY_WAIT_TIME) + + # # Run localize once to start. + # self.localize() def _upload_graph_and_snapshots(self) -> None: """Upload the graph and snapshots to the robot.""" @@ -133,23 +135,23 @@ def localize(self, It's good practice to call this periodically to avoid drift issues. April tags need to be in view. """ - try: - localization_state = self.graph_nav_client.get_localization_state() - transform = localization_state.localization.seed_tform_body - if str(transform) == "": - raise LocalizationFailure("Received empty localization state.") - except (ResponseError, TimedOutError, LocalizationFailure) as e: - # Retry or fail. - if num_retries <= 0: - msg = f"Localization failed permanently: {e}." - logging.warning(msg) - raise LocalizationFailure(msg) - logging.warning("Localization failed once, retrying.") - time.sleep(retry_wait_time) - return self.localize(num_retries=num_retries - 1, - retry_wait_time=retry_wait_time) - logging.info("Localization succeeded.") - self._robot_pose = math_helpers.SE3Pose.from_proto(transform) + # try: + # localization_state = self.graph_nav_client.get_localization_state() + # transform = localization_state.localization.seed_tform_body + # if str(transform) == "": + # raise LocalizationFailure("Received empty localization state.") + # except (ResponseError, TimedOutError, LocalizationFailure) as e: + # # Retry or fail. + # if num_retries <= 0: + # msg = f"Localization failed permanently: {e}." 
+ # logging.warning(msg) + # raise LocalizationFailure(msg) + # logging.warning("Localization failed once, retrying.") + # time.sleep(retry_wait_time) + # return self.localize(num_retries=num_retries - 1, + # retry_wait_time=retry_wait_time) + # logging.info("Localization succeeded.") + # self._robot_pose = math_helpers.SE3Pose.from_proto(transform) return None diff --git a/tests/envs/test_spot_envs.py b/tests/envs/test_spot_envs.py index 6f9fdbffba..f05deceff3 100644 --- a/tests/envs/test_spot_envs.py +++ b/tests/envs/test_spot_envs.py @@ -25,6 +25,12 @@ from predicators.structs import Action, GroundAtom, _GroundNSRT +def test_spot_vlm_debug(): + utils.reset_config({ + "env": "spot_vlm_test_env", + }) + pass + def test_spot_env_dry_run(): """Dry run tests (do not require access to robot).""" utils.reset_config({ @@ -265,8 +271,13 @@ def real_robot_cube_env_test() -> None: "test_task_json_dir": args.get("test_task_json_dir", None), }) + import pdb; pdb.set_trace() rng = np.random.default_rng(123) + import os + os.environ["BOSDYN_CLIENT_USERNAME"] = "user" + os.environ["BOSDYN_CLIENT_PASSWORD"] = "bbbdddaaaiii" env = SpotCubeEnv() + import pdb; pdb.set_trace() perceiver = SpotPerceiver() nsrts = get_gt_nsrts(env.get_name(), env.predicates, get_gt_options(env.get_name())) @@ -275,6 +286,7 @@ def real_robot_cube_env_test() -> None: task = env.get_test_tasks()[0] obs = env.reset("test", 0) perceiver.reset(task) + import pdb; pdb.set_trace() assert len(obs.objects_in_view) == 4 cube, floor, table1, table2 = sorted(obs.objects_in_view) assert cube.name == "cube" @@ -646,6 +658,6 @@ def real_robot_sweeping_nsrt_test() -> None: if __name__ == "__main__": - # real_robot_cube_env_test() + real_robot_cube_env_test() # real_robot_drafting_table_placement_test() - real_robot_sweeping_nsrt_test() + # real_robot_sweeping_nsrt_test() From 0a25b2d83e29e86a94bcd08acb0e94b3d47557b1 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Tue, 10 Sep 2024 15:00:09 -0400 Subject: [PATCH 02/24] Some more progress. 
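
Rename the VLM predicate from "On" to "VLMOn", store the observation's objects-in-view as a set rather than a pose dict, and add a first-pass SpotMinimumPerceiver that builds a hard-coded robot/table/cup/pan state and a goal of HandEmpty plus VLMOn(cup, pan) for "put the cup in the pan". Each VLM goal atom is rendered as a query string such as "VLMOn(cup, pan)" via _get_vlm_query_str.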
--- predicators/envs/spot_env.py | 11 +- predicators/perception/spot_perceiver.py | 139 +++++++++++++++++++++++ 2 files changed, 145 insertions(+), 5 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index fb95a4f5a2..d204fadb9e 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -89,8 +89,8 @@ class _TruncatedSpotObservation: """An observation for a SpotEnv.""" # Camera name to image images: Dict[str, RGBDImageWithContext] - # Objects that are seen in the current image and their positions in world - objects_in_view: Dict[Object, math_helpers.SE3Pose] + # Objects in the environment + objects_in_view: Set[Object] # Objects seen only by the hand camera objects_in_hand_view: Set[Object] # Objects seen by any camera except the back camera @@ -1485,7 +1485,7 @@ def _get_sweeping_surface_for_container(container: Object, def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str: return pred_name + "(" + ", ".join(str(obj.name) for obj in objects) + ")" # pragma: no cover _VLMOn = utils.create_vlm_predicate( - "On" + "VLMOn" [_movable_object_type, _immovable_object_type], _get_vlm_query_str ) @@ -2505,9 +2505,10 @@ def _actively_construct_env_task(self) -> EnvironmentTask: assert self._robot is not None rgbd_images = capture_images_without_context(self._robot) gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) + objects_in_view = [] obs = _TruncatedSpotObservation( rgbd_images, - dict(), + set(objects_in_view), set(), set(), self._spot_object, @@ -2518,7 +2519,7 @@ def _actively_construct_env_task(self) -> EnvironmentTask: return task def _generate_goal_description(self) -> GoalDescription: - return "put the cup in the pan." + return "put the cup in the pan" def reset(self, train_or_test: str, task_idx: int) -> Observation: prompt = f"Please set up {train_or_test} task {task_idx}!" diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index d39107a060..c33a480f8e 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -577,3 +577,142 @@ def render_mental_images(self, observation: Observation, logging.info(f"Wrote out to {outfile}") plt.close() return [img] + + +class SpotMinimumPerceiver(BasePerceiver): + """A perceiver for spot envs with minimal functionality.""" + + @classmethod + def get_name(cls) -> str: + return "spot_minimal_perceiver" + + def __init__(self) -> None: + super().__init__() + # self._known_object_poses: Dict[Object, math_helpers.SE3Pose] = {} + # self._objects_in_view: Set[Object] = set() + # self._objects_in_hand_view: Set[Object] = set() + # self._objects_in_any_view_except_back: Set[Object] = set() + self._robot: Optional[Object] = None + # self._nonpercept_atoms: Set[GroundAtom] = set() + # self._nonpercept_predicates: Set[Predicate] = set() + # self._percept_predicates: Set[Predicate] = set() + self._prev_action: Optional[Action] = None + self._held_object: Optional[Object] = None + self._gripper_open_percentage = 0.0 + self._robot_pos: math_helpers.SE3Pose = math_helpers.SE3Pose( + 0, 0, 0, math_helpers.Quat()) + # self._lost_objects: Set[Object] = set() + self._curr_env: Optional[BaseEnv] = None + self._waiting_for_observation = True + self._ordered_objects: List[Object] = [] # list of all known objects + # # Keep track of objects that are contained (out of view) in another + # # object, like a bag or bucket. 
This is important not only for gremlins + # # but also for small changes in the container's perceived pose. + # self._container_to_contained_objects: Dict[Object, Set[Object]] = {} + # Load static, hard-coded features of objects, like their shapes. + # meta = load_spot_metadata() + # self._static_object_features = meta.get("static-object-features", {}) + + def update_perceiver_with_action(self, action: Action) -> None: + # NOTE: we need to keep track of the previous action + # because the step function (where we need knowledge + # of the previous action) occurs *after* the action + # has already been taken. + self._prev_action = action + + def _create_goal(self, state: State, + goal_description: GoalDescription) -> Set[GroundAtom]: + del state # not used + # Unfortunate hack to deal with the fact that the state is actually + # not yet set. Hopefully one day other cleanups will enable cleaning. + assert self._curr_env is not None + pred_name_to_pred = {p.name: p for p in self._curr_env.predicates} + VLMOn = pred_name_to_pred["VLMOn"] + HandEmpty = pred_name_to_pred["HandEmpty"] + if goal_description == "put the cup in the pan": + robot = Object("robot", _robot_type) + cup = Object("cup", _movable_object_type) + pan = Object("pan", _container_type) + goal = { + GroundAtom(HandEmpty, [robot]), + GroundAtom(VLMOn, [cup, pan]) + } + return goal + raise NotImplementedError("Unrecognized goal description") + + def _create_state(self) -> State: + if self._waiting_for_observation: + return DefaultState + # Build the continuous part of the state. + assert self._robot is not None + table = Object("talbe", _immovable_object_type) + cup = Object("cup", _movable_object_type) + pan = Object("pan", _container_type) + state_dict = { + self._robot: { + "gripper_open_percentage": self._gripper_open_percentage, + "x": self._robot_pos.x, + "y": self._robot_pos.y, + "z": self._robot_pos.z, + "qw": self._robot_pos.rot.w, + "qx": self._robot_pos.rot.x, + "qy": self._robot_pos.rot.y, + "qz": self._robot_pos.rot.z, + }, + table: { + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + "shape": 0, + "height": 0, + "width" : 0, + "length": 0, + "object_id": 1, + "flat_top_surface": 1 + }, + cup: { + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + "shape": 0, + "height": 0, + "width" : 0, + "length": 0, + "object_id": 2, + "placeable": 1, + "held": 0, + "lost": 0, + "in_hand_view": 0, + "in_view": 1, + "is_sweeper": 0 + }, + pan: { + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + "shape": 0, + "height": 0, + "width" : 0, + "length": 0, + "object_id": 3, + "placeable": 1, + "held": 0, + "lost": 0, + "in_hand_view": 0, + "in_view": 1, + "is_sweeper": 0 + } + } + From e62e8cf07af20824c9d0c2b611ede8676b84cadb Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Tue, 10 Sep 2024 17:27:12 -0400 Subject: [PATCH 03/24] Some more progress. 
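
Implement VLMTestEnv.step so each step re-captures images without context plus the gripper open percentage and returns a fresh _TruncatedSpotObservation, and give the minimal perceiver first-cut reset and step methods that record the captured images in the state's simulator_state.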
--- predicators/envs/spot_env.py | 14 ++++++++++++- predicators/perception/spot_perceiver.py | 26 +++++++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index d204fadb9e..ade6d2e548 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -2540,7 +2540,19 @@ def reset(self, train_or_test: str, task_idx: int) -> Observation: return self._current_task.init_obs def step(self, action: Action) -> Observation: - pass + assert self._robot is not None + rgbd_images = capture_images_without_context(self._robot) + gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) + objects_in_view = [] + obs = _TruncatedSpotObservation( + rgbd_images, + set(objects_in_view), + set(), + set(), + self._spot_object, + gripper_open_percentage + ) + return obs ############################################################################### # Cube Table Env # diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index c33a480f8e..6f037f7f1e 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -640,6 +640,30 @@ def _create_goal(self, state: State, return goal raise NotImplementedError("Unrecognized goal description") + def update_perceiver_with_action(self, action: Action) -> None: + # NOTE: we need to keep track of the previous action + # because the step function (where we need knowledge + # of the previous action) occurs *after* the action + # has already been taken. + self._prev_action = action + + def reset(self, env_task: EnvironmentTask) -> Task: + init_obs = env_task.init_bos + imgs = init_obs.rgbd_images + self._robot = init_obs.robot + state = self._create_state() + state.simulator_state["images"] = [imgs] + state.set(self._robot, "gripper_open_percentage") = init_obs.gripper_open_percentage + self._curr_state = state + goal = self._create_goal(state, env_task.goal_description) + return Task(state, goal) + + def step(self, observation: Observation) -> State: + imgs = observation.rgbd_images + self._curr_state.simulator_state["images"].append([imgs]) + self._curr_state.set(self._robot, "gripper_open_percentage") = observation.gripper_open_percentage + return self._curr_state.copy() + def _create_state(self) -> State: if self._waiting_for_observation: return DefaultState @@ -715,4 +739,4 @@ def _create_state(self) -> State: "is_sweeper": 0 } } - + return State(state_dict) From aed57089aefcad81aa62056321187c80279eaa77 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Tue, 10 Sep 2024 19:51:25 -0400 Subject: [PATCH 04/24] More progress after pair programming. 
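
Thread a VLM through the pipeline: construct it in BaseApproach from CFG.vlm_model_name, pass it to Task.goal_holds so VLM goal atoms are evaluated with query_vlm_for_atom_vals, rename the perceiver to SpotMinimalPerceiver, and fix several VLMTestEnv constructor and registration bugs from the earlier commits.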
--- predicators/approaches/base_approach.py | 3 +- .../approaches/bilevel_planning_approach.py | 1 + .../approaches/spot_wrapper_approach.py | 2 +- predicators/envs/spot_env.py | 38 +++++++++---- .../ground_truth_models/spot_env/options.py | 2 +- predicators/perception/spot_perceiver.py | 55 +++++++++++++------ predicators/pretrained_model_interface.py | 6 +- predicators/structs.py | 11 +++- 8 files changed, 80 insertions(+), 38 deletions(-) diff --git a/predicators/approaches/base_approach.py b/predicators/approaches/base_approach.py index e780b822ff..2a1fb50ccc 100644 --- a/predicators/approaches/base_approach.py +++ b/predicators/approaches/base_approach.py @@ -11,7 +11,7 @@ from predicators.structs import Action, Dataset, InteractionRequest, \ InteractionResult, Metrics, ParameterizedOption, Predicate, State, Task, \ Type -from predicators.utils import ExceptionWithInfo +from predicators.utils import ExceptionWithInfo, create_vlm_by_name class BaseApproach(abc.ABC): @@ -29,6 +29,7 @@ def __init__(self, initial_predicates: Set[Predicate], self._train_tasks = train_tasks self._metrics: Metrics = defaultdict(float) self._set_seed(CFG.seed) + self._vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover @classmethod @abc.abstractmethod diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index cb8ca38347..79db49699c 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -58,6 +58,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: seed = self._seed + self._num_calls nsrts = self._get_current_nsrts() preds = self._get_current_predicates() + import pdb; pdb.set_trace() # Run task planning only and then greedily sample and execute in the # policy. diff --git a/predicators/approaches/spot_wrapper_approach.py b/predicators/approaches/spot_wrapper_approach.py index d0c2aae1f1..eda1ed5a1d 100644 --- a/predicators/approaches/spot_wrapper_approach.py +++ b/predicators/approaches/spot_wrapper_approach.py @@ -55,7 +55,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: def _policy(state: State) -> Action: nonlocal base_approach_policy, need_stow # If we think that we're done, return the done action. 
- if task.goal_holds(state): + if task.goal_holds(state, self._vlm): extra_info = SpotActionExtraInfo("done", [], None, tuple(), None, tuple()) return utils.create_spot_env_action(extra_info) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index ade6d2e548..17ce0623f6 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -88,7 +88,7 @@ class _SpotObservation: class _TruncatedSpotObservation: """An observation for a SpotEnv.""" # Camera name to image - images: Dict[str, RGBDImageWithContext] + rgbd_images: Dict[str, RGBDImageWithContext] # Objects in the environment objects_in_view: Set[Object] # Objects seen only by the hand camera @@ -186,7 +186,7 @@ def get_robot( @functools.lru_cache(maxsize=None) -def get_robot_only(self) -> Tuple[Optional[Robot], Optional[LeaseClient]]: +def get_robot_only() -> Tuple[Optional[Robot], Optional[LeaseClient]]: hostname = CFG.spot_robot_ip sdk = create_standard_sdk("PredicatorsClient-") robot = sdk.create_robot(hostname) @@ -265,7 +265,6 @@ def __init__(self, use_gui: bool = True) -> None: if not CFG.bilevel_plan_without_sim: self._initialize_pybullet() _SIMULATED_SPOT_ROBOT = self._sim_robot - import pdb; pdb.set_trace() robot, localizer, lease_client = get_robot() self._robot = robot self._localizer = localizer @@ -1485,9 +1484,9 @@ def _get_sweeping_surface_for_container(container: Object, def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str: return pred_name + "(" + ", ".join(str(obj.name) for obj in objects) + ")" # pragma: no cover _VLMOn = utils.create_vlm_predicate( - "VLMOn" - [_movable_object_type, _immovable_object_type], - _get_vlm_query_str + "VLMOn", + [_movable_object_type, _base_object_type], + lambda o: _get_vlm_query_str("VLMOn", o) ) _ALL_PREDICATES = { @@ -2428,8 +2427,17 @@ class VLMTestEnv(SpotRearrangementEnv): @classmethod def get_name(cls) -> str: return "spot_vlm_test_env" + + def _get_dry_task(self, train_or_test: str, + task_idx: int) -> EnvironmentTask: + raise NotImplementedError("No dry task for VLMTestEnv.") - def _create_operators() -> Iterator[STRIPSOperator]: + @property + def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: + """Get an object from a perception detection ID.""" + raise NotImplementedError("No dry task for VLMTestEnv.") + + def _create_operators(self) -> Iterator[STRIPSOperator]: # Pick object robot = Variable("?robot", _robot_type) obj = Variable("?object", _movable_object_type) @@ -2480,11 +2488,16 @@ def _generate_test_tasks(self) -> List[EnvironmentTask]: goal = self._generate_goal_description() # currently just one goal return [EnvironmentTask(None, goal) for _ in range(CFG.num_test_tasks)] - def __init__(self, use_ui: bool = True) -> None: - super().__init__(use_gui) + def _generate_train_tasks(self) -> List[EnvironmentTask]: + goal = self._generate_goal_description() # currently just one goal + return [ + EnvironmentTask(None, goal) for _ in range(CFG.num_train_tasks) + ] + + def __init__(self, use_gui: bool = True) -> None: robot, lease_client = get_robot_only() self._robot = robot - self._lease_cient = lease_client + self._lease_client = lease_client self._strips_operators: Set[STRIPSOperator] = set() # Used to do [something] when the agent thinks the goal is reached # but the human says it is not. @@ -2494,12 +2507,14 @@ def __init__(self, use_ui: bool = True) -> None: self._last_action: Optional[Action] = None # Create constant objects. 
self._spot_object = Object("robot", _robot_type) - op_to_name = {o.name for o in _create_operators()} + op_to_name = {o.name: o for o in self._create_operators()} op_names_to_keep = { "Pick", "Place" } self._strips_operators = {op_to_name[o] for o in op_names_to_keep} + self._train_tasks = [] + self._test_tasks = [] def _actively_construct_env_task(self) -> EnvironmentTask: assert self._robot is not None @@ -2564,7 +2579,6 @@ class SpotCubeEnv(SpotRearrangementEnv): attempts to place an April Tag cube onto a particular table.""" def __init__(self, use_gui: bool = True) -> None: - import pdb; pdb.set_trace() super().__init__(use_gui) op_to_name = {o.name: o for o in _create_operators()} diff --git a/predicators/ground_truth_models/spot_env/options.py b/predicators/ground_truth_models/spot_env/options.py index f28372b753..9f786eab5b 100644 --- a/predicators/ground_truth_models/spot_env/options.py +++ b/predicators/ground_truth_models/spot_env/options.py @@ -1019,7 +1019,7 @@ class SpotEnvsGroundTruthOptionFactory(GroundTruthOptionFactory): @classmethod def get_env_names(cls) -> Set[str]: return { - "spot_vlm_test_env" + "spot_vlm_test_env", "spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env", diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 6f037f7f1e..fdcb53e35f 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -579,9 +579,13 @@ def render_mental_images(self, observation: Observation, return [img] -class SpotMinimumPerceiver(BasePerceiver): +class SpotMinimalPerceiver(BasePerceiver): """A perceiver for spot envs with minimal functionality.""" + def render_mental_images(self, observation: Observation, + env_task: EnvironmentTask) -> Video: + raise NotImplementedError() + @classmethod def get_name(cls) -> str: return "spot_minimal_perceiver" @@ -648,40 +652,51 @@ def update_perceiver_with_action(self, action: Action) -> None: self._prev_action = action def reset(self, env_task: EnvironmentTask) -> Task: - init_obs = env_task.init_bos - imgs = init_obs.rgbd_images - self._robot = init_obs.robot + # import pdb; pdb.set_trace() + # init_obs = env_task.init_obs + # imgs = init_obs.rgbd_images + # self._robot = init_obs.robot + # state = self._create_state() + # state.simulator_state["images"] = [imgs] + # state.set(self._robot, "gripper_open_percentage", init_obs.gripper_open_percentage) + # self._curr_state = state + self._curr_env = get_or_create_env(CFG.env) state = self._create_state() - state.simulator_state["images"] = [imgs] - state.set(self._robot, "gripper_open_percentage") = init_obs.gripper_open_percentage + state.simulator_state = {} + state.simulator_state["images"] = [] self._curr_state = state goal = self._create_goal(state, env_task.goal_description) return Task(state, goal) def step(self, observation: Observation) -> State: + self._waiting_for_observation = False + self._robot = observation.robot imgs = observation.rgbd_images - self._curr_state.simulator_state["images"].append([imgs]) - self._curr_state.set(self._robot, "gripper_open_percentage") = observation.gripper_open_percentage - return self._curr_state.copy() + imgs = [v.rgb for _, v in imgs.items()] + self._curr_state = self._create_state() + self._curr_state.simulator_state["images"] = imgs + self._gripper_open_percentage = observation.gripper_open_percentage + ret_state = self._curr_state.copy() + return ret_state def _create_state(self) -> State: if self._waiting_for_observation: return DefaultState # 
Build the continuous part of the state. assert self._robot is not None - table = Object("talbe", _immovable_object_type) + table = Object("table", _immovable_object_type) cup = Object("cup", _movable_object_type) pan = Object("pan", _container_type) state_dict = { self._robot: { "gripper_open_percentage": self._gripper_open_percentage, - "x": self._robot_pos.x, - "y": self._robot_pos.y, - "z": self._robot_pos.z, - "qw": self._robot_pos.rot.w, - "qx": self._robot_pos.rot.x, - "qy": self._robot_pos.rot.y, - "qz": self._robot_pos.rot.z, + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, }, table: { "x": 0, @@ -739,4 +754,8 @@ def _create_state(self) -> State: "is_sweeper": 0 } } - return State(state_dict) + state_dict = {k: list(v.values()) for k, v in state_dict.items()} + ret_state = State(state_dict) + ret_state.simulator_state = {} + ret_state.simulator_state["images"] = [] + return ret_state diff --git a/predicators/pretrained_model_interface.py b/predicators/pretrained_model_interface.py index dfd779b8f6..9fc07a6004 100644 --- a/predicators/pretrained_model_interface.py +++ b/predicators/pretrained_model_interface.py @@ -102,7 +102,7 @@ def sample_completions(self, if not os.path.exists(cache_filepath): if CFG.llm_use_cache_only: raise ValueError("No cached response found for prompt.") - logging.debug(f"Querying model {model_id} with new prompt.") + print(f"Querying model {model_id} with new prompt.") # Query the model. completions = self._sample_completions(prompt, imgs, temperature, seed, stop_token, @@ -118,11 +118,11 @@ def sample_completions(self, for i, img in enumerate(imgs): filename_suffix = str(i) + ".jpg" img.save(os.path.join(imgs_folderpath, filename_suffix)) - logging.debug(f"Saved model response to {cache_filepath}.") + print(f"Saved model response to {cache_filepath}.") # Load the saved completion. with open(cache_filepath, 'r', encoding='utf-8') as f: cache_str = f.read() - logging.debug(f"Loaded model response from {cache_filepath}.") + print(f"Loaded model response from {cache_filepath}.") assert cache_str.count(_CACHE_SEP) == num_completions cached_prompt, completion_strs = cache_str.split(_CACHE_SEP, 1) assert cached_prompt == prompt diff --git a/predicators/structs.py b/predicators/structs.py index 1ad0c05b2c..185deecdf5 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -489,9 +489,16 @@ def __post_init__(self) -> None: for atom in self.goal: assert isinstance(atom, GroundAtom) - def goal_holds(self, state: State) -> bool: + def goal_holds(self, state: State, vlm: Optional[Any] = None) -> bool: """Return whether the goal of this task holds in the given state.""" - return all(goal_atom.holds(state) for goal_atom in self.goal) + from predicators.utils import query_vlm_for_atom_vals + vlm_atoms = set(atom for atom in self.goal if isinstance(atom.predicate, VLMPredicate)) + for atom in self.goal: + if atom not in vlm_atoms: + if not atom.holds(state): + return False + true_vlm_atoms = query_vlm_for_atom_vals(vlm_atoms, state, vlm) + return len(true_vlm_atoms) == len(vlm_atoms) def replace_goal_with_alt_goal(self) -> Task: """Return a Task with the goal replaced with the alternative goal if it From c832fb0126aeb727c44aa9b068596170c4e87493 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Tue, 10 Sep 2024 20:02:46 -0400 Subject: [PATCH 05/24] Get plumbing to work. 
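
Restrict VLMTestEnv to the VLMOn, Holding, and HandEmpty predicates (also used as its goal predicates), switch vlm_model_name to gemini-1.5-pro-latest, and temporarily disable the retry decorator on Gemini queries while debugging.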
--- predicators/envs/spot_env.py | 8 ++++++++ predicators/pretrained_model_interface.py | 4 ++-- predicators/settings.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 17ce0623f6..d3610c8c63 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -2424,6 +2424,14 @@ def _dry_simulate_pick_and_dump_container( class VLMTestEnv(SpotRearrangementEnv): """An environment to start testing the VLM pipeline.""" + @property + def predicates(self) -> Set[Predicate]: + return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty"]) + + @property + def goal_predicates(self) -> Set[Predicate]: + return self.predicates + @classmethod def get_name(cls) -> str: return "spot_vlm_test_env" diff --git a/predicators/pretrained_model_interface.py b/predicators/pretrained_model_interface.py index 9fc07a6004..609d5ee900 100644 --- a/predicators/pretrained_model_interface.py +++ b/predicators/pretrained_model_interface.py @@ -289,8 +289,8 @@ class GoogleGeminiVLM(VisionLanguageModel, GoogleGeminiModel): necessary API key to query the particular model name. """ - @retry(wait=wait_random_exponential(min=1, max=60), - stop=stop_after_attempt(10)) + # @retry(wait=wait_random_exponential(min=1, max=60), + # stop=stop_after_attempt(10)) def _sample_completions( self, prompt: str, diff --git a/predicators/settings.py b/predicators/settings.py index 8cac7a2f0b..357cd7e1e1 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -434,7 +434,7 @@ class GlobalSettings: # parameters for vision language models # gemini-1.5-pro-latest, gpt-4-turbo, gpt-4o - vlm_model_name = "gemini-pro-vision" + vlm_model_name = "gemini-1.5-pro-latest" vlm_temperature = 0.0 vlm_num_completions = 1 vlm_include_cropped_images = False From ea5f1bf263a9d423b00536e34ca6e8ae4b5d6c32 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Wed, 11 Sep 2024 15:17:04 -0400 Subject: [PATCH 06/24] initial test working --- .../approaches/bilevel_planning_approach.py | 3 +- .../approaches/spot_wrapper_approach.py | 3 +- predicators/envs/spot_env.py | 30 +++++++++++++++++++ .../ground_truth_models/spot_env/options.py | 2 +- predicators/perception/spot_perceiver.py | 5 +++- predicators/settings.py | 2 +- .../spot_utils/perception/spot_cameras.py | 10 +++---- predicators/utils.py | 1 + 8 files changed, 46 insertions(+), 10 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 79db49699c..7878b143f4 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -58,13 +58,14 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: seed = self._seed + self._num_calls nsrts = self._get_current_nsrts() preds = self._get_current_predicates() - import pdb; pdb.set_trace() + # utils.abstract(task.init, preds, self._vlm) # Run task planning only and then greedily sample and execute in the # policy. 
if self._plan_without_sim: nsrt_plan, atoms_seq, metrics = self._run_task_plan( task, nsrts, preds, timeout, seed) + # import pdb; pdb.set_trace() self._last_nsrt_plan = nsrt_plan self._last_atoms_seq = atoms_seq policy = utils.nsrt_plan_to_greedy_policy(nsrt_plan, task.goal, diff --git a/predicators/approaches/spot_wrapper_approach.py b/predicators/approaches/spot_wrapper_approach.py index eda1ed5a1d..00b5925044 100644 --- a/predicators/approaches/spot_wrapper_approach.py +++ b/predicators/approaches/spot_wrapper_approach.py @@ -102,7 +102,8 @@ def _policy(state: State) -> Action: self._base_approach_has_control = True # Need to call this once here to fix off-by-one issue. atom_seq = self._base_approach.get_execution_monitoring_info() - assert all(a.holds(state) for a in atom_seq[0]) + # TODO: consider reinstating the line below. + # assert all(a.holds(state) for a in atom_seq[0]) # Use the base policy. return base_approach_policy(state) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index d3610c8c63..76396d038c 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -2564,8 +2564,38 @@ def reset(self, train_or_test: str, task_idx: int) -> Observation: def step(self, action: Action) -> Observation: assert self._robot is not None + action_name = action.extra_info.action_name + # Special case: the action is "done", indicating that the robot + # believes it has finished the task. Used for goal checking. + if action_name == "done": + while True: + goal_description = self._current_task.goal_description + logging.info(f"The goal is: {goal_description}") + prompt = "Is the goal accomplished? Answer y or n. " + response = utils.prompt_user(prompt).strip() + if response == "y": + self._current_task_goal_reached = True + break + if response == "n": + self._current_task_goal_reached = False + break + logging.info("Invalid input, must be either 'y' or 'n'") + return self._current_observation + + # Execute the action in the real environment. Automatically retry + # if a retryable error is encountered. + action_fn = action.extra_info.real_world_fn + action_fn_args = action.extra_info.real_world_fn_args + while True: + try: + action_fn(*action_fn_args) # type: ignore + break + except RetryableRpcError as e: + logging.warning("WARNING: the following retryable error " + f"was encountered. 
Trying again.\n{e}") rgbd_images = capture_images_without_context(self._robot) gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) + print(gripper_open_percentage) objects_in_view = [] obs = _TruncatedSpotObservation( rgbd_images, diff --git a/predicators/ground_truth_models/spot_env/options.py b/predicators/ground_truth_models/spot_env/options.py index 9f786eab5b..4e9b7293bf 100644 --- a/predicators/ground_truth_models/spot_env/options.py +++ b/predicators/ground_truth_models/spot_env/options.py @@ -901,7 +901,7 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action: del state, memory, params - + robot, lease_client = get_robot_only() def _teleop(robot: Robot, lease_client: LeaseClient): diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index fdcb53e35f..88d164f178 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -673,9 +673,12 @@ def step(self, observation: Observation) -> State: self._robot = observation.robot imgs = observation.rgbd_images imgs = [v.rgb for _, v in imgs.items()] + # import PIL + # PIL.Image.fromarray(imgs[0]).show() + # import pdb; pdb.set_trace() + self._gripper_open_percentage = observation.gripper_open_percentage self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = imgs - self._gripper_open_percentage = observation.gripper_open_percentage ret_state = self._curr_state.copy() return ret_state diff --git a/predicators/settings.py b/predicators/settings.py index 357cd7e1e1..1af1ae0b62 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -434,7 +434,7 @@ class GlobalSettings: # parameters for vision language models # gemini-1.5-pro-latest, gpt-4-turbo, gpt-4o - vlm_model_name = "gemini-1.5-pro-latest" + vlm_model_name = "gemini-1.5-flash" #"gemini-1.5-pro-latest" vlm_temperature = 0.0 vlm_num_completions = 1 vlm_include_cropped_images = False diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index 7bf5a1fae5..f6ead0248e 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -22,12 +22,12 @@ 'right_fisheye_image': 180 } RGB_TO_DEPTH_CAMERAS = { - "hand_color_image": "hand_depth_in_hand_color_frame", - "left_fisheye_image": "left_depth_in_visual_frame", - "right_fisheye_image": "right_depth_in_visual_frame", + # "hand_color_image": "hand_depth_in_hand_color_frame", + # "left_fisheye_image": "left_depth_in_visual_frame", + # "right_fisheye_image": "right_depth_in_visual_frame", "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", - "frontright_fisheye_image": "frontright_depth_in_visual_frame", - "back_fisheye_image": "back_depth_in_visual_frame" + # "frontright_fisheye_image": "frontright_depth_in_visual_frame", + # "back_fisheye_image": "back_depth_in_visual_frame" } # Hack to avoid double image capturing when we want to (1) get object states diff --git a/predicators/utils.py b/predicators/utils.py index 5478bc9f97..1aed539f3d 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2527,6 +2527,7 @@ def query_vlm_for_atom_vals( num_completions=1) assert len(vlm_output) == 1 vlm_output_str = vlm_output[0] + print(f"VLM output: {vlm_output_str}") all_atom_queries = atom_queries_str.strip().split("\n") all_vlm_responses = 
vlm_output_str.strip().split("\n") From 0b8d125f0f9b77051dd6e6514da73d13c0186365 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Wed, 11 Sep 2024 17:40:49 -0400 Subject: [PATCH 07/24] Rotate spot camera images. --- predicators/perception/spot_perceiver.py | 1 - predicators/spot_utils/perception/spot_cameras.py | 12 +++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 88d164f178..cf3cfa7dc8 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -675,7 +675,6 @@ def step(self, observation: Observation) -> State: imgs = [v.rgb for _, v in imgs.items()] # import PIL # PIL.Image.fromarray(imgs[0]).show() - # import pdb; pdb.set_trace() self._gripper_open_percentage = observation.gripper_open_percentage self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = imgs diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index f6ead0248e..0c249a89f4 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -2,6 +2,7 @@ from typing import Collection, Dict, Optional, Type import cv2 +from scipy import ndimage import numpy as np from bosdyn.api import image_pb2 from bosdyn.client.frame_helpers import BODY_FRAME_NAME, get_a_tform_b @@ -22,12 +23,12 @@ 'right_fisheye_image': 180 } RGB_TO_DEPTH_CAMERAS = { - # "hand_color_image": "hand_depth_in_hand_color_frame", - # "left_fisheye_image": "left_depth_in_visual_frame", - # "right_fisheye_image": "right_depth_in_visual_frame", + "hand_color_image": "hand_depth_in_hand_color_frame", + "left_fisheye_image": "left_depth_in_visual_frame", + "right_fisheye_image": "right_depth_in_visual_frame", "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", - # "frontright_fisheye_image": "frontright_depth_in_visual_frame", - # "back_fisheye_image": "back_depth_in_visual_frame" + "frontright_fisheye_image": "frontright_depth_in_visual_frame", + "back_fisheye_image": "back_depth_in_visual_frame" } # Hack to avoid double image capturing when we want to (1) get object states @@ -173,6 +174,7 @@ def capture_images_without_context( rgb_img_resp = name_to_response[camera_name] depth_img_resp = name_to_response[RGB_TO_DEPTH_CAMERAS[camera_name]] rgb_img = _image_response_to_image(rgb_img_resp) + rgb_img = ndimage.rotate(rgb_img, ROTATION_ANGLE[camera_name]) depth_img = _image_response_to_image(depth_img_resp) # # Create transform. # camera_tform_body = get_a_tform_b( From d255810299cce18985bf35ab2a0cc07b40ca3f08 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Thu, 12 Sep 2024 14:31:54 -0400 Subject: [PATCH 08/24] Play-test other VLM predicates and update annotation in spot images. 
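
Add Upright, Toasted, VLMIn, Open, and Stained VLM predicates for play-testing, have the perceiver annotate each camera image with its source camera name (via PIL ImageDraw) before the images are stored in the state for VLM queries, and call utils.abstract with the VLM in the bilevel planning approach for debugging.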
--- .../approaches/bilevel_planning_approach.py | 2 + predicators/envs/spot_env.py | 30 ++- predicators/perception/spot_perceiver.py | 179 ++++++++++++++---- 3 files changed, 169 insertions(+), 42 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 7878b143f4..3c02e7dbc5 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -58,6 +58,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: seed = self._seed + self._num_calls nsrts = self._get_current_nsrts() preds = self._get_current_predicates() + utils.abstract(task.init, preds, self._vlm) + import pdb; pdb.set_trace() # utils.abstract(task.init, preds, self._vlm) # Run task planning only and then greedily sample and execute in the diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 76396d038c..affadd59a9 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -1488,13 +1488,38 @@ def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str: [_movable_object_type, _base_object_type], lambda o: _get_vlm_query_str("VLMOn", o) ) +_Upright = utils.create_vlm_predicate( + "Upright", + [_movable_object_type], + lambda o: _get_vlm_query_str("Upright", o) +) +_Toasted = utils.create_vlm_predicate( + "Toasted", + [_movable_object_type], + lambda o: _get_vlm_query_str("Toasted", o) +) +_VLMIn = utils.create_vlm_predicate( + "VLMIn", + [_movable_object_type, _immovable_object_type], + lambda o: _get_vlm_query_str("In", o) +) +_Open = utils.create_vlm_predicate( + "Open", + [_movable_object_type], + lambda o: _get_vlm_query_str("Open", o) +) +_Stained = utils.create_vlm_predicate( + "Stained", + [_movable_object_type], + lambda o: _get_vlm_query_str("Stained", o) +) _ALL_PREDICATES = { _NEq, _On, _TopAbove, _Inside, _NotInsideAnyContainer, _FitsInXY, _HandEmpty, _Holding, _NotHolding, _InHandView, _InView, _Reachable, _Blocking, _NotBlocked, _ContainerReadyForSweeping, _IsPlaceable, _IsNotPlaceable, _IsSweeper, _HasFlatTopSurface, _RobotReadyForSweeping, - _IsSemanticallyGreaterThan, _VLMOn + _IsSemanticallyGreaterThan, _VLMOn, _Upright, _Toasted, _VLMIn, _Open, _Stained } _NONPERCEPT_PREDICATES: Set[Predicate] = set() @@ -2426,7 +2451,8 @@ class VLMTestEnv(SpotRearrangementEnv): @property def predicates(self) -> Set[Predicate]: - return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty"]) + # return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Pourable", "Toasted", "VLMIn", "Open"]) + return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Upright"]) @property def goal_predicates(self) -> Set[Predicate]: diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index cf3cfa7dc8..14c4b5550c 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -582,6 +582,15 @@ def render_mental_images(self, observation: Observation, class SpotMinimalPerceiver(BasePerceiver): """A perceiver for spot envs with minimal functionality.""" + camera_name_to_annotation = { + 'hand_color_image': "Hand Camera Image", + 'back_fisheye_image': "Back Camera Image", + 'frontleft_fisheye_image': "Front Left Camera Image", + 'frontright_fisheye_image': "Front Right Camera Image", + 'left_fisheye_image': "Left Camera Image", + 'right_fisheye_image': "Right Camera Image" 
+ } + def render_mental_images(self, observation: Observation, env_task: EnvironmentTask) -> Video: raise NotImplementedError() @@ -672,12 +681,23 @@ def step(self, observation: Observation) -> State: self._waiting_for_observation = False self._robot = observation.robot imgs = observation.rgbd_images + img_names = [v.camera_name for _, v in imgs.items()] imgs = [v.rgb for _, v in imgs.items()] - # import PIL - # PIL.Image.fromarray(imgs[0]).show() + import pdb; pdb.set_trace() + import PIL + from PIL import ImageDraw + annotated_pil_imgs = [] + for img, img_name in zip(imgs, img_names): + pil_img = PIL.Image.fromarray(img) + draw = ImageDraw.Draw(pil_img) + font = utils.get_scaled_default_font(draw, 4) + annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) + annotated_pil_imgs.append(pil_img) + annotated_imgs = [np.array(img) for img in annotated_pil_imgs] + import pdb; pdb.set_trace() self._gripper_open_percentage = observation.gripper_open_percentage self._curr_state = self._create_state() - self._curr_state.simulator_state["images"] = imgs + self._curr_state.simulator_state["images"] = annotated_imgs ret_state = self._curr_state.copy() return ret_state @@ -686,9 +706,13 @@ def _create_state(self) -> State: return DefaultState # Build the continuous part of the state. assert self._robot is not None - table = Object("table", _immovable_object_type) + # table = Object("table", _immovable_object_type) cup = Object("cup", _movable_object_type) - pan = Object("pan", _container_type) + # pan = Object("pan", _container_type) + # bread = Object("bread", _movable_object_type) + # toaster = Object("toaster", _immovable_object_type) + # microwave = Object("microwave", _movable_object_type) + # napkin = Object("napkin", _movable_object_type) state_dict = { self._robot: { "gripper_open_percentage": self._gripper_open_percentage, @@ -700,21 +724,21 @@ def _create_state(self) -> State: "qy": 0, "qz": 0, }, - table: { - "x": 0, - "y": 0, - "z": 0, - "qw": 0, - "qx": 0, - "qy": 0, - "qz": 0, - "shape": 0, - "height": 0, - "width" : 0, - "length": 0, - "object_id": 1, - "flat_top_surface": 1 - }, + # table: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 1, + # "flat_top_surface": 1 + # }, cup: { "x": 0, "y": 0, @@ -735,26 +759,101 @@ def _create_state(self) -> State: "in_view": 1, "is_sweeper": 0 }, - pan: { - "x": 0, - "y": 0, - "z": 0, - "qw": 0, - "qx": 0, - "qy": 0, - "qz": 0, - "shape": 0, - "height": 0, - "width" : 0, - "length": 0, - "object_id": 3, - "placeable": 1, - "held": 0, - "lost": 0, - "in_hand_view": 0, - "in_view": 1, - "is_sweeper": 0 - } + # napkin: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 2, + # "placeable": 1, + # "held": 0, + # "lost": 0, + # "in_hand_view": 0, + # "in_view": 1, + # "is_sweeper": 0 + # }, + # microwave: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 2, + # "placeable": 1, + # "held": 0, + # "lost": 0, + # "in_hand_view": 0, + # "in_view": 1, + # "is_sweeper": 0 + # }, + # bread: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 
0, + # "object_id": 2, + # "placeable": 1, + # "held": 0, + # "lost": 0, + # "in_hand_view": 0, + # "in_view": 1, + # "is_sweeper": 0 + # }, + # toaster: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 1, + # "flat_top_surface": 1 + # }, + # pan: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 3, + # "placeable": 1, + # "held": 0, + # "lost": 0, + # "in_hand_view": 0, + # "in_view": 1, + # "is_sweeper": 0 + # } } state_dict = {k: list(v.values()) for k, v in state_dict.items()} ret_state = State(state_dict) From 803683ddf66e6fac1b2006c4f5e7854e18c7a235 Mon Sep 17 00:00:00 2001 From: NishanthJKumar Date: Thu, 12 Sep 2024 19:36:03 -0400 Subject: [PATCH 09/24] making some progress; took way longer than expected though... --- .../approaches/bilevel_planning_approach.py | 3 +- predicators/envs/spot_env.py | 136 ++++++++---------- .../ground_truth_models/spot_env/nsrts.py | 10 +- .../ground_truth_models/spot_env/options.py | 23 +-- predicators/perception/spot_perceiver.py | 26 ++-- predicators/settings.py | 5 +- .../perception/perception_structs.py | 3 +- .../spot_utils/perception/spot_cameras.py | 9 +- predicators/spot_utils/spot_localization.py | 6 +- predicators/structs.py | 3 +- predicators/utils.py | 91 ++++++++++-- tests/envs/test_spot_envs.py | 10 +- 12 files changed, 200 insertions(+), 125 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 3c02e7dbc5..92fa468fe9 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -59,7 +59,8 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: nsrts = self._get_current_nsrts() preds = self._get_current_predicates() utils.abstract(task.init, preds, self._vlm) - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() # utils.abstract(task.init, preds, self._vlm) # Run task planning only and then greedily sample and execute in the diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index affadd59a9..e4cbc042fd 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -30,7 +30,8 @@ brush_prompt, bucket_prompt, football_prompt, train_toy_prompt from predicators.spot_utils.perception.perception_structs import \ RGBDImageWithContext -from predicators.spot_utils.perception.spot_cameras import capture_images, capture_images_without_context +from predicators.spot_utils.perception.spot_cameras import capture_images, \ + capture_images_without_context from predicators.spot_utils.skills.spot_find_objects import \ init_search_for_objects from predicators.spot_utils.skills.spot_hand_move import \ @@ -187,15 +188,17 @@ def get_robot( @functools.lru_cache(maxsize=None) def get_robot_only() -> Tuple[Optional[Robot], Optional[LeaseClient]]: - hostname = CFG.spot_robot_ip - sdk = create_standard_sdk("PredicatorsClient-") - robot = sdk.create_robot(hostname) - robot.authenticate("user", "bbbdddaaaiii") - verify_estop(robot) - lease_client = robot.ensure_client(LeaseClient.default_service_name) - lease_client.take() - lease_keepalive = LeaseKeepAlive(lease_client, must_acquire=True, return_at_exit=True) - return robot, lease_client + hostname = CFG.spot_robot_ip + sdk = 
create_standard_sdk("PredicatorsClient-") + robot = sdk.create_robot(hostname) + robot.authenticate("user", "bbbdddaaaiii") + verify_estop(robot) + lease_client = robot.ensure_client(LeaseClient.default_service_name) + lease_client.take() + lease_keepalive = LeaseKeepAlive(lease_client, + must_acquire=True, + return_at_exit=True) + return robot, lease_client @functools.lru_cache(maxsize=None) @@ -1481,45 +1484,37 @@ def _get_sweeping_surface_for_container(container: Object, "IsSemanticallyGreaterThan", [_base_object_type, _base_object_type], _is_semantically_greater_than_classifier) + def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str: - return pred_name + "(" + ", ".join(str(obj.name) for obj in objects) + ")" # pragma: no cover -_VLMOn = utils.create_vlm_predicate( - "VLMOn", - [_movable_object_type, _base_object_type], - lambda o: _get_vlm_query_str("VLMOn", o) -) + return pred_name + "(" + ", ".join( + str(obj.name) for obj in objects) + ")" # pragma: no cover + + +_VLMOn = utils.create_vlm_predicate("VLMOn", + [_movable_object_type, _base_object_type], + lambda o: _get_vlm_query_str("VLMOn", o)) _Upright = utils.create_vlm_predicate( - "Upright", - [_movable_object_type], - lambda o: _get_vlm_query_str("Upright", o) -) + "Upright", [_movable_object_type], + lambda o: _get_vlm_query_str("Upright", o)) _Toasted = utils.create_vlm_predicate( - "Toasted", - [_movable_object_type], - lambda o: _get_vlm_query_str("Toasted", o) -) + "Toasted", [_movable_object_type], + lambda o: _get_vlm_query_str("Toasted", o)) _VLMIn = utils.create_vlm_predicate( - "VLMIn", - [_movable_object_type, _immovable_object_type], - lambda o: _get_vlm_query_str("In", o) -) -_Open = utils.create_vlm_predicate( - "Open", - [_movable_object_type], - lambda o: _get_vlm_query_str("Open", o) -) + "VLMIn", [_movable_object_type, _immovable_object_type], + lambda o: _get_vlm_query_str("In", o)) +_Open = utils.create_vlm_predicate("Open", [_movable_object_type], + lambda o: _get_vlm_query_str("Open", o)) _Stained = utils.create_vlm_predicate( - "Stained", - [_movable_object_type], - lambda o: _get_vlm_query_str("Stained", o) -) + "Stained", [_movable_object_type], + lambda o: _get_vlm_query_str("Stained", o)) _ALL_PREDICATES = { _NEq, _On, _TopAbove, _Inside, _NotInsideAnyContainer, _FitsInXY, _HandEmpty, _Holding, _NotHolding, _InHandView, _InView, _Reachable, _Blocking, _NotBlocked, _ContainerReadyForSweeping, _IsPlaceable, _IsNotPlaceable, _IsSweeper, _HasFlatTopSurface, _RobotReadyForSweeping, - _IsSemanticallyGreaterThan, _VLMOn, _Upright, _Toasted, _VLMIn, _Open, _Stained + _IsSemanticallyGreaterThan, _VLMOn, _Upright, _Toasted, _VLMIn, _Open, + _Stained } _NONPERCEPT_PREDICATES: Set[Predicate] = set() @@ -2452,8 +2447,9 @@ class VLMTestEnv(SpotRearrangementEnv): @property def predicates(self) -> Set[Predicate]: # return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Pourable", "Toasted", "VLMIn", "Open"]) - return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Upright"]) - + return set(p for p in _ALL_PREDICATES + if p.name in ["VLMOn", "Holding", "HandEmpty", "Upright"]) + @property def goal_predicates(self) -> Set[Predicate]: return self.predicates @@ -2461,7 +2457,7 @@ def goal_predicates(self) -> Set[Predicate]: @classmethod def get_name(cls) -> str: return "spot_vlm_test_env" - + def _get_dry_task(self, train_or_test: str, task_idx: int) -> EnvironmentTask: raise NotImplementedError("No dry task for VLMTestEnv.") @@ -2482,35 
+2478,31 @@ def _create_operators(self) -> Iterator[STRIPSOperator]: LiftedAtom(_NotHolding, [robot, obj]), LiftedAtom(_VLMOn, [obj, table]) } - add_effs: Set[LiftedAtom] = { - LiftedAtom(_Holding, [robot, obj]) - } + add_effs: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} del_effs: Set[LiftedAtom] = { LiftedAtom(_HandEmpty, [robot]), LiftedAtom(_NotHolding, [robot, obj]), LiftedAtom(_VLMOn, [obj, table]) } ignore_effs: Set[LiftedAtom] = set() - yield STRIPSOperator("Pick", parameters, preconds, add_effs, del_effs, ignore_effs) + yield STRIPSOperator("Pick", parameters, preconds, add_effs, del_effs, + ignore_effs) # Place object robot = Variable("?robot", _robot_type) obj = Variable("?object", _movable_object_type) pan = Variable("?pan", _container_type) parameters = [robot, obj, pan] - preconds: Set[LiftedAtom] = { - LiftedAtom(_Holding, [robot, obj]) - } + preconds: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} add_effs: Set[LiftedAtom] = { LiftedAtom(_HandEmpty, [robot]), LiftedAtom(_NotHolding, [robot, obj]), LiftedAtom(_VLMOn, [obj, pan]) } - del_effs: Set[LiftedAtom] = { - LiftedAtom(_Holding, [robot, obj]) - } + del_effs: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} ignore_effs: Set[LiftedAtom] = set() - yield STRIPSOperator("Place", parameters, preconds, add_effs, del_effs, ignore_effs) + yield STRIPSOperator("Place", parameters, preconds, add_effs, del_effs, + ignore_effs) # def _generate_train_tasks(self) -> List[EnvironmentTask]: # goal = self._generate_goal_description() # currently just one goal @@ -2542,34 +2534,27 @@ def __init__(self, use_gui: bool = True) -> None: # Create constant objects. self._spot_object = Object("robot", _robot_type) op_to_name = {o.name: o for o in self._create_operators()} - op_names_to_keep = { - "Pick", - "Place" - } + op_names_to_keep = {"Pick", "Place"} self._strips_operators = {op_to_name[o] for o in op_names_to_keep} self._train_tasks = [] self._test_tasks = [] - + def _actively_construct_env_task(self) -> EnvironmentTask: assert self._robot is not None rgbd_images = capture_images_without_context(self._robot) - gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) + gripper_open_percentage = get_robot_gripper_open_percentage( + self._robot) objects_in_view = [] - obs = _TruncatedSpotObservation( - rgbd_images, - set(objects_in_view), - set(), - set(), - self._spot_object, - gripper_open_percentage - ) + obs = _TruncatedSpotObservation(rgbd_images, set(objects_in_view), + set(), set(), self._spot_object, + gripper_open_percentage) goal_description = self._generate_goal_description() task = EnvironmentTask(obs, goal_description) return task def _generate_goal_description(self) -> GoalDescription: return "put the cup in the pan" - + def reset(self, train_or_test: str, task_idx: int) -> Observation: prompt = f"Please set up {train_or_test} task {task_idx}!" utils.prompt_user(prompt) @@ -2587,7 +2572,7 @@ def reset(self, train_or_test: str, task_idx: int) -> Observation: self._current_task_goal_reached = False self._last_action = None return self._current_task.init_obs - + def step(self, action: Action) -> Observation: assert self._robot is not None action_name = action.extra_info.action_name @@ -2620,19 +2605,16 @@ def step(self, action: Action) -> Observation: logging.warning("WARNING: the following retryable error " f"was encountered. 
Trying again.\n{e}") rgbd_images = capture_images_without_context(self._robot) - gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) + gripper_open_percentage = get_robot_gripper_open_percentage( + self._robot) print(gripper_open_percentage) objects_in_view = [] - obs = _TruncatedSpotObservation( - rgbd_images, - set(objects_in_view), - set(), - set(), - self._spot_object, - gripper_open_percentage - ) + obs = _TruncatedSpotObservation(rgbd_images, set(objects_in_view), + set(), set(), self._spot_object, + gripper_open_percentage) return obs + ############################################################################### # Cube Table Env # ############################################################################### diff --git a/predicators/ground_truth_models/spot_env/nsrts.py b/predicators/ground_truth_models/spot_env/nsrts.py index 7d4e5dad45..849854dbcd 100644 --- a/predicators/ground_truth_models/spot_env/nsrts.py +++ b/predicators/ground_truth_models/spot_env/nsrts.py @@ -285,11 +285,11 @@ class SpotEnvsGroundTruthNSRTFactory(GroundTruthNSRTFactory): @classmethod def get_env_names(cls) -> Set[str]: return { - "spot_vlm_test_env", - "spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env", - "spot_soda_bucket_env", "spot_soda_chair_env", - "spot_main_sweep_env", "spot_ball_and_cup_sticky_table_env", - "spot_brush_shelf_env", "lis_spot_block_floor_env" + "spot_vlm_test_env", "spot_cube_env", "spot_soda_floor_env", + "spot_soda_table_env", "spot_soda_bucket_env", + "spot_soda_chair_env", "spot_main_sweep_env", + "spot_ball_and_cup_sticky_table_env", "spot_brush_shelf_env", + "lis_spot_block_floor_env" } @staticmethod diff --git a/predicators/ground_truth_models/spot_env/options.py b/predicators/ground_truth_models/spot_env/options.py index 4e9b7293bf..6e26a7ccf6 100644 --- a/predicators/ground_truth_models/spot_env/options.py +++ b/predicators/ground_truth_models/spot_env/options.py @@ -1,22 +1,22 @@ """Ground-truth options for Spot environments.""" +import logging import time from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple -import logging import numpy as np import pbrspot from bosdyn.client import math_helpers -from bosdyn.client.sdk import Robot from bosdyn.client.lease import LeaseClient +from bosdyn.client.sdk import Robot from gym.spaces import Box from predicators import utils from predicators.envs import get_or_create_env from predicators.envs.spot_env import HANDEMPTY_GRIPPER_THRESHOLD, \ SpotRearrangementEnv, _get_sweeping_surface_for_container, \ - get_detection_id_for_object, get_robot, get_robot_only, \ - get_robot_gripper_open_percentage, get_simulated_object, \ + get_detection_id_for_object, get_robot, \ + get_robot_gripper_open_percentage, get_robot_only, get_simulated_object, \ get_simulated_robot from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.settings import CFG @@ -899,9 +899,11 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, robot_obj_idx, target_obj_idx, do_gaze, state, memory, objects, params) -def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action: + +def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> Action: del state, memory, params - + robot, lease_client = get_robot_only() def _teleop(robot: Robot, lease_client: LeaseClient): @@ -914,15 +916,14 @@ def _teleop(robot: Robot, lease_client: LeaseClient): # Take back control. 
robot, lease_client = get_robot_only() lease_client.take() - + fn = _teleop fn_args = (robot, lease_client) sim_fn = lambda _: None sim_fn_args = () name = "teleop" - action_extra_info = SpotActionExtraInfo( - name, objects, fn, fn_args, sim_fn, sim_fn_args - ) + action_extra_info = SpotActionExtraInfo(name, objects, fn, fn_args, sim_fn, + sim_fn_args) return utils.create_spot_env_action(action_extra_info) @@ -958,7 +959,7 @@ def _teleop(robot: Robot, lease_client: LeaseClient): "MoveToReadySweep": Box(0, 1, (0, )), # empty "Pick": Box(0, 1, (0, )), # empty "Place": Box(0, 1, (0, )) # empty -} +} # NOTE: the policies MUST be unique because they output actions with extra info # that includes the name of the operators. diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 14c4b5550c..a9d123fec8 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -2,8 +2,9 @@ import logging import time +from collections import deque from pathlib import Path -from typing import Dict, List, Optional, Set +from typing import Deque, Dict, List, Optional, Set import imageio.v2 as iio import numpy as np @@ -577,7 +578,7 @@ def render_mental_images(self, observation: Observation, logging.info(f"Wrote out to {outfile}") plt.close() return [img] - + class SpotMinimalPerceiver(BasePerceiver): """A perceiver for spot envs with minimal functionality.""" @@ -618,6 +619,8 @@ def __init__(self) -> None: self._curr_env: Optional[BaseEnv] = None self._waiting_for_observation = True self._ordered_objects: List[Object] = [] # list of all known objects + self._state_history: Deque[State] = deque( + maxlen=5) # TODO: (njk) I just picked an arbitrary constant here! Didn't properly consider this. # # Keep track of objects that are contained (out of view) in another # # object, like a bag or bucket. This is important not only for gremlins # # but also for small changes in the container's perceived pose. @@ -625,14 +628,14 @@ def __init__(self) -> None: # Load static, hard-coded features of objects, like their shapes. # meta = load_spot_metadata() # self._static_object_features = meta.get("static-object-features", {}) - + def update_perceiver_with_action(self, action: Action) -> None: # NOTE: we need to keep track of the previous action # because the step function (where we need knowledge # of the previous action) occurs *after* the action # has already been taken. 
self._prev_action = action - + def _create_goal(self, state: State, goal_description: GoalDescription) -> Set[GroundAtom]: del state # not used @@ -683,7 +686,8 @@ def step(self, observation: Observation) -> State: imgs = observation.rgbd_images img_names = [v.camera_name for _, v in imgs.items()] imgs = [v.rgb for _, v in imgs.items()] - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() import PIL from PIL import ImageDraw annotated_pil_imgs = [] @@ -691,14 +695,18 @@ def step(self, observation: Observation) -> State: pil_img = PIL.Image.fromarray(img) draw = ImageDraw.Draw(pil_img) font = utils.get_scaled_default_font(draw, 4) - annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) + annotated_pil_img = utils.add_text_to_draw_img( + draw, (0, 0), self.camera_name_to_annotation[img_name], font) annotated_pil_imgs.append(pil_img) annotated_imgs = [np.array(img) for img in annotated_pil_imgs] - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() self._gripper_open_percentage = observation.gripper_open_percentage self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = annotated_imgs ret_state = self._curr_state.copy() + ret_state.simulator_state["state_history"] = list(self._state_history) + self._state_history.append(ret_state) return ret_state def _create_state(self) -> State: @@ -749,7 +757,7 @@ def _create_state(self) -> State: "qz": 0, "shape": 0, "height": 0, - "width" : 0, + "width": 0, "length": 0, "object_id": 2, "placeable": 1, @@ -856,7 +864,7 @@ def _create_state(self) -> State: # } } state_dict = {k: list(v.values()) for k, v in state_dict.items()} - ret_state = State(state_dict) + ret_state = State(state_dict) ret_state.simulator_state = {} ret_state.simulator_state["images"] = [] return ret_state diff --git a/predicators/settings.py b/predicators/settings.py index 1af1ae0b62..f2c4d9ed27 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -434,7 +434,7 @@ class GlobalSettings: # parameters for vision language models # gemini-1.5-pro-latest, gpt-4-turbo, gpt-4o - vlm_model_name = "gemini-1.5-flash" #"gemini-1.5-pro-latest" + vlm_model_name = "gemini-1.5-flash" #"gemini-1.5-pro-latest" vlm_temperature = 0.0 vlm_num_completions = 1 vlm_include_cropped_images = False @@ -719,6 +719,9 @@ class GlobalSettings: # saved_vlm_img_demos_folder vlm_trajs_folder_name = "" vlm_predicate_vision_api_generate_ground_atoms = False + # At test-time, we will use the below number of states + # as part of labelling the current state's VLM atoms. 
+ vlm_test_time_atom_label_prompt_type = "per_scene_naive" @classmethod def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: diff --git a/predicators/spot_utils/perception/perception_structs.py b/predicators/spot_utils/perception/perception_structs.py index d676d20cf2..a39f78f380 100644 --- a/predicators/spot_utils/perception/perception_structs.py +++ b/predicators/spot_utils/perception/perception_structs.py @@ -29,9 +29,10 @@ def rotated_rgb(self) -> NDArray[np.uint8]: """The image rotated to be upright.""" return ndimage.rotate(self.rgb, self.image_rot, reshape=False) + @dataclass class RGBDImage: - """An RGBD image""" + """An RGBD image.""" rgb: NDArray[np.uint8] depth: NDArray[np.uint16] image_rot: float diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index 0c249a89f4..896a8f5e6a 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -2,16 +2,16 @@ from typing import Collection, Dict, Optional, Type import cv2 -from scipy import ndimage import numpy as np from bosdyn.api import image_pb2 from bosdyn.client.frame_helpers import BODY_FRAME_NAME, get_a_tform_b from bosdyn.client.image import ImageClient, build_image_request from bosdyn.client.sdk import Robot from numpy.typing import NDArray +from scipy import ndimage -from predicators.spot_utils.perception.perception_structs import \ - RGBDImageWithContext, RGBDImage +from predicators.spot_utils.perception.perception_structs import RGBDImage, \ + RGBDImageWithContext from predicators.spot_utils.spot_localization import SpotLocalizer ROTATION_ANGLE = { @@ -193,7 +193,8 @@ def capture_images_without_context( # world_tform_camera, depth_scale, # transforms_snapshot, # frame_name_image_sensor, camera_model) - rgbd = RGBDImage(rgb_img, depth_img, rot, camera_name, depth_scale, camera_model) + rgbd = RGBDImage(rgb_img, depth_img, rot, camera_name, depth_scale, + camera_model) rgbds[camera_name] = rgbd # _LAST_CAPTURED_IMAGES = rgbds diff --git a/predicators/spot_utils/spot_localization.py b/predicators/spot_utils/spot_localization.py index 27283b248e..e3178b6dbb 100644 --- a/predicators/spot_utils/spot_localization.py +++ b/predicators/spot_utils/spot_localization.py @@ -55,8 +55,10 @@ def __init__(self, robot: Robot, upload_path: Path, self._robot_pose = math_helpers.SE3Pose(0, 0, 0, math_helpers.Quat()) # Initialize the robot's position in the map. 
robot_state = get_robot_state(self._robot) - z_position = robot_state.kinematic_state.transforms_snapshot.child_to_parent_edge_map["gpe"].parent_tform_child.position.z - import pdb; pdb.set_trace() + z_position = robot_state.kinematic_state.transforms_snapshot.child_to_parent_edge_map[ + "gpe"].parent_tform_child.position.z + import pdb + pdb.set_trace() # current_odom_tform_body = get_odom_tform_body( # robot_state.kinematic_state.transforms_snapshot).to_proto() # localization = nav_pb2.Localization() diff --git a/predicators/structs.py b/predicators/structs.py index 185deecdf5..5309fc4ee3 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -492,7 +492,8 @@ def __post_init__(self) -> None: def goal_holds(self, state: State, vlm: Optional[Any] = None) -> bool: """Return whether the goal of this task holds in the given state.""" from predicators.utils import query_vlm_for_atom_vals - vlm_atoms = set(atom for atom in self.goal if isinstance(atom.predicate, VLMPredicate)) + vlm_atoms = set(atom for atom in self.goal + if isinstance(atom.predicate, VLMPredicate)) for atom in self.goal: if atom not in vlm_atoms: if not atom.holds(state): diff --git a/predicators/utils.py b/predicators/utils.py index 1aed539f3d..bd2472f3f0 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2494,6 +2494,73 @@ def parse_model_output_into_option_plan( return option_plan +def get_prompt_for_vlm_state_labelling( + prompt_type: str, atoms_list: List[str], label_history: List[str], + imgs_history: List[List[PIL.Image.Image]], + cropped_imgs_history: List[List[PIL.Image.Image]], + skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]: + """Prompt for generating labels for an entire trajectory. Similar to the + above prompting method, this outputs a list of prompts to label the state + at each timestep of traj with atom values). + + Note that all our prompts are saved as separate txt files under the + 'vlm_input_data_prompts/atom_labelling' folder. + """ + # Load the pre-specified prompt. + filepath_prefix = get_path_to_predicators_root() + \ + "/predicators/datasets/vlm_input_data_prompts/atom_proposal/" + try: + with open(filepath_prefix + + CFG.grammar_search_vlm_atom_label_prompt_type + ".txt", + "r", + encoding="utf-8") as f: + prompt = f.read() + except FileNotFoundError: + raise ValueError("Unknown VLM prompting option " + + f"{CFG.grammar_search_vlm_atom_label_prompt_type}") + # The prompt ends with a section for 'Predicates', so list these. + for atom_str in atoms_list: + prompt += f"\n{atom_str}" + + if "img_option_diffs" in prompt_type: + # In this case, we need to load the 'per_scene_naive' prompt as well + # for the first timestep. + with open(filepath_prefix + "per_scene_naive.txt", + "r", + encoding="utf-8") as f: + init_prompt = f.read() + for atom_str in atoms_list: + init_prompt += f"\n{atom_str}" + if len(label_history) == 0: + return (init_prompt, imgs_history[0]) + # Now, we use actual difference-based prompting for the second timestep + # and beyond. 
+ curr_prompt = prompt[:] + curr_prompt_imgs = [ + imgs_timestep[0] for imgs_timestep in imgs_history[-1] + ] + if CFG.vlm_include_cropped_images: + if CFG.env in ["burger", "burger_no_move"]: # pragma: no cover + curr_prompt_imgs.extend( + [cropped_imgs_history[-1][1], cropped_imgs_history[-1][0]]) + else: + raise NotImplementedError( + f"Cropped images not implemented for {CFG.env}.") + curr_prompt += "\n\nSkill executed between states: " + skill_name = skill_history[-1].name + str(skill_history[-1].objects) + curr_prompt += skill_name + if "label_history" in prompt_type: + curr_prompt += "\n\nPredicate values in the first scene, " \ + "before the skill was executed: \n" + curr_prompt += label_history[-1] + return (curr_prompt, curr_prompt_imgs) + else: + # NOTE: we rip out only the first image from each trajectory + # which is fine for most domains, but will be problematic for + # situations in which there is more than one image per state. + return (prompt, [imgs_history[-1][0]]) + + def query_vlm_for_atom_vals( vlm_atoms: Collection[GroundAtom], state: State, @@ -2505,17 +2572,20 @@ def query_vlm_for_atom_vals( # vlm can be called on. assert state.simulator_state is not None assert isinstance(state.simulator_state["images"], List) + if "vlm_atoms_history" not in state.simulator_state: + state.simulator_state["vlm_atoms_history"] = [] imgs = state.simulator_state["images"] + previous_states = [] + # We assume the state.simulator_state contains a list of previous states. + if "state_history" in state.simulator_state: + previous_states = state.simulator_state["state_history"] + state_imgs_history = [state.simulator_state["images"] for state in previous_states] + + # TODO: need to somehow get the history of skills executed; i'll think about this more and then implement. + vlm_atoms = sorted(vlm_atoms) - atom_queries_str = "\n* " - atom_queries_str += "\n* ".join(atom.get_vlm_query_str() - for atom in vlm_atoms) - filepath_to_vlm_prompt = get_path_to_predicators_root() + \ - "/predicators/datasets/vlm_input_data_prompts/atom_labelling/" + \ - "per_scene_naive.txt" - with open(filepath_to_vlm_prompt, "r", encoding="utf-8") as f: - vlm_query_str = f.read() - vlm_query_str += atom_queries_str + atom_queries_str = [atom.get_vlm_query_str() for atom in vlm_atoms] + vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, state.simulator_state["vlm_atoms_history"], state_imgs_history, [], skill_history) if vlm is None: vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. vlm_input_imgs = \ @@ -2530,7 +2600,6 @@ def query_vlm_for_atom_vals( print(f"VLM output: {vlm_output_str}") all_atom_queries = atom_queries_str.strip().split("\n") all_vlm_responses = vlm_output_str.strip().split("\n") - # NOTE: this assumption is likely too brittle; if this is breaking, feel # free to remove/adjust this and change the below parsing loop accordingly! assert len(all_atom_queries) == len(all_vlm_responses) @@ -2542,6 +2611,8 @@ def query_vlm_for_atom_vals( if curr_vlm_output_line[len(atom_query + ":"):period_idx].lower().strip() == "true": true_atoms.add(vlm_atoms[i]) + # Add the text of the VLM's response to the state, to be used in the future! 
+ state.simulator_state["vlm_atoms_history"].append(all_vlm_responses) return true_atoms diff --git a/tests/envs/test_spot_envs.py b/tests/envs/test_spot_envs.py index f05deceff3..f723f4f138 100644 --- a/tests/envs/test_spot_envs.py +++ b/tests/envs/test_spot_envs.py @@ -31,6 +31,7 @@ def test_spot_vlm_debug(): }) pass + def test_spot_env_dry_run(): """Dry run tests (do not require access to robot).""" utils.reset_config({ @@ -271,13 +272,15 @@ def real_robot_cube_env_test() -> None: "test_task_json_dir": args.get("test_task_json_dir", None), }) - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() rng = np.random.default_rng(123) import os os.environ["BOSDYN_CLIENT_USERNAME"] = "user" os.environ["BOSDYN_CLIENT_PASSWORD"] = "bbbdddaaaiii" env = SpotCubeEnv() - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() perceiver = SpotPerceiver() nsrts = get_gt_nsrts(env.get_name(), env.predicates, get_gt_options(env.get_name())) @@ -286,7 +289,8 @@ def real_robot_cube_env_test() -> None: task = env.get_test_tasks()[0] obs = env.reset("test", 0) perceiver.reset(task) - import pdb; pdb.set_trace() + import pdb + pdb.set_trace() assert len(obs.objects_in_view) == 4 cube, floor, table1, table2 = sorted(obs.objects_in_view) assert cube.name == "cube" From cf4033494dd7d3e883acd3f118971f8373af5bf1 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Fri, 13 Sep 2024 16:17:53 -0400 Subject: [PATCH 10/24] Progress towards object detection result. --- predicators/envs/spot_env.py | 116 +++++++++++++++++- predicators/perception/spot_perceiver.py | 9 +- .../spot_utils/perception/object_detection.py | 109 +++++++++++++++- .../spot_utils/perception/spot_cameras.py | 12 +- 4 files changed, 230 insertions(+), 16 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index affadd59a9..d0b5f5f3b9 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -25,11 +25,11 @@ from predicators.spot_utils.perception.object_detection import \ AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ LanguageObjectDetectionID, ObjectDetectionID, detect_objects, \ - visualize_all_artifacts + visualize_all_artifacts, _query_detic_sam2 from predicators.spot_utils.perception.object_specific_grasp_selection import \ brush_prompt, bucket_prompt, football_prompt, train_toy_prompt from predicators.spot_utils.perception.perception_structs import \ - RGBDImageWithContext + RGBDImageWithContext, RGBDImage, SegmentedBoundingBox from predicators.spot_utils.perception.spot_cameras import capture_images, capture_images_without_context from predicators.spot_utils.skills.spot_find_objects import \ init_search_for_objects @@ -2468,8 +2468,16 @@ def _get_dry_task(self, train_or_test: str, @property def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: - """Get an object from a perception detection ID.""" - raise NotImplementedError("No dry task for VLMTestEnv.") + + detection_id_to_obj: Dict[ObjectDetectionID, Object] = {} + objects = { + Object("pan", _movable_object_type), + Object("cup", _movable_object_type) + } + for o in objects: + detection_id = LanguageObjectDetectionID(o.name) + detection_id_to_obj[detection_id] = o + return detection_id_to_obj def _create_operators(self) -> Iterator[STRIPSOperator]: # Pick object @@ -2553,8 +2561,106 @@ def __init__(self, use_gui: bool = True) -> None: def _actively_construct_env_task(self) -> EnvironmentTask: assert self._robot is not None rgbd_images = capture_images_without_context(self._robot) + # import PIL + # imgs 
= [v.rgb for _, v in rgbd_images.items()] + # rot_imgs = [v.rotated_rgb for _, v in rgbd_images.items()] + # ex1 = PIL.Image.fromarray(imgs[0]) + # ex2 = PIL.Image.fromarray(rot_imgs[0]) + # import pdb; pdb.set_trace() gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) objects_in_view = [] + + # Perform object detection. + object_ids = self._detection_id_to_obj.keys() + ret = _query_detic_sam2(object_ids, rgbd_images) + artifacts = {"language": {"rgbds": rgbd_images, "object_id_to_img_detections": ret}} + detections_outfile = Path(".") / "object_detection_artifacts.png" + no_detections_outfile = Path(".") / "no_detection_artifacts.png" + visualize_all_artifacts(artifacts, detections_outfile, no_detections_outfile) + + # Draw object bounding box on images. + rgbds = artifacts["language"]["rgbds"] + detections = artifacts["language"]["object_id_to_img_detections"] + flat_detections: List[Tuple[RGBDImage, + LanguageObjectDetectionID, + SegmentedBoundingBox]] = [] + for obj_id, img_detections in detections.items(): + for camera, seg_bb in img_detections.items(): + rgbd = rgbds[camera] + flat_detections.append((rgbd, obj_id, seg_bb)) + + # For now assume we only have 1 image, front-left. + import pdb; pdb.set_trace() + import PIL + from PIL import ImageDraw, ImageFont + bb_pil_imgs = [] + img = list(rgbd_images.values())[0].rotated_rgb + pil_img = PIL.Image.fromarray(img) + draw = ImageDraw.Draw(pil_img) + for i, (rgbd, obj_id, seg_bb) in enumerate(flat_detections): + # img = rgbd.rotated_rgb + # pil_img = PIL.Image.fromarray(img) + x0, y0, x1, y1 = seg_bb.bounding_box + draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) + text = f"{obj_id.language_id}" + font = ImageFont.load_default() + # font = utils.get_scaled_default_font(draw, 4) + # text_width, text_height = draw.textsize(text, font) + # text_width = draw.textlength(text, font) + # text_height = font.getsize("hg")[1] + text_mask = font.getmask(text) + text_width, text_height = text_mask.size + text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)] + draw.rectangle(text_bbox, fill='green') + draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font) + + import pdb; pdb.set_trace() + + + + # box = seg_bb.bounding_box + # x0, y0 = box[0], box[1] + # w, h = box[2] - box[0], box[3] - box[1] + # ax_row[3].add_patch( + # plt.Rectangle((x0, y0), + # w, + # h, + # edgecolor='green', + # facecolor=(0, 0, 0, 0), + # lw=1)) + + # import PIL + # from PIL import ImageDraw + # annotated_pil_imgs = [] + # for img, img_name in zip(imgs, img_names): + # pil_img = PIL.Image.fromarray(img) + # draw = ImageDraw.Draw(pil_img) + # font = utils.get_scaled_default_font(draw, 4) + # annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) + # annotated_pil_imgs.append(pil_img) + # annotated_imgs = [np.array(img) for img in annotated_pil_imgs] + + # im = Image.open(image_path) + # draw = ImageDraw.Draw(im) + # font = ImageFont.load_default() # You can use a specific font if needed + + # for mask in masks: + # # Assuming you have a function to convert the mask to a PIL Image or polygon + # mask_image = convert_mask_to_pil(mask) + # im.paste(mask_image, (0, 0), mask_image) + + # for box, class_name, score in zip(input_boxes, classes, scores): + # x0, y0, x1, y1 = box + # draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) + # text = f"{class_name}: {score:.2f}" + # text_width, text_height = draw.textsize(text, font) + # text_bbox = [(x0, y0 - 
text_height - 2), (x0 + text_width + 2, y0)] + # draw.rectangle(text_bbox, fill='green') + # draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font) + + # im.show() # Or save it: im.save("output.jpg") + # import pdb; pdb.set_trace() + obs = _TruncatedSpotObservation( rgbd_images, set(objects_in_view), @@ -2621,8 +2727,8 @@ def step(self, action: Action) -> Observation: f"was encountered. Trying again.\n{e}") rgbd_images = capture_images_without_context(self._robot) gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) - print(gripper_open_percentage) objects_in_view = [] + obs = _TruncatedSpotObservation( rgbd_images, set(objects_in_view), diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 14c4b5550c..eef7e53c29 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -24,6 +24,11 @@ GoalDescription, GroundAtom, Object, Observation, Predicate, \ SpotActionExtraInfo, State, Task, Video +from predicators.spot_utils.perception.object_detection import \ + AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ + LanguageObjectDetectionID, ObjectDetectionID, detect_objects, \ + visualize_all_artifacts, _query_detic_sam2 + class SpotPerceiver(BasePerceiver): """A perceiver specific to spot envs.""" @@ -683,7 +688,7 @@ def step(self, observation: Observation) -> State: imgs = observation.rgbd_images img_names = [v.camera_name for _, v in imgs.items()] imgs = [v.rgb for _, v in imgs.items()] - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() import PIL from PIL import ImageDraw annotated_pil_imgs = [] @@ -694,7 +699,7 @@ def step(self, observation: Observation) -> State: annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) annotated_pil_imgs.append(pil_img) annotated_imgs = [np.array(img) for img in annotated_pil_imgs] - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() self._gripper_open_percentage = observation.gripper_open_percentage self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = annotated_imgs diff --git a/predicators/spot_utils/perception/object_detection.py b/predicators/spot_utils/perception/object_detection.py index 6757e8f324..94b7581d66 100644 --- a/predicators/spot_utils/perception/object_detection.py +++ b/predicators/spot_utils/perception/object_detection.py @@ -41,7 +41,7 @@ from predicators.spot_utils.perception.perception_structs import \ AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ LanguageObjectDetectionID, ObjectDetectionID, PythonicObjectDetectionID, \ - RGBDImageWithContext, SegmentedBoundingBox + RGBDImageWithContext, RGBDImage, SegmentedBoundingBox from predicators.spot_utils.utils import get_april_tag_transform, \ get_graph_nav_dir from predicators.utils import rotate_point_in_image @@ -351,6 +351,108 @@ def _query_detic_sam( return object_id_to_img_detections +def _query_detic_sam2( + object_ids: Collection[LanguageObjectDetectionID], + rgbds: Dict[str, RGBDImage], + max_server_retries: int = 5, + detection_threshold: float = CFG.spot_vision_detection_threshold +) -> Dict[ObjectDetectionID, Dict[str, SegmentedBoundingBox]]: + """Returns object ID to image ID (camera) to segmented bounding box.""" + + object_id_to_img_detections: Dict[ObjectDetectionID, + Dict[str, SegmentedBoundingBox]] = { + obj_id: {} + for obj_id in object_ids + } + + # Create buffer dictionary to send to server. 
+ buf_dict = {} + for camera_name, rgbd in rgbds.items(): + pil_rotated_img = PIL.Image.fromarray(rgbd.rotated_rgb) # type: ignore + buf_dict[camera_name] = _image_to_bytes(pil_rotated_img) + + # Extract all the classes that we want to detect. + classes = sorted(o.language_id for o in object_ids) + + # Query server, retrying to handle possible wifi issues. + # import pdb; pdb.set_trace() + # imgs = [v.rotated_rgb for _, v in rgbds.items()] + # pil_img = PIL.Image.fromarray(imgs[0]) + # import pdb; pdb.set_trace() + + for _ in range(max_server_retries): + try: + r = requests.post("http://localhost:5550/batch_predict", + files=buf_dict, + data={"classes": ",".join(classes)}) + break + except requests.exceptions.ConnectionError: + continue + else: + logging.warning("DETIC-SAM FAILED, POSSIBLE SERVER/WIFI ISSUE") + return object_id_to_img_detections + + # If the status code is not 200, then fail. + if r.status_code != 200: + logging.warning(f"DETIC-SAM FAILED! STATUS CODE: {r.status_code}") + return object_id_to_img_detections + + # Querying the server succeeded; unpack the contents. + with io.BytesIO(r.content) as f: + try: + server_results = np.load(f, allow_pickle=True) + # Corrupted results. + except pkl.UnpicklingError: + logging.warning("DETIC-SAM FAILED DURING UNPICKLING!") + return object_id_to_img_detections + + # Process the results and save all detections per object ID. + for camera_name, rgbd in rgbds.items(): + rot_boxes = server_results[f"{camera_name}_boxes"] + ret_classes = server_results[f"{camera_name}_classes"] + rot_masks = server_results[f"{camera_name}_masks"] + scores = server_results[f"{camera_name}_scores"] + + # Invert the rotation immediately so we don't need to worry about + # them henceforth. + # h, w = rgbd.rgb.shape[:2] + # image_rot = rgbd.image_rot + # boxes = [ + # _rotate_bounding_box(bb, -image_rot, h, w) for bb in rot_boxes + # ] + # masks = [ + # ndimage.rotate(m.squeeze(), -image_rot, reshape=False) + # for m in rot_masks + # ] + boxes = rot_boxes + masks = rot_masks + + # Filter out detections by confidence. We threshold detections + # at a set confidence level minimum, and if there are multiple, + # we only select the most confident one. This structure makes + # it easy for us to select multiple detections if that's ever + # necessary in the future. + for obj_id in object_ids: + # If there were no detections (which means all the + # returned values will be numpy arrays of shape (0, 0)) + # then just skip this source. + if ret_classes.size == 0: + continue + obj_id_mask = (ret_classes == obj_id.language_id) + if not np.any(obj_id_mask): + continue + max_score = np.max(scores[obj_id_mask]) + best_idx = np.where(scores == max_score)[0].item() + if scores[best_idx] < detection_threshold: + continue + # Save the detection. + seg_bb = SegmentedBoundingBox(boxes[best_idx], masks[best_idx], + scores[best_idx]) + object_id_to_img_detections[obj_id][rgbd.camera_name] = seg_bb + + import pdb; pdb.set_trace() + return object_id_to_img_detections + def _image_to_bytes(img: PIL.Image.Image) -> io.BytesIO: """Helper function to convert from a PIL image into a bytes object.""" @@ -522,7 +624,8 @@ def visualize_all_artifacts(artifacts: Dict[str, ax_row[2].imshow(rgbd.depth, cmap='Greys_r', vmin=0, vmax=10000) # Bounding box. 
- ax_row[3].imshow(rgbd.rgb) + # ax_row[3].imshow(rgbd.rgb) + ax_row[3].imshow(rgbd.rotated_rgb) box = seg_bb.bounding_box x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] @@ -534,7 +637,7 @@ def visualize_all_artifacts(artifacts: Dict[str, facecolor=(0, 0, 0, 0), lw=1)) - ax_row[4].imshow(seg_bb.mask, cmap="binary_r", vmin=0, vmax=1) + # ax_row[4].imshow(seg_bb.mask, cmap="binary_r", vmin=0, vmax=1) # Labels. abbreviated_name = obj_id.language_id diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index 0c249a89f4..299fd8db0f 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -23,12 +23,12 @@ 'right_fisheye_image': 180 } RGB_TO_DEPTH_CAMERAS = { - "hand_color_image": "hand_depth_in_hand_color_frame", - "left_fisheye_image": "left_depth_in_visual_frame", - "right_fisheye_image": "right_depth_in_visual_frame", + # "hand_color_image": "hand_depth_in_hand_color_frame", + # "left_fisheye_image": "left_depth_in_visual_frame", + # "right_fisheye_image": "right_depth_in_visual_frame", "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", - "frontright_fisheye_image": "frontright_depth_in_visual_frame", - "back_fisheye_image": "back_depth_in_visual_frame" + # "frontright_fisheye_image": "frontright_depth_in_visual_frame", + # "back_fisheye_image": "back_depth_in_visual_frame" } # Hack to avoid double image capturing when we want to (1) get object states @@ -174,7 +174,7 @@ def capture_images_without_context( rgb_img_resp = name_to_response[camera_name] depth_img_resp = name_to_response[RGB_TO_DEPTH_CAMERAS[camera_name]] rgb_img = _image_response_to_image(rgb_img_resp) - rgb_img = ndimage.rotate(rgb_img, ROTATION_ANGLE[camera_name]) + # rgb_img = ndimage.rotate(rgb_img, ROTATION_ANGLE[camera_name], reshape=False) depth_img = _image_response_to_image(depth_img_resp) # # Create transform. # camera_tform_body = get_a_tform_b( From 369b399c497aa8c19d89f8be421471b8f8a1fec2 Mon Sep 17 00:00:00 2001 From: NishanthJKumar Date: Fri, 13 Sep 2024 17:36:04 -0400 Subject: [PATCH 11/24] finish up implementation! 
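The rough shape of this change: the environment now records the skill that was just executed in each observation, and the perceiver mirrors that into short bounded histories that the VLM atom-labelling prompt can consume alongside the image history. Below is a minimal sketch of that data flow; the dict keys, class name, and the bound of 5 are illustrative assumptions, not the exact structures used in the repo.

    from collections import deque
    from typing import Deque, Optional

    HISTORY_LEN = 5  # assumed bound, mirroring the maxlen used in the perceiver

    class HistorySketch:
        """Sketch: keep the last few states and executed skills for prompting."""

        def __init__(self) -> None:
            self._states: Deque[dict] = deque(maxlen=HISTORY_LEN)
            self._skills: Deque[Optional[str]] = deque(maxlen=HISTORY_LEN)

        def step(self, state: dict, executed_skill: Optional[str]) -> dict:
            # The skill that led to this state goes into the history first,
            # while the state itself is appended afterwards, so a state never
            # lists itself in its own history.
            self._skills.append(executed_skill)
            state["state_history"] = list(self._states)
            state["skill_history"] = list(self._skills)
            self._states.append(state)
            return state

The labelling code can then read state["state_history"] and state["skill_history"] to build a difference-based prompt instead of querying each image in isolation.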
--- predicators/envs/spot_env.py | 7 ++++--- predicators/perception/spot_perceiver.py | 5 ++++- predicators/utils.py | 5 +---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index e4cbc042fd..1a88c59028 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -48,7 +48,7 @@ update_pbrspot_robot_conf, verify_estop from predicators.structs import Action, EnvironmentTask, GoalDescription, \ GroundAtom, LiftedAtom, Object, Observation, Predicate, \ - SpotActionExtraInfo, State, STRIPSOperator, Type, Variable + SpotActionExtraInfo, State, STRIPSOperator, Type, Variable, _Option ############################################################################### # Base Class # @@ -106,6 +106,7 @@ class _TruncatedSpotObservation: # # A placeholder until all predicates have classifiers # nonpercept_atoms: Set[GroundAtom] # nonpercept_predicates: Set[Predicate] + executed_skill: Optional[_Option] = None class _PartialPerceptionState(State): @@ -2547,7 +2548,7 @@ def _actively_construct_env_task(self) -> EnvironmentTask: objects_in_view = [] obs = _TruncatedSpotObservation(rgbd_images, set(objects_in_view), set(), set(), self._spot_object, - gripper_open_percentage) + gripper_open_percentage, None) goal_description = self._generate_goal_description() task = EnvironmentTask(obs, goal_description) return task @@ -2611,7 +2612,7 @@ def step(self, action: Action) -> Observation: objects_in_view = [] obs = _TruncatedSpotObservation(rgbd_images, set(objects_in_view), set(), set(), self._spot_object, - gripper_open_percentage) + gripper_open_percentage, action.get_option()) return obs diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index a9d123fec8..137edae86a 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -23,7 +23,7 @@ get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom from predicators.structs import Action, DefaultState, EnvironmentTask, \ GoalDescription, GroundAtom, Object, Observation, Predicate, \ - SpotActionExtraInfo, State, Task, Video + SpotActionExtraInfo, State, Task, Video, _Option class SpotPerceiver(BasePerceiver): @@ -621,6 +621,7 @@ def __init__(self) -> None: self._ordered_objects: List[Object] = [] # list of all known objects self._state_history: Deque[State] = deque( maxlen=5) # TODO: (njk) I just picked an arbitrary constant here! Didn't properly consider this. + self._executed_skill_history: Deque[_Option] = deque(maxlen=5) # # Keep track of objects that are contained (out of view) in another # # object, like a bag or bucket. This is important not only for gremlins # # but also for small changes in the container's perceived pose. 
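One point worth keeping in mind about the bounded histories added in the hunk above (nothing beyond the standard library is assumed here): collections.deque with a maxlen silently evicts the oldest entry once the bound is reached, so the perceiver only ever holds the most recent handful of states and skills.

    from collections import deque

    history = deque(maxlen=3)
    for i in range(5):
        history.append(i)
    print(list(history))  # [2, 3, 4]: the two oldest entries were evicted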
@@ -707,6 +708,8 @@ def step(self, observation: Observation) -> State: ret_state = self._curr_state.copy() ret_state.simulator_state["state_history"] = list(self._state_history) self._state_history.append(ret_state) + self._executed_skill_history.append(observation.executed_skill) + ret_state.simulator_state["skill_history"] = list(self._executed_skill_history) return ret_state def _create_state(self) -> State: diff --git a/predicators/utils.py b/predicators/utils.py index bd2472f3f0..4980c50850 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2580,12 +2580,9 @@ def query_vlm_for_atom_vals( if "state_history" in state.simulator_state: previous_states = state.simulator_state["state_history"] state_imgs_history = [state.simulator_state["images"] for state in previous_states] - - # TODO: need to somehow get the history of skills executed; i'll think about this more and then implement. - vlm_atoms = sorted(vlm_atoms) atom_queries_str = [atom.get_vlm_query_str() for atom in vlm_atoms] - vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, state.simulator_state["vlm_atoms_history"], state_imgs_history, [], skill_history) + vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, state.simulator_state["vlm_atoms_history"], state_imgs_history, [], state.simulator_state["skill_history"]) if vlm is None: vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. vlm_input_imgs = \ From 94dff09968ad3fc417ea0ae6ef555887d22b59d7 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Sat, 14 Sep 2024 18:16:16 -0400 Subject: [PATCH 12/24] Get camera annotation + object detection annotation to work on all images. --- predicators/envs/spot_env.py | 158 ++++++++---------- predicators/perception/spot_perceiver.py | 36 +++- .../spot_utils/perception/object_detection.py | 2 +- .../spot_utils/perception/spot_cameras.py | 12 +- 4 files changed, 106 insertions(+), 102 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index d0b5f5f3b9..3700100e0c 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -105,6 +105,8 @@ class _TruncatedSpotObservation: # # A placeholder until all predicates have classifiers # nonpercept_atoms: Set[GroundAtom] # nonpercept_predicates: Set[Predicate] + # Object detections per camera in self.rgbd_images. + object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]] class _PartialPerceptionState(State): @@ -2472,7 +2474,9 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: detection_id_to_obj: Dict[ObjectDetectionID, Object] = {} objects = { Object("pan", _movable_object_type), - Object("cup", _movable_object_type) + Object("cup", _movable_object_type), + Object("chair", _movable_object_type), + Object("bowl", _movable_object_type), } for o in objects: detection_id = LanguageObjectDetectionID(o.name) @@ -2558,6 +2562,18 @@ def __init__(self, use_gui: bool = True) -> None: self._train_tasks = [] self._test_tasks = [] + def detect_objects(self, rgbd_images: Dict[str, RGBDImage]) -> Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]]: + object_ids = self._detection_id_to_obj.keys() + object_id_to_img_detections = _query_detic_sam2(object_ids, rgbd_images) + # This ^ is currently a mapping of object_id -> camera_name -> SegmentedBoundingBox. 
+ # We want to do our annotations by camera image, so let's turn this into a + # mapping of camera_name -> object_id -> SegmentedBoundingBox. + detections = {k: [] for k in rgbd_images.keys()} + for object_id, d in object_id_to_img_detections.items(): + for camera_name, seg_bb in d.items(): + detections[camera_name].append((object_id, seg_bb)) + return detections + def _actively_construct_env_task(self) -> EnvironmentTask: assert self._robot is not None rgbd_images = capture_images_without_context(self._robot) @@ -2571,103 +2587,60 @@ def _actively_construct_env_task(self) -> EnvironmentTask: objects_in_view = [] # Perform object detection. - object_ids = self._detection_id_to_obj.keys() - ret = _query_detic_sam2(object_ids, rgbd_images) - artifacts = {"language": {"rgbds": rgbd_images, "object_id_to_img_detections": ret}} - detections_outfile = Path(".") / "object_detection_artifacts.png" - no_detections_outfile = Path(".") / "no_detection_artifacts.png" - visualize_all_artifacts(artifacts, detections_outfile, no_detections_outfile) - - # Draw object bounding box on images. - rgbds = artifacts["language"]["rgbds"] - detections = artifacts["language"]["object_id_to_img_detections"] - flat_detections: List[Tuple[RGBDImage, - LanguageObjectDetectionID, - SegmentedBoundingBox]] = [] - for obj_id, img_detections in detections.items(): - for camera, seg_bb in img_detections.items(): - rgbd = rgbds[camera] - flat_detections.append((rgbd, obj_id, seg_bb)) + object_detections_per_camera = self.detect_objects(rgbd_images) + + + # artifacts = {"language": {"rgbds": rgbd_images, "object_id_to_img_detections": ret}} + # detections_outfile = Path(".") / "object_detection_artifacts.png" + # no_detections_outfile = Path(".") / "no_detection_artifacts.png" + # visualize_all_artifacts(artifacts, detections_outfile, no_detections_outfile) + + # # Draw object bounding box on images. + # rgbds = artifacts["language"]["rgbds"] + # detections = artifacts["language"]["object_id_to_img_detections"] + # flat_detections: List[Tuple[RGBDImage, + # LanguageObjectDetectionID, + # SegmentedBoundingBox]] = [] + # for obj_id, img_detections in detections.items(): + # for camera, seg_bb in img_detections.items(): + # rgbd = rgbds[camera] + # flat_detections.append((rgbd, obj_id, seg_bb)) - # For now assume we only have 1 image, front-left. 
- import pdb; pdb.set_trace() - import PIL - from PIL import ImageDraw, ImageFont - bb_pil_imgs = [] - img = list(rgbd_images.values())[0].rotated_rgb - pil_img = PIL.Image.fromarray(img) - draw = ImageDraw.Draw(pil_img) - for i, (rgbd, obj_id, seg_bb) in enumerate(flat_detections): - # img = rgbd.rotated_rgb - # pil_img = PIL.Image.fromarray(img) - x0, y0, x1, y1 = seg_bb.bounding_box - draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) - text = f"{obj_id.language_id}" - font = ImageFont.load_default() - # font = utils.get_scaled_default_font(draw, 4) - # text_width, text_height = draw.textsize(text, font) - # text_width = draw.textlength(text, font) - # text_height = font.getsize("hg")[1] - text_mask = font.getmask(text) - text_width, text_height = text_mask.size - text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)] - draw.rectangle(text_bbox, fill='green') - draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font) - - import pdb; pdb.set_trace() - - - - # box = seg_bb.bounding_box - # x0, y0 = box[0], box[1] - # w, h = box[2] - box[0], box[3] - box[1] - # ax_row[3].add_patch( - # plt.Rectangle((x0, y0), - # w, - # h, - # edgecolor='green', - # facecolor=(0, 0, 0, 0), - # lw=1)) - - # import PIL - # from PIL import ImageDraw - # annotated_pil_imgs = [] - # for img, img_name in zip(imgs, img_names): - # pil_img = PIL.Image.fromarray(img) - # draw = ImageDraw.Draw(pil_img) - # font = utils.get_scaled_default_font(draw, 4) - # annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) - # annotated_pil_imgs.append(pil_img) - # annotated_imgs = [np.array(img) for img in annotated_pil_imgs] - - # im = Image.open(image_path) - # draw = ImageDraw.Draw(im) - # font = ImageFont.load_default() # You can use a specific font if needed - - # for mask in masks: - # # Assuming you have a function to convert the mask to a PIL Image or polygon - # mask_image = convert_mask_to_pil(mask) - # im.paste(mask_image, (0, 0), mask_image) - - # for box, class_name, score in zip(input_boxes, classes, scores): - # x0, y0, x1, y1 = box - # draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) - # text = f"{class_name}: {score:.2f}" - # text_width, text_height = draw.textsize(text, font) - # text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)] - # draw.rectangle(text_bbox, fill='green') - # draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font) - - # im.show() # Or save it: im.save("output.jpg") - # import pdb; pdb.set_trace() + # # For now assume we only have 1 image, front-left. 
+ # import pdb; pdb.set_trace() + # import PIL + # from PIL import ImageDraw, ImageFont + # bb_pil_imgs = [] + # img = list(rgbd_images.values())[0].rotated_rgb + # pil_img = PIL.Image.fromarray(img) + # draw = ImageDraw.Draw(pil_img) + # for i, (rgbd, obj_id, seg_bb) in enumerate(flat_detections): + # # img = rgbd.rotated_rgb + # # pil_img = PIL.Image.fromarray(img) + # x0, y0, x1, y1 = seg_bb.bounding_box + # draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) + # text = f"{obj_id.language_id}" + # font = ImageFont.load_default() + # # font = utils.get_scaled_default_font(draw, 4) + # # text_width, text_height = draw.textsize(text, font) + # # text_width = draw.textlength(text, font) + # # text_height = font.getsize("hg")[1] + # text_mask = font.getmask(text) + # text_width, text_height = text_mask.size + # text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)] + # draw.rectangle(text_bbox, fill='green') + # draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font) + # import pdb; pdb.set_trace() + obs = _TruncatedSpotObservation( rgbd_images, set(objects_in_view), set(), set(), self._spot_object, - gripper_open_percentage + gripper_open_percentage, + object_detections_per_camera ) goal_description = self._generate_goal_description() task = EnvironmentTask(obs, goal_description) @@ -2728,6 +2701,8 @@ def step(self, action: Action) -> Observation: rgbd_images = capture_images_without_context(self._robot) gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) objects_in_view = [] + # Perform object detection. + object_detections_per_camera = self.detect_objects(rgbd_images) obs = _TruncatedSpotObservation( rgbd_images, @@ -2735,7 +2710,8 @@ def step(self, action: Action) -> Observation: set(), set(), self._spot_object, - gripper_open_percentage + gripper_open_percentage, + object_detections_per_camera ) return obs diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index eef7e53c29..1f86189b71 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -683,12 +683,40 @@ def reset(self, env_task: EnvironmentTask) -> Task: return Task(state, goal) def step(self, observation: Observation) -> State: + import pdb; pdb.set_trace() self._waiting_for_observation = False self._robot = observation.robot - imgs = observation.rgbd_images - img_names = [v.camera_name for _, v in imgs.items()] - imgs = [v.rgb for _, v in imgs.items()] - # import pdb; pdb.set_trace() + img_objects = observation.rgbd_images # RGBDImage objects + img_names = [v.camera_name for _, v in img_objects.items()] + imgs = [v.rotated_rgb for _, v in img_objects.items()] + import PIL + from PIL import ImageDraw, ImageFont + pil_imgs = [PIL.Image.fromarray(img) for img in imgs] + # Annotate images with detected objects (names + bounding box) + # and camera name. + object_detections_per_camera = observation.object_detections_per_camera + imgs_with_objects_annotated = [] # These are PIL images. + for i, camera_name in enumerate(img_names): + draw = ImageDraw.Draw(pil_imgs[i]) + # Annotate with camera name. + font = utils.get_scaled_default_font(draw, 4) + _ = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[camera_name], font) + # Annotate with object detections. 
+ detections = object_detections_per_camera[camera_name] + for obj_id, seg_bb in detections: + x0, y0, x1, y1 = seg_bb.bounding_box + draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) + text = f"{obj_id.language_id}" + font = ImageFont.load_default() + text_mask = font.getmask(text) + text_width, text_height = text_mask.size + text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)] + draw.rectangle(text_bbox, fill='green') + draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font) + + import pdb; pdb.set_trace() + + import PIL from PIL import ImageDraw annotated_pil_imgs = [] diff --git a/predicators/spot_utils/perception/object_detection.py b/predicators/spot_utils/perception/object_detection.py index 94b7581d66..90365f2b02 100644 --- a/predicators/spot_utils/perception/object_detection.py +++ b/predicators/spot_utils/perception/object_detection.py @@ -450,7 +450,7 @@ def _query_detic_sam2( scores[best_idx]) object_id_to_img_detections[obj_id][rgbd.camera_name] = seg_bb - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() return object_id_to_img_detections diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index 299fd8db0f..be5131690a 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -23,12 +23,12 @@ 'right_fisheye_image': 180 } RGB_TO_DEPTH_CAMERAS = { - # "hand_color_image": "hand_depth_in_hand_color_frame", - # "left_fisheye_image": "left_depth_in_visual_frame", - # "right_fisheye_image": "right_depth_in_visual_frame", + "hand_color_image": "hand_depth_in_hand_color_frame", + "left_fisheye_image": "left_depth_in_visual_frame", + "right_fisheye_image": "right_depth_in_visual_frame", "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", - # "frontright_fisheye_image": "frontright_depth_in_visual_frame", - # "back_fisheye_image": "back_depth_in_visual_frame" + "frontright_fisheye_image": "frontright_depth_in_visual_frame", + "back_fisheye_image": "back_depth_in_visual_frame" } # Hack to avoid double image capturing when we want to (1) get object states @@ -125,7 +125,7 @@ def capture_images_without_context( robot: Robot, camera_names: Optional[Collection[str]] = None, quality_percent: int = 100, -) -> Dict[str, RGBDImageWithContext]: +) -> Dict[str, RGBDImage]: """Build an image request and get the responses. If no camera names are provided, all RGB cameras are used. From aec3f863604361410d0342a934a6e215b7b6b1d7 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Sat, 14 Sep 2024 21:00:21 -0400 Subject: [PATCH 13/24] Update annotation text font and background. --- predicators/perception/spot_perceiver.py | 7 +++---- predicators/spot_utils/perception/spot_cameras.py | 10 +++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 1f86189b71..6b834521c7 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -695,7 +695,6 @@ def step(self, observation: Observation) -> State: # Annotate images with detected objects (names + bounding box) # and camera name. object_detections_per_camera = observation.object_detections_per_camera - imgs_with_objects_annotated = [] # These are PIL images. for i, camera_name in enumerate(img_names): draw = ImageDraw.Draw(pil_imgs[i]) # Annotate with camera name. 
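
For reference, the annotation flow implemented by this perceiver hunk (and tweaked in the next hunk) is: group detections by camera, then draw each bounding box plus a label on a filled background onto the rotated RGB image. A simplified standalone sketch of that flow is below; it is an illustration rather than the repo's exact helper, and it assumes detections arrive as (name, (x0, y0, x1, y1)) tuples.

    from typing import List, Tuple

    import numpy as np
    import PIL.Image
    from PIL import ImageDraw, ImageFont


    def annotate_camera_image(
        rgb: np.ndarray,
        camera_name: str,
        detections: List[Tuple[str, Tuple[float, float, float, float]]],
    ) -> np.ndarray:
        """Draw per-object boxes and labels onto one (already rotated) image."""
        pil_img = PIL.Image.fromarray(rgb)
        draw = ImageDraw.Draw(pil_img)
        font = ImageFont.load_default()
        # Tag the image with the camera it came from.
        draw.text((0, 0), camera_name, fill="white", font=font)
        for name, (x0, y0, x1, y1) in detections:
            draw.rectangle([(x0, y0), (x1, y1)], outline="green", width=2)
            # Measure the label text so a filled background can sit behind it.
            text_width, text_height = font.getmask(name).size
            draw.rectangle([(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)],
                           fill="green")
            draw.text((x0 + 1, y0 - text_height - 1), name, fill="white", font=font)
        return np.array(pil_img)

The hunk that follows only swaps the label font for a scaled default and enlarges the label background; the flow stays the same.
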
@@ -707,12 +706,12 @@ def step(self, observation: Observation) -> State: x0, y0, x1, y1 = seg_bb.bounding_box draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) text = f"{obj_id.language_id}" - font = ImageFont.load_default() + font = utils.get_scaled_default_font(draw, 3) text_mask = font.getmask(text) text_width, text_height = text_mask.size - text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)] + text_bbox = [(x0, y0 - 1.5*text_height), (x0 + text_width + 1, y0)] draw.rectangle(text_bbox, fill='green') - draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font) + draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font) import pdb; pdb.set_trace() diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index be5131690a..15349535b8 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -23,12 +23,12 @@ 'right_fisheye_image': 180 } RGB_TO_DEPTH_CAMERAS = { - "hand_color_image": "hand_depth_in_hand_color_frame", - "left_fisheye_image": "left_depth_in_visual_frame", - "right_fisheye_image": "right_depth_in_visual_frame", + # "hand_color_image": "hand_depth_in_hand_color_frame", + # "left_fisheye_image": "left_depth_in_visual_frame", + # "right_fisheye_image": "right_depth_in_visual_frame", "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", - "frontright_fisheye_image": "frontright_depth_in_visual_frame", - "back_fisheye_image": "back_depth_in_visual_frame" + # "frontright_fisheye_image": "frontright_depth_in_visual_frame", + # "back_fisheye_image": "back_depth_in_visual_frame" } # Hack to avoid double image capturing when we want to (1) get object states From ac9fc2f69b6ff01832e6c202feebfd359903f211 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Sat, 14 Sep 2024 21:07:05 -0400 Subject: [PATCH 14/24] Run autoformatter. --- predicators/approaches/bilevel_planning_approach.py | 1 - predicators/envs/spot_env.py | 11 ++++++----- predicators/ground_truth_models/spot_env/options.py | 8 ++++---- predicators/perception/spot_perceiver.py | 11 ++++------- predicators/spot_utils/perception/object_detection.py | 2 +- predicators/spot_utils/perception/spot_cameras.py | 6 +++--- predicators/spot_utils/spot_localization.py | 1 - 7 files changed, 18 insertions(+), 22 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 3c02e7dbc5..dbf9eb8a33 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -61,7 +61,6 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: utils.abstract(task.init, preds, self._vlm) import pdb; pdb.set_trace() # utils.abstract(task.init, preds, self._vlm) - # Run task planning only and then greedily sample and execute in the # policy. 
if self._plan_without_sim: diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 3700100e0c..f7a6798a53 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -24,13 +24,14 @@ from predicators.settings import CFG from predicators.spot_utils.perception.object_detection import \ AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ - LanguageObjectDetectionID, ObjectDetectionID, detect_objects, \ - visualize_all_artifacts, _query_detic_sam2 + LanguageObjectDetectionID, ObjectDetectionID, _query_detic_sam2, \ + detect_objects, visualize_all_artifacts from predicators.spot_utils.perception.object_specific_grasp_selection import \ brush_prompt, bucket_prompt, football_prompt, train_toy_prompt -from predicators.spot_utils.perception.perception_structs import \ - RGBDImageWithContext, RGBDImage, SegmentedBoundingBox -from predicators.spot_utils.perception.spot_cameras import capture_images, capture_images_without_context +from predicators.spot_utils.perception.perception_structs import RGBDImage, \ + RGBDImageWithContext, SegmentedBoundingBox +from predicators.spot_utils.perception.spot_cameras import capture_images, \ + capture_images_without_context from predicators.spot_utils.skills.spot_find_objects import \ init_search_for_objects from predicators.spot_utils.skills.spot_hand_move import \ diff --git a/predicators/ground_truth_models/spot_env/options.py b/predicators/ground_truth_models/spot_env/options.py index 4e9b7293bf..59749ab03f 100644 --- a/predicators/ground_truth_models/spot_env/options.py +++ b/predicators/ground_truth_models/spot_env/options.py @@ -1,22 +1,22 @@ """Ground-truth options for Spot environments.""" +import logging import time from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple -import logging import numpy as np import pbrspot from bosdyn.client import math_helpers -from bosdyn.client.sdk import Robot from bosdyn.client.lease import LeaseClient +from bosdyn.client.sdk import Robot from gym.spaces import Box from predicators import utils from predicators.envs import get_or_create_env from predicators.envs.spot_env import HANDEMPTY_GRIPPER_THRESHOLD, \ SpotRearrangementEnv, _get_sweeping_surface_for_container, \ - get_detection_id_for_object, get_robot, get_robot_only, \ - get_robot_gripper_open_percentage, get_simulated_object, \ + get_detection_id_for_object, get_robot, \ + get_robot_gripper_open_percentage, get_robot_only, get_simulated_object, \ get_simulated_robot from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.settings import CFG diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 6b834521c7..b9e56d7199 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -17,6 +17,10 @@ _PartialPerceptionState, _SpotObservation, in_general_view_classifier from predicators.perception.base_perceiver import BasePerceiver from predicators.settings import CFG +from predicators.spot_utils.perception.object_detection import \ + AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ + LanguageObjectDetectionID, ObjectDetectionID, _query_detic_sam2, \ + detect_objects, visualize_all_artifacts from predicators.spot_utils.utils import _container_type, \ _immovable_object_type, _movable_object_type, _robot_type, \ get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom @@ -24,11 +28,6 @@ GoalDescription, GroundAtom, Object, Observation, Predicate, \ 
SpotActionExtraInfo, State, Task, Video -from predicators.spot_utils.perception.object_detection import \ - AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ - LanguageObjectDetectionID, ObjectDetectionID, detect_objects, \ - visualize_all_artifacts, _query_detic_sam2 - class SpotPerceiver(BasePerceiver): """A perceiver specific to spot envs.""" @@ -714,8 +713,6 @@ def step(self, observation: Observation) -> State: draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font) import pdb; pdb.set_trace() - - import PIL from PIL import ImageDraw annotated_pil_imgs = [] diff --git a/predicators/spot_utils/perception/object_detection.py b/predicators/spot_utils/perception/object_detection.py index 90365f2b02..6b627a90ef 100644 --- a/predicators/spot_utils/perception/object_detection.py +++ b/predicators/spot_utils/perception/object_detection.py @@ -41,7 +41,7 @@ from predicators.spot_utils.perception.perception_structs import \ AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ LanguageObjectDetectionID, ObjectDetectionID, PythonicObjectDetectionID, \ - RGBDImageWithContext, RGBDImage, SegmentedBoundingBox + RGBDImage, RGBDImageWithContext, SegmentedBoundingBox from predicators.spot_utils.utils import get_april_tag_transform, \ get_graph_nav_dir from predicators.utils import rotate_point_in_image diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index 15349535b8..036dc089ec 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -2,16 +2,16 @@ from typing import Collection, Dict, Optional, Type import cv2 -from scipy import ndimage import numpy as np from bosdyn.api import image_pb2 from bosdyn.client.frame_helpers import BODY_FRAME_NAME, get_a_tform_b from bosdyn.client.image import ImageClient, build_image_request from bosdyn.client.sdk import Robot from numpy.typing import NDArray +from scipy import ndimage -from predicators.spot_utils.perception.perception_structs import \ - RGBDImageWithContext, RGBDImage +from predicators.spot_utils.perception.perception_structs import RGBDImage, \ + RGBDImageWithContext from predicators.spot_utils.spot_localization import SpotLocalizer ROTATION_ANGLE = { diff --git a/predicators/spot_utils/spot_localization.py b/predicators/spot_utils/spot_localization.py index 27283b248e..145178cd7a 100644 --- a/predicators/spot_utils/spot_localization.py +++ b/predicators/spot_utils/spot_localization.py @@ -74,7 +74,6 @@ def __init__(self, robot: Robot, upload_path: Path, # raise LocalizationFailure(msg) # logging.warning("Localization failed once, retrying.") # time.sleep(LOCALIZATION_RETRY_WAIT_TIME) - # # Run localize once to start. # self.localize() From 5a461cc7fb24cfd9a09940453dc6e30969c04bbd Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Mon, 16 Sep 2024 14:30:31 -0400 Subject: [PATCH 15/24] Fix indent. 
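
The function re-indented below is a process-wide, memoized robot/lease factory. Stripped of repo specifics, the pattern is roughly the sketch below (illustrative function name; the hostname and credentials are placeholders, and the repo's e-stop verification step is omitted):

    import functools
    from typing import Tuple

    from bosdyn.client.lease import LeaseClient, LeaseKeepAlive
    from bosdyn.client.sdk import Robot, create_standard_sdk


    @functools.lru_cache(maxsize=None)
    def get_robot_client() -> Tuple[Robot, LeaseClient]:
        """Create the robot and take its lease exactly once per process.

        Thanks to lru_cache, every later call returns the same authenticated
        robot and lease client instead of reconnecting and re-taking the lease.
        """
        sdk = create_standard_sdk("PredicatorsClient-")
        robot = sdk.create_robot("SPOT_IP_PLACEHOLDER")  # placeholder hostname
        robot.authenticate("USERNAME", "PASSWORD")  # placeholder credentials
        lease_client = robot.ensure_client(LeaseClient.default_service_name)
        lease_client.take()
        # Keep the lease alive for the lifetime of the process.
        LeaseKeepAlive(lease_client, must_acquire=True, return_at_exit=True)
        return robot, lease_client
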
--- predicators/envs/spot_env.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index f7a6798a53..374c0114a6 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -190,15 +190,15 @@ def get_robot( @functools.lru_cache(maxsize=None) def get_robot_only() -> Tuple[Optional[Robot], Optional[LeaseClient]]: - hostname = CFG.spot_robot_ip - sdk = create_standard_sdk("PredicatorsClient-") - robot = sdk.create_robot(hostname) - robot.authenticate("user", "bbbdddaaaiii") - verify_estop(robot) - lease_client = robot.ensure_client(LeaseClient.default_service_name) - lease_client.take() - lease_keepalive = LeaseKeepAlive(lease_client, must_acquire=True, return_at_exit=True) - return robot, lease_client + hostname = CFG.spot_robot_ip + sdk = create_standard_sdk("PredicatorsClient-") + robot = sdk.create_robot(hostname) + robot.authenticate("user", "bbbdddaaaiii") + verify_estop(robot) + lease_client = robot.ensure_client(LeaseClient.default_service_name) + lease_client.take() + lease_keepalive = LeaseKeepAlive(lease_client, must_acquire=True, return_at_exit=True) + return robot, lease_client @functools.lru_cache(maxsize=None) From c4c978f0bb31966fc76d1335995bdcfeb3bd7534 Mon Sep 17 00:00:00 2001 From: NishanthJKumar Date: Mon, 16 Sep 2024 14:45:21 -0400 Subject: [PATCH 16/24] fixes for comments --- predicators/perception/spot_perceiver.py | 8 +++----- predicators/utils.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 137edae86a..e7b61d386e 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -2,9 +2,8 @@ import logging import time -from collections import deque from pathlib import Path -from typing import Deque, Dict, List, Optional, Set +from typing import Dict, List, Optional, Set import imageio.v2 as iio import numpy as np @@ -619,9 +618,8 @@ def __init__(self) -> None: self._curr_env: Optional[BaseEnv] = None self._waiting_for_observation = True self._ordered_objects: List[Object] = [] # list of all known objects - self._state_history: Deque[State] = deque( - maxlen=5) # TODO: (njk) I just picked an arbitrary constant here! Didn't properly consider this. - self._executed_skill_history: Deque[_Option] = deque(maxlen=5) + self._state_history: List[State] = [] + self._executed_skill_history: List[_Option] = [] # # Keep track of objects that are contained (out of view) in another # # object, like a bag or bucket. This is important not only for gremlins # # but also for small changes in the container's perceived pose. diff --git a/predicators/utils.py b/predicators/utils.py index 4980c50850..f8a50a7950 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2537,7 +2537,7 @@ def get_prompt_for_vlm_state_labelling( # and beyond. 
curr_prompt = prompt[:] curr_prompt_imgs = [ - imgs_timestep[0] for imgs_timestep in imgs_history[-1] + imgs_timestep for imgs_timestep in imgs_history[-1] ] if CFG.vlm_include_cropped_images: if CFG.env in ["burger", "burger_no_move"]: # pragma: no cover From 199698eee81007c5cace76f7428606c901dbb7d2 Mon Sep 17 00:00:00 2001 From: NishanthJKumar Date: Mon, 16 Sep 2024 14:48:47 -0400 Subject: [PATCH 17/24] autoformat --- predicators/utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/predicators/utils.py b/predicators/utils.py index f8a50a7950..b13201b3c7 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2579,10 +2579,15 @@ def query_vlm_for_atom_vals( # We assume the state.simulator_state contains a list of previous states. if "state_history" in state.simulator_state: previous_states = state.simulator_state["state_history"] - state_imgs_history = [state.simulator_state["images"] for state in previous_states] + state_imgs_history = [ + state.simulator_state["images"] for state in previous_states + ] vlm_atoms = sorted(vlm_atoms) atom_queries_str = [atom.get_vlm_query_str() for atom in vlm_atoms] - vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, state.simulator_state["vlm_atoms_history"], state_imgs_history, [], state.simulator_state["skill_history"]) + vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( + CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, + state.simulator_state["vlm_atoms_history"], state_imgs_history, [], + state.simulator_state["skill_history"]) if vlm is None: vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. vlm_input_imgs = \ From 6e25b726e2ebf83de78cec0a6b14cb05f45b0dba Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Mon, 16 Sep 2024 14:57:46 -0400 Subject: [PATCH 18/24] Add last skill execution to _TruncatedSpotObservation. --- predicators/envs/spot_env.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index b702638bf1..05e6eaf162 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -109,6 +109,8 @@ class _TruncatedSpotObservation: executed_skill: Optional[_Option] = None # Object detections per camera in self.rgbd_images. object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]] + # Last skill + executed_skill: Optional[_Option] = None class _PartialPerceptionState(State): @@ -2630,7 +2632,8 @@ def _actively_construct_env_task(self) -> EnvironmentTask: set(), self._spot_object, gripper_open_percentage, - object_detections_per_camera + object_detections_per_camera, + None ) goal_description = self._generate_goal_description() task = EnvironmentTask(obs, goal_description) @@ -2701,7 +2704,8 @@ def step(self, action: Action) -> Observation: set(), self._spot_object, gripper_open_percentage, - object_detections_per_camera + object_detections_per_camera, + action.get_option() ) return obs From fe7da3016916a5dc9b57ea98d29edc2bb3bb49b5 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Mon, 16 Sep 2024 15:57:13 -0400 Subject: [PATCH 19/24] Debugging progressm. 
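
Context for the changes below: the labelling prompt asks the VLM for one line per queried atom, e.g. "* On(cup, table): True.", and the parsing in this series pairs response lines with queries in order. A minimal sketch of that parsing, assuming the one-line-per-query format holds (the real code asserts this and notes the assumption is brittle):

    from typing import List, Set


    def parse_vlm_atom_labels(vlm_output: str,
                              atom_queries: List[str]) -> Set[str]:
        """Return the subset of atom queries that the VLM labelled True.

        Assumes one response line per query, in the same order, formatted
        roughly as "* <query>: True." or "* <query>: False.".
        """
        responses = vlm_output.strip().split("\n")
        assert len(responses) == len(atom_queries), "Unexpected VLM output format"
        true_queries: Set[str] = set()
        for query, line in zip(atom_queries, responses):
            assert query in line, f"Response line does not mention {query}"
            # Take the text after the last ':' and strip the trailing period.
            value = line.split(":")[-1].strip().rstrip(".").lower()
            if value == "true":
                true_queries.add(query)
        return true_queries


    # Example:
    #   parse_vlm_atom_labels(
    #       "* On(cup, table): True.\n* Holding(robot, cup): False.",
    #       ["On(cup, table)", "Holding(robot, cup)"])
    #   -> {"On(cup, table)"}
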
--- .../approaches/bilevel_planning_approach.py | 4 +--- predicators/envs/spot_env.py | 2 -- predicators/perception/spot_perceiver.py | 6 +++--- predicators/utils.py | 17 +++++++++++------ 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 5cadda34ce..fc45a70de9 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -58,9 +58,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: seed = self._seed + self._num_calls nsrts = self._get_current_nsrts() preds = self._get_current_predicates() - utils.abstract(task.init, preds, self._vlm) - import pdb - pdb.set_trace() + # utils.abstract(task.init, preds, self._vlm) # utils.abstract(task.init, preds, self._vlm) # Run task planning only and then greedily sample and execute in the # policy. diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 05e6eaf162..e84aa7ae79 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -106,10 +106,8 @@ class _TruncatedSpotObservation: # # A placeholder until all predicates have classifiers # nonpercept_atoms: Set[GroundAtom] # nonpercept_predicates: Set[Predicate] - executed_skill: Optional[_Option] = None # Object detections per camera in self.rgbd_images. object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]] - # Last skill executed_skill: Optional[_Option] = None diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 62394889bb..201df39f21 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -684,7 +684,7 @@ def reset(self, env_task: EnvironmentTask) -> Task: return Task(state, goal) def step(self, observation: Observation) -> State: - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() self._waiting_for_observation = False self._robot = observation.robot img_objects = observation.rgbd_images # RGBDImage objects @@ -714,7 +714,7 @@ def step(self, observation: Observation) -> State: draw.rectangle(text_bbox, fill='green') draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font) - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() import PIL from PIL import ImageDraw annotated_pil_imgs = [] @@ -730,8 +730,8 @@ def step(self, observation: Observation) -> State: self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = annotated_imgs ret_state = self._curr_state.copy() - ret_state.simulator_state["state_history"] = list(self._state_history) self._state_history.append(ret_state) + ret_state.simulator_state["state_history"] = list(self._state_history) self._executed_skill_history.append(observation.executed_skill) ret_state.simulator_state["skill_history"] = list(self._executed_skill_history) return ret_state diff --git a/predicators/utils.py b/predicators/utils.py index b13201b3c7..f4e32a3b7a 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2499,6 +2499,7 @@ def get_prompt_for_vlm_state_labelling( imgs_history: List[List[PIL.Image.Image]], cropped_imgs_history: List[List[PIL.Image.Image]], skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]: + # import pdb; pdb.set_trace() """Prompt for generating labels for an entire trajectory. 
Similar to the above prompting method, this outputs a list of prompts to label the state at each timestep of traj with atom values). @@ -2508,7 +2509,7 @@ def get_prompt_for_vlm_state_labelling( """ # Load the pre-specified prompt. filepath_prefix = get_path_to_predicators_root() + \ - "/predicators/datasets/vlm_input_data_prompts/atom_proposal/" + "/predicators/datasets/vlm_input_data_prompts/atom_labelling/" try: with open(filepath_prefix + CFG.grammar_search_vlm_atom_label_prompt_type + ".txt", @@ -2516,6 +2517,7 @@ def get_prompt_for_vlm_state_labelling( encoding="utf-8") as f: prompt = f.read() except FileNotFoundError: + import pdb; pdb.set_trace() raise ValueError("Unknown VLM prompting option " + f"{CFG.grammar_search_vlm_atom_label_prompt_type}") # The prompt ends with a section for 'Predicates', so list these. @@ -2583,9 +2585,9 @@ def query_vlm_for_atom_vals( state.simulator_state["images"] for state in previous_states ] vlm_atoms = sorted(vlm_atoms) - atom_queries_str = [atom.get_vlm_query_str() for atom in vlm_atoms] + atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms] vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( - CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, + CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list, state.simulator_state["vlm_atoms_history"], state_imgs_history, [], state.simulator_state["skill_history"]) if vlm is None: @@ -2600,21 +2602,24 @@ def query_vlm_for_atom_vals( assert len(vlm_output) == 1 vlm_output_str = vlm_output[0] print(f"VLM output: {vlm_output_str}") - all_atom_queries = atom_queries_str.strip().split("\n") all_vlm_responses = vlm_output_str.strip().split("\n") # NOTE: this assumption is likely too brittle; if this is breaking, feel # free to remove/adjust this and change the below parsing loop accordingly! - assert len(all_atom_queries) == len(all_vlm_responses) + assert len(atom_queries_list) == len(all_vlm_responses) for i, (atom_query, curr_vlm_output_line) in enumerate( - zip(all_atom_queries, all_vlm_responses)): + zip(atom_queries_list, all_vlm_responses)): assert atom_query + ":" in curr_vlm_output_line assert "." in curr_vlm_output_line period_idx = curr_vlm_output_line.find(".") if curr_vlm_output_line[len(atom_query + ":"):period_idx].lower().strip() == "true": true_atoms.add(vlm_atoms[i]) + + breakpoint() # Add the text of the VLM's response to the state, to be used in the future! + # REMOVE THIS -> AND PUT IT IN THE PERCEIVER state.simulator_state["vlm_atoms_history"].append(all_vlm_responses) + return true_atoms From 4cc03364e7e8d407a52a7dfd3fb677aa7a0dbcec Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Mon, 16 Sep 2024 22:24:48 -0400 Subject: [PATCH 20/24] Some progress towards fixing vlm atom history. 
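
The direction this commit starts moving in, and later commits finish, is to keep all VLM bookkeeping in the perceiver: call utils.abstract() once per step, then rebuild the per-atom label lines from the resulting abstract state so they can be replayed to the VLM as history at the next step. The reconstruction amounts to roughly the sketch below, which assumes the repo's ground VLM atoms (orderable, exposing get_vlm_query_str()):

    from typing import Iterable, List, Set


    def reconstruct_vlm_labels(all_vlm_atoms: Iterable, true_atoms: Set) -> str:
        """Rebuild the label block the VLM would have produced for one step.

        `all_vlm_atoms` is every ground VLM atom that gets queried;
        `true_atoms` is the subset that utils.abstract() returned as holding.
        """
        lines: List[str] = []
        for atom in sorted(all_vlm_atoms):  # same ordering as the query
            truth_value = "True" if atom in true_atoms else "False"
            lines.append(f"* {atom.get_vlm_query_str()}: {truth_value}")
        return "\n".join(lines)

One such string per step is appended to the perceiver's label history and later handed to the prompt builder.
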
--- predicators/perception/spot_perceiver.py | 48 +++++++++++++++++++++--- predicators/utils.py | 5 +++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 201df39f21..54e0f513bc 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -26,7 +26,7 @@ get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom from predicators.structs import Action, DefaultState, EnvironmentTask, \ GoalDescription, GroundAtom, Object, Observation, Predicate, \ - SpotActionExtraInfo, State, Task, Video, _Option + SpotActionExtraInfo, State, Task, Video, _Option, VLMPredicate class SpotPerceiver(BasePerceiver): @@ -624,6 +624,8 @@ def __init__(self) -> None: self._ordered_objects: List[Object] = [] # list of all known objects self._state_history: List[State] = [] self._executed_skill_history: List[_Option] = [] + self._vlm_label_history: List[str] = [] + self._curr_state = None # # Keep track of objects that are contained (out of view) in another # # object, like a bag or bucket. This is important not only for gremlins # # but also for small changes in the container's perceived pose. @@ -631,6 +633,7 @@ def __init__(self) -> None: # Load static, hard-coded features of objects, like their shapes. # meta = load_spot_metadata() # self._static_object_features = meta.get("static-object-features", {}) + def update_perceiver_with_action(self, action: Action) -> None: # NOTE: we need to keep track of the previous action @@ -679,7 +682,8 @@ def reset(self, env_task: EnvironmentTask) -> Task: state = self._create_state() state.simulator_state = {} state.simulator_state["images"] = [] - self._curr_state = state + # self._curr_state = state + self._curr_state = None # this will get set by self.step() goal = self._create_goal(state, env_task.goal_description) return Task(state, goal) @@ -713,8 +717,7 @@ def step(self, observation: Observation) -> State: text_bbox = [(x0, y0 - 1.5*text_height), (x0 + text_width + 1, y0)] draw.rectangle(text_bbox, fill='green') draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font) - - # import pdb; pdb.set_trace() + import PIL from PIL import ImageDraw annotated_pil_imgs = [] @@ -722,11 +725,14 @@ def step(self, observation: Observation) -> State: pil_img = PIL.Image.fromarray(img) draw = ImageDraw.Draw(pil_img) font = utils.get_scaled_default_font(draw, 4) - annotated_pil_img = utils.add_text_to_draw_img( - draw, (0, 0), self.camera_name_to_annotation[img_name], font) + annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) annotated_pil_imgs.append(pil_img) annotated_imgs = [np.array(img) for img in annotated_pil_imgs] + self._gripper_open_percentage = observation.gripper_open_percentage + + curr_state = self._create_state + self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = annotated_imgs ret_state = self._curr_state.copy() @@ -734,6 +740,36 @@ def step(self, observation: Observation) -> State: ret_state.simulator_state["state_history"] = list(self._state_history) self._executed_skill_history.append(observation.executed_skill) ret_state.simulator_state["skill_history"] = list(self._executed_skill_history) + + # Save "all_vlm_responses" towards building vlm atom history. + # Any time utils.abstract() is called, e.g. 
approach or planner, + # we may (depending on flags) want to pass in the vlm atom history + # into the prompt to the VLM. + # We could save `all_vlm_responses` computed internally by + # utils.query_vlm_for_aotm_vals(), but that would require us to + # change how utils.abstract() works. Instead, we'll re-compute the + # `all_vlm_responses` based on the true atoms returned by utils.abstract(). + assert self._curr_env is not None + preds = self._curr_env.predicates + state_copy = ret_state.copy() # temporary, to ease debugging + abstract_state = utils.abstract(state_copy, preds) + # We should avoid recomputing the abstract state (VLM noise?) so let's store it in + # the state. + ret_state.simulator_state["abstract_state"] = abstract_state + # Re-compute the VLM labeling for the VLM atoms in this state to store in our + # vlm atom history. + # This code also appears in utils.abstract() + if self._curr_state is not None: + vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate)) + vlm_atoms = set() + for pred in vlm_preds: + for choice in utils.get_object_combinations(list(state_copy), pred.types): + vlm_atoms.add(GroundAtom(pred, choice)) + vlm_atoms = sorted(vlm_atoms) + import pdb; pdb.set_trace() + ret_state.simulator_state["vlm_atoms_history"].append(abstract_state) + else: + self._curr_state = ret_state.copy() return ret_state def _create_state(self) -> State: diff --git a/predicators/utils.py b/predicators/utils.py index f4e32a3b7a..9d3f3f723b 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2618,6 +2618,11 @@ def query_vlm_for_atom_vals( breakpoint() # Add the text of the VLM's response to the state, to be used in the future! # REMOVE THIS -> AND PUT IT IN THE PERCEIVER + # Perceiver calls utils.abstract once, and puts it in the state history. + # According to a flag, anywhere else we normally call utils.abstract, we + # instead just pull the abstract state from the state simulator state field that has it already. + # The appending of vlm atom history is currently done in query_vlm_for_atom_vals() in utils.py, + # and utils.ground calls that. state.simulator_state["vlm_atoms_history"].append(all_vlm_responses) return true_atoms From 41b58ea42dcfde3693dab7842847c40174f09154 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Tue, 17 Sep 2024 14:59:32 -0400 Subject: [PATCH 21/24] Progress towards passing history into VLM at test time. 
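
At test time, the query path this commit works toward pulls the image, label, and skill histories out of state.simulator_state, appends the current images, and hands everything to the prompt builder. A simplified sketch of that assembly is below; field names follow the simulator_state keys used in these diffs, and the prompt construction itself is elided.

    from typing import Any, Dict, List, Tuple


    def collect_vlm_history(
            simulator_state: Dict[str, Any]) -> Tuple[List, List, List]:
        """Gather (images_history, label_history, skill_history) for the VLM.

        The *_history keys only cover previous steps, so the current step's
        images are appended at the end to complete the image history.
        """
        curr_images = simulator_state["images"]
        prev_states = simulator_state.get("state_history", [])
        images_history = [s.simulator_state["images"] for s in prev_states]
        images_history = images_history + [curr_images]
        label_history = simulator_state.get("vlm_label_history", [])
        skill_history = simulator_state.get("skill_history", [])
        return images_history, label_history, skill_history

Note the ordering: the bug-fix commit later in the series settles on putting the current images last, which is what the sketch does.
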
--- predicators/envs/spot_env.py | 3 +- predicators/main.py | 1 + predicators/perception/spot_perceiver.py | 176 +++++++++++------- .../spot_utils/perception/spot_cameras.py | 4 +- predicators/utils.py | 6 +- 5 files changed, 122 insertions(+), 68 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index e84aa7ae79..800c44fd48 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -2451,7 +2451,7 @@ class VLMTestEnv(SpotRearrangementEnv): def predicates(self) -> Set[Predicate]: # return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Pourable", "Toasted", "VLMIn", "Open"]) return set(p for p in _ALL_PREDICATES - if p.name in ["VLMOn", "Holding", "HandEmpty", "Upright"]) + if p.name in ["VLMOn", "Holding", "HandEmpty"]) @property def goal_predicates(self) -> Set[Predicate]: @@ -2474,6 +2474,7 @@ def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: Object("cup", _movable_object_type), Object("chair", _movable_object_type), Object("bowl", _movable_object_type), + Object("table", _movable_object_type), } for o in objects: detection_id = LanguageObjectDetectionID(o.name) diff --git a/predicators/main.py b/predicators/main.py index 173b660d8f..08748d681c 100644 --- a/predicators/main.py +++ b/predicators/main.py @@ -362,6 +362,7 @@ def _run_testing(env: BaseEnv, cogman: CogMan) -> Metrics: metrics: Metrics = defaultdict(float) curr_num_nodes_created = 0.0 curr_num_nodes_expanded = 0.0 + import pdb; pdb.set_trace() for test_task_idx, env_task in enumerate(test_tasks): solve_start = time.perf_counter() try: diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 54e0f513bc..b9325aadfe 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -635,13 +635,6 @@ def __init__(self) -> None: # self._static_object_features = meta.get("static-object-features", {}) - def update_perceiver_with_action(self, action: Action) -> None: - # NOTE: we need to keep track of the previous action - # because the step function (where we need knowledge - # of the previous action) occurs *after* the action - # has already been taken. - self._prev_action = action - def _create_goal(self, state: State, goal_description: GoalDescription) -> Set[GroundAtom]: del state # not used @@ -680,11 +673,20 @@ def reset(self, env_task: EnvironmentTask) -> Task: # self._curr_state = state self._curr_env = get_or_create_env(CFG.env) state = self._create_state() - state.simulator_state = {} - state.simulator_state["images"] = [] - # self._curr_state = state - self._curr_state = None # this will get set by self.step() + # state.simulator_state = {} + # state.simulator_state["images"] = [] + # state.simulator_state["state_history"] = [] + # state.simulator_state["skill_history"] = [] + # state.simulator_state["vlm_atoms_history"] = [] + self._curr_state = state goal = self._create_goal(state, env_task.goal_description) + + # Reset run-specific things. 
+ self._state_history = [] + self._executed_skill_history = [] + self._vlm_label_history = [] + self._prev_action = None + return Task(state, goal) def step(self, observation: Observation) -> State: @@ -718,21 +720,66 @@ def step(self, observation: Observation) -> State: draw.rectangle(text_bbox, fill='green') draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font) - import PIL - from PIL import ImageDraw - annotated_pil_imgs = [] - for img, img_name in zip(imgs, img_names): - pil_img = PIL.Image.fromarray(img) - draw = ImageDraw.Draw(pil_img) - font = utils.get_scaled_default_font(draw, 4) - annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) - annotated_pil_imgs.append(pil_img) - annotated_imgs = [np.array(img) for img in annotated_pil_imgs] + # import PIL + # from PIL import ImageDraw + # annotated_pil_imgs = [] + # for img, img_name in zip(imgs, img_names): + # pil_img = PIL.Image.fromarray(img) + # draw = ImageDraw.Draw(pil_img) + # font = utils.get_scaled_default_font(draw, 4) + # annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) + # annotated_pil_imgs.append(pil_img) + annotated_imgs = [np.array(img) for img in pil_imgs] self._gripper_open_percentage = observation.gripper_open_percentage - curr_state = self._create_state + # check if self._curr_state is what we expect it to be. + import pdb; pdb.set_trace() + self._curr_state = self._create_state() + # This state is a default/empty. We have to set the attributes + # of the objects and set the simulator state properly. + self._curr_state.simulator_state["images"] = annotated_imgs + # At the first timestep, these histories will be empty due to self.reset(). + # But at every timestep that isn't the first one, they will be non-empty. + self._curr_state.simulator_state["state_history"] = list(self._state_history) + self._curr_state.simulator_state["skill_history"] = list(self._executed_skill_history) + self._curr_state.simulator_state["vlm_label_history"] = list(self._vlm_label_history) + + # Add to histories. + # A bit of extra work is required to build the VLM label history. + # We want to keep `utils.abstract()` as straightforward as possible, + # so we'll "rebuild" the VLM labels from the abstract state + # returned by `utils.abstract()`. And since we call this function, + # we might as well store the abstract state as a part of the simulator + # state so that we don't need to recompute it later in the approach or + # in planning. + assert self._curr_env is not None + preds = self._curr_env.predicates + state_copy = self._curr_env.copy() + abstract_state = utils.abstract(state_copy, preds) + self._curr_state.simulator_state["abstract_state"] = abstract_state + # Compute all the VLM atoms. `utils.abstract()` only returns the ones that + # are True. The remaining ones are the ones that are False. + vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate)) + vlm_atoms = set() + for pred in vlm_preds: + for choice in utils.get_object_combinations(list(state_copy), pred.types): + vlm_atoms.add(GroundAtom(pred, choice)) + vlm_atoms = sorted(vlm_atoms) + import pdb; pdb.set_trace() + + self._state_history.append(self._curr_state.copy()) + # The executed skill will be `None` in the first timestep. + # This should be handled in the function that processes the + # history when passing it to the VLM. 
+ self._executed_skill_history.append(observation.executed_skill) + + ############################# + + + + curr_state = self._create_state self._curr_state = self._create_state() self._curr_state.simulator_state["images"] = annotated_imgs ret_state = self._curr_state.copy() @@ -777,9 +824,9 @@ def _create_state(self) -> State: return DefaultState # Build the continuous part of the state. assert self._robot is not None - # table = Object("table", _immovable_object_type) + table = Object("table", _immovable_object_type) cup = Object("cup", _movable_object_type) - # pan = Object("pan", _container_type) + pan = Object("pan", _container_type) # bread = Object("bread", _movable_object_type) # toaster = Object("toaster", _immovable_object_type) # microwave = Object("microwave", _movable_object_type) @@ -795,21 +842,21 @@ def _create_state(self) -> State: "qy": 0, "qz": 0, }, - # table: { - # "x": 0, - # "y": 0, - # "z": 0, - # "qw": 0, - # "qx": 0, - # "qy": 0, - # "qz": 0, - # "shape": 0, - # "height": 0, - # "width" : 0, - # "length": 0, - # "object_id": 1, - # "flat_top_surface": 1 - # }, + table: { + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + "shape": 0, + "height": 0, + "width" : 0, + "length": 0, + "object_id": 1, + "flat_top_surface": 1 + }, cup: { "x": 0, "y": 0, @@ -905,29 +952,32 @@ def _create_state(self) -> State: # "object_id": 1, # "flat_top_surface": 1 # }, - # pan: { - # "x": 0, - # "y": 0, - # "z": 0, - # "qw": 0, - # "qx": 0, - # "qy": 0, - # "qz": 0, - # "shape": 0, - # "height": 0, - # "width" : 0, - # "length": 0, - # "object_id": 3, - # "placeable": 1, - # "held": 0, - # "lost": 0, - # "in_hand_view": 0, - # "in_view": 1, - # "is_sweeper": 0 - # } + pan: { + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + "shape": 0, + "height": 0, + "width" : 0, + "length": 0, + "object_id": 3, + "placeable": 1, + "held": 0, + "lost": 0, + "in_hand_view": 0, + "in_view": 1, + "is_sweeper": 0 + } } state_dict = {k: list(v.values()) for k, v in state_dict.items()} - ret_state = State(state_dict) - ret_state.simulator_state = {} - ret_state.simulator_state["images"] = [] - return ret_state + state = State(state_dict) + state.simulator_state = {} + state.simulator_state["images"] = [] + state.simulator_state["state_history"] = [] + state.simulator_state["skill_history"] = [] + state.simulator_state["vlm_atoms_history"] = [] + return state diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index 9f85bcc83d..f39dbc0cc4 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -26,8 +26,8 @@ # "hand_color_image": "hand_depth_in_hand_color_frame", # "left_fisheye_image": "left_depth_in_visual_frame", # "right_fisheye_image": "right_depth_in_visual_frame", - "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", - # "frontright_fisheye_image": "frontright_depth_in_visual_frame", + # "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", + "frontright_fisheye_image": "frontright_depth_in_visual_frame", # "back_fisheye_image": "back_depth_in_visual_frame" } diff --git a/predicators/utils.py b/predicators/utils.py index 9d3f3f723b..d41f2c09a2 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2603,6 +2603,7 @@ def query_vlm_for_atom_vals( vlm_output_str = vlm_output[0] print(f"VLM output: {vlm_output_str}") all_vlm_responses = vlm_output_str.strip().split("\n") + # import pdb; pdb.set_trace() 
# NOTE: this assumption is likely too brittle; if this is breaking, feel # free to remove/adjust this and change the below parsing loop accordingly! assert len(atom_queries_list) == len(all_vlm_responses) @@ -2615,7 +2616,7 @@ def query_vlm_for_atom_vals( ":"):period_idx].lower().strip() == "true": true_atoms.add(vlm_atoms[i]) - breakpoint() + # breakpoint() # Add the text of the VLM's response to the state, to be used in the future! # REMOVE THIS -> AND PUT IT IN THE PERCEIVER # Perceiver calls utils.abstract once, and puts it in the state history. @@ -2623,7 +2624,7 @@ def query_vlm_for_atom_vals( # instead just pull the abstract state from the state simulator state field that has it already. # The appending of vlm atom history is currently done in query_vlm_for_atom_vals() in utils.py, # and utils.ground calls that. - state.simulator_state["vlm_atoms_history"].append(all_vlm_responses) + # state.simulator_state["vlm_atoms_history"].append(all_vlm_responses) return true_atoms @@ -2652,6 +2653,7 @@ def abstract(state: State, for pred in vlm_preds: for choice in get_object_combinations(list(state), pred.types): vlm_atoms.add(GroundAtom(pred, choice)) + # import pdb; pdb.set_trace() true_vlm_atoms = query_vlm_for_atom_vals(vlm_atoms, state, vlm) atoms |= true_vlm_atoms return atoms From 71fe6d390db9afc5b044d11ef5246917be349849 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Tue, 17 Sep 2024 16:21:04 -0400 Subject: [PATCH 22/24] Update and fix computation of history and how it gets passed into the VLM query. --- predicators/main.py | 1 - predicators/perception/spot_perceiver.py | 68 +++++------------------- predicators/utils.py | 51 +++++++++++------- 3 files changed, 46 insertions(+), 74 deletions(-) diff --git a/predicators/main.py b/predicators/main.py index 08748d681c..173b660d8f 100644 --- a/predicators/main.py +++ b/predicators/main.py @@ -362,7 +362,6 @@ def _run_testing(env: BaseEnv, cogman: CogMan) -> Metrics: metrics: Metrics = defaultdict(float) curr_num_nodes_created = 0.0 curr_num_nodes_expanded = 0.0 - import pdb; pdb.set_trace() for test_task_idx, env_task in enumerate(test_tasks): solve_start = time.perf_counter() try: diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index b9325aadfe..db177f3fba 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -730,12 +730,8 @@ def step(self, observation: Observation) -> State: # annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) # annotated_pil_imgs.append(pil_img) annotated_imgs = [np.array(img) for img in pil_imgs] - self._gripper_open_percentage = observation.gripper_open_percentage - # check if self._curr_state is what we expect it to be. - import pdb; pdb.set_trace() - self._curr_state = self._create_state() # This state is a default/empty. We have to set the attributes # of the objects and set the simulator state properly. @@ -756,7 +752,7 @@ def step(self, observation: Observation) -> State: # in planning. assert self._curr_env is not None preds = self._curr_env.predicates - state_copy = self._curr_env.copy() + state_copy = self._curr_state.copy() abstract_state = utils.abstract(state_copy, preds) self._curr_state.simulator_state["abstract_state"] = abstract_state # Compute all the VLM atoms. 
`utils.abstract()` only returns the ones that @@ -767,57 +763,19 @@ def step(self, observation: Observation) -> State: for choice in utils.get_object_combinations(list(state_copy), pred.types): vlm_atoms.add(GroundAtom(pred, choice)) vlm_atoms = sorted(vlm_atoms) - import pdb; pdb.set_trace() - + atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms] + reconstructed_all_vlm_responses = [] + for atom in vlm_atoms: + if atom in abstract_state: + truth_value = 'True' + else: + truth_value = 'False' + atom_label = f"* {atom.get_vlm_query_str()}: {truth_value}" + reconstructed_all_vlm_responses.append(atom_label) + self._vlm_label_history.append(reconstructed_all_vlm_responses) self._state_history.append(self._curr_state.copy()) - # The executed skill will be `None` in the first timestep. - # This should be handled in the function that processes the - # history when passing it to the VLM. - self._executed_skill_history.append(observation.executed_skill) - - ############################# - - - - curr_state = self._create_state - self._curr_state = self._create_state() - self._curr_state.simulator_state["images"] = annotated_imgs - ret_state = self._curr_state.copy() - self._state_history.append(ret_state) - ret_state.simulator_state["state_history"] = list(self._state_history) - self._executed_skill_history.append(observation.executed_skill) - ret_state.simulator_state["skill_history"] = list(self._executed_skill_history) - - # Save "all_vlm_responses" towards building vlm atom history. - # Any time utils.abstract() is called, e.g. approach or planner, - # we may (depending on flags) want to pass in the vlm atom history - # into the prompt to the VLM. - # We could save `all_vlm_responses` computed internally by - # utils.query_vlm_for_aotm_vals(), but that would require us to - # change how utils.abstract() works. Instead, we'll re-compute the - # `all_vlm_responses` based on the true atoms returned by utils.abstract(). - assert self._curr_env is not None - preds = self._curr_env.predicates - state_copy = ret_state.copy() # temporary, to ease debugging - abstract_state = utils.abstract(state_copy, preds) - # We should avoid recomputing the abstract state (VLM noise?) so let's store it in - # the state. - ret_state.simulator_state["abstract_state"] = abstract_state - # Re-compute the VLM labeling for the VLM atoms in this state to store in our - # vlm atom history. - # This code also appears in utils.abstract() - if self._curr_state is not None: - vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate)) - vlm_atoms = set() - for pred in vlm_preds: - for choice in utils.get_object_combinations(list(state_copy), pred.types): - vlm_atoms.add(GroundAtom(pred, choice)) - vlm_atoms = sorted(vlm_atoms) - import pdb; pdb.set_trace() - ret_state.simulator_state["vlm_atoms_history"].append(abstract_state) - else: - self._curr_state = ret_state.copy() - return ret_state + self._executed_skill_history.append(observation.executed_skill) # None in first timestep. 
+ return self._curr_state.copy() def _create_state(self) -> State: if self._waiting_for_observation: diff --git a/predicators/utils.py b/predicators/utils.py index d41f2c09a2..18172d4e8d 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2499,7 +2499,6 @@ def get_prompt_for_vlm_state_labelling( imgs_history: List[List[PIL.Image.Image]], cropped_imgs_history: List[List[PIL.Image.Image]], skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]: - # import pdb; pdb.set_trace() """Prompt for generating labels for an entire trajectory. Similar to the above prompting method, this outputs a list of prompts to label the state at each timestep of traj with atom values). @@ -2517,7 +2516,6 @@ def get_prompt_for_vlm_state_labelling( encoding="utf-8") as f: prompt = f.read() except FileNotFoundError: - import pdb; pdb.set_trace() raise ValueError("Unknown VLM prompting option " + f"{CFG.grammar_search_vlm_atom_label_prompt_type}") # The prompt ends with a section for 'Predicates', so list these. @@ -2574,22 +2572,40 @@ def query_vlm_for_atom_vals( # vlm can be called on. assert state.simulator_state is not None assert isinstance(state.simulator_state["images"], List) - if "vlm_atoms_history" not in state.simulator_state: - state.simulator_state["vlm_atoms_history"] = [] - imgs = state.simulator_state["images"] - previous_states = [] - # We assume the state.simulator_state contains a list of previous states. - if "state_history" in state.simulator_state: - previous_states = state.simulator_state["state_history"] - state_imgs_history = [ - state.simulator_state["images"] for state in previous_states - ] + # if "vlm_atoms_history" not in state.simulator_state: + # state.simulator_state["vlm_atoms_history"] = [] + # imgs = state.simulator_state["images"] + # previous_states = [] + # # We assume the state.simulator_state contains a list of previous states. + # if "state_history" in state.simulator_state: + # previous_states = state.simulator_state["state_history"] + # state_imgs_history = [ + # state.simulator_state["images"] for state in previous_states + # ] vlm_atoms = sorted(vlm_atoms) atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms] + # All "history" fields in the simulator state contain things from + # previous states -- not the current state. + # We want the image history to include the images from the current state. 
+ curr_state_images = state.simulator_state["images"] + if "state_history" in state.simulator_state: + prev_states = state.simulator_state["state_history"] + prev_states_imgs_history = [s.simulator_state["images"] for s in prev_states] + images_history = [curr_state_images] + prev_states_imgs_history + skill_history = [] + if "skill_history" in state.simulator_state: + skill_history = state.simulator_state["skill_history"] + label_history = [] + if "vlm_label_history" in state.simulator_state: + label_history = state.simulator_state["vlm_label_history"] + + # vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( + # CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list, + # state.simulator_state["vlm_atoms_history"], state_imgs_history, [], + # state.simulator_state["skill_history"]) vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list, - state.simulator_state["vlm_atoms_history"], state_imgs_history, [], - state.simulator_state["skill_history"]) + label_history, images_history, [], skill_history) if vlm is None: vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. vlm_input_imgs = \ @@ -2603,7 +2619,6 @@ def query_vlm_for_atom_vals( vlm_output_str = vlm_output[0] print(f"VLM output: {vlm_output_str}") all_vlm_responses = vlm_output_str.strip().split("\n") - # import pdb; pdb.set_trace() # NOTE: this assumption is likely too brittle; if this is breaking, feel # free to remove/adjust this and change the below parsing loop accordingly! assert len(atom_queries_list) == len(all_vlm_responses) @@ -2612,8 +2627,9 @@ def query_vlm_for_atom_vals( assert atom_query + ":" in curr_vlm_output_line assert "." in curr_vlm_output_line period_idx = curr_vlm_output_line.find(".") - if curr_vlm_output_line[len(atom_query + - ":"):period_idx].lower().strip() == "true": + # value = curr_vlm_output_line[len(atom_query + ":"):period_idx].lower().strip() + value = curr_vlm_output_line.split(': ')[-1].strip('.').lower() + if value == "true": true_atoms.add(vlm_atoms[i]) # breakpoint() @@ -2625,7 +2641,6 @@ def query_vlm_for_atom_vals( # The appending of vlm atom history is currently done in query_vlm_for_atom_vals() in utils.py, # and utils.ground calls that. # state.simulator_state["vlm_atoms_history"].append(all_vlm_responses) - return true_atoms From 5232fa2b26d31f23c16deb0cd5ec5128e3cc4446 Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Tue, 17 Sep 2024 17:43:30 -0400 Subject: [PATCH 23/24] Fix bugs related to image history to VLM and abstracting the state. 
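
One of the fixes below is to compute the abstract state once per perceiver step and cache it in state.simulator_state, so goal checking and task planning reuse it rather than re-querying the VLM (slow and potentially noisy). The memoization pattern, reduced to a sketch with the expensive VLM-backed path passed in as a callable:

    from typing import Any, Callable, Set


    def cached_abstract(state: Any, preds: Set,
                        compute: Callable[[Any, Set], Set]) -> Set:
        """Return the abstract state, reusing a cached copy when available.

        `compute` stands in for the VLM-backed utils.abstract() path. The
        perceiver stores its once-per-step result under
        state.simulator_state["abstract_state"]; downstream callers read that
        instead of triggering another VLM query.
        """
        sim = state.simulator_state
        if isinstance(sim, dict) and "abstract_state" in sim:
            return sim["abstract_state"]
        atoms = compute(state, preds)
        if isinstance(sim, dict):
            sim["abstract_state"] = atoms
        return atoms
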
--- .../approaches/bilevel_planning_approach.py | 2 +- predicators/envs/spot_env.py | 2 +- predicators/perception/spot_perceiver.py | 24 ++++++++++++------- .../spot_utils/perception/spot_cameras.py | 2 +- predicators/structs.py | 3 +++ predicators/utils.py | 5 +++- 6 files changed, 26 insertions(+), 12 deletions(-) diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index fc45a70de9..e1c45b0838 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -65,7 +65,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: if self._plan_without_sim: nsrt_plan, atoms_seq, metrics = self._run_task_plan( task, nsrts, preds, timeout, seed) - # import pdb; pdb.set_trace() + import pdb; pdb.set_trace() self._last_nsrt_plan = nsrt_plan self._last_atoms_seq = atoms_seq policy = utils.nsrt_plan_to_greedy_policy(nsrt_plan, task.goal, diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 800c44fd48..60cfeab20f 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -2485,7 +2485,7 @@ def _create_operators(self) -> Iterator[STRIPSOperator]: # Pick object robot = Variable("?robot", _robot_type) obj = Variable("?object", _movable_object_type) - table = Variable("?table", _immovable_object_type) + table = Variable("?table", _movable_object_type) parameters = [robot, obj, table] preconds: Set[LiftedAtom] = { LiftedAtom(_HandEmpty, [robot]), diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index db177f3fba..44c1b3922d 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -739,6 +739,9 @@ def step(self, observation: Observation) -> State: # At the first timestep, these histories will be empty due to self.reset(). # But at every timestep that isn't the first one, they will be non-empty. self._curr_state.simulator_state["state_history"] = list(self._state_history) + # We do this here so the call to `utils.abstract()` a few lines later has the skill + # that was just run. + self._executed_skill_history.append(observation.executed_skill) # None in first timestep. self._curr_state.simulator_state["skill_history"] = list(self._executed_skill_history) self._curr_state.simulator_state["vlm_label_history"] = list(self._vlm_label_history) @@ -753,6 +756,7 @@ def step(self, observation: Observation) -> State: assert self._curr_env is not None preds = self._curr_env.predicates state_copy = self._curr_state.copy() + print(f"Right before abstract state, skill in obs: {observation.executed_skill}") abstract_state = utils.abstract(state_copy, preds) self._curr_state.simulator_state["abstract_state"] = abstract_state # Compute all the VLM atoms. 
`utils.abstract()` only returns the ones that @@ -763,7 +767,6 @@ def step(self, observation: Observation) -> State: for choice in utils.get_object_combinations(list(state_copy), pred.types): vlm_atoms.add(GroundAtom(pred, choice)) vlm_atoms = sorted(vlm_atoms) - atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms] reconstructed_all_vlm_responses = [] for atom in vlm_atoms: if atom in abstract_state: @@ -772,9 +775,9 @@ def step(self, observation: Observation) -> State: truth_value = 'False' atom_label = f"* {atom.get_vlm_query_str()}: {truth_value}" reconstructed_all_vlm_responses.append(atom_label) - self._vlm_label_history.append(reconstructed_all_vlm_responses) + str_vlm_response = '\n'.join(reconstructed_all_vlm_responses) + self._vlm_label_history.append(str_vlm_response) self._state_history.append(self._curr_state.copy()) - self._executed_skill_history.append(observation.executed_skill) # None in first timestep. return self._curr_state.copy() def _create_state(self) -> State: @@ -782,7 +785,7 @@ def _create_state(self) -> State: return DefaultState # Build the continuous part of the state. assert self._robot is not None - table = Object("table", _immovable_object_type) + table = Object("table", _movable_object_type) cup = Object("cup", _movable_object_type) pan = Object("pan", _container_type) # bread = Object("bread", _movable_object_type) @@ -812,8 +815,13 @@ def _create_state(self) -> State: "height": 0, "width" : 0, "length": 0, - "object_id": 1, - "flat_top_surface": 1 + "object_id": 0, + "placeable": 1, + "held": 0, + "lost": 0, + "in_hand_view": 0, + "in_view": 1, + "is_sweeper": 0 }, cup: { "x": 0, @@ -827,7 +835,7 @@ def _create_state(self) -> State: "height": 0, "width": 0, "length": 0, - "object_id": 2, + "object_id": 1, "placeable": 1, "held": 0, "lost": 0, @@ -922,7 +930,7 @@ def _create_state(self) -> State: "height": 0, "width" : 0, "length": 0, - "object_id": 3, + "object_id": 2, "placeable": 1, "held": 0, "lost": 0, diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index f39dbc0cc4..a1a9780e94 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -24,7 +24,7 @@ } RGB_TO_DEPTH_CAMERAS = { # "hand_color_image": "hand_depth_in_hand_color_frame", - # "left_fisheye_image": "left_depth_in_visual_frame", + "left_fisheye_image": "left_depth_in_visual_frame", # "right_fisheye_image": "right_depth_in_visual_frame", # "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", "frontright_fisheye_image": "frontright_depth_in_visual_frame", diff --git a/predicators/structs.py b/predicators/structs.py index 5309fc4ee3..d1c8e1880f 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -491,6 +491,9 @@ def __post_init__(self) -> None: def goal_holds(self, state: State, vlm: Optional[Any] = None) -> bool: """Return whether the goal of this task holds in the given state.""" + if "abstract_state" in state.simulator_state: + abstract_state = state.simulator_state["abstract_state"] + return self.goal.issubset(abstract_state) from predicators.utils import query_vlm_for_atom_vals vlm_atoms = set(atom for atom in self.goal if isinstance(atom.predicate, VLMPredicate)) diff --git a/predicators/utils.py b/predicators/utils.py index 18172d4e8d..b8ea1e97e3 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -2591,7 +2591,7 @@ def query_vlm_for_atom_vals( if "state_history" in state.simulator_state: prev_states = 
state.simulator_state["state_history"] prev_states_imgs_history = [s.simulator_state["images"] for s in prev_states] - images_history = [curr_state_images] + prev_states_imgs_history + images_history = prev_states_imgs_history + [curr_state_images] skill_history = [] if "skill_history" in state.simulator_state: skill_history = state.simulator_state["skill_history"] @@ -2606,6 +2606,7 @@ def query_vlm_for_atom_vals( vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list, label_history, images_history, [], skill_history) + import pdb; pdb.set_trace() if vlm is None: vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. vlm_input_imgs = \ @@ -2652,6 +2653,8 @@ def abstract(state: State, Duplicate arguments in predicates are allowed. """ + if "abstract_state" in state.simulator_state: + return state.simulator_state["abstract_state"] # Start by pulling out all VLM predicates. vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate)) # Next, classify all non-VLM predicates. From 63031302f68acedf75ec51990c8ccd30904ce6cb Mon Sep 17 00:00:00 2001 From: Ashay Athalye Date: Tue, 17 Sep 2024 19:43:01 -0400 Subject: [PATCH 24/24] Fix skill action history in case of done action. --- predicators/envs/spot_env.py | 13 +++++++++++-- predicators/perception/spot_perceiver.py | 9 ++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 60cfeab20f..7ea7fd0cc9 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -2677,7 +2677,16 @@ def step(self, action: Action) -> Observation: self._current_task_goal_reached = False break logging.info("Invalid input, must be either 'y' or 'n'") - return self._current_observation + return _TruncatedSpotObservation( + self._current_observation.rgbd_images, + self._current_observation.objects_in_view, + set(), + set(), + self._spot_object, + self._current_observation.gripper_open_percentage, + self._current_observation.object_detections_per_camera, + action + ) # Execute the action in the real environment. Automatically retry # if a retryable error is encountered. @@ -2704,7 +2713,7 @@ def step(self, action: Action) -> Observation: self._spot_object, gripper_open_percentage, object_detections_per_camera, - action.get_option() + action ) return obs diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 44c1b3922d..3f908130f2 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -741,7 +741,14 @@ def step(self, observation: Observation) -> State: self._curr_state.simulator_state["state_history"] = list(self._state_history) # We do this here so the call to `utils.abstract()` a few lines later has the skill # that was just run. - self._executed_skill_history.append(observation.executed_skill) # None in first timestep. + executed_skill = None + + if observation.executed_skill is not None: + if observation.executed_skill.extra_info.action_name == "done": + # Just return the default state + return DefaultState + executed_skill = observation.executed_skill.get_option() + self._executed_skill_history.append(executed_skill) # None in first timestep. self._curr_state.simulator_state["skill_history"] = list(self._executed_skill_history) self._curr_state.simulator_state["vlm_label_history"] = list(self._vlm_label_history)