diff --git a/predicators/approaches/base_approach.py b/predicators/approaches/base_approach.py
index e780b822ff..2a1fb50ccc 100644
--- a/predicators/approaches/base_approach.py
+++ b/predicators/approaches/base_approach.py
@@ -11,7 +11,7 @@
 from predicators.structs import Action, Dataset, InteractionRequest, \
     InteractionResult, Metrics, ParameterizedOption, Predicate, State, Task, \
     Type
-from predicators.utils import ExceptionWithInfo
+from predicators.utils import ExceptionWithInfo, create_vlm_by_name
 
 
 class BaseApproach(abc.ABC):
@@ -29,6 +29,7 @@ def __init__(self, initial_predicates: Set[Predicate],
         self._train_tasks = train_tasks
         self._metrics: Metrics = defaultdict(float)
         self._set_seed(CFG.seed)
+        self._vlm = create_vlm_by_name(CFG.vlm_model_name)  # pragma: no cover
 
     @classmethod
     @abc.abstractmethod
diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py
index cb8ca38347..e1c45b0838 100644
--- a/predicators/approaches/bilevel_planning_approach.py
+++ b/predicators/approaches/bilevel_planning_approach.py
@@ -58,12 +58,12 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
         seed = self._seed + self._num_calls
         nsrts = self._get_current_nsrts()
         preds = self._get_current_predicates()
-
+        # utils.abstract(task.init, preds, self._vlm)
         # Run task planning only and then greedily sample and execute in the
         # policy.
         if self._plan_without_sim:
             nsrt_plan, atoms_seq, metrics = self._run_task_plan(
                 task, nsrts, preds, timeout, seed)
             self._last_nsrt_plan = nsrt_plan
             self._last_atoms_seq = atoms_seq
             policy = utils.nsrt_plan_to_greedy_policy(nsrt_plan, task.goal,
diff --git a/predicators/approaches/spot_wrapper_approach.py b/predicators/approaches/spot_wrapper_approach.py
index d0c2aae1f1..00b5925044 100644
--- a/predicators/approaches/spot_wrapper_approach.py
+++ b/predicators/approaches/spot_wrapper_approach.py
@@ -55,7 +55,7 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
         def _policy(state: State) -> Action:
             nonlocal base_approach_policy, need_stow
             # If we think that we're done, return the done action.
-            if task.goal_holds(state):
+            if task.goal_holds(state, self._vlm):
                 extra_info = SpotActionExtraInfo("done", [], None, tuple(),
                                                  None, tuple())
                 return utils.create_spot_env_action(extra_info)
@@ -102,7 +102,8 @@ def _policy(state: State) -> Action:
                 self._base_approach_has_control = True
                 # Need to call this once here to fix off-by-one issue.
                 atom_seq = self._base_approach.get_execution_monitoring_info()
-                assert all(a.holds(state) for a in atom_seq[0])
+                # TODO: consider reinstating the line below.
+                # assert all(a.holds(state) for a in atom_seq[0])
             # Use the base policy.
             return base_approach_policy(state)
diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py
index 60ab8e33d2..7ea7fd0cc9 100644
--- a/predicators/envs/spot_env.py
+++ b/predicators/envs/spot_env.py
@@ -24,13 +24,14 @@
 from predicators.settings import CFG
 from predicators.spot_utils.perception.object_detection import \
     AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \
-    LanguageObjectDetectionID, ObjectDetectionID, detect_objects, \
-    visualize_all_artifacts
+    LanguageObjectDetectionID, ObjectDetectionID, _query_detic_sam2, \
+    detect_objects, visualize_all_artifacts
 from predicators.spot_utils.perception.object_specific_grasp_selection import \
     brush_prompt, bucket_prompt, football_prompt, train_toy_prompt
-from predicators.spot_utils.perception.perception_structs import \
-    RGBDImageWithContext
-from predicators.spot_utils.perception.spot_cameras import capture_images
+from predicators.spot_utils.perception.perception_structs import RGBDImage, \
+    RGBDImageWithContext, SegmentedBoundingBox
+from predicators.spot_utils.perception.spot_cameras import capture_images, \
+    capture_images_without_context
 from predicators.spot_utils.skills.spot_find_objects import \
     init_search_for_objects
 from predicators.spot_utils.skills.spot_hand_move import \
@@ -47,7 +48,7 @@
     update_pbrspot_robot_conf, verify_estop
 from predicators.structs import Action, EnvironmentTask, GoalDescription, \
     GroundAtom, LiftedAtom, Object, Observation, Predicate, \
-    SpotActionExtraInfo, State, STRIPSOperator, Type, Variable
+    SpotActionExtraInfo, State, STRIPSOperator, Type, Variable, _Option
 
 ###############################################################################
 #                                 Base Class                                  #
@@ -84,6 +85,32 @@ class _SpotObservation:
     nonpercept_predicates: Set[Predicate]
 
 
+@dataclass(frozen=True)
+class _TruncatedSpotObservation:
+    """An observation for a SpotEnv."""
+    # Camera name to image.
+    rgbd_images: Dict[str, RGBDImage]
+    # Objects in the environment.
+    objects_in_view: Set[Object]
+    # Objects seen only by the hand camera.
+    objects_in_hand_view: Set[Object]
+    # Objects seen by any camera except the back camera.
+    objects_in_any_view_except_back: Set[Object]
+    # Expose the robot object.
+    robot: Object
+    # Status of the robot gripper.
+    gripper_open_percentage: float
+    # # Robot SE3 Pose
+    # robot_pos: math_helpers.SE3Pose
+    # # Ground atoms without ground-truth classifiers
+    # # A placeholder until all predicates have classifiers
+    # nonpercept_atoms: Set[GroundAtom]
+    # nonpercept_predicates: Set[Predicate]
+    # Object detections per camera in self.rgbd_images.
+    object_detections_per_camera: Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]]
+    executed_skill: Optional[_Option] = None
+
+
 class _PartialPerceptionState(State):
     """Some continuous object features, and ground atoms in simulator_state.
@@ -158,9 +185,25 @@ def get_robot( return_at_exit=True) assert path.exists() localizer = SpotLocalizer(robot, path, lease_client, lease_keepalive) + # localizer = None return robot, localizer, lease_client +@functools.lru_cache(maxsize=None) +def get_robot_only() -> Tuple[Optional[Robot], Optional[LeaseClient]]: + hostname = CFG.spot_robot_ip + sdk = create_standard_sdk("PredicatorsClient-") + robot = sdk.create_robot(hostname) + robot.authenticate("user", "bbbdddaaaiii") + verify_estop(robot) + lease_client = robot.ensure_client(LeaseClient.default_service_name) + lease_client.take() + lease_keepalive = LeaseKeepAlive(lease_client, + must_acquire=True, + return_at_exit=True) + return robot, lease_client + + @functools.lru_cache(maxsize=None) def get_detection_id_for_object(obj: Object) -> ObjectDetectionID: """Exposed for wrapper and options.""" @@ -1443,12 +1486,38 @@ def _get_sweeping_surface_for_container(container: Object, _IsSemanticallyGreaterThan = Predicate( "IsSemanticallyGreaterThan", [_base_object_type, _base_object_type], _is_semantically_greater_than_classifier) + + +def _get_vlm_query_str(pred_name: str, objects: Sequence[Object]) -> str: + return pred_name + "(" + ", ".join( + str(obj.name) for obj in objects) + ")" # pragma: no cover + + +_VLMOn = utils.create_vlm_predicate("VLMOn", + [_movable_object_type, _base_object_type], + lambda o: _get_vlm_query_str("VLMOn", o)) +_Upright = utils.create_vlm_predicate( + "Upright", [_movable_object_type], + lambda o: _get_vlm_query_str("Upright", o)) +_Toasted = utils.create_vlm_predicate( + "Toasted", [_movable_object_type], + lambda o: _get_vlm_query_str("Toasted", o)) +_VLMIn = utils.create_vlm_predicate( + "VLMIn", [_movable_object_type, _immovable_object_type], + lambda o: _get_vlm_query_str("In", o)) +_Open = utils.create_vlm_predicate("Open", [_movable_object_type], + lambda o: _get_vlm_query_str("Open", o)) +_Stained = utils.create_vlm_predicate( + "Stained", [_movable_object_type], + lambda o: _get_vlm_query_str("Stained", o)) + _ALL_PREDICATES = { _NEq, _On, _TopAbove, _Inside, _NotInsideAnyContainer, _FitsInXY, _HandEmpty, _Holding, _NotHolding, _InHandView, _InView, _Reachable, _Blocking, _NotBlocked, _ContainerReadyForSweeping, _IsPlaceable, _IsNotPlaceable, _IsSweeper, _HasFlatTopSurface, _RobotReadyForSweeping, - _IsSemanticallyGreaterThan + _IsSemanticallyGreaterThan, _VLMOn, _Upright, _Toasted, _VLMIn, _Open, + _Stained } _NONPERCEPT_PREDICATES: Set[Predicate] = set() @@ -2372,6 +2441,283 @@ def _dry_simulate_pick_and_dump_container( return next_obs +############################################################################### +# VLM Test Env # +############################################################################### +class VLMTestEnv(SpotRearrangementEnv): + """An environment to start testing the VLM pipeline.""" + + @property + def predicates(self) -> Set[Predicate]: + # return set(p for p in _ALL_PREDICATES if p.name in ["VLMOn", "Holding", "HandEmpty", "Pourable", "Toasted", "VLMIn", "Open"]) + return set(p for p in _ALL_PREDICATES + if p.name in ["VLMOn", "Holding", "HandEmpty"]) + + @property + def goal_predicates(self) -> Set[Predicate]: + return self.predicates + + @classmethod + def get_name(cls) -> str: + return "spot_vlm_test_env" + + def _get_dry_task(self, train_or_test: str, + task_idx: int) -> EnvironmentTask: + raise NotImplementedError("No dry task for VLMTestEnv.") + + @property + def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: + + detection_id_to_obj: 
Dict[ObjectDetectionID, Object] = {} + objects = { + Object("pan", _movable_object_type), + Object("cup", _movable_object_type), + Object("chair", _movable_object_type), + Object("bowl", _movable_object_type), + Object("table", _movable_object_type), + } + for o in objects: + detection_id = LanguageObjectDetectionID(o.name) + detection_id_to_obj[detection_id] = o + return detection_id_to_obj + + def _create_operators(self) -> Iterator[STRIPSOperator]: + # Pick object + robot = Variable("?robot", _robot_type) + obj = Variable("?object", _movable_object_type) + table = Variable("?table", _movable_object_type) + parameters = [robot, obj, table] + preconds: Set[LiftedAtom] = { + LiftedAtom(_HandEmpty, [robot]), + LiftedAtom(_NotHolding, [robot, obj]), + LiftedAtom(_VLMOn, [obj, table]) + } + add_effs: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} + del_effs: Set[LiftedAtom] = { + LiftedAtom(_HandEmpty, [robot]), + LiftedAtom(_NotHolding, [robot, obj]), + LiftedAtom(_VLMOn, [obj, table]) + } + ignore_effs: Set[LiftedAtom] = set() + yield STRIPSOperator("Pick", parameters, preconds, add_effs, del_effs, + ignore_effs) + + # Place object + robot = Variable("?robot", _robot_type) + obj = Variable("?object", _movable_object_type) + pan = Variable("?pan", _container_type) + parameters = [robot, obj, pan] + preconds: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} + add_effs: Set[LiftedAtom] = { + LiftedAtom(_HandEmpty, [robot]), + LiftedAtom(_NotHolding, [robot, obj]), + LiftedAtom(_VLMOn, [obj, pan]) + } + del_effs: Set[LiftedAtom] = {LiftedAtom(_Holding, [robot, obj])} + ignore_effs: Set[LiftedAtom] = set() + yield STRIPSOperator("Place", parameters, preconds, add_effs, del_effs, + ignore_effs) + + # def _generate_train_tasks(self) -> List[EnvironmentTask]: + # goal = self._generate_goal_description() # currently just one goal + # return [ + # EnvironmentTask(None, goal) for _ in range(CFG.num_train_tasks) + # ] + + def _generate_test_tasks(self) -> List[EnvironmentTask]: + goal = self._generate_goal_description() # currently just one goal + return [EnvironmentTask(None, goal) for _ in range(CFG.num_test_tasks)] + + def _generate_train_tasks(self) -> List[EnvironmentTask]: + goal = self._generate_goal_description() # currently just one goal + return [ + EnvironmentTask(None, goal) for _ in range(CFG.num_train_tasks) + ] + + def __init__(self, use_gui: bool = True) -> None: + robot, lease_client = get_robot_only() + self._robot = robot + self._lease_client = lease_client + self._strips_operators: Set[STRIPSOperator] = set() + # Used to do [something] when the agent thinks the goal is reached + # but the human says it is not. + self._current_task_goal_reached = False + # Used when we want to doa special check for a specific + # action. + self._last_action: Optional[Action] = None + # Create constant objects. + self._spot_object = Object("robot", _robot_type) + op_to_name = {o.name: o for o in self._create_operators()} + op_names_to_keep = {"Pick", "Place"} + self._strips_operators = {op_to_name[o] for o in op_names_to_keep} + self._train_tasks = [] + self._test_tasks = [] + + def detect_objects(self, rgbd_images: Dict[str, RGBDImage]) -> Dict[str, List[Tuple[ObjectDetectionID, SegmentedBoundingBox]]]: + object_ids = self._detection_id_to_obj.keys() + object_id_to_img_detections = _query_detic_sam2(object_ids, rgbd_images) + # This ^ is currently a mapping of object_id -> camera_name -> SegmentedBoundingBox. 
+ # We want to do our annotations by camera image, so let's turn this into a + # mapping of camera_name -> object_id -> SegmentedBoundingBox. + detections = {k: [] for k in rgbd_images.keys()} + for object_id, d in object_id_to_img_detections.items(): + for camera_name, seg_bb in d.items(): + detections[camera_name].append((object_id, seg_bb)) + return detections + + def _actively_construct_env_task(self) -> EnvironmentTask: + assert self._robot is not None + rgbd_images = capture_images_without_context(self._robot) + # import PIL + # imgs = [v.rgb for _, v in rgbd_images.items()] + # rot_imgs = [v.rotated_rgb for _, v in rgbd_images.items()] + # ex1 = PIL.Image.fromarray(imgs[0]) + # ex2 = PIL.Image.fromarray(rot_imgs[0]) + # import pdb; pdb.set_trace() + gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) + objects_in_view = [] + + # Perform object detection. + object_detections_per_camera = self.detect_objects(rgbd_images) + + + # artifacts = {"language": {"rgbds": rgbd_images, "object_id_to_img_detections": ret}} + # detections_outfile = Path(".") / "object_detection_artifacts.png" + # no_detections_outfile = Path(".") / "no_detection_artifacts.png" + # visualize_all_artifacts(artifacts, detections_outfile, no_detections_outfile) + + # # Draw object bounding box on images. + # rgbds = artifacts["language"]["rgbds"] + # detections = artifacts["language"]["object_id_to_img_detections"] + # flat_detections: List[Tuple[RGBDImage, + # LanguageObjectDetectionID, + # SegmentedBoundingBox]] = [] + # for obj_id, img_detections in detections.items(): + # for camera, seg_bb in img_detections.items(): + # rgbd = rgbds[camera] + # flat_detections.append((rgbd, obj_id, seg_bb)) + + # # For now assume we only have 1 image, front-left. + # import pdb; pdb.set_trace() + # import PIL + # from PIL import ImageDraw, ImageFont + # bb_pil_imgs = [] + # img = list(rgbd_images.values())[0].rotated_rgb + # pil_img = PIL.Image.fromarray(img) + # draw = ImageDraw.Draw(pil_img) + # for i, (rgbd, obj_id, seg_bb) in enumerate(flat_detections): + # # img = rgbd.rotated_rgb + # # pil_img = PIL.Image.fromarray(img) + # x0, y0, x1, y1 = seg_bb.bounding_box + # draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) + # text = f"{obj_id.language_id}" + # font = ImageFont.load_default() + # # font = utils.get_scaled_default_font(draw, 4) + # # text_width, text_height = draw.textsize(text, font) + # # text_width = draw.textlength(text, font) + # # text_height = font.getsize("hg")[1] + # text_mask = font.getmask(text) + # text_width, text_height = text_mask.size + # text_bbox = [(x0, y0 - text_height - 2), (x0 + text_width + 2, y0)] + # draw.rectangle(text_bbox, fill='green') + # draw.text((x0 + 1, y0 - text_height - 1), text, fill='white', font=font) + + # import pdb; pdb.set_trace() + + obs = _TruncatedSpotObservation( + rgbd_images, + set(objects_in_view), + set(), + set(), + self._spot_object, + gripper_open_percentage, + object_detections_per_camera, + None + ) + goal_description = self._generate_goal_description() + task = EnvironmentTask(obs, goal_description) + return task + + def _generate_goal_description(self) -> GoalDescription: + return "put the cup in the pan" + + def reset(self, train_or_test: str, task_idx: int) -> Observation: + prompt = f"Please set up {train_or_test} task {task_idx}!" + utils.prompt_user(prompt) + assert self._lease_client is not None + # Automatically retry if a retryable error is encountered. 
+ while True: + try: + self._lease_client.take() + self._current_task = self._actively_construct_env_task() + break + except RetryableRpcError as e: + logging.warning("WARNING: the following retryable error " + f"was encountered. Trying again.\n{e}") + self._current_observation = self._current_task.init_obs + self._current_task_goal_reached = False + self._last_action = None + return self._current_task.init_obs + + def step(self, action: Action) -> Observation: + assert self._robot is not None + action_name = action.extra_info.action_name + # Special case: the action is "done", indicating that the robot + # believes it has finished the task. Used for goal checking. + if action_name == "done": + while True: + goal_description = self._current_task.goal_description + logging.info(f"The goal is: {goal_description}") + prompt = "Is the goal accomplished? Answer y or n. " + response = utils.prompt_user(prompt).strip() + if response == "y": + self._current_task_goal_reached = True + break + if response == "n": + self._current_task_goal_reached = False + break + logging.info("Invalid input, must be either 'y' or 'n'") + return _TruncatedSpotObservation( + self._current_observation.rgbd_images, + self._current_observation.objects_in_view, + set(), + set(), + self._spot_object, + self._current_observation.gripper_open_percentage, + self._current_observation.object_detections_per_camera, + action + ) + + # Execute the action in the real environment. Automatically retry + # if a retryable error is encountered. + action_fn = action.extra_info.real_world_fn + action_fn_args = action.extra_info.real_world_fn_args + while True: + try: + action_fn(*action_fn_args) # type: ignore + break + except RetryableRpcError as e: + logging.warning("WARNING: the following retryable error " + f"was encountered. Trying again.\n{e}") + rgbd_images = capture_images_without_context(self._robot) + gripper_open_percentage = get_robot_gripper_open_percentage(self._robot) + objects_in_view = [] + # Perform object detection. 
+ object_detections_per_camera = self.detect_objects(rgbd_images) + + obs = _TruncatedSpotObservation( + rgbd_images, + set(objects_in_view), + set(), + set(), + self._spot_object, + gripper_open_percentage, + object_detections_per_camera, + action + ) + return obs + + ############################################################################### # Cube Table Env # ############################################################################### diff --git a/predicators/ground_truth_models/spot_env/nsrts.py b/predicators/ground_truth_models/spot_env/nsrts.py index 8ab36470a6..849854dbcd 100644 --- a/predicators/ground_truth_models/spot_env/nsrts.py +++ b/predicators/ground_truth_models/spot_env/nsrts.py @@ -285,10 +285,11 @@ class SpotEnvsGroundTruthNSRTFactory(GroundTruthNSRTFactory): @classmethod def get_env_names(cls) -> Set[str]: return { - "spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env", - "spot_soda_bucket_env", "spot_soda_chair_env", - "spot_main_sweep_env", "spot_ball_and_cup_sticky_table_env", - "spot_brush_shelf_env", "lis_spot_block_floor_env" + "spot_vlm_test_env", "spot_cube_env", "spot_soda_floor_env", + "spot_soda_table_env", "spot_soda_bucket_env", + "spot_soda_chair_env", "spot_main_sweep_env", + "spot_ball_and_cup_sticky_table_env", "spot_brush_shelf_env", + "lis_spot_block_floor_env" } @staticmethod @@ -320,6 +321,8 @@ def get_nsrts(env_name: str, types: Dict[str, Type], "PrepareContainerForSweeping": _prepare_sweeping_sampler, "DropNotPlaceableObject": utils.null_sampler, "MoveToReadySweep": utils.null_sampler, + "Pick": utils.null_sampler, + "Place": utils.null_sampler } # If we're doing proper bilevel planning with a simulator, then diff --git a/predicators/ground_truth_models/spot_env/options.py b/predicators/ground_truth_models/spot_env/options.py index 4729a50493..6e26a7ccf6 100644 --- a/predicators/ground_truth_models/spot_env/options.py +++ b/predicators/ground_truth_models/spot_env/options.py @@ -1,11 +1,13 @@ """Ground-truth options for Spot environments.""" +import logging import time from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple import numpy as np import pbrspot from bosdyn.client import math_helpers +from bosdyn.client.lease import LeaseClient from bosdyn.client.sdk import Robot from gym.spaces import Box @@ -14,7 +16,7 @@ from predicators.envs.spot_env import HANDEMPTY_GRIPPER_THRESHOLD, \ SpotRearrangementEnv, _get_sweeping_surface_for_container, \ get_detection_id_for_object, get_robot, \ - get_robot_gripper_open_percentage, get_simulated_object, \ + get_robot_gripper_open_percentage, get_robot_only, get_simulated_object, \ get_simulated_robot from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.settings import CFG @@ -898,6 +900,33 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, state, memory, objects, params) +def _teleop_policy(state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> Action: + del state, memory, params + + robot, lease_client = get_robot_only() + + def _teleop(robot: Robot, lease_client: LeaseClient): + prompt = "Press (y) when you are done with teleop." + while True: + response = utils.prompt_user(prompt).strip() + if response == "y": + break + logging.info("Invalid input. Press (y) when y") + # Take back control. 
+ robot, lease_client = get_robot_only() + lease_client.take() + + fn = _teleop + fn_args = (robot, lease_client) + sim_fn = lambda _: None + sim_fn_args = () + name = "teleop" + action_extra_info = SpotActionExtraInfo(name, objects, fn, fn_args, sim_fn, + sim_fn_args) + return utils.create_spot_env_action(action_extra_info) + + ############################################################################### # Parameterized option factory # ############################################################################### @@ -928,6 +957,8 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, "PrepareContainerForSweeping": Box(-np.inf, np.inf, (3, )), # dx, dy, dyaw "DropNotPlaceableObject": Box(0, 1, (0, )), # empty "MoveToReadySweep": Box(0, 1, (0, )), # empty + "Pick": Box(0, 1, (0, )), # empty + "Place": Box(0, 1, (0, )) # empty } # NOTE: the policies MUST be unique because they output actions with extra info @@ -951,6 +982,8 @@ def _move_to_ready_sweep_policy(state: State, memory: Dict, "PrepareContainerForSweeping": _prepare_container_for_sweeping_policy, "DropNotPlaceableObject": _drop_not_placeable_object_policy, "MoveToReadySweep": _move_to_ready_sweep_policy, + "Pick": _teleop_policy, + "Place": _teleop_policy } @@ -987,6 +1020,7 @@ class SpotEnvsGroundTruthOptionFactory(GroundTruthOptionFactory): @classmethod def get_env_names(cls) -> Set[str]: return { + "spot_vlm_test_env", "spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env", diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index d39107a060..3f908130f2 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -17,12 +17,16 @@ _PartialPerceptionState, _SpotObservation, in_general_view_classifier from predicators.perception.base_perceiver import BasePerceiver from predicators.settings import CFG +from predicators.spot_utils.perception.object_detection import \ + AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ + LanguageObjectDetectionID, ObjectDetectionID, _query_detic_sam2, \ + detect_objects, visualize_all_artifacts from predicators.spot_utils.utils import _container_type, \ _immovable_object_type, _movable_object_type, _robot_type, \ get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom from predicators.structs import Action, DefaultState, EnvironmentTask, \ GoalDescription, GroundAtom, Object, Observation, Predicate, \ - SpotActionExtraInfo, State, Task, Video + SpotActionExtraInfo, State, Task, Video, _Option, VLMPredicate class SpotPerceiver(BasePerceiver): @@ -577,3 +581,376 @@ def render_mental_images(self, observation: Observation, logging.info(f"Wrote out to {outfile}") plt.close() return [img] + + +class SpotMinimalPerceiver(BasePerceiver): + """A perceiver for spot envs with minimal functionality.""" + + camera_name_to_annotation = { + 'hand_color_image': "Hand Camera Image", + 'back_fisheye_image': "Back Camera Image", + 'frontleft_fisheye_image': "Front Left Camera Image", + 'frontright_fisheye_image': "Front Right Camera Image", + 'left_fisheye_image': "Left Camera Image", + 'right_fisheye_image': "Right Camera Image" + } + + def render_mental_images(self, observation: Observation, + env_task: EnvironmentTask) -> Video: + raise NotImplementedError() + + @classmethod + def get_name(cls) -> str: + return "spot_minimal_perceiver" + + def __init__(self) -> None: + super().__init__() + # self._known_object_poses: Dict[Object, math_helpers.SE3Pose] = {} + # self._objects_in_view: 
Set[Object] = set() + # self._objects_in_hand_view: Set[Object] = set() + # self._objects_in_any_view_except_back: Set[Object] = set() + self._robot: Optional[Object] = None + # self._nonpercept_atoms: Set[GroundAtom] = set() + # self._nonpercept_predicates: Set[Predicate] = set() + # self._percept_predicates: Set[Predicate] = set() + self._prev_action: Optional[Action] = None + self._held_object: Optional[Object] = None + self._gripper_open_percentage = 0.0 + self._robot_pos: math_helpers.SE3Pose = math_helpers.SE3Pose( + 0, 0, 0, math_helpers.Quat()) + # self._lost_objects: Set[Object] = set() + self._curr_env: Optional[BaseEnv] = None + self._waiting_for_observation = True + self._ordered_objects: List[Object] = [] # list of all known objects + self._state_history: List[State] = [] + self._executed_skill_history: List[_Option] = [] + self._vlm_label_history: List[str] = [] + self._curr_state = None + # # Keep track of objects that are contained (out of view) in another + # # object, like a bag or bucket. This is important not only for gremlins + # # but also for small changes in the container's perceived pose. + # self._container_to_contained_objects: Dict[Object, Set[Object]] = {} + # Load static, hard-coded features of objects, like their shapes. + # meta = load_spot_metadata() + # self._static_object_features = meta.get("static-object-features", {}) + + + def _create_goal(self, state: State, + goal_description: GoalDescription) -> Set[GroundAtom]: + del state # not used + # Unfortunate hack to deal with the fact that the state is actually + # not yet set. Hopefully one day other cleanups will enable cleaning. + assert self._curr_env is not None + pred_name_to_pred = {p.name: p for p in self._curr_env.predicates} + VLMOn = pred_name_to_pred["VLMOn"] + HandEmpty = pred_name_to_pred["HandEmpty"] + if goal_description == "put the cup in the pan": + robot = Object("robot", _robot_type) + cup = Object("cup", _movable_object_type) + pan = Object("pan", _container_type) + goal = { + GroundAtom(HandEmpty, [robot]), + GroundAtom(VLMOn, [cup, pan]) + } + return goal + raise NotImplementedError("Unrecognized goal description") + + def update_perceiver_with_action(self, action: Action) -> None: + # NOTE: we need to keep track of the previous action + # because the step function (where we need knowledge + # of the previous action) occurs *after* the action + # has already been taken. + self._prev_action = action + + def reset(self, env_task: EnvironmentTask) -> Task: + # import pdb; pdb.set_trace() + # init_obs = env_task.init_obs + # imgs = init_obs.rgbd_images + # self._robot = init_obs.robot + # state = self._create_state() + # state.simulator_state["images"] = [imgs] + # state.set(self._robot, "gripper_open_percentage", init_obs.gripper_open_percentage) + # self._curr_state = state + self._curr_env = get_or_create_env(CFG.env) + state = self._create_state() + # state.simulator_state = {} + # state.simulator_state["images"] = [] + # state.simulator_state["state_history"] = [] + # state.simulator_state["skill_history"] = [] + # state.simulator_state["vlm_atoms_history"] = [] + self._curr_state = state + goal = self._create_goal(state, env_task.goal_description) + + # Reset run-specific things. 
+ self._state_history = [] + self._executed_skill_history = [] + self._vlm_label_history = [] + self._prev_action = None + + return Task(state, goal) + + def step(self, observation: Observation) -> State: + # import pdb; pdb.set_trace() + self._waiting_for_observation = False + self._robot = observation.robot + img_objects = observation.rgbd_images # RGBDImage objects + img_names = [v.camera_name for _, v in img_objects.items()] + imgs = [v.rotated_rgb for _, v in img_objects.items()] + import PIL + from PIL import ImageDraw, ImageFont + pil_imgs = [PIL.Image.fromarray(img) for img in imgs] + # Annotate images with detected objects (names + bounding box) + # and camera name. + object_detections_per_camera = observation.object_detections_per_camera + for i, camera_name in enumerate(img_names): + draw = ImageDraw.Draw(pil_imgs[i]) + # Annotate with camera name. + font = utils.get_scaled_default_font(draw, 4) + _ = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[camera_name], font) + # Annotate with object detections. + detections = object_detections_per_camera[camera_name] + for obj_id, seg_bb in detections: + x0, y0, x1, y1 = seg_bb.bounding_box + draw.rectangle([(x0, y0), (x1, y1)], outline='green', width=2) + text = f"{obj_id.language_id}" + font = utils.get_scaled_default_font(draw, 3) + text_mask = font.getmask(text) + text_width, text_height = text_mask.size + text_bbox = [(x0, y0 - 1.5*text_height), (x0 + text_width + 1, y0)] + draw.rectangle(text_bbox, fill='green') + draw.text((x0 + 1, y0 - 1.5*text_height), text, fill='white', font=font) + + # import PIL + # from PIL import ImageDraw + # annotated_pil_imgs = [] + # for img, img_name in zip(imgs, img_names): + # pil_img = PIL.Image.fromarray(img) + # draw = ImageDraw.Draw(pil_img) + # font = utils.get_scaled_default_font(draw, 4) + # annotated_pil_img = utils.add_text_to_draw_img(draw, (0, 0), self.camera_name_to_annotation[img_name], font) + # annotated_pil_imgs.append(pil_img) + annotated_imgs = [np.array(img) for img in pil_imgs] + self._gripper_open_percentage = observation.gripper_open_percentage + + self._curr_state = self._create_state() + # This state is a default/empty. We have to set the attributes + # of the objects and set the simulator state properly. + self._curr_state.simulator_state["images"] = annotated_imgs + # At the first timestep, these histories will be empty due to self.reset(). + # But at every timestep that isn't the first one, they will be non-empty. + self._curr_state.simulator_state["state_history"] = list(self._state_history) + # We do this here so the call to `utils.abstract()` a few lines later has the skill + # that was just run. + executed_skill = None + + if observation.executed_skill is not None: + if observation.executed_skill.extra_info.action_name == "done": + # Just return the default state + return DefaultState + executed_skill = observation.executed_skill.get_option() + self._executed_skill_history.append(executed_skill) # None in first timestep. + self._curr_state.simulator_state["skill_history"] = list(self._executed_skill_history) + self._curr_state.simulator_state["vlm_label_history"] = list(self._vlm_label_history) + + # Add to histories. + # A bit of extra work is required to build the VLM label history. + # We want to keep `utils.abstract()` as straightforward as possible, + # so we'll "rebuild" the VLM labels from the abstract state + # returned by `utils.abstract()`. 
And since we call this function, + # we might as well store the abstract state as a part of the simulator + # state so that we don't need to recompute it later in the approach or + # in planning. + assert self._curr_env is not None + preds = self._curr_env.predicates + state_copy = self._curr_state.copy() + print(f"Right before abstract state, skill in obs: {observation.executed_skill}") + abstract_state = utils.abstract(state_copy, preds) + self._curr_state.simulator_state["abstract_state"] = abstract_state + # Compute all the VLM atoms. `utils.abstract()` only returns the ones that + # are True. The remaining ones are the ones that are False. + vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate)) + vlm_atoms = set() + for pred in vlm_preds: + for choice in utils.get_object_combinations(list(state_copy), pred.types): + vlm_atoms.add(GroundAtom(pred, choice)) + vlm_atoms = sorted(vlm_atoms) + reconstructed_all_vlm_responses = [] + for atom in vlm_atoms: + if atom in abstract_state: + truth_value = 'True' + else: + truth_value = 'False' + atom_label = f"* {atom.get_vlm_query_str()}: {truth_value}" + reconstructed_all_vlm_responses.append(atom_label) + str_vlm_response = '\n'.join(reconstructed_all_vlm_responses) + self._vlm_label_history.append(str_vlm_response) + self._state_history.append(self._curr_state.copy()) + return self._curr_state.copy() + + def _create_state(self) -> State: + if self._waiting_for_observation: + return DefaultState + # Build the continuous part of the state. + assert self._robot is not None + table = Object("table", _movable_object_type) + cup = Object("cup", _movable_object_type) + pan = Object("pan", _container_type) + # bread = Object("bread", _movable_object_type) + # toaster = Object("toaster", _immovable_object_type) + # microwave = Object("microwave", _movable_object_type) + # napkin = Object("napkin", _movable_object_type) + state_dict = { + self._robot: { + "gripper_open_percentage": self._gripper_open_percentage, + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + }, + table: { + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + "shape": 0, + "height": 0, + "width" : 0, + "length": 0, + "object_id": 0, + "placeable": 1, + "held": 0, + "lost": 0, + "in_hand_view": 0, + "in_view": 1, + "is_sweeper": 0 + }, + cup: { + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + "shape": 0, + "height": 0, + "width": 0, + "length": 0, + "object_id": 1, + "placeable": 1, + "held": 0, + "lost": 0, + "in_hand_view": 0, + "in_view": 1, + "is_sweeper": 0 + }, + # napkin: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 2, + # "placeable": 1, + # "held": 0, + # "lost": 0, + # "in_hand_view": 0, + # "in_view": 1, + # "is_sweeper": 0 + # }, + # microwave: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 2, + # "placeable": 1, + # "held": 0, + # "lost": 0, + # "in_hand_view": 0, + # "in_view": 1, + # "is_sweeper": 0 + # }, + # bread: { + # "x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 2, + # "placeable": 1, + # "held": 0, + # "lost": 0, + # "in_hand_view": 0, + # "in_view": 1, + # "is_sweeper": 0 + # }, + # toaster: { + # 
"x": 0, + # "y": 0, + # "z": 0, + # "qw": 0, + # "qx": 0, + # "qy": 0, + # "qz": 0, + # "shape": 0, + # "height": 0, + # "width" : 0, + # "length": 0, + # "object_id": 1, + # "flat_top_surface": 1 + # }, + pan: { + "x": 0, + "y": 0, + "z": 0, + "qw": 0, + "qx": 0, + "qy": 0, + "qz": 0, + "shape": 0, + "height": 0, + "width" : 0, + "length": 0, + "object_id": 2, + "placeable": 1, + "held": 0, + "lost": 0, + "in_hand_view": 0, + "in_view": 1, + "is_sweeper": 0 + } + } + state_dict = {k: list(v.values()) for k, v in state_dict.items()} + state = State(state_dict) + state.simulator_state = {} + state.simulator_state["images"] = [] + state.simulator_state["state_history"] = [] + state.simulator_state["skill_history"] = [] + state.simulator_state["vlm_atoms_history"] = [] + return state diff --git a/predicators/pretrained_model_interface.py b/predicators/pretrained_model_interface.py index dfd779b8f6..609d5ee900 100644 --- a/predicators/pretrained_model_interface.py +++ b/predicators/pretrained_model_interface.py @@ -102,7 +102,7 @@ def sample_completions(self, if not os.path.exists(cache_filepath): if CFG.llm_use_cache_only: raise ValueError("No cached response found for prompt.") - logging.debug(f"Querying model {model_id} with new prompt.") + print(f"Querying model {model_id} with new prompt.") # Query the model. completions = self._sample_completions(prompt, imgs, temperature, seed, stop_token, @@ -118,11 +118,11 @@ def sample_completions(self, for i, img in enumerate(imgs): filename_suffix = str(i) + ".jpg" img.save(os.path.join(imgs_folderpath, filename_suffix)) - logging.debug(f"Saved model response to {cache_filepath}.") + print(f"Saved model response to {cache_filepath}.") # Load the saved completion. with open(cache_filepath, 'r', encoding='utf-8') as f: cache_str = f.read() - logging.debug(f"Loaded model response from {cache_filepath}.") + print(f"Loaded model response from {cache_filepath}.") assert cache_str.count(_CACHE_SEP) == num_completions cached_prompt, completion_strs = cache_str.split(_CACHE_SEP, 1) assert cached_prompt == prompt @@ -289,8 +289,8 @@ class GoogleGeminiVLM(VisionLanguageModel, GoogleGeminiModel): necessary API key to query the particular model name. """ - @retry(wait=wait_random_exponential(min=1, max=60), - stop=stop_after_attempt(10)) + # @retry(wait=wait_random_exponential(min=1, max=60), + # stop=stop_after_attempt(10)) def _sample_completions( self, prompt: str, diff --git a/predicators/settings.py b/predicators/settings.py index 8cac7a2f0b..f2c4d9ed27 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -434,7 +434,7 @@ class GlobalSettings: # parameters for vision language models # gemini-1.5-pro-latest, gpt-4-turbo, gpt-4o - vlm_model_name = "gemini-pro-vision" + vlm_model_name = "gemini-1.5-flash" #"gemini-1.5-pro-latest" vlm_temperature = 0.0 vlm_num_completions = 1 vlm_include_cropped_images = False @@ -719,6 +719,9 @@ class GlobalSettings: # saved_vlm_img_demos_folder vlm_trajs_folder_name = "" vlm_predicate_vision_api_generate_ground_atoms = False + # At test-time, we will use the below number of states + # as part of labelling the current state's VLM atoms. 
+ vlm_test_time_atom_label_prompt_type = "per_scene_naive" @classmethod def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: diff --git a/predicators/spot_utils/perception/object_detection.py b/predicators/spot_utils/perception/object_detection.py index 6757e8f324..6b627a90ef 100644 --- a/predicators/spot_utils/perception/object_detection.py +++ b/predicators/spot_utils/perception/object_detection.py @@ -41,7 +41,7 @@ from predicators.spot_utils.perception.perception_structs import \ AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \ LanguageObjectDetectionID, ObjectDetectionID, PythonicObjectDetectionID, \ - RGBDImageWithContext, SegmentedBoundingBox + RGBDImage, RGBDImageWithContext, SegmentedBoundingBox from predicators.spot_utils.utils import get_april_tag_transform, \ get_graph_nav_dir from predicators.utils import rotate_point_in_image @@ -351,6 +351,108 @@ def _query_detic_sam( return object_id_to_img_detections +def _query_detic_sam2( + object_ids: Collection[LanguageObjectDetectionID], + rgbds: Dict[str, RGBDImage], + max_server_retries: int = 5, + detection_threshold: float = CFG.spot_vision_detection_threshold +) -> Dict[ObjectDetectionID, Dict[str, SegmentedBoundingBox]]: + """Returns object ID to image ID (camera) to segmented bounding box.""" + + object_id_to_img_detections: Dict[ObjectDetectionID, + Dict[str, SegmentedBoundingBox]] = { + obj_id: {} + for obj_id in object_ids + } + + # Create buffer dictionary to send to server. + buf_dict = {} + for camera_name, rgbd in rgbds.items(): + pil_rotated_img = PIL.Image.fromarray(rgbd.rotated_rgb) # type: ignore + buf_dict[camera_name] = _image_to_bytes(pil_rotated_img) + + # Extract all the classes that we want to detect. + classes = sorted(o.language_id for o in object_ids) + + # Query server, retrying to handle possible wifi issues. + # import pdb; pdb.set_trace() + # imgs = [v.rotated_rgb for _, v in rgbds.items()] + # pil_img = PIL.Image.fromarray(imgs[0]) + # import pdb; pdb.set_trace() + + for _ in range(max_server_retries): + try: + r = requests.post("http://localhost:5550/batch_predict", + files=buf_dict, + data={"classes": ",".join(classes)}) + break + except requests.exceptions.ConnectionError: + continue + else: + logging.warning("DETIC-SAM FAILED, POSSIBLE SERVER/WIFI ISSUE") + return object_id_to_img_detections + + # If the status code is not 200, then fail. + if r.status_code != 200: + logging.warning(f"DETIC-SAM FAILED! STATUS CODE: {r.status_code}") + return object_id_to_img_detections + + # Querying the server succeeded; unpack the contents. + with io.BytesIO(r.content) as f: + try: + server_results = np.load(f, allow_pickle=True) + # Corrupted results. + except pkl.UnpicklingError: + logging.warning("DETIC-SAM FAILED DURING UNPICKLING!") + return object_id_to_img_detections + + # Process the results and save all detections per object ID. + for camera_name, rgbd in rgbds.items(): + rot_boxes = server_results[f"{camera_name}_boxes"] + ret_classes = server_results[f"{camera_name}_classes"] + rot_masks = server_results[f"{camera_name}_masks"] + scores = server_results[f"{camera_name}_scores"] + + # Invert the rotation immediately so we don't need to worry about + # them henceforth. 
+ # h, w = rgbd.rgb.shape[:2] + # image_rot = rgbd.image_rot + # boxes = [ + # _rotate_bounding_box(bb, -image_rot, h, w) for bb in rot_boxes + # ] + # masks = [ + # ndimage.rotate(m.squeeze(), -image_rot, reshape=False) + # for m in rot_masks + # ] + boxes = rot_boxes + masks = rot_masks + + # Filter out detections by confidence. We threshold detections + # at a set confidence level minimum, and if there are multiple, + # we only select the most confident one. This structure makes + # it easy for us to select multiple detections if that's ever + # necessary in the future. + for obj_id in object_ids: + # If there were no detections (which means all the + # returned values will be numpy arrays of shape (0, 0)) + # then just skip this source. + if ret_classes.size == 0: + continue + obj_id_mask = (ret_classes == obj_id.language_id) + if not np.any(obj_id_mask): + continue + max_score = np.max(scores[obj_id_mask]) + best_idx = np.where(scores == max_score)[0].item() + if scores[best_idx] < detection_threshold: + continue + # Save the detection. + seg_bb = SegmentedBoundingBox(boxes[best_idx], masks[best_idx], + scores[best_idx]) + object_id_to_img_detections[obj_id][rgbd.camera_name] = seg_bb + + # import pdb; pdb.set_trace() + return object_id_to_img_detections + def _image_to_bytes(img: PIL.Image.Image) -> io.BytesIO: """Helper function to convert from a PIL image into a bytes object.""" @@ -522,7 +624,8 @@ def visualize_all_artifacts(artifacts: Dict[str, ax_row[2].imshow(rgbd.depth, cmap='Greys_r', vmin=0, vmax=10000) # Bounding box. - ax_row[3].imshow(rgbd.rgb) + # ax_row[3].imshow(rgbd.rgb) + ax_row[3].imshow(rgbd.rotated_rgb) box = seg_bb.bounding_box x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] @@ -534,7 +637,7 @@ def visualize_all_artifacts(artifacts: Dict[str, facecolor=(0, 0, 0, 0), lw=1)) - ax_row[4].imshow(seg_bb.mask, cmap="binary_r", vmin=0, vmax=1) + # ax_row[4].imshow(seg_bb.mask, cmap="binary_r", vmin=0, vmax=1) # Labels. 
abbreviated_name = obj_id.language_id diff --git a/predicators/spot_utils/perception/perception_structs.py b/predicators/spot_utils/perception/perception_structs.py index 321907c3b6..a39f78f380 100644 --- a/predicators/spot_utils/perception/perception_structs.py +++ b/predicators/spot_utils/perception/perception_structs.py @@ -30,6 +30,22 @@ def rotated_rgb(self) -> NDArray[np.uint8]: return ndimage.rotate(self.rgb, self.image_rot, reshape=False) +@dataclass +class RGBDImage: + """An RGBD image.""" + rgb: NDArray[np.uint8] + depth: NDArray[np.uint16] + image_rot: float + camera_name: str + depth_scale: float + camera_model: Any # bosdyn.api.image_pb2.PinholeModel, but not available + + @property + def rotated_rgb(self) -> NDArray[np.uint8]: + """The image rotated to be upright.""" + return ndimage.rotate(self.rgb, self.image_rot, reshape=False) + + @dataclass(frozen=True) class ObjectDetectionID: """A unique identifier for an object that is to be detected.""" diff --git a/predicators/spot_utils/perception/spot_cameras.py b/predicators/spot_utils/perception/spot_cameras.py index cbc8d4dff3..a1a9780e94 100644 --- a/predicators/spot_utils/perception/spot_cameras.py +++ b/predicators/spot_utils/perception/spot_cameras.py @@ -8,8 +8,9 @@ from bosdyn.client.image import ImageClient, build_image_request from bosdyn.client.sdk import Robot from numpy.typing import NDArray +from scipy import ndimage -from predicators.spot_utils.perception.perception_structs import \ +from predicators.spot_utils.perception.perception_structs import RGBDImage, \ RGBDImageWithContext from predicators.spot_utils.spot_localization import SpotLocalizer @@ -22,12 +23,12 @@ 'right_fisheye_image': 180 } RGB_TO_DEPTH_CAMERAS = { - "hand_color_image": "hand_depth_in_hand_color_frame", + # "hand_color_image": "hand_depth_in_hand_color_frame", "left_fisheye_image": "left_depth_in_visual_frame", - "right_fisheye_image": "right_depth_in_visual_frame", - "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", + # "right_fisheye_image": "right_depth_in_visual_frame", + # "frontleft_fisheye_image": "frontleft_depth_in_visual_frame", "frontright_fisheye_image": "frontright_depth_in_visual_frame", - "back_fisheye_image": "back_depth_in_visual_frame" + # "back_fisheye_image": "back_depth_in_visual_frame" } # Hack to avoid double image capturing when we want to (1) get object states @@ -120,6 +121,87 @@ def capture_images( return rgbds +def capture_images_without_context( + robot: Robot, + camera_names: Optional[Collection[str]] = None, + quality_percent: int = 100, +) -> Dict[str, RGBDImage]: + """Build an image request and get the responses. + + If no camera names are provided, all RGB cameras are used. + """ + # global _LAST_CAPTURED_IMAGES # pylint: disable=global-statement + + if camera_names is None: + camera_names = set(RGB_TO_DEPTH_CAMERAS) + + image_client = robot.ensure_client(ImageClient.default_service_name) + + rgbds: Dict[str, RGBDImage] = {} + + # # Get the world->robot transform so we can store world->camera transforms + # # in the RGBDWithContexts. + # if relocalize: + # localizer.localize() + # world_tform_body = localizer.get_last_robot_pose() + # body_tform_world = world_tform_body.inverse() + + # Package all the requests together. + img_reqs: image_pb2.ImageRequest = [] + for camera_name in camera_names: + # Build RGB image request. 
+ if "hand" in camera_name: + rgb_pixel_format = None + else: + rgb_pixel_format = image_pb2.Image.PIXEL_FORMAT_RGB_U8 # pylint: disable=no-member + rgb_img_req = build_image_request(camera_name, + quality_percent=quality_percent, + pixel_format=rgb_pixel_format) + img_reqs.append(rgb_img_req) + # Build depth image request. + depth_camera_name = RGB_TO_DEPTH_CAMERAS[camera_name] + depth_img_req = build_image_request(depth_camera_name, + quality_percent=quality_percent, + pixel_format=None) + img_reqs.append(depth_img_req) + + # Send the request. + responses = image_client.get_image(img_reqs) + name_to_response = {r.source.name: r for r in responses} + + # Build RGBDImageWithContexts. + for camera_name in camera_names: + rgb_img_resp = name_to_response[camera_name] + depth_img_resp = name_to_response[RGB_TO_DEPTH_CAMERAS[camera_name]] + rgb_img = _image_response_to_image(rgb_img_resp) + # rgb_img = ndimage.rotate(rgb_img, ROTATION_ANGLE[camera_name], reshape=False) + depth_img = _image_response_to_image(depth_img_resp) + # # Create transform. + # camera_tform_body = get_a_tform_b( + # rgb_img_resp.shot.transforms_snapshot, + # rgb_img_resp.shot.frame_name_image_sensor, BODY_FRAME_NAME) + # camera_tform_world = camera_tform_body * body_tform_world + # world_tform_camera = camera_tform_world.inverse() + # Extract other context. + rot = ROTATION_ANGLE[camera_name] + depth_scale = depth_img_resp.source.depth_scale + # transforms_snapshot = rgb_img_resp.shot.transforms_snapshot + # frame_name_image_sensor = rgb_img_resp.shot.frame_name_image_sensor + camera_model = rgb_img_resp.source.pinhole + # Finish RGBDImageWithContext. + # rgbd = RGBDImageWithContext(rgb_img, depth_img, rot, camera_name, + # world_tform_camera, depth_scale, + # transforms_snapshot, + # frame_name_image_sensor, camera_model) + rgbd = RGBDImage(rgb_img, depth_img, rot, camera_name, depth_scale, + camera_model) + rgbds[camera_name] = rgbd + + # _LAST_CAPTURED_IMAGES = rgbds + + return rgbds + + def _image_response_to_image( image_response: image_pb2.ImageResponse, ) -> NDArray: """Extract an image from an image response. diff --git a/predicators/spot_utils/spot_localization.py b/predicators/spot_utils/spot_localization.py index 3c4c557b6d..a853992700 100644 --- a/predicators/spot_utils/spot_localization.py +++ b/predicators/spot_utils/spot_localization.py @@ -55,26 +55,29 @@ def __init__(self, robot: Robot, upload_path: Path, self._robot_pose = math_helpers.SE3Pose(0, 0, 0, math_helpers.Quat()) # Initialize the robot's position in the map. robot_state = get_robot_state(self._robot) - current_odom_tform_body = get_odom_tform_body( - robot_state.kinematic_state.transforms_snapshot).to_proto() - localization = nav_pb2.Localization() - for r in range(NUM_LOCALIZATION_RETRIES + 1): - try: - self.graph_nav_client.set_localization( - initial_guess_localization=localization, - ko_tform_body=current_odom_tform_body) - break - except (ResponseError, TimedOutError) as e: - # Retry or fail. - if r == NUM_LOCALIZATION_RETRIES: - msg = f"Localization failed permanently: {e}." - logging.warning(msg) - raise LocalizationFailure(msg) - logging.warning("Localization failed once, retrying.") - time.sleep(LOCALIZATION_RETRY_WAIT_TIME) - - # Run localize once to start. 
- self.localize() + z_position = robot_state.kinematic_state.transforms_snapshot.child_to_parent_edge_map[ + "gpe"].parent_tform_child.position.z + import pdb + pdb.set_trace() + # current_odom_tform_body = get_odom_tform_body( + # robot_state.kinematic_state.transforms_snapshot).to_proto() + # localization = nav_pb2.Localization() + # for r in range(NUM_LOCALIZATION_RETRIES + 1): + # try: + # self.graph_nav_client.set_localization( + # initial_guess_localization=localization, + # ko_tform_body=current_odom_tform_body) + # break + # except (ResponseError, TimedOutError) as e: + # # Retry or fail. + # if r == NUM_LOCALIZATION_RETRIES: + # msg = f"Localization failed permanently: {e}." + # logging.warning(msg) + # raise LocalizationFailure(msg) + # logging.warning("Localization failed once, retrying.") + # time.sleep(LOCALIZATION_RETRY_WAIT_TIME) + # # Run localize once to start. + # self.localize() def _upload_graph_and_snapshots(self) -> None: """Upload the graph and snapshots to the robot.""" @@ -133,23 +136,23 @@ def localize(self, It's good practice to call this periodically to avoid drift issues. April tags need to be in view. """ - try: - localization_state = self.graph_nav_client.get_localization_state() - transform = localization_state.localization.seed_tform_body - if str(transform) == "": - raise LocalizationFailure("Received empty localization state.") - except (ResponseError, TimedOutError, LocalizationFailure) as e: - # Retry or fail. - if num_retries <= 0: - msg = f"Localization failed permanently: {e}." - logging.warning(msg) - raise LocalizationFailure(msg) - logging.warning("Localization failed once, retrying.") - time.sleep(retry_wait_time) - return self.localize(num_retries=num_retries - 1, - retry_wait_time=retry_wait_time) - logging.info("Localization succeeded.") - self._robot_pose = math_helpers.SE3Pose.from_proto(transform) + # try: + # localization_state = self.graph_nav_client.get_localization_state() + # transform = localization_state.localization.seed_tform_body + # if str(transform) == "": + # raise LocalizationFailure("Received empty localization state.") + # except (ResponseError, TimedOutError, LocalizationFailure) as e: + # # Retry or fail. + # if num_retries <= 0: + # msg = f"Localization failed permanently: {e}." 
+        #         logging.warning(msg)
+        #         raise LocalizationFailure(msg)
+        #     logging.warning("Localization failed once, retrying.")
+        #     time.sleep(retry_wait_time)
+        #     return self.localize(num_retries=num_retries - 1,
+        #                          retry_wait_time=retry_wait_time)
+        # logging.info("Localization succeeded.")
+        # self._robot_pose = math_helpers.SE3Pose.from_proto(transform)
         return None
diff --git a/predicators/structs.py b/predicators/structs.py
index 1ad0c05b2c..d1c8e1880f 100644
--- a/predicators/structs.py
+++ b/predicators/structs.py
@@ -489,9 +489,20 @@ def __post_init__(self) -> None:
         for atom in self.goal:
             assert isinstance(atom, GroundAtom)
 
-    def goal_holds(self, state: State) -> bool:
+    def goal_holds(self, state: State, vlm: Optional[Any] = None) -> bool:
         """Return whether the goal of this task holds in the given state."""
-        return all(goal_atom.holds(state) for goal_atom in self.goal)
+        if state.simulator_state and "abstract_state" in state.simulator_state:
+            abstract_state = state.simulator_state["abstract_state"]
+            return self.goal.issubset(abstract_state)
+        from predicators.utils import query_vlm_for_atom_vals
+        vlm_atoms = set(atom for atom in self.goal
+                        if isinstance(atom.predicate, VLMPredicate))
+        for atom in self.goal:
+            if atom not in vlm_atoms:
+                if not atom.holds(state):
+                    return False
+        true_vlm_atoms = query_vlm_for_atom_vals(vlm_atoms, state, vlm)
+        return len(true_vlm_atoms) == len(vlm_atoms)
 
     def replace_goal_with_alt_goal(self) -> Task:
         """Return a Task with the goal replaced with the alternative goal if it
diff --git a/predicators/utils.py b/predicators/utils.py
index 5478bc9f97..b8ea1e97e3 100644
--- a/predicators/utils.py
+++ b/predicators/utils.py
@@ -2494,6 +2494,73 @@ def parse_model_output_into_option_plan(
     return option_plan
 
 
+def get_prompt_for_vlm_state_labelling(
+        prompt_type: str, atoms_list: List[str], label_history: List[str],
+        imgs_history: List[List[PIL.Image.Image]],
+        cropped_imgs_history: List[List[PIL.Image.Image]],
+        skill_history: List[Action]) -> Tuple[str, List[PIL.Image.Image]]:
+    """Prompt for generating labels for an entire trajectory. Similar to the
+    above prompting method, this outputs a prompt used to label the state at
+    each timestep of the trajectory with atom values.
+
+    Note that all our prompts are saved as separate txt files under the
+    'vlm_input_data_prompts/atom_labelling' folder.
+    """
+    # Load the pre-specified prompt.
+    filepath_prefix = get_path_to_predicators_root() + \
+        "/predicators/datasets/vlm_input_data_prompts/atom_labelling/"
+    try:
+        with open(filepath_prefix +
+                  CFG.grammar_search_vlm_atom_label_prompt_type + ".txt",
+                  "r",
+                  encoding="utf-8") as f:
+            prompt = f.read()
+    except FileNotFoundError:
+        raise ValueError("Unknown VLM prompting option " +
+                         f"{CFG.grammar_search_vlm_atom_label_prompt_type}")
+    # The prompt ends with a section for 'Predicates', so list these.
+    for atom_str in atoms_list:
+        prompt += f"\n{atom_str}"
+
+    if "img_option_diffs" in prompt_type:
+        # In this case, we need to load the 'per_scene_naive' prompt as well
+        # for the first timestep.
+        with open(filepath_prefix + "per_scene_naive.txt",
+                  "r",
+                  encoding="utf-8") as f:
+            init_prompt = f.read()
+        for atom_str in atoms_list:
+            init_prompt += f"\n{atom_str}"
+        if len(label_history) == 0:
+            return (init_prompt, imgs_history[0])
+        # Now, we use actual difference-based prompting for the second
+        # timestep and beyond.
+        curr_prompt = prompt[:]
+        curr_prompt_imgs = [
+            imgs_timestep for imgs_timestep in imgs_history[-1]
+        ]
+        if CFG.vlm_include_cropped_images:
+            if CFG.env in ["burger", "burger_no_move"]:  # pragma: no cover
+                curr_prompt_imgs.extend(
+                    [cropped_imgs_history[-1][1], cropped_imgs_history[-1][0]])
+            else:
+                raise NotImplementedError(
+                    f"Cropped images not implemented for {CFG.env}.")
+        curr_prompt += "\n\nSkill executed between states: "
+        skill_name = skill_history[-1].name + str(skill_history[-1].objects)
+        curr_prompt += skill_name
+        if "label_history" in prompt_type:
+            curr_prompt += "\n\nPredicate values in the first scene, " \
+                "before the skill was executed: \n"
+            curr_prompt += label_history[-1]
+        return (curr_prompt, curr_prompt_imgs)
+    else:
+        # NOTE: we rip out only the first image from each trajectory
+        # which is fine for most domains, but will be problematic for
+        # situations in which there is more than one image per state.
+        return (prompt, [imgs_history[-1][0]])
+
+
 def query_vlm_for_atom_vals(
         vlm_atoms: Collection[GroundAtom],
         state: State,
@@ -2505,17 +2572,41 @@
     # vlm can be called on.
     assert state.simulator_state is not None
     assert isinstance(state.simulator_state["images"], List)
-    imgs = state.simulator_state["images"]
+    # if "vlm_atoms_history" not in state.simulator_state:
+    #     state.simulator_state["vlm_atoms_history"] = []
+    # imgs = state.simulator_state["images"]
+    # previous_states = []
+    # # We assume the state.simulator_state contains a list of previous states.
+    # if "state_history" in state.simulator_state:
+    #     previous_states = state.simulator_state["state_history"]
+    # state_imgs_history = [
+    #     state.simulator_state["images"] for state in previous_states
+    # ]
     vlm_atoms = sorted(vlm_atoms)
-    atom_queries_str = "\n* "
-    atom_queries_str += "\n* ".join(atom.get_vlm_query_str()
-                                    for atom in vlm_atoms)
-    filepath_to_vlm_prompt = get_path_to_predicators_root() + \
-        "/predicators/datasets/vlm_input_data_prompts/atom_labelling/" + \
-        "per_scene_naive.txt"
-    with open(filepath_to_vlm_prompt, "r", encoding="utf-8") as f:
-        vlm_query_str = f.read()
-    vlm_query_str += atom_queries_str
+    atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms]
+    # All "history" fields in the simulator state contain things from
+    # previous states -- not the current state.
+    # We want the image history to include the images from the current state.
+    curr_state_images = state.simulator_state["images"]
+    images_history = [curr_state_images]
+    if "state_history" in state.simulator_state:
+        prev_states = state.simulator_state["state_history"]
+        images_history = [s.simulator_state["images"]
+                          for s in prev_states] + [curr_state_images]
+    skill_history = []
+    if "skill_history" in state.simulator_state:
+        skill_history = state.simulator_state["skill_history"]
+    label_history = []
+    if "vlm_label_history" in state.simulator_state:
+        label_history = state.simulator_state["vlm_label_history"]
+
+    # vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(
+    #     CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list,
+    #     state.simulator_state["vlm_atoms_history"], state_imgs_history, [],
+    #     state.simulator_state["skill_history"])
+    vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(
+        CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list,
+        label_history, images_history, [], skill_history)
     if vlm is None:
         vlm = create_vlm_by_name(CFG.vlm_model_name)  # pragma: no cover.
     vlm_input_imgs = \
@@ -2527,20 +2618,26 @@
         num_completions=1)
     assert len(vlm_output) == 1
     vlm_output_str = vlm_output[0]
-    all_atom_queries = atom_queries_str.strip().split("\n")
+    logging.debug(f"VLM output: {vlm_output_str}")
     all_vlm_responses = vlm_output_str.strip().split("\n")
-
     # NOTE: this assumption is likely too brittle; if this is breaking, feel
     # free to remove/adjust this and change the below parsing loop accordingly!
-    assert len(all_atom_queries) == len(all_vlm_responses)
+    assert len(atom_queries_list) == len(all_vlm_responses)
     for i, (atom_query, curr_vlm_output_line) in enumerate(
-            zip(all_atom_queries, all_vlm_responses)):
+            zip(atom_queries_list, all_vlm_responses)):
         assert atom_query + ":" in curr_vlm_output_line
         assert "." in curr_vlm_output_line
         period_idx = curr_vlm_output_line.find(".")
-        if curr_vlm_output_line[len(atom_query +
-                                    ":"):period_idx].lower().strip() == "true":
+        # value = curr_vlm_output_line[len(atom_query + ":"):period_idx].lower().strip()
+        value = curr_vlm_output_line.split(': ')[-1].strip('.').lower()
+        if value == "true":
             true_atoms.add(vlm_atoms[i])
+
+    # NOTE: the perceiver calls utils.abstract() once per step and caches the
+    # result in state.simulator_state["abstract_state"]; abstract() below
+    # returns that cached set when it is present, so the VLM is queried at
+    # most once per state. The VLM label history is appended by the
+    # perceiver, not here.
     return true_atoms
 
 
@@ -2552,6 +2649,8 @@
 
     Duplicate arguments in predicates are allowed.
     """
+    if state.simulator_state and "abstract_state" in state.simulator_state:
+        return state.simulator_state["abstract_state"]
     # Start by pulling out all VLM predicates.
     vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate))
     # Next, classify all non-VLM predicates.
diff --git a/tests/envs/test_spot_envs.py b/tests/envs/test_spot_envs.py
index 6f9fdbffba..f723f4f138 100644
--- a/tests/envs/test_spot_envs.py
+++ b/tests/envs/test_spot_envs.py
@@ -25,6 +25,12 @@
 from predicators.structs import Action, GroundAtom, _GroundNSRT
 
 
+def test_spot_vlm_debug():
+    utils.reset_config({
+        "env": "spot_vlm_test_env",
+    })
+
+
 def test_spot_env_dry_run():
     """Dry run tests (do not require access to robot)."""
     utils.reset_config({
@@ -265,8 +271,11 @@ def real_robot_cube_env_test() -> None:
         "test_task_json_dir": args.get("test_task_json_dir", None),
     })
 
     rng = np.random.default_rng(123)
+    import os
+    os.environ["BOSDYN_CLIENT_USERNAME"] = "user"
+    os.environ["BOSDYN_CLIENT_PASSWORD"] = "bbbdddaaaiii"
     env = SpotCubeEnv()
     perceiver = SpotPerceiver()
     nsrts = get_gt_nsrts(env.get_name(), env.predicates,
                          get_gt_options(env.get_name()))
@@ -646,6 +655,6 @@
 
 
 if __name__ == "__main__":
-    # real_robot_cube_env_test()
+    real_robot_cube_env_test()
     # real_robot_drafting_table_placement_test()
-    real_robot_sweeping_nsrt_test()
+    # real_robot_sweeping_nsrt_test()