finish up implementation!
NishanthJKumar committed Sep 13, 2024
1 parent 803683d commit 369b399
Showing 3 changed files with 9 additions and 8 deletions.
7 changes: 4 additions & 3 deletions predicators/envs/spot_env.py
@@ -48,7 +48,7 @@
     update_pbrspot_robot_conf, verify_estop
 from predicators.structs import Action, EnvironmentTask, GoalDescription, \
     GroundAtom, LiftedAtom, Object, Observation, Predicate, \
-    SpotActionExtraInfo, State, STRIPSOperator, Type, Variable
+    SpotActionExtraInfo, State, STRIPSOperator, Type, Variable, _Option

 ###############################################################################
 #                                Base Class                                  #
@@ -106,6 +106,7 @@ class _TruncatedSpotObservation:
     # # A placeholder until all predicates have classifiers
     # nonpercept_atoms: Set[GroundAtom]
     # nonpercept_predicates: Set[Predicate]
+    executed_skill: Optional[_Option] = None


 class _PartialPerceptionState(State):
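Reviewer note: for orientation, here is a minimal sketch of what the observation dataclass plausibly looks like after this change. Only the new `executed_skill` field, its default, and `gripper_open_percentage` are confirmed by the diff; every other field name is an assumption inferred from the 7-argument constructor calls in the hunks below.

from dataclasses import dataclass
from typing import Dict, Optional, Set

@dataclass(frozen=True)
class _TruncatedSpotObservation:
    """Sketch only: field names other than executed_skill are guesses."""
    rgbd_images: Dict[str, object]  # camera name -> RGBD image (assumed shape)
    objects_in_view: Set[object]
    objects_in_hand: Set[object]  # assumption: one of the two set() arguments
    nonpercept_atoms: Set[object]  # assumption: the other set() argument
    robot: object  # passed as self._spot_object below
    gripper_open_percentage: float
    # New in this commit: the skill (ground _Option) whose execution produced
    # this observation; None for the initial, pre-action observation.
    executed_skill: Optional[object] = None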
@@ -2547,7 +2548,7 @@ def _actively_construct_env_task(self) -> EnvironmentTask:
         objects_in_view = []
         obs = _TruncatedSpotObservation(rgbd_images, set(objects_in_view),
                                         set(), set(), self._spot_object,
-                                        gripper_open_percentage)
+                                        gripper_open_percentage, None)
         goal_description = self._generate_goal_description()
         task = EnvironmentTask(obs, goal_description)
         return task
@@ -2611,7 +2612,7 @@ def step(self, action: Action) -> Observation:
         objects_in_view = []
         obs = _TruncatedSpotObservation(rgbd_images, set(objects_in_view),
                                         set(), set(), self._spot_object,
-                                        gripper_open_percentage)
+                                        gripper_open_percentage, action.get_option())
         return obs


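Reviewer note: the two call sites above fix the convention for the new argument: the initial observation built in _actively_construct_env_task passes None (no skill has run yet), while step() passes action.get_option(), i.e. the skill that produced the observation. A compressed sketch of that step() flow; the two perception helpers are hypothetical stand-ins for the elided body.

def step_sketch(self, action):
    """Sketch of the updated step(); not the real implementation."""
    # ... execute `action` on the robot, then re-perceive ...
    rgbd_images = self._capture_rgbd_images()  # hypothetical helper
    gripper_open_percentage = self._read_gripper_state()  # hypothetical helper
    # Thread the executed skill into the observation so downstream consumers
    # (the perceiver, and ultimately the VLM prompt) see which skill
    # produced each state.
    return _TruncatedSpotObservation(
        rgbd_images, set(), set(), set(), self._spot_object,
        gripper_open_percentage, action.get_option())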
5 changes: 4 additions & 1 deletion predicators/perception/spot_perceiver.py
@@ -23,7 +23,7 @@
     get_allowed_map_regions, load_spot_metadata, object_to_top_down_geom
 from predicators.structs import Action, DefaultState, EnvironmentTask, \
     GoalDescription, GroundAtom, Object, Observation, Predicate, \
-    SpotActionExtraInfo, State, Task, Video
+    SpotActionExtraInfo, State, Task, Video, _Option


 class SpotPerceiver(BasePerceiver):
@@ -621,6 +621,7 @@ def __init__(self) -> None:
         self._ordered_objects: List[Object] = []  # list of all known objects
         self._state_history: Deque[State] = deque(
             maxlen=5)  # TODO: (njk) I just picked an arbitrary constant here! Didn't properly consider this.
+        self._executed_skill_history: Deque[_Option] = deque(maxlen=5)
         # # Keep track of objects that are contained (out of view) in another
         # # object, like a bag or bucket. This is important not only for gremlins
         # # but also for small changes in the container's perceived pose.
@@ -707,6 +708,8 @@ def step(self, observation: Observation) -> State:
         ret_state = self._curr_state.copy()
         ret_state.simulator_state["state_history"] = list(self._state_history)
         self._state_history.append(ret_state)
+        self._executed_skill_history.append(observation.executed_skill)
+        ret_state.simulator_state["skill_history"] = list(self._executed_skill_history)
         return ret_state

def _create_state(self) -> State:
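Reviewer note: both windows are bounded deques of length 5 and are appended in lockstep, so entry i of each deque refers to the same environment step. One subtlety worth flagging: the stored "state_history" is snapshotted before the current state is appended (it holds only previous states), while "skill_history" is snapshotted after appending (it includes the skill that just produced the current state), so the two stored lists are offset by one step. A self-contained sketch of that exact ordering:

from collections import deque
from typing import Any, Deque, Dict, List

class _HistorySketch:
    """Stand-in for the perceiver's paired history windows."""

    def __init__(self, window: int = 5) -> None:
        self._state_history: Deque[Any] = deque(maxlen=window)
        self._skill_history: Deque[Any] = deque(maxlen=window)

    def step(self, state: Any, executed_skill: Any) -> Dict[str, List[Any]]:
        out: Dict[str, List[Any]] = {}
        # Snapshot BEFORE appending: previous states only, as in the diff.
        out["state_history"] = list(self._state_history)
        self._state_history.append(state)
        # Snapshot AFTER appending: includes the skill just executed.
        self._skill_history.append(executed_skill)
        out["skill_history"] = list(self._skill_history)
        return out

# After six steps, each deque holds only the five most recent entries,
# and the two snapshots are offset by one step as described above.
h = _HistorySketch()
for i in range(6):
    out = h.step(f"s{i}", f"skill{i}")
assert out["state_history"] == ["s0", "s1", "s2", "s3", "s4"]
assert out["skill_history"] == ["skill1", "skill2", "skill3", "skill4", "skill5"]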
5 changes: 1 addition & 4 deletions predicators/utils.py
@@ -2580,12 +2580,9 @@ def query_vlm_for_atom_vals(
     if "state_history" in state.simulator_state:
         previous_states = state.simulator_state["state_history"]
     state_imgs_history = [state.simulator_state["images"] for state in previous_states]
-
-    # TODO: need to somehow get the history of skills executed; i'll think about this more and then implement.
-
     vlm_atoms = sorted(vlm_atoms)
     atom_queries_str = [atom.get_vlm_query_str() for atom in vlm_atoms]
-    vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, state.simulator_state["vlm_atoms_history"], state_imgs_history, [], skill_history)
+    vlm_query_str, imgs = get_prompt_for_vlm_state_labelling(CFG.vlm_test_time_atom_label_prompt_type, atom_queries_str, state.simulator_state["vlm_atoms_history"], state_imgs_history, [], state.simulator_state["skill_history"])
     if vlm is None:
         vlm = create_vlm_by_name(CFG.vlm_model_name)  # pragma: no cover.
     vlm_input_imgs = \
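Reviewer note: this hunk resolves the removed TODO. Rather than threading a separate skill_history argument into query_vlm_for_atom_vals (the old call referenced an undefined skill_history name), the function now reads the window the perceiver stashed in simulator_state. A sketch of the lookup pattern; the key names match the diff, but the .get() fallbacks for the first timestep are an assumption (the diff indexes "skill_history" directly).

from typing import Any, Dict, List, Tuple

def _histories_for_vlm_prompt(
        simulator_state: Dict[str, Any]) -> Tuple[List[Any], List[Any]]:
    """Pull both history windows stored by SpotPerceiver.step()."""
    previous_states = simulator_state.get("state_history", [])
    # One image batch per recent state, as in the list comprehension above.
    state_imgs_history = [s.simulator_state["images"] for s in previous_states]
    # The skills executed between those states, newly available here.
    skill_history = simulator_state.get("skill_history", [])
    return state_imgs_history, skill_history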
