diff --git a/predicators/structs.py b/predicators/structs.py index 8f113ce2af..f5363e9f3b 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -315,6 +315,7 @@ def _negated_classifier(self, state: State, return not self._classifier(state, objects) +@dataclass(frozen=True, order=True, repr=False) class VLMPredicate(Predicate): """Struct defining a predicate (a lifted classifier over states) that uses a VLM for evaluation. @@ -325,7 +326,10 @@ class VLMPredicate(Predicate): at once. """ + # A classifier is not needed for VLM predicates _classifier: Optional[Callable[[State, Sequence[Object]], bool]] = None + # An optional prompt additionally provided for each VLM predicate + prompt: Optional[str] = None def holds(self, state: State, objects: Sequence[Object]) -> bool: """Public method for getting predicate value. @@ -463,7 +467,6 @@ class VLMGroundAtom(GroundAtom): # NOTE: This subclasses GroundAtom to support VLM predicates and classifiers predicate: VLMPredicate - prompt: Optional[str] = None def get_query_str(self, without_type: bool = False) -> str: """Get a query string for this ground atom. @@ -478,8 +481,8 @@ def get_query_str(self, without_type: bool = False) -> str: else: string = str(self) - if self.prompt is not None: - string += f" [Prompt: {self.prompt}]" + if self.predicate.prompt is not None: + string += f" [Prompt: {self.predicate.prompt}]" return string def holds(self, state: State) -> bool: