Skip to content

Commit

Permalink
debugging new goals via vlm
Browse files Browse the repository at this point in the history
  • Loading branch information
wmcclinton committed Oct 29, 2024
1 parent 8efea46 commit 974457e
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 6 deletions.
Binary file modified image_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified image_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified image_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified image_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified image_4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified image_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions predicators/cogman.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,34 @@ def step(self, observation: Observation) -> Optional[Action]:
self._exec_monitor.env_task = self._current_env_task
if self._exec_monitor.step(state):
logging.info("[CogMan] Replanning triggered.")
import jellyfish
def map_goal_to_state(goal_predicates, state_data):
goal_to_state_mapping = {}
state_usage_count = {state_obj: 0 for state_obj in state_data.keys()}
for pred in goal_predicates:
for goal_obj in pred.objects:
goal_obj_name = str(goal_obj)
closest_state_obj = None
min_distance = float('inf')
for state_obj in state_data.keys():
state_obj_name = str(state_obj)
distance = jellyfish.levenshtein_distance(goal_obj_name, state_obj_name)
if distance < min_distance:
min_distance = distance
closest_state_obj = state_obj
if state_usage_count[closest_state_obj] > 0:
virtual_obj = Object(f"{closest_state_obj}_{state_usage_count[closest_state_obj]}", closest_state_obj.type)
goal_to_state_mapping[goal_obj] = virtual_obj
else:
goal_to_state_mapping[goal_obj] = closest_state_obj
state_usage_count[closest_state_obj] += 1
return goal_to_state_mapping
mapping = map_goal_to_state(self._exec_monitor._curr_goal, state.data)
new_goal = set()
for pred in self._exec_monitor._curr_goal:
new_goal.add(GroundAtom(pred.predicate, [mapping[obj] for obj in pred.objects]))
import ipdb; ipdb.set_trace()
self._current_goal = new_goal
assert self._current_goal is not None
task = Task(state, self._current_goal)
self._reset_policy(task)
Expand Down
5 changes: 4 additions & 1 deletion predicators/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,10 @@ def _parse_goal_from_json(self, spec: Dict[str, List[List[str]]],
for pred, args in pred_to_args.items():
for id_args in args:
obj_args = [id_to_obj[a] for a in id_args]
goal_atom = GroundAtom(pred, obj_args)
try:
goal_atom = GroundAtom(pred, obj_args)
except AssertionError as e:
import ipdb; ipdb.set_trace()
goal.add(goal_atom)
return goal

Expand Down
1 change: 1 addition & 0 deletions predicators/execution_monitoring/base_execution_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class BaseExecutionMonitor(abc.ABC):
def __init__(self) -> None:
self._approach_info: List[Any] = []
self._curr_plan_timestep = 0
self._curr_goal = None

@classmethod
@abc.abstractmethod
Expand Down
11 changes: 8 additions & 3 deletions predicators/execution_monitoring/expected_atoms_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@ def step(self, state: State) -> bool:
unsat_atoms = {a for a in next_expected_atoms if not a.holds(state)}
# Check goal
assert self.perceiver is not None and self.env_task is not None
goal = self.perceiver._create_goal(state, self.env_task.goal_description)
import ipdb; ipdb.set_trace()
#
new_goal = self.perceiver._create_goal(state, self.env_task.goal_description)
if new_goal != self._curr_goal:
logging.info(
"Expected atoms execution monitor triggered replanning "
"because the goal has changed.")
logging.info(f"Old goal: {self._curr_goal}")
logging.info(f"New goal: {new_goal}")
self._curr_goal = new_goal
if not unsat_atoms:
return False
logging.info(
Expand Down
18 changes: 16 additions & 2 deletions predicators/perception/spot_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,9 @@ def _parse_vlm_goal_from_state(
self, state: State, language_goal: str,
id_to_obj: Dict[str, Object]) -> Set[GroundAtom]:
"""Helper for parsing language-based goals from JSON task specs."""
###
language_goal = 'place empty cups into the plastic_bin'
###
object_names = set(id_to_obj)
prompt_prefix = self._get_language_goal_prompt_prefix(object_names)
prompt = prompt_prefix + f"\n# {language_goal}"
Expand All @@ -741,6 +744,17 @@ def _parse_vlm_goal_from_state(
goal_spec = json.loads(response)
except json.JSONDecodeError as e:
goal_spec = json.loads(response.replace('`', '').replace('json', ''))

return self._curr_env._parse_goal_from_json(goal_spec, id_to_obj)

for pred, args in goal_spec.items():
for arg in args:
for obj_name in arg:
if 'robot' in obj_name:
id_to_obj[obj_name] = Object(obj_name, _robot_type)
elif any([name in obj_name for name in ['cup', 'box', 'bin']]):
id_to_obj[obj_name] = Object(obj_name, _container_type)
else:
id_to_obj[obj_name] = Object(obj_name, _movable_object_type)

new_goal = self._curr_env._parse_goal_from_json(goal_spec, id_to_obj)
return new_goal

0 comments on commit 974457e

Please sign in to comment.