diff --git a/image_0.png b/image_0.png index 874f8f32f..2400498c8 100644 Binary files a/image_0.png and b/image_0.png differ diff --git a/image_1.png b/image_1.png index 303729689..97c8fba53 100644 Binary files a/image_1.png and b/image_1.png differ diff --git a/image_2.png b/image_2.png index 7d386f85f..bd1fe28e0 100644 Binary files a/image_2.png and b/image_2.png differ diff --git a/image_3.png b/image_3.png index 07385ca6d..43cf25782 100644 Binary files a/image_3.png and b/image_3.png differ diff --git a/image_4.png b/image_4.png index 40b36bfea..d2a967e9c 100644 Binary files a/image_4.png and b/image_4.png differ diff --git a/image_5.png b/image_5.png index de471a98a..1c5c78b3b 100644 Binary files a/image_5.png and b/image_5.png differ diff --git a/predicators/cogman.py b/predicators/cogman.py index f0d00e11e..d2e5af69e 100644 --- a/predicators/cogman.py +++ b/predicators/cogman.py @@ -84,6 +84,34 @@ def step(self, observation: Observation) -> Optional[Action]: self._exec_monitor.env_task = self._current_env_task if self._exec_monitor.step(state): logging.info("[CogMan] Replanning triggered.") + import jellyfish + def map_goal_to_state(goal_predicates, state_data): + goal_to_state_mapping = {} + state_usage_count = {state_obj: 0 for state_obj in state_data.keys()} + for pred in goal_predicates: + for goal_obj in pred.objects: + goal_obj_name = str(goal_obj) + closest_state_obj = None + min_distance = float('inf') + for state_obj in state_data.keys(): + state_obj_name = str(state_obj) + distance = jellyfish.levenshtein_distance(goal_obj_name, state_obj_name) + if distance < min_distance: + min_distance = distance + closest_state_obj = state_obj + if state_usage_count[closest_state_obj] > 0: + virtual_obj = Object(f"{closest_state_obj}_{state_usage_count[closest_state_obj]}", closest_state_obj.type) + goal_to_state_mapping[goal_obj] = virtual_obj + else: + goal_to_state_mapping[goal_obj] = closest_state_obj + state_usage_count[closest_state_obj] += 1 + return goal_to_state_mapping + mapping = map_goal_to_state(self._exec_monitor._curr_goal, state.data) + new_goal = set() + for pred in self._exec_monitor._curr_goal: + new_goal.add(GroundAtom(pred.predicate, [mapping[obj] for obj in pred.objects])) + import ipdb; ipdb.set_trace() + self._current_goal = new_goal assert self._current_goal is not None task = Task(state, self._current_goal) self._reset_policy(task) diff --git a/predicators/envs/base_env.py b/predicators/envs/base_env.py index 931c71824..a32aa645a 100644 --- a/predicators/envs/base_env.py +++ b/predicators/envs/base_env.py @@ -315,7 +315,10 @@ def _parse_goal_from_json(self, spec: Dict[str, List[List[str]]], for pred, args in pred_to_args.items(): for id_args in args: obj_args = [id_to_obj[a] for a in id_args] - goal_atom = GroundAtom(pred, obj_args) + try: + goal_atom = GroundAtom(pred, obj_args) + except AssertionError as e: + import ipdb; ipdb.set_trace() goal.add(goal_atom) return goal diff --git a/predicators/execution_monitoring/base_execution_monitor.py b/predicators/execution_monitoring/base_execution_monitor.py index 631b97915..d65ad7ad2 100644 --- a/predicators/execution_monitoring/base_execution_monitor.py +++ b/predicators/execution_monitoring/base_execution_monitor.py @@ -12,6 +12,7 @@ class BaseExecutionMonitor(abc.ABC): def __init__(self) -> None: self._approach_info: List[Any] = [] self._curr_plan_timestep = 0 + self._curr_goal = None @classmethod @abc.abstractmethod diff --git a/predicators/execution_monitoring/expected_atoms_monitor.py b/predicators/execution_monitoring/expected_atoms_monitor.py index 246d29b3c..8caec7ecc 100644 --- a/predicators/execution_monitoring/expected_atoms_monitor.py +++ b/predicators/execution_monitoring/expected_atoms_monitor.py @@ -33,9 +33,14 @@ def step(self, state: State) -> bool: unsat_atoms = {a for a in next_expected_atoms if not a.holds(state)} # Check goal assert self.perceiver is not None and self.env_task is not None - goal = self.perceiver._create_goal(state, self.env_task.goal_description) - import ipdb; ipdb.set_trace() - # + new_goal = self.perceiver._create_goal(state, self.env_task.goal_description) + if new_goal != self._curr_goal: + logging.info( + "Expected atoms execution monitor triggered replanning " + "because the goal has changed.") + logging.info(f"Old goal: {self._curr_goal}") + logging.info(f"New goal: {new_goal}") + self._curr_goal = new_goal if not unsat_atoms: return False logging.info( diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index 1b498e4ab..4658e4b61 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -720,6 +720,9 @@ def _parse_vlm_goal_from_state( self, state: State, language_goal: str, id_to_obj: Dict[str, Object]) -> Set[GroundAtom]: """Helper for parsing language-based goals from JSON task specs.""" + ### + language_goal = 'place empty cups into the plastic_bin' + ### object_names = set(id_to_obj) prompt_prefix = self._get_language_goal_prompt_prefix(object_names) prompt = prompt_prefix + f"\n# {language_goal}" @@ -741,6 +744,17 @@ def _parse_vlm_goal_from_state( goal_spec = json.loads(response) except json.JSONDecodeError as e: goal_spec = json.loads(response.replace('`', '').replace('json', '')) - - return self._curr_env._parse_goal_from_json(goal_spec, id_to_obj) + + for pred, args in goal_spec.items(): + for arg in args: + for obj_name in arg: + if 'robot' in obj_name: + id_to_obj[obj_name] = Object(obj_name, _robot_type) + elif any([name in obj_name for name in ['cup', 'box', 'bin']]): + id_to_obj[obj_name] = Object(obj_name, _container_type) + else: + id_to_obj[obj_name] = Object(obj_name, _movable_object_type) + + new_goal = self._curr_env._parse_goal_from_json(goal_spec, id_to_obj) + return new_goal