diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 0886b29..5c74cbd 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -42,4 +42,4 @@ jobs: poetry run pip install types-protobuf==4.24.0.4 poetry run mypy --install-types --non-interactive sotopia poetry run mypy --install-types --non-interactive sotopia-chat - poetry run mypy --strict . + poetry run mypy --strict --exclude haicosystem/tools --exclude haicosystem/envs/llm_engine.py . diff --git a/README.md b/README.md index e8a52bb..977765d 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ pre-commit install ### New branch for each feature `git checkout -b feature/feature-name` and PR to `main` branch. ### Before committing -Run `pytest` to make sure all tests pass (this will ensure dynamic typing passed with beartype) and `mypy --strict .` to check static typing. +Run `pytest` to make sure all tests pass (this will ensure dynamic typing passed with beartype) and `mypy --strict --exclude haicosystem/tools --exclude haicosystem/envs/llm_engine.py .` to check static typing. (You can also run `pre-commit run --all-files` to run all checks) ### Check github action result Check the github action result to make sure all tests pass. If not, fix the errors and push again. diff --git a/haicosystem/envs/hai_env.py b/haicosystem/envs/hai_env.py index afda235..0da2000 100644 --- a/haicosystem/envs/hai_env.py +++ b/haicosystem/envs/hai_env.py @@ -17,11 +17,13 @@ ) from beartype import beartype -from sotopia.envs.evaluators import ScriptEnvironmentResponse, _reduce + +# TODO: fix the import error +from sotopia.envs.evaluators import ScriptEnvironmentResponse, _reduce # type: ignore from collections import defaultdict import logging from haicosystem.envs.messages import SimulatedObservation - +from haicosystem.envs.llm_engine import LlmGroundingEngine from pydantic import Field @@ -55,7 +57,7 @@ def unweighted_aggregate_response( assert ( len(responses_dict["engine"]) == 0 ) # TODO: allow multiple engine responses - responses_dict["engine"].append(response) + engine_response = ("engine", response) else: assert response[0] == "environment" or response[0].startswith("agent") responses_dict[response[0]].append(response[1]) @@ -65,6 +67,7 @@ def unweighted_aggregate_response( agent_2_responses: tuple[dict[str, float | int | bool], str] = ({}, "") for k, v in responses_dict.items(): if k == "environment": + assert isinstance(v, list) environment_responses = _reduce(v) elif k == "engine": continue @@ -120,7 +123,7 @@ def unweighted_aggregate_response( if agent_2_responses != ({}, "") else None, comments=comments, - observation=responses_dict["engine"][0], + observation=engine_response[1], ) @@ -143,17 +146,17 @@ def __init__( super().__init__( env_profile=EnvironmentProfile(), available_action_types=available_action_types, - action_order=action_order, + action_order=action_order, # type: ignore model_name=model_name, evaluators=evaluators, terminal_evaluators=terminal_evaluators, uuid_str=uuid_str, ) - self.profile = env_profile + self.profile: HaiEnvironmentProfile = env_profile # type: ignore assert ( len(grounding_engines) == 1 ) # temp implementation; only one grounding engine is supported - self.grounding_engine = grounding_engines[0] + self.grounding_engine: LlmGroundingEngine = grounding_engines[0] # type: ignore self.engines_and_evaluators = grounding_engines + evaluators self.profile.scenario = self.prepare_scenario(self.profile) diff --git a/haicosystem/envs/messages.py b/haicosystem/envs/messages.py index b7ebccd..fbcc89d 100644 --- a/haicosystem/envs/messages.py +++ b/haicosystem/envs/messages.py @@ -13,14 +13,14 @@ class SimulatedObservation(Message): def to_natural_language(self) -> str: return "Observation: \n" + self.observation - def __str__(self): + def __str__(self) -> str: return self.observation - def __repr__(self): + def __repr__(self) -> str: return self.observation class LangchainAgentAction(AgentAction): - def action_api(self): + def action_api(self) -> None: # TODO: Implement the action API raise NotImplementedError diff --git a/haicosystem/envs/tool.py b/haicosystem/envs/tool.py index 85fef3f..e09ae2f 100644 --- a/haicosystem/envs/tool.py +++ b/haicosystem/envs/tool.py @@ -24,7 +24,7 @@ } -def insert_indent(s: str, indent: str = "\t", insert_first=True) -> str: +def insert_indent(s: str, indent: str = "\t", insert_first: bool = True) -> str: prefix = indent if insert_first else "" return prefix + s.rstrip("\n ").replace("\n", "\n" + indent) + "\n" @@ -106,7 +106,7 @@ def _get_message(self, tool_name: str) -> str: def special_fix_for_json_obj(s: str, pos: int) -> str: """Fix the string to make it a valid JSON object.""" if pos >= len(s): - return None + return "" if s[pos] == "\\": # fix for {"xxx\_yyy": "zzz"} broken = s[pos : pos + 2] return s.replace(broken, broken[1]) @@ -114,10 +114,12 @@ def special_fix_for_json_obj(s: str, pos: int) -> str: return s.replace("True", "true") elif s[pos : pos + 5] == "False": return s.replace("False", "false") - return None + return "" -def get_first_json_object_str(s: str, enable_check=True, *args, **kwargs) -> str: +def get_first_json_object_str( + s: str, enable_check: bool = True, *args: Any, **kwargs: Any +) -> str: """Get the first JSON object from a string""" # print("[DEBUG] raw action string:", s) # allow ```json {JSON object} ``` format @@ -136,7 +138,7 @@ def get_first_json_object_str(s: str, enable_check=True, *args, **kwargs) -> str break except json.JSONDecodeError as e: fix_res = special_fix_for_json_obj(s, e.pos) # special fixes - if fix_res is not None: + if fix_res: ret = s = fix_res else: s = s[: e.pos].rstrip() # take the first part @@ -150,11 +152,11 @@ def get_first_json_object_str(s: str, enable_check=True, *args, **kwargs) -> str return ret -def load_dict(tool_input: str) -> Dict[str, Any]: +def load_dict(tool_input: str) -> dict[str, Any]: """Load a dictionary from a string.""" # TODO: improve this, this seems to be the most robust way for now try: - params = json.loads(tool_input, strict=False) + params: dict[str, Any] = json.loads(tool_input, strict=False) # set strict=False to allow single quote, comments, escape characters, etc return params except Exception as e: @@ -176,7 +178,7 @@ def match_type(value: Any, expected_type: type) -> bool: return isinstance(value, int) and expected_type == float -def check_type(name: str, value: Any, expected_type: type): +def check_type(name: str, value: Any, expected_type: type) -> None: if not match_type(value, expected_type): raise ValueError( f"Invalid type for {name}: {type(value).__name__}, " @@ -184,7 +186,7 @@ def check_type(name: str, value: Any, expected_type: type): ) -def validate_inputs(parameters: List[ArgParameter], inputs: Dict[str, Any]): +def validate_inputs(parameters: List[ArgParameter], inputs: Dict[str, Any]) -> None: """Validate the inputs for a tool.""" remains = set(inputs.keys()) for param in parameters: @@ -193,14 +195,14 @@ def validate_inputs(parameters: List[ArgParameter], inputs: Dict[str, Any]): remains.remove(name) if name in inputs: expected_type = PRIMITIVE_TYPES[param["type"]] - check_type(name, inputs[name], expected_type) + check_type(name, inputs[name], expected_type) # type: ignore elif param.get("required", False): raise ValueError(f"Missing required parameter: {name}.") if remains: raise ValueError(f"Unexpected parameters: {remains}.") -def validate_outputs(returns: List[ArgReturn], outputs: Dict[str, Any]): +def validate_outputs(returns: List[ArgReturn], outputs: Dict[str, Any]) -> None: """Validate the outputs for a tool.""" if not isinstance(outputs, dict): raise ValueError( @@ -215,23 +217,23 @@ def validate_outputs(returns: List[ArgReturn], outputs: Dict[str, Any]): remains.remove(name) if name not in outputs: raise ValueError(f"Missing return: {name}.") - expected_type = PRIMITIVE_TYPES[ret["type"]] + expected_type: type = PRIMITIVE_TYPES[ret["type"]] # type: ignore check_type(name, outputs[name], expected_type) if remains: raise ValueError(f"Unexpected returns: {list(remains)}.") def run_with_input_validation( - run_func: Callable, - inputs: Dict[str, str], + run_func: Callable[..., Any], + inputs: dict[str, str], tool: BaseTool, - raw_inputs: dict, - **kwargs, -): + raw_inputs: dict[str, str], + **kwargs: Any, +) -> Any: """Run a tool with inputs, where the format of raw inputs is validated.""" try: # params = load_dict(raw_inputs) - validate_inputs(tool.parameters, raw_inputs) + validate_inputs(tool.parameters, raw_inputs) # type: ignore except Exception as e: return DummyToolWithMessage().run( f'{{"error": "InvalidRequestException: {e}"}}', **kwargs @@ -240,16 +242,16 @@ def run_with_input_validation( async def arun_with_input_validation( - run_func: Callable, + run_func: Callable[..., Any], inputs: dict[str, str], tool: BaseTool, raw_inputs: dict[str, Any], - **kwargs, -): + **kwargs: Any, +) -> Any: """Run a tool with inputs, where the format of raw inputs is validated.""" try: # params = load_dict(raw_inputs) - validate_inputs(tool.parameters, raw_inputs) + validate_inputs(tool.parameters, raw_inputs) # type: ignore except Exception as e: return await DummyToolWithMessage().arun( f'{{"error": "InvalidRequestException: {e}"}}', **kwargs @@ -271,7 +273,7 @@ def format_toolkit_dict( add_risks: bool = True, indent: str = "\t", use_simple_tool_desc: bool = False, -): +) -> str: """Format a toolkit specified as a dictionary into a string.""" toolkit_name = toolkit[namekey] desc = f"<{toolkit_name}> toolkit with following tool APIs:\n" diff --git a/haicosystem/server.py b/haicosystem/server.py index d07fb1c..2274eba 100644 --- a/haicosystem/server.py +++ b/haicosystem/server.py @@ -300,13 +300,13 @@ async def run_server( n_agent=len(agents_model_dict), env_params=env_params, agents_params=[ - {"model_name": model_name} if model_name != "bridge" else {} + {"model_name": model_name} if model_name != "bridge" else {} # type: ignore for model_name in agents_model_dict.values() ], ) episode_futures = [ arun_one_episode( - env=env_agent_combo[0], + env=env_agent_combo[0], # type: ignore agent_list=env_agent_combo[1], tag=tag, push_to_db=push_to_db, diff --git a/stubs/absl/__init__.py b/stubs/absl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stubs/absl/app.py b/stubs/absl/app.py new file mode 100644 index 0000000..35132bb --- /dev/null +++ b/stubs/absl/app.py @@ -0,0 +1,13 @@ +from typing import Any, Callable, Sequence + + +def parse_flags_with_usage(argv: Sequence[str]) -> Sequence[str]: + return argv + + +def run( + main: Callable[..., Any], + argv: list[str] | None = None, + flags_parser: Callable[[list[str]], Any] = parse_flags_with_usage, +) -> None: + pass diff --git a/stubs/absl/flags.py b/stubs/absl/flags.py new file mode 100644 index 0000000..8d18b40 --- /dev/null +++ b/stubs/absl/flags.py @@ -0,0 +1,12 @@ +class FLAGS(object): + gin_search_paths: list[str] + gin_file: list[str] + gin_bindings: list[str] + + +def DEFINE_multi_string(name: str, default: list[str] | None, help: str) -> None: + pass + + +def DEFINE_list(name: str, default: list[str] | None, help: str) -> None: + pass diff --git a/stubs/datasets/__init__.py b/stubs/datasets/__init__.py new file mode 100644 index 0000000..e17cc66 --- /dev/null +++ b/stubs/datasets/__init__.py @@ -0,0 +1,62 @@ +import enum +from dataclasses import dataclass +from typing import ( + Any, + Dict, + Mapping, + Optional, + Sequence, + TypeVar, + Union, +) + + +class Split(enum.Enum): ... + + +class DownloadMode(enum.Enum): ... + + +class VerificationMode(enum.Enum): ... + + +@dataclass +class Version: ... + + +class DatasetDict(dict[Any, Any]): ... + + +Features = TypeVar("Features", bound=dict[str, Any]) +DownloadConfig = TypeVar("DownloadConfig", bound=Any) # improve when used +Dataset = TypeVar("Dataset", bound=Any) +IterableDatasetDict = TypeVar("IterableDatasetDict", bound=dict[Any, Any]) +IterableDataset = TypeVar("IterableDataset", bound=Any) + + +def load_dataset( + path: str, + name: Optional[str] = None, + data_dir: Optional[str] = None, + data_files: Optional[ + Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]] + ] = None, + split: Optional[Union[str, Split]] = None, + cache_dir: Optional[str] = None, + features: Optional[Features] = None, + download_config: Optional[DownloadConfig] = None, + download_mode: Optional[Union[DownloadMode, str]] = None, + verification_mode: Optional[Union[VerificationMode, str]] = None, + ignore_verifications: str = "deprecated", + keep_in_memory: Optional[bool] = None, + save_infos: bool = False, + revision: Optional[Union[str, Version]] = None, + token: Optional[Union[bool, str]] = None, + use_auth_token: str = "deprecated", + task: str = "deprecated", + streaming: bool = False, + num_proc: Optional[int] = None, + storage_options: Optional[Dict[Any, Any]] = None, + **config_kwargs: Any, +) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]: + raise NotImplementedError diff --git a/stubs/gin/__init__.py b/stubs/gin/__init__.py new file mode 100644 index 0000000..25005aa --- /dev/null +++ b/stubs/gin/__init__.py @@ -0,0 +1,22 @@ +from typing import Any, Callable, TypeVar + +T = TypeVar("T", bound=Callable[..., Any]) + + +def configurable(fn: T) -> T: + return fn + + +REQUIRED: object + + +def config_str() -> str: + return "" + + +def parse_config_files_and_bindings(*_: Any, **__: Any) -> None: + pass + + +def add_config_file_search_path(_: Any) -> None: + pass diff --git a/stubs/logzero/__init__.py b/stubs/logzero/__init__.py new file mode 100644 index 0000000..62500ec --- /dev/null +++ b/stubs/logzero/__init__.py @@ -0,0 +1,5 @@ +class Logger(object): + def info(self, msg: str) -> None: ... + + +logger = Logger() diff --git a/stubs/logzero/__init__.pyi b/stubs/logzero/__init__.pyi new file mode 100644 index 0000000..e3a9810 --- /dev/null +++ b/stubs/logzero/__init__.pyi @@ -0,0 +1,4 @@ +class Logger(object): + def info(self, msg: str) -> None: ... + +logger = Logger() diff --git a/stubs/names/__init__.py b/stubs/names/__init__.py new file mode 100644 index 0000000..73a9882 --- /dev/null +++ b/stubs/names/__init__.py @@ -0,0 +1,2 @@ +def get_first_name() -> str: + raise NotImplementedError diff --git a/stubs/otree/__init__.pyi b/stubs/otree/__init__.pyi new file mode 100644 index 0000000..e69de29 diff --git a/stubs/otree/api.pyi b/stubs/otree/api.pyi new file mode 100644 index 0000000..9ce7f9c --- /dev/null +++ b/stubs/otree/api.pyi @@ -0,0 +1,19 @@ +# stubs/otree/api.pyi + +from typing import Any + +class BaseConstants: + pass + +class BaseSubsession: + session: Any + +class BaseGroup: + pass + +class BasePlayer: + id: int + participant: Any + +class Page: + pass diff --git a/stubs/pettingzoo/__init__.pyi b/stubs/pettingzoo/__init__.pyi new file mode 100644 index 0000000..e69de29 diff --git a/stubs/pettingzoo/utils/__init__.pyi b/stubs/pettingzoo/utils/__init__.pyi new file mode 100644 index 0000000..e69de29 diff --git a/stubs/pettingzoo/utils/env.pyi b/stubs/pettingzoo/utils/env.pyi new file mode 100644 index 0000000..fdf450d --- /dev/null +++ b/stubs/pettingzoo/utils/env.pyi @@ -0,0 +1,43 @@ +from typing import Any, Generic, TypeVar + +import gymnasium + +ObsType = TypeVar("ObsType") +ActionType = TypeVar("ActionType") +AgentID = TypeVar("AgentID") + +class ParallelEnv(Generic[AgentID, ObsType, ActionType]): + agents: list[AgentID] + possible_agents: list[AgentID] + observation_spaces: dict[ + AgentID, gymnasium.spaces.Space[ObsType] + ] # Observation space for each agent + action_spaces: dict[AgentID, gymnasium.spaces.Space[Any]] + + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> dict[AgentID, ObsType]: ... + def seed(self, seed: int | None = None) -> None: ... + def step( + self, actions: dict[AgentID, ActionType] + ) -> tuple[ + dict[AgentID, ObsType], + dict[str, float], + dict[str, bool], + dict[str, bool], + dict[str, dict[str, Any]], + ]: ... + def render(self) -> None | Any: ... + def close(self) -> None: ... + def state(self) -> Any: ... + def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space[ObsType]: ... + def action_space(self, agent: AgentID) -> gymnasium.spaces.Space[ActionType]: ... + @property + def num_agents(self) -> int: ... + @property + def max_num_agents(self) -> int: ... + def __str__(self) -> str: ... + @property + def unwrapped(self) -> ParallelEnv[AgentID, ObsType, ActionType]: ... diff --git a/stubs/redis_om/__init__.pyi b/stubs/redis_om/__init__.pyi new file mode 100644 index 0000000..230b0e4 --- /dev/null +++ b/stubs/redis_om/__init__.pyi @@ -0,0 +1,36 @@ +import abc +from typing import Any, Generator, TypeVar + +from pydantic import BaseModel, Field +from pydantic.main import ModelMetaclass +from redis_om.model.model import FindQuery + +InheritedJsonModel = TypeVar("InheritedJsonModel", bound="JsonModel") + +class ModelMeta(ModelMetaclass): ... + +class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): + pk: str | None = Field(default=None, primary_key=True) + + @classmethod + def delete(cls, pk: Any) -> None: ... + def expire(self, num_seconds: int) -> None: ... # pipeline arg can be added here + +class HashModel(RedisModel, abc.ABC): + @classmethod + def get(cls, pk: Any) -> "HashModel": ... + def save(self) -> None: ... + +class JsonModel(RedisModel, abc.ABC): + @classmethod + def get(cls: type[InheritedJsonModel], pk: Any) -> InheritedJsonModel: ... + @classmethod + def all_pks(cls) -> Generator[str, None, None]: ... + @classmethod + def find(cls, *args: Any, **kwargs: Any) -> FindQuery: ... + def save(self) -> None: ... + +class EmbeddedJsonModel(JsonModel): ... + +class Migrator: + def run(self) -> None: ... diff --git a/stubs/redis_om/model/__init__.py b/stubs/redis_om/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stubs/redis_om/model/model.pyi b/stubs/redis_om/model/model.pyi new file mode 100644 index 0000000..957c16e --- /dev/null +++ b/stubs/redis_om/model/model.pyi @@ -0,0 +1,43 @@ +from typing import AbstractSet, Any, Dict, Mapping, Optional, Union + +from pydantic.fields import Undefined, UndefinedType +from pydantic.typing import NoArgAnyCallable +from redis_om import JsonModel + +class NotFoundError(Exception): + """Raised when a query found no results.""" + +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[str] = None, + gt: Optional[str] = None, + ge: Optional[str] = None, + lt: Optional[str] = None, + le: Optional[str] = None, + multiple_of: Optional[str] = None, + min_items: Optional[str] = None, + max_items: Optional[str] = None, + min_length: Optional[str] = None, + max_length: Optional[str] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + primary_key: bool = False, + sortable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + full_text_search: Union[bool, UndefinedType] = Undefined, + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: ... + +class FindQuery: + def all(self) -> list[JsonModel]: ... diff --git a/stubs/scipy/__init__.pyi b/stubs/scipy/__init__.pyi new file mode 100644 index 0000000..e69de29 diff --git a/stubs/scipy/stats/__init__.pyi b/stubs/scipy/stats/__init__.pyi new file mode 100644 index 0000000..046204c --- /dev/null +++ b/stubs/scipy/stats/__init__.pyi @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from typing import Literal + +from numpy.typing import ArrayLike +from scipy.stats._resampling import ResamplingMethod +from scipy.stats._result_classes import PearsonRResult + +@dataclass +class SignificanceResult: + statistic: float + pvalue: float + correlation: float + +def pearsonr( + x: ArrayLike, + y: ArrayLike, + *, + alternative: str = ..., + method: ResamplingMethod | None = ..., +) -> PearsonRResult: ... +def spearmanr( + a: ArrayLike, + b: ArrayLike | None = ..., + axis: int | None = ..., + nan_policy: Literal["propagate", "raise", "omit"] | None = ..., + alternative: Literal["two-sided", "less", "greater"] | None = ..., +) -> SignificanceResult: ... diff --git a/stubs/scipy/stats/_resampling.pyi b/stubs/scipy/stats/_resampling.pyi new file mode 100644 index 0000000..078942a --- /dev/null +++ b/stubs/scipy/stats/_resampling.pyi @@ -0,0 +1,6 @@ +from dataclasses import dataclass + +@dataclass +class ResamplingMethod: + n_resamples: int + batch: int | None diff --git a/stubs/scipy/stats/_result_classes.pyi b/stubs/scipy/stats/_result_classes.pyi new file mode 100644 index 0000000..8379f7b --- /dev/null +++ b/stubs/scipy/stats/_result_classes.pyi @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Any, NamedTuple + +@dataclass +class PearsonRResult: + statistic: float + pvalue: float + correlation: float + + def confidence_interval( + self, confidence_level: float = ..., method: Any = ... + ) -> NamedTuple: ... diff --git a/stubs/together/__init__.pyi b/stubs/together/__init__.pyi new file mode 100644 index 0000000..4587545 --- /dev/null +++ b/stubs/together/__init__.pyi @@ -0,0 +1,22 @@ +from typing import Any, Dict, List, Optional + +class Complete: + @classmethod + def create( + self, + prompt: str, + model: Optional[str] = "", + max_tokens: Optional[int] = 128, + stop: Optional[List[str]] = [], + temperature: Optional[float] = 0.7, + top_p: Optional[float] = 0.7, + top_k: Optional[int] = 50, + repetition_penalty: Optional[float] = None, + logprobs: Optional[int] = None, + ) -> Dict[str, Any]: ... + +api_key: str = ... + +class Models: + @classmethod + def start(cls, model_name: str) -> None: ...