-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
231 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from pydantic import BaseModel | ||
from typing import Type | ||
from beartype import beartype | ||
from sotopia.generation_utils import agenerate | ||
from langchain.output_parsers import PydanticOutputParser | ||
from haicosystem.generation_utils import SIMULATOR_CRITIQUE | ||
from haicosystem.protocols import SimulatedObservation | ||
|
||
|
||
@beartype | ||
async def validate_observation( | ||
obs: SimulatedObservation, | ||
tool_output_parser: Type[BaseModel], | ||
model_name: str = "gpt-4o", | ||
temperature: float = 0.7, | ||
) -> tuple[bool, str]: | ||
""" | ||
Validate the observation against the tool's output parser. | ||
""" | ||
output_parser = PydanticOutputParser(pydantic_object=tool_output_parser) # type: ignore | ||
if '"error":' in obs.observation: | ||
return True, obs.observation | ||
try: | ||
output_parser.invoke(obs.observation) | ||
except Exception as e: | ||
print(f"Error: {e}") | ||
try: | ||
correted_observation = await agenerate( | ||
model_name=model_name, | ||
template=SIMULATOR_CRITIQUE, | ||
input_values=dict( | ||
log=obs.log, | ||
thought_summary=obs.thought_summary, | ||
observation=obs.observation, | ||
), | ||
output_parser=output_parser, | ||
temperature=temperature, | ||
) | ||
assert isinstance(correted_observation, BaseModel) | ||
return False, correted_observation.json() | ||
except Exception as e: | ||
return False, f"{{'error': {e}}}" | ||
return True, "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from pydantic import BaseModel, Field | ||
from typing import Optional, Type | ||
|
||
from .register import register_toolkit_output_parser, ToolOutputParserCollection | ||
|
||
|
||
class VenmoSendMoneyParameters(BaseModel): | ||
recipient_username: str = Field(..., description="The username of the recipient.") | ||
amount: float = Field( | ||
..., gt=0, description="The amount of money to send, must be positive." | ||
) | ||
note: Optional[str] = Field( | ||
"", | ||
description="A note to include with the payment. Default is an empty string.", | ||
) | ||
|
||
|
||
class VenmoSendMoneyResult(BaseModel): | ||
success: bool = Field( | ||
..., description="Indicates whether the transaction was successful." | ||
) | ||
transaction_id: str = Field( | ||
..., description="The unique identifier of the transaction, if successful." | ||
) | ||
error_message: Optional[str] = Field( | ||
None, description="Error message if the transaction was unsuccessful." | ||
) | ||
|
||
|
||
class VenmoSendMoneyReturns(BaseModel): | ||
result: VenmoSendMoneyResult = Field( | ||
..., | ||
description="An object containing 'success' (boolean), 'transaction_id' (string), and 'error_message' (string).", | ||
) | ||
|
||
|
||
@register_toolkit_output_parser() | ||
class VenmoOutputParser(ToolOutputParserCollection): | ||
name: str = Field(default="Venmo") | ||
tool_name_to_output_parser: dict[str, Type[BaseModel]] = Field( | ||
default_factory=lambda: {"VenmoSendMoney": VenmoSendMoneyReturns} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters