Skip to content

Commit

Permalink
♻️ fix typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Jul 10, 2024
1 parent a7592a6 commit 352e881
Show file tree
Hide file tree
Showing 27 changed files with 370 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ jobs:
poetry run pip install types-protobuf==4.24.0.4
poetry run mypy --install-types --non-interactive sotopia
poetry run mypy --install-types --non-interactive sotopia-chat
poetry run mypy --strict .
poetry run mypy --strict --exclude haicosystem/tools --exclude haicosystem/envs/llm_engine.py .
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pre-commit install
### New branch for each feature
`git checkout -b feature/feature-name` and PR to `main` branch.
### Before committing
Run `pytest` to make sure all tests pass (this will ensure dynamic typing passed with beartype) and `mypy --strict .` to check static typing.
Run `pytest` to make sure all tests pass (this will ensure dynamic typing passed with beartype) and `mypy --strict --exclude haicosystem/tools --exclude haicosystem/envs/llm_engine.py .` to check static typing.
(You can also run `pre-commit run --all-files` to run all checks)
### Check github action result
Check the github action result to make sure all tests pass. If not, fix the errors and push again.
17 changes: 10 additions & 7 deletions haicosystem/envs/hai_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
)

from beartype import beartype
from sotopia.envs.evaluators import ScriptEnvironmentResponse, _reduce

# TODO: fix the import error
from sotopia.envs.evaluators import ScriptEnvironmentResponse, _reduce # type: ignore
from collections import defaultdict
import logging
from haicosystem.envs.messages import SimulatedObservation

from haicosystem.envs.llm_engine import LlmGroundingEngine

from pydantic import Field

Expand Down Expand Up @@ -55,7 +57,7 @@ def unweighted_aggregate_response(
assert (
len(responses_dict["engine"]) == 0
) # TODO: allow multiple engine responses
responses_dict["engine"].append(response)
engine_response = ("engine", response)
else:
assert response[0] == "environment" or response[0].startswith("agent")
responses_dict[response[0]].append(response[1])
Expand All @@ -65,6 +67,7 @@ def unweighted_aggregate_response(
agent_2_responses: tuple[dict[str, float | int | bool], str] = ({}, "")
for k, v in responses_dict.items():
if k == "environment":
assert isinstance(v, list)
environment_responses = _reduce(v)
elif k == "engine":
continue
Expand Down Expand Up @@ -120,7 +123,7 @@ def unweighted_aggregate_response(
if agent_2_responses != ({}, "")
else None,
comments=comments,
observation=responses_dict["engine"][0],
observation=engine_response[1],
)


Expand All @@ -143,17 +146,17 @@ def __init__(
super().__init__(
env_profile=EnvironmentProfile(),
available_action_types=available_action_types,
action_order=action_order,
action_order=action_order, # type: ignore
model_name=model_name,
evaluators=evaluators,
terminal_evaluators=terminal_evaluators,
uuid_str=uuid_str,
)
self.profile = env_profile
self.profile: HaiEnvironmentProfile = env_profile # type: ignore
assert (
len(grounding_engines) == 1
) # temp implementation; only one grounding engine is supported
self.grounding_engine = grounding_engines[0]
self.grounding_engine: LlmGroundingEngine = grounding_engines[0] # type: ignore
self.engines_and_evaluators = grounding_engines + evaluators
self.profile.scenario = self.prepare_scenario(self.profile)

Expand Down
6 changes: 3 additions & 3 deletions haicosystem/envs/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ class SimulatedObservation(Message):
def to_natural_language(self) -> str:
return "Observation: \n" + self.observation

def __str__(self):
def __str__(self) -> str:
return self.observation

def __repr__(self):
def __repr__(self) -> str:
return self.observation


class LangchainAgentAction(AgentAction):
def action_api(self):
def action_api(self) -> None:
# TODO: Implement the action API
raise NotImplementedError
48 changes: 25 additions & 23 deletions haicosystem/envs/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
}


def insert_indent(s: str, indent: str = "\t", insert_first=True) -> str:
def insert_indent(s: str, indent: str = "\t", insert_first: bool = True) -> str:
prefix = indent if insert_first else ""
return prefix + s.rstrip("\n ").replace("\n", "\n" + indent) + "\n"

Expand Down Expand Up @@ -106,18 +106,20 @@ def _get_message(self, tool_name: str) -> str:
def special_fix_for_json_obj(s: str, pos: int) -> str:
"""Fix the string to make it a valid JSON object."""
if pos >= len(s):
return None
return ""
if s[pos] == "\\": # fix for {"xxx\_yyy": "zzz"}
broken = s[pos : pos + 2]
return s.replace(broken, broken[1])
elif s[pos : pos + 4] == "True":
return s.replace("True", "true")
elif s[pos : pos + 5] == "False":
return s.replace("False", "false")
return None
return ""


def get_first_json_object_str(s: str, enable_check=True, *args, **kwargs) -> str:
def get_first_json_object_str(
s: str, enable_check: bool = True, *args: Any, **kwargs: Any
) -> str:
"""Get the first JSON object from a string"""
# print("[DEBUG] raw action string:", s)
# allow ```json {JSON object} ``` format
Expand All @@ -136,7 +138,7 @@ def get_first_json_object_str(s: str, enable_check=True, *args, **kwargs) -> str
break
except json.JSONDecodeError as e:
fix_res = special_fix_for_json_obj(s, e.pos) # special fixes
if fix_res is not None:
if fix_res:
ret = s = fix_res
else:
s = s[: e.pos].rstrip() # take the first part
Expand All @@ -150,11 +152,11 @@ def get_first_json_object_str(s: str, enable_check=True, *args, **kwargs) -> str
return ret


def load_dict(tool_input: str) -> Dict[str, Any]:
def load_dict(tool_input: str) -> dict[str, Any]:
"""Load a dictionary from a string."""
# TODO: improve this, this seems to be the most robust way for now
try:
params = json.loads(tool_input, strict=False)
params: dict[str, Any] = json.loads(tool_input, strict=False)
# set strict=False to allow single quote, comments, escape characters, etc
return params
except Exception as e:
Expand All @@ -176,15 +178,15 @@ def match_type(value: Any, expected_type: type) -> bool:
return isinstance(value, int) and expected_type == float


def check_type(name: str, value: Any, expected_type: type):
def check_type(name: str, value: Any, expected_type: type) -> None:
if not match_type(value, expected_type):
raise ValueError(
f"Invalid type for {name}: {type(value).__name__}, "
f"expected {expected_type.__name__}."
)


def validate_inputs(parameters: List[ArgParameter], inputs: Dict[str, Any]):
def validate_inputs(parameters: List[ArgParameter], inputs: Dict[str, Any]) -> None:
"""Validate the inputs for a tool."""
remains = set(inputs.keys())
for param in parameters:
Expand All @@ -193,14 +195,14 @@ def validate_inputs(parameters: List[ArgParameter], inputs: Dict[str, Any]):
remains.remove(name)
if name in inputs:
expected_type = PRIMITIVE_TYPES[param["type"]]
check_type(name, inputs[name], expected_type)
check_type(name, inputs[name], expected_type) # type: ignore
elif param.get("required", False):
raise ValueError(f"Missing required parameter: {name}.")
if remains:
raise ValueError(f"Unexpected parameters: {remains}.")


def validate_outputs(returns: List[ArgReturn], outputs: Dict[str, Any]):
def validate_outputs(returns: List[ArgReturn], outputs: Dict[str, Any]) -> None:
"""Validate the outputs for a tool."""
if not isinstance(outputs, dict):
raise ValueError(
Expand All @@ -215,23 +217,23 @@ def validate_outputs(returns: List[ArgReturn], outputs: Dict[str, Any]):
remains.remove(name)
if name not in outputs:
raise ValueError(f"Missing return: {name}.")
expected_type = PRIMITIVE_TYPES[ret["type"]]
expected_type: type = PRIMITIVE_TYPES[ret["type"]] # type: ignore
check_type(name, outputs[name], expected_type)
if remains:
raise ValueError(f"Unexpected returns: {list(remains)}.")


def run_with_input_validation(
run_func: Callable,
inputs: Dict[str, str],
run_func: Callable[..., Any],
inputs: dict[str, str],
tool: BaseTool,
raw_inputs: dict,
**kwargs,
):
raw_inputs: dict[str, str],
**kwargs: Any,
) -> Any:
"""Run a tool with inputs, where the format of raw inputs is validated."""
try:
# params = load_dict(raw_inputs)
validate_inputs(tool.parameters, raw_inputs)
validate_inputs(tool.parameters, raw_inputs) # type: ignore
except Exception as e:
return DummyToolWithMessage().run(
f'{{"error": "InvalidRequestException: {e}"}}', **kwargs
Expand All @@ -240,16 +242,16 @@ def run_with_input_validation(


async def arun_with_input_validation(
run_func: Callable,
run_func: Callable[..., Any],
inputs: dict[str, str],
tool: BaseTool,
raw_inputs: dict[str, Any],
**kwargs,
):
**kwargs: Any,
) -> Any:
"""Run a tool with inputs, where the format of raw inputs is validated."""
try:
# params = load_dict(raw_inputs)
validate_inputs(tool.parameters, raw_inputs)
validate_inputs(tool.parameters, raw_inputs) # type: ignore
except Exception as e:
return await DummyToolWithMessage().arun(
f'{{"error": "InvalidRequestException: {e}"}}', **kwargs
Expand All @@ -271,7 +273,7 @@ def format_toolkit_dict(
add_risks: bool = True,
indent: str = "\t",
use_simple_tool_desc: bool = False,
):
) -> str:
"""Format a toolkit specified as a dictionary into a string."""
toolkit_name = toolkit[namekey]
desc = f"<{toolkit_name}> toolkit with following tool APIs:\n"
Expand Down
4 changes: 2 additions & 2 deletions haicosystem/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,13 @@ async def run_server(
n_agent=len(agents_model_dict),
env_params=env_params,
agents_params=[
{"model_name": model_name} if model_name != "bridge" else {}
{"model_name": model_name} if model_name != "bridge" else {} # type: ignore
for model_name in agents_model_dict.values()
],
)
episode_futures = [
arun_one_episode(
env=env_agent_combo[0],
env=env_agent_combo[0], # type: ignore
agent_list=env_agent_combo[1],
tag=tag,
push_to_db=push_to_db,
Expand Down
Empty file added stubs/absl/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions stubs/absl/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Any, Callable, Sequence


def parse_flags_with_usage(argv: Sequence[str]) -> Sequence[str]:
return argv


def run(
main: Callable[..., Any],
argv: list[str] | None = None,
flags_parser: Callable[[list[str]], Any] = parse_flags_with_usage,
) -> None:
pass
12 changes: 12 additions & 0 deletions stubs/absl/flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class FLAGS(object):
gin_search_paths: list[str]
gin_file: list[str]
gin_bindings: list[str]


def DEFINE_multi_string(name: str, default: list[str] | None, help: str) -> None:
pass


def DEFINE_list(name: str, default: list[str] | None, help: str) -> None:
pass
62 changes: 62 additions & 0 deletions stubs/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import enum
from dataclasses import dataclass
from typing import (
Any,
Dict,
Mapping,
Optional,
Sequence,
TypeVar,
Union,
)


class Split(enum.Enum): ...


class DownloadMode(enum.Enum): ...


class VerificationMode(enum.Enum): ...


@dataclass
class Version: ...


class DatasetDict(dict[Any, Any]): ...


Features = TypeVar("Features", bound=dict[str, Any])
DownloadConfig = TypeVar("DownloadConfig", bound=Any) # improve when used
Dataset = TypeVar("Dataset", bound=Any)
IterableDatasetDict = TypeVar("IterableDatasetDict", bound=dict[Any, Any])
IterableDataset = TypeVar("IterableDataset", bound=Any)


def load_dataset(
path: str,
name: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[
Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
] = None,
split: Optional[Union[str, Split]] = None,
cache_dir: Optional[str] = None,
features: Optional[Features] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[Union[DownloadMode, str]] = None,
verification_mode: Optional[Union[VerificationMode, str]] = None,
ignore_verifications: str = "deprecated",
keep_in_memory: Optional[bool] = None,
save_infos: bool = False,
revision: Optional[Union[str, Version]] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token: str = "deprecated",
task: str = "deprecated",
streaming: bool = False,
num_proc: Optional[int] = None,
storage_options: Optional[Dict[Any, Any]] = None,
**config_kwargs: Any,
) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:
raise NotImplementedError
22 changes: 22 additions & 0 deletions stubs/gin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, Callable, TypeVar

T = TypeVar("T", bound=Callable[..., Any])


def configurable(fn: T) -> T:
return fn


REQUIRED: object


def config_str() -> str:
return ""


def parse_config_files_and_bindings(*_: Any, **__: Any) -> None:
pass


def add_config_file_search_path(_: Any) -> None:
pass
5 changes: 5 additions & 0 deletions stubs/logzero/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class Logger(object):
def info(self, msg: str) -> None: ...


logger = Logger()
4 changes: 4 additions & 0 deletions stubs/logzero/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class Logger(object):
def info(self, msg: str) -> None: ...

logger = Logger()
2 changes: 2 additions & 0 deletions stubs/names/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def get_first_name() -> str:
raise NotImplementedError
Empty file added stubs/otree/__init__.pyi
Empty file.
Loading

0 comments on commit 352e881

Please sign in to comment.