Skip to content

Commit

Permalink
add prompt field to VLM predicate / atom
Browse files Browse the repository at this point in the history
  • Loading branch information
lf-zhao committed May 10, 2024
1 parent d6daa39 commit 7a4d20e
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions predicators/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def _negated_classifier(self, state: State,
return not self._classifier(state, objects)


@dataclass(frozen=True, order=True, repr=False)
class VLMPredicate(Predicate):
"""Struct defining a predicate (a lifted classifier over states) that uses
a VLM for evaluation.
Expand All @@ -325,7 +326,10 @@ class VLMPredicate(Predicate):
at once.
"""

# A classifier is not needed for VLM predicates
_classifier: Optional[Callable[[State, Sequence[Object]], bool]] = None
# An optional prompt additionally provided for each VLM predicate
prompt: Optional[str] = None

def holds(self, state: State, objects: Sequence[Object]) -> bool:
"""Public method for getting predicate value.
Expand Down Expand Up @@ -463,7 +467,6 @@ class VLMGroundAtom(GroundAtom):

# NOTE: This subclasses GroundAtom to support VLM predicates and classifiers
predicate: VLMPredicate
prompt: Optional[str] = None

def get_query_str(self, without_type: bool = False) -> str:
"""Get a query string for this ground atom.
Expand All @@ -478,8 +481,8 @@ def get_query_str(self, without_type: bool = False) -> str:
else:
string = str(self)

if self.prompt is not None:
string += f" [Prompt: {self.prompt}]"
if self.predicate.prompt is not None:
string += f" [Prompt: {self.predicate.prompt}]"
return string

def holds(self, state: State) -> bool:
Expand Down

0 comments on commit 7a4d20e

Please sign in to comment.