Skip to content

Commit

Permalink
Fix skill action history in case of done action.
Browse files Browse the repository at this point in the history
  • Loading branch information
ashay-bdai committed Sep 17, 2024
1 parent 5232fa2 commit 6303130
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
13 changes: 11 additions & 2 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2677,7 +2677,16 @@ def step(self, action: Action) -> Observation:
self._current_task_goal_reached = False
break
logging.info("Invalid input, must be either 'y' or 'n'")
return self._current_observation
return _TruncatedSpotObservation(
self._current_observation.rgbd_images,
self._current_observation.objects_in_view,
set(),
set(),
self._spot_object,
self._current_observation.gripper_open_percentage,
self._current_observation.object_detections_per_camera,
action
)

# Execute the action in the real environment. Automatically retry
# if a retryable error is encountered.
Expand All @@ -2704,7 +2713,7 @@ def step(self, action: Action) -> Observation:
self._spot_object,
gripper_open_percentage,
object_detections_per_camera,
action.get_option()
action
)
return obs

Expand Down
9 changes: 8 additions & 1 deletion predicators/perception/spot_perceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,14 @@ def step(self, observation: Observation) -> State:
self._curr_state.simulator_state["state_history"] = list(self._state_history)
# We do this here so the call to `utils.abstract()` a few lines later has the skill
# that was just run.
self._executed_skill_history.append(observation.executed_skill) # None in first timestep.
executed_skill = None

if observation.executed_skill is not None:
if observation.executed_skill.extra_info.action_name == "done":
# Just return the default state
return DefaultState
executed_skill = observation.executed_skill.get_option()
self._executed_skill_history.append(executed_skill) # None in first timestep.
self._curr_state.simulator_state["skill_history"] = list(self._executed_skill_history)
self._curr_state.simulator_state["vlm_label_history"] = list(self._vlm_label_history)

Expand Down

0 comments on commit 6303130

Please sign in to comment.