Skip to content

Commit

Permalink
clean up previous legacy version of individual VLM classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
lf-zhao committed May 18, 2024
1 parent dc6ae7b commit 0e0d6cf
Showing 1 changed file with 22 additions and 56 deletions.
78 changes: 22 additions & 56 deletions predicators/envs/spot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,36 +1196,23 @@ def _object_in_xy_classifier(state: State,
def _on_classifier(state: State, objects: Sequence[Object]) -> bool:
obj_on, obj_surface = objects

currently_visible = all([o in state.visible_objects for o in objects])
# If object not all visible and choose to use VLM,
# then use predicate values of previous time step
if CFG.spot_vlm_eval_predicate and not currently_visible:
# TODO: add all previous atoms to the state
raise NotImplementedError

# Call VLM to evaluate predicate value
elif CFG.spot_vlm_eval_predicate and currently_visible:
# predicate_str = f"""
# On({obj_on}, {obj_surface})
# (Whether {obj_on} is on {obj_surface} in the image?)
# """
# return vlm_predicate_classify(predicate_str, state)
# NOTE: Legacy version evaluate predicates individually
if CFG.spot_vlm_eval_predicate:
raise RuntimeError(
"VLM predicate classifier should be evaluated in batch!")

else:
# Check that the bottom of the object is close to the top of the surface.
expect = state.get(obj_surface,
"z") + state.get(obj_surface, "height") / 2
actual = state.get(obj_on, "z") - state.get(obj_on, "height") / 2
classification_val = abs(actual - expect) < _ONTOP_Z_THRESHOLD

# If so, check that the object is within the bounds of the surface.
if not _object_in_xy_classifier(
state, obj_on, obj_surface, buffer=_ONTOP_SURFACE_BUFFER):
return False
# Check that the bottom of the object is close to the top of the surface.
expect = state.get(obj_surface,
"z") + state.get(obj_surface, "height") / 2
actual = state.get(obj_on, "z") - state.get(obj_on, "height") / 2
classification_val = abs(actual - expect) < _ONTOP_Z_THRESHOLD

# If so, check that the object is within the bounds of the surface.
if not _object_in_xy_classifier(
state, obj_on, obj_surface, buffer=_ONTOP_SURFACE_BUFFER):
return False

return classification_val
return classification_val


def _top_above_classifier(state: State, objects: Sequence[Object]) -> bool:
Expand All @@ -1240,22 +1227,11 @@ def _top_above_classifier(state: State, objects: Sequence[Object]) -> bool:
def _inside_classifier(state: State, objects: Sequence[Object]) -> bool:
obj_in, obj_container = objects

# currently_visible = all([o in state.visible_objects for o in objects])
# # If object not all visible and choose to use VLM,
# # then use predicate values of previous time step
# if CFG.spot_vlm_eval_predicate and not currently_visible:
# # TODO: add all previous atoms to the state
# raise NotImplementedError
#
# # Call VLM to evaluate predicate value
# elif CFG.spot_vlm_eval_predicate and currently_visible:
# predicate_str = f"""
# Inside({obj_in}, {obj_container})
# (Whether {obj_in} is inside {obj_container} in the image?)
# """
# return vlm_predicate_classify(predicate_str, state)
#
# else:
# NOTE: Legacy version evaluate predicates individually
if CFG.spot_vlm_eval_predicate:
raise RuntimeError(
"VLM predicate classifier should be evaluated in batch!")

if not _object_in_xy_classifier(
state, obj_in, obj_container, buffer=_INSIDE_SURFACE_BUFFER):
return False
Expand Down Expand Up @@ -1369,20 +1345,10 @@ def _blocking_classifier(state: State, objects: Sequence[Object]) -> bool:
if blocker_obj == blocked_obj:
return False

# currently_visible = all([o in state.visible_objects for o in objects])
# # If object not all visible and choose to use VLM,
# # then use predicate values of previous time step
# if CFG.spot_vlm_eval_predicate and not currently_visible:
# # TODO: add all previous atoms to the state
# raise NotImplementedError
#
# # Call VLM to evaluate predicate value
# elif CFG.spot_vlm_eval_predicate and currently_visible:
# predicate_str = f"""
# (Whether {blocker_obj} is blocking {blocked_obj} for further manipulation in the image?)
# Blocking({blocker_obj}, {blocked_obj})
# """
# return vlm_predicate_classify(predicate_str, state)
# NOTE: Legacy version evaluate predicates individually
if CFG.spot_vlm_eval_predicate:
raise RuntimeError(
"VLM predicate classifier should be evaluated in batch!")

# Only consider draggable (non-placeable, movable) objects to be blockers.
if not blocker_obj.is_instance(_movable_object_type):
Expand Down

0 comments on commit 0e0d6cf

Please sign in to comment.