Commit

update
lf-zhao committed May 4, 2024
1 parent 68ef57d commit 6560a94
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions predicators/envs/spot_env.py
@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable, ClassVar, Collection, Dict, Iterator, List, \
-    Optional, Sequence, Set, Tuple, Any
+    Optional, Sequence, Set, Tuple
 
 import PIL.Image
 import matplotlib
@@ -22,6 +22,7 @@
 
 from predicators import utils
 from predicators.envs import BaseEnv
+from predicators.pretrained_model_interface import OpenAIVLM
 from predicators.settings import CFG
 from predicators.spot_utils.perception.object_detection import \
     AprilTagObjectDetectionID, KnownStaticObjectDetectionID, \
@@ -49,7 +50,6 @@
 from predicators.structs import Action, EnvironmentTask, GoalDescription, \
     GroundAtom, LiftedAtom, Object, Observation, Predicate, \
     SpotActionExtraInfo, State, STRIPSOperator, Type, Variable
-from predicators.pretrained_model_interface import OpenAIVLM
 
 ###############################################################################
 #                                 Base Class                                  #
@@ -96,12 +96,7 @@ class _PartialPerceptionState(State):
     in the classifier definitions for the dummy predicates
     """
 
-    # # DEBUG Add an additional field to store Spot images
-    # # This would be directly copied from the images in raw Observation
-    # # NOTE: This is only used when using VLM for predicate evaluation
-    # # NOTE: Performance aspect should be considered later
-    # cam_images: Optional[Dict[str, RGBDImageWithContext]] = None
-    # # TODO: it's still unclear how we select and store useful images!
-    # obs_images: Optional[Dict[str, RGBDImageWithContext]] = None
 
     @property
     def _simulator_state_predicates(self) -> Set[Predicate]:
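
For reference, a minimal sketch of the pattern the deleted draft comments describe: an optional per-camera image store on a dataclass-style state. This is illustrative only; the commit drops the idea in favor of the state.camera_images accessor used below, and the type name is quoted because the real RGBDImageWithContext lives elsewhere in the repo.

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class _SketchState:
    """Illustrative only: what the removed draft fields would have looked like."""

    # One RGB-D frame per camera name; None until perception fills it in.
    camera_images: Optional[Dict[str, "RGBDImageWithContext"]] = None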
@@ -1071,11 +1066,11 @@ def _generate_goal_description(self) -> GoalDescription:
 
 def vlm_predicate_classify(question: str, state: State) -> bool:
     """Use VLM to evaluate (classify) a predicate in a given state."""
-    full_prompt = vlm_predicate_eval_prompt_prefix.format(
-        question=question
-    )
+    full_prompt = vlm_predicate_eval_prompt_prefix.format(question=question)
     images_dict: Dict[str, RGBDImageWithContext] = state.camera_images
-    images = [PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()]
+    images = [
+        PIL.Image.fromarray(v.rotated_rgb) for _, v in images_dict.items()
+    ]
 
     logging.info(f"VLM predicate evaluation for: {question}")
     logging.info(f"Prompt: {full_prompt}")
@@ -1095,7 +1090,8 @@ def vlm_predicate_classify(question: str, state: State) -> bool:
     elif vlm_response == "no":
         return False
     else:
-        logging.error(f"VLM response not understood: {vlm_response}. Treat as False.")
+        logging.error(
+            f"VLM response not understood: {vlm_response}. Treat as False.")
         return False
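
For context, a self-contained sketch (not repository code) of the answer-handling convention in vlm_predicate_classify: only "yes" makes the predicate hold, while "no" and anything unrecognized are treated as False, so a malformed VLM answer can never spuriously satisfy a predicate. The strip/lower normalization and the parse_vlm_answer name are illustrative additions; the diff compares the raw response directly.

from typing import Optional


def parse_vlm_answer(vlm_response: Optional[str]) -> bool:
    """Map a raw VLM answer onto a boolean predicate value."""
    if vlm_response is None:
        return False
    answer = vlm_response.strip().lower()  # normalization added for the sketch
    if answer == "yes":
        return True
    if answer != "no":
        # Mirrors the logging.error branch above: unknown answers are
        # reported and conservatively treated as False.
        print(f"VLM response not understood: {vlm_response!r}. Treat as False.")
    return False


assert parse_vlm_answer("yes") is True
assert parse_vlm_answer("no") is False
assert parse_vlm_answer("perhaps") is False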


@@ -1197,7 +1193,8 @@ def _on_classifier(state: State, objects: Sequence[Object]) -> bool:
 
     else:
         # Check that the bottom of the object is close to the top of the surface.
-        expect = state.get(obj_surface, "z") + state.get(obj_surface, "height") / 2
+        expect = state.get(obj_surface,
+                           "z") + state.get(obj_surface, "height") / 2
         actual = state.get(obj_on, "z") - state.get(obj_on, "height") / 2
         classification_val = abs(actual - expect) < _ONTOP_Z_THRESHOLD
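
For intuition, a worked instance of the on-top test above with made-up numbers; _ONTOP_Z_THRESHOLD is defined elsewhere in spot_env.py, so the 0.05 m value here is an assumed stand-in. Poses are center coordinates, which is why half the height is added or subtracted.

ONTOP_Z_THRESHOLD = 0.05  # assumed stand-in for _ONTOP_Z_THRESHOLD (meters)

surface_z, surface_height = 0.70, 0.10  # hypothetical table: center z, extent
obj_z, obj_height = 0.80, 0.10          # hypothetical cup: center z, extent

expect = surface_z + surface_height / 2  # top of the surface: 0.75
actual = obj_z - obj_height / 2          # bottom of the object: 0.75
assert abs(actual - expect) < ONTOP_Z_THRESHOLD  # cup counts as on the table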

@@ -1575,7 +1572,8 @@ def _get_sweeping_surface_for_container(container: Object,
 # is VLM perceptible or not.
 # NOTE: candidates: on, inside, door opened, blocking, not blocked, ...
 _VLM_EVAL_PREDICATES: {
-    _On, _Inside,
+    _On,
+    _Inside,
 }
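
Note that as written in the diff, _VLM_EVAL_PREDICATES: {...} is a bare annotation rather than an assignment (an = would bind the set). A hedged sketch, not repository code, of how such a set might be consulted when classifying a ground atom: predicates in the set are phrased as a natural-language question for the VLM, everything else falls back to a geometric classifier. All names below are illustrative.

from typing import Callable, Tuple

VLM_EVAL_PREDICATE_NAMES = {"On", "Inside"}  # stand-in for _VLM_EVAL_PREDICATES


def classify_atom(pred: str, args: Tuple[str, str],
                  vlm_fn: Callable[[str], bool],
                  geom_fn: Callable[[str, Tuple[str, str]], bool]) -> bool:
    """Route a ground atom either to the VLM or to a geometric check."""
    if pred in VLM_EVAL_PREDICATE_NAMES:
        question = f"Is the {args[0]} {pred.lower()} the {args[1]}?"
        return vlm_fn(question)
    return geom_fn(pred, args)


# Usage with stub callables standing in for the real classifiers:
print(classify_atom("On", ("cup", "table"),
                    vlm_fn=lambda q: True,
                    geom_fn=lambda p, a: False))  # -> True, routed to the VLM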

