From d70ed9bc1f932ef6172ead05b1f0109de56868bc Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 26 Sep 2023 15:42:26 -0500 Subject: [PATCH] refactor: BaseModel getattr logic sharing [APE-1415] (#1675) --- src/ape/api/config.py | 2 +- src/ape/api/networks.py | 47 ++-------- src/ape/api/projects.py | 61 +++++-------- src/ape/api/transactions.py | 12 ++- src/ape/types/__init__.py | 43 +++++---- src/ape/utils/__init__.py | 2 + src/ape/utils/basemodel.py | 149 ++++++++++++++++++++++++++++++- tests/functional/test_receipt.py | 5 ++ 8 files changed, 212 insertions(+), 109 deletions(-) diff --git a/src/ape/api/config.py b/src/ape/api/config.py index 6bceb1b3ab..d683da2aea 100644 --- a/src/ape/api/config.py +++ b/src/ape/api/config.py @@ -40,7 +40,7 @@ def update(root: Dict, value_map: Dict): return cls(**update(default_values, overrides)) def __getattr__(self, attr_name: str) -> Any: - # allow hyphens in plugin config files + # Allow hyphens in plugin config files. attr_name = attr_name.replace("-", "_") return super().__getattribute__(attr_name) diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index afcb5232e7..d6f9baa7a5 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -25,6 +25,7 @@ from ape.utils import ( DEFAULT_TRANSACTION_ACCEPTANCE_TIMEOUT, BaseInterfaceModel, + ExtraModelAttributes, ManagerAccessMixin, abstractmethod, cached_property, @@ -221,45 +222,13 @@ def __post_init__(self): if len(self.networks) == 0: raise NetworkError("Must define at least one network in ecosystem") - def __getitem__(self, network_name: str) -> "NetworkAPI": - """ - Get a network by name. - - Raises: - :class:`~ape.exceptions.NetworkNotFoundError`: - When there is no network with the given name. - - Args: - network_name (str): The name of the network to retrieve. - - Returns: - :class:`~ape.api.networks.NetworkAPI` - """ - return self.get_network(network_name) - - def __getattr__(self, network_name: str) -> "NetworkAPI": - """ - Get a network by name using ``.`` access. - - Usage example:: - - from ape import networks - mainnet = networks.ecosystem.mainnet - - Raises: - :class:`~ape.exceptions.NetworkNotFoundError`: - When there is no network with the given name. - - Args: - network_name (str): The name of the network to retrieve. - - Returns: - :class:`~ape.api.networks.NetworkAPI` - """ - try: - return self.get_network(network_name.replace("_", "-")) - except NetworkNotFoundError: - return self.__getattribute__(network_name) + def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]: + yield ExtraModelAttributes( + name="networks", + attributes=self.networks, + include_getattr=True, + include_getitem=True, + ) def add_network(self, network_name: str, network: "NetworkAPI"): """ diff --git a/src/ape/api/projects.py b/src/ape/api/projects.py index f41524858b..7ff8c899ae 100644 --- a/src/ape/api/projects.py +++ b/src/ape/api/projects.py @@ -2,7 +2,7 @@ import re import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union from ethpm_types import Checksum, ContractType, PackageManifest, Source from ethpm_types.manifest import PackageName @@ -10,10 +10,10 @@ from packaging.version import InvalidVersion, Version from pydantic import ValidationError -from ape.exceptions import ApeAttributeError from ape.logging import logger from ape.utils import ( BaseInterfaceModel, + ExtraModelAttributes, abstractmethod, cached_property, get_all_files_in_directory, @@ -259,6 +259,14 @@ class DependencyAPI(BaseInterfaceModel): def __repr__(self): return f"<{self.__class__.__name__} name='{self.name}'>" + def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]: + yield ExtraModelAttributes( + name=self.name, + attributes=self.contracts, + include_getattr=True, + include_getitem=True, + ) + @property @abstractmethod def version_id(self) -> str: @@ -321,46 +329,19 @@ def cached_manifest(self) -> Optional[PackageManifest]: return self._cached_manifest - def __getitem__(self, contract_name: str) -> "ContractContainer": - try: - container = self.get(contract_name) - except Exception as err: - raise IndexError(str(err)) from err - - if not container: - raise IndexError(f"Contract '{contract_name}' not found.") - - return container - - def __getattr__(self, contract_name: str) -> "ContractContainer": - try: - return self.__getattribute__(contract_name) - except AttributeError: - pass - - try: - container = self.get(contract_name) - except Exception as err: - raise ApeAttributeError( - f"Dependency project '{self.name}' has no contract " - f"or attribute '{contract_name}'.\n{err}" - ) from err - - if not container: - raise ApeAttributeError( - f"Dependency project '{self.name}' has no contract '{contract_name}'." - ) - - return container + @cached_property + def contracts(self) -> Dict[str, "ContractContainer"]: + """ + A mapping of name to contract type of all the contracts + in this dependency. + """ + return { + n: self.chain_manager.contracts.get_container(c) + for n, c in (self.compile().contract_types or {}).items() + } def get(self, contract_name: str) -> Optional["ContractContainer"]: - manifest = self.compile() - options = (contract_name, contract_name.replace("-", "_")) - for name in options: - if contract_type := manifest.get_contract_type(name): - return self.chain_manager.contracts.get_container(contract_type) - - return None + return self.contracts.get(contract_name) def compile(self, use_cache: bool = True) -> PackageManifest: """ diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index eac25a2db9..1e3ef14074 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -27,7 +27,13 @@ TraceFrame, TransactionSignature, ) -from ape.utils import BaseInterfaceModel, abstractmethod, cached_property, raises_not_implemented +from ape.utils import ( + BaseInterfaceModel, + ExtraModelAttributes, + abstractmethod, + cached_property, + raises_not_implemented, +) if TYPE_CHECKING: from ape.api.providers import BlockAPI @@ -261,8 +267,8 @@ class ReceiptAPI(BaseInterfaceModel): def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.txn_hash}>" - def __getattr__(self, item: str) -> Any: - return getattr(self.transaction, item) + def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]: + yield ExtraModelAttributes(name="transaction", attributes=self.transaction) @validator("transaction", pre=True) def confirm_transaction(cls, value): diff --git a/src/ape/types/__init__.py b/src/ape/types/__init__.py index 3e3123328e..cf3bf73112 100644 --- a/src/ape/types/__init__.py +++ b/src/ape/types/__init__.py @@ -1,5 +1,16 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Union, + cast, +) from eth_abi.abi import encode from eth_abi.packed import encode_packed @@ -21,7 +32,6 @@ from pydantic import BaseModel, root_validator, validator from web3.types import FilterParams -from ape.exceptions import ApeAttributeError from ape.types.address import AddressType, RawAddress from ape.types.coverage import ( ContractCoverage, @@ -32,7 +42,7 @@ ) from ape.types.signatures import MessageSignature, SignableMessage, TransactionSignature from ape.types.trace import CallTreeNode, ControlFlow, GasReport, SourceTraceback, TraceFrame -from ape.utils import BaseInterfaceModel, cached_property +from ape.utils import BaseInterfaceModel, ExtraModelAttributes, cached_property from ape.utils.misc import ZERO_ADDRESS, to_int if TYPE_CHECKING: @@ -294,23 +304,13 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"<{self.event_name} {self._event_args_str}>" - def __getattr__(self, item: str) -> Any: - """ - Access properties from the log via ``.`` access. - - Args: - item (str): The name of the property. - """ - - try: - return self.__getattribute__(item) - except AttributeError: - pass - - if item not in self.event_arguments: - raise ApeAttributeError(f"{self.__class__.__name__} has no attribute '{item}'.") - - return self.event_arguments[item] + def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]: + yield ExtraModelAttributes( + name=self.event_name, + attributes=self.event_arguments, + include_getattr=True, + include_getitem=True, + ) def __contains__(self, item: str) -> bool: return item in self.event_arguments @@ -338,9 +338,6 @@ def __eq__(self, other: Any) -> bool: # call __eq__ on parent class return super().__eq__(other) - def __getitem__(self, item: str) -> Any: - return self.event_arguments[item] - def get(self, item: str, default: Optional[Any] = None) -> Any: return self.event_arguments.get(item, default) diff --git a/src/ape/utils/__init__.py b/src/ape/utils/__init__.py index 84bff51e82..3fccd224aa 100644 --- a/src/ape/utils/__init__.py +++ b/src/ape/utils/__init__.py @@ -13,6 +13,7 @@ from ape.utils.basemodel import ( BaseInterface, BaseInterfaceModel, + ExtraModelAttributes, ManagerAccessMixin, injected_before_use, ) @@ -75,6 +76,7 @@ "EMPTY_BYTES32", "expand_environment_variables", "extract_nested_value", + "ExtraModelAttributes", "get_relative_path", "gas_estimation_error_message", "get_package_version", diff --git a/src/ape/utils/basemodel.py b/src/ape/utils/basemodel.py index e10d8fe49f..361a9a0786 100644 --- a/src/ape/utils/basemodel.py +++ b/src/ape/utils/basemodel.py @@ -1,9 +1,9 @@ from abc import ABC -from typing import TYPE_CHECKING, ClassVar, List, Optional, cast +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union, cast -from ethpm_types import BaseModel +from ethpm_types import BaseModel as _BaseModel -from ape.exceptions import ProviderNotConnectedError +from ape.exceptions import ApeAttributeError, ProviderNotConnectedError from ape.logging import logger from ape.utils.misc import cached_property, singledispatchmethod @@ -100,6 +100,149 @@ class BaseInterface(ManagerAccessMixin, ABC): """ +def _get_alt(name: str) -> Optional[str]: + alt = None + if ("-" not in name and "_" not in name) or ("-" in name and "_" in name): + alt = None + elif "-" in name: + alt = name.replace("-", "_") + elif "_" in name: + alt = name.replace("_", "-") + + return alt + + +class ExtraModelAttributes(_BaseModel): + """ + A class for defining extra model attributes. + """ + + name: str + """ + The name of the attributes. This is important + in instances such as when an attribute is missing, + we can show a more accurate exception message. + """ + + attributes: Union[Dict[str, Any], "BaseModel"] + """The attributes.""" + + include_getattr: bool = True + """Whether to use these in ``__getattr__``.""" + + include_getitem: bool = False + """Whether to use these in ``__getitem__``.""" + + def __contains__(self, name: str) -> bool: + attr_dict = self.attributes if isinstance(self.attributes, dict) else self.attributes.dict() + if name in attr_dict: + return True + + elif alt := _get_alt(name): + return alt in attr_dict + + return False + + def get(self, name: str) -> Optional[Any]: + """ + Get an attribute. + + Args: + name (str): The name of the attribute. + + Returns: + Optional[Any]: The attribute if it exists, else ``None``. + """ + + res = self._get(name) + if res is not None: + return res + + if alt := _get_alt(name): + res = self._get(alt) + if res is not None: + return res + + return None + + def _get(self, name: str) -> Optional[Any]: + return ( + self.attributes.get(name) + if isinstance(self.attributes, dict) + else getattr(self.attributes, name, None) + ) + + +class BaseModel(_BaseModel): + """ + An ape-pydantic BaseModel. + """ + + def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]: + """ + Override this method to supply extra attributes + to a model in Ape; this allow more properties + to be available when invoking ``__getattr__``. + + Returns: + Iterator[:class:`~ape.utils.basemodel.ExtraModelAttributes`]: A + series of instances defining extra model attributes. + """ + return iter(()) + + def __getattr__(self, name: str) -> Any: + """ + An overridden ``__getattr__`` implementation that takes into + account :meth:`~ape.utils.basemodel.BaseModel.__ape_extra_attributes__`. + """ + + try: + return super().__getattr__(name) + except AttributeError: + extras_checked = set() + for ape_extra in self.__ape_extra_attributes__(): + if not ape_extra.include_getattr: + continue + + if name in ape_extra: + # Attribute was found in one of the supplied + # extra attributes mappings. + return ape_extra.get(name) + + extras_checked.add(ape_extra.name) + + # The error message mentions the alternative mappings, + # such as a contract-type map. + message = f"No attribute with name '{name}'." + if extras_checked: + extras_str = ", ".join(extras_checked) + message = f"{message} Also checked '{extras_str}'." + + raise ApeAttributeError(message) + + def __getitem__(self, name: Any) -> Any: + # For __getitem__, we first try the extra (unlike `__getattr__`). + extras_checked = set() + for extra in self.__ape_extra_attributes__(): + if not extra.include_getitem: + continue + + if name in extra: + return extra.get(name) + + extras_checked.add(extra.name) + + # NOTE: If extras were supplied, the user was expecting it to be + # there (unlike __getattr__). + if extras_checked: + extras_str = ", ".join(extras_checked) + raise IndexError(f"Unable to find '{name}' in any of '{extras_str}'.") + + # The user did not supply any extra __getitem__ attributes. + # Do what you would have normally done. + return super().__getitem__(name) + + class BaseInterfaceModel(BaseInterface, BaseModel): """ An abstract base-class with manager access on a pydantic base model. diff --git a/tests/functional/test_receipt.py b/tests/functional/test_receipt.py index 59beb4ae6f..4842e7b862 100644 --- a/tests/functional/test_receipt.py +++ b/tests/functional/test_receipt.py @@ -192,3 +192,8 @@ def test_track_coverage(deploy_receipt, mocker): assert mock_runner.track_coverage.call_count == 0 ManagerAccessMixin._test_runner = original + + +def test_access_from_tx(deploy_receipt): + actual = deploy_receipt.receiver + assert actual == ""