Skip to content

Commit

Permalink
refactor: BaseModel getattr logic sharing [APE-1415] (#1675)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Sep 26, 2023
1 parent 803d952 commit d70ed9b
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 109 deletions.
2 changes: 1 addition & 1 deletion src/ape/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def update(root: Dict, value_map: Dict):
return cls(**update(default_values, overrides))

def __getattr__(self, attr_name: str) -> Any:
# allow hyphens in plugin config files
# Allow hyphens in plugin config files.
attr_name = attr_name.replace("-", "_")
return super().__getattribute__(attr_name)

Expand Down
47 changes: 8 additions & 39 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ape.utils import (
DEFAULT_TRANSACTION_ACCEPTANCE_TIMEOUT,
BaseInterfaceModel,
ExtraModelAttributes,
ManagerAccessMixin,
abstractmethod,
cached_property,
Expand Down Expand Up @@ -221,45 +222,13 @@ def __post_init__(self):
if len(self.networks) == 0:
raise NetworkError("Must define at least one network in ecosystem")

def __getitem__(self, network_name: str) -> "NetworkAPI":
"""
Get a network by name.
Raises:
:class:`~ape.exceptions.NetworkNotFoundError`:
When there is no network with the given name.
Args:
network_name (str): The name of the network to retrieve.
Returns:
:class:`~ape.api.networks.NetworkAPI`
"""
return self.get_network(network_name)

def __getattr__(self, network_name: str) -> "NetworkAPI":
"""
Get a network by name using ``.`` access.
Usage example::
from ape import networks
mainnet = networks.ecosystem.mainnet
Raises:
:class:`~ape.exceptions.NetworkNotFoundError`:
When there is no network with the given name.
Args:
network_name (str): The name of the network to retrieve.
Returns:
:class:`~ape.api.networks.NetworkAPI`
"""
try:
return self.get_network(network_name.replace("_", "-"))
except NetworkNotFoundError:
return self.__getattribute__(network_name)
def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]:
yield ExtraModelAttributes(
name="networks",
attributes=self.networks,
include_getattr=True,
include_getitem=True,
)

def add_network(self, network_name: str, network: "NetworkAPI"):
"""
Expand Down
61 changes: 21 additions & 40 deletions src/ape/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
import re
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

from ethpm_types import Checksum, ContractType, PackageManifest, Source
from ethpm_types.manifest import PackageName
from ethpm_types.utils import Algorithm, AnyUrl, compute_checksum
from packaging.version import InvalidVersion, Version
from pydantic import ValidationError

from ape.exceptions import ApeAttributeError
from ape.logging import logger
from ape.utils import (
BaseInterfaceModel,
ExtraModelAttributes,
abstractmethod,
cached_property,
get_all_files_in_directory,
Expand Down Expand Up @@ -259,6 +259,14 @@ class DependencyAPI(BaseInterfaceModel):
def __repr__(self):
return f"<{self.__class__.__name__} name='{self.name}'>"

def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]:
yield ExtraModelAttributes(
name=self.name,
attributes=self.contracts,
include_getattr=True,
include_getitem=True,
)

@property
@abstractmethod
def version_id(self) -> str:
Expand Down Expand Up @@ -321,46 +329,19 @@ def cached_manifest(self) -> Optional[PackageManifest]:

return self._cached_manifest

def __getitem__(self, contract_name: str) -> "ContractContainer":
try:
container = self.get(contract_name)
except Exception as err:
raise IndexError(str(err)) from err

if not container:
raise IndexError(f"Contract '{contract_name}' not found.")

return container

def __getattr__(self, contract_name: str) -> "ContractContainer":
try:
return self.__getattribute__(contract_name)
except AttributeError:
pass

try:
container = self.get(contract_name)
except Exception as err:
raise ApeAttributeError(
f"Dependency project '{self.name}' has no contract "
f"or attribute '{contract_name}'.\n{err}"
) from err

if not container:
raise ApeAttributeError(
f"Dependency project '{self.name}' has no contract '{contract_name}'."
)

return container
@cached_property
def contracts(self) -> Dict[str, "ContractContainer"]:
"""
A mapping of name to contract type of all the contracts
in this dependency.
"""
return {
n: self.chain_manager.contracts.get_container(c)
for n, c in (self.compile().contract_types or {}).items()
}

def get(self, contract_name: str) -> Optional["ContractContainer"]:
manifest = self.compile()
options = (contract_name, contract_name.replace("-", "_"))
for name in options:
if contract_type := manifest.get_contract_type(name):
return self.chain_manager.contracts.get_container(contract_type)

return None
return self.contracts.get(contract_name)

def compile(self, use_cache: bool = True) -> PackageManifest:
"""
Expand Down
12 changes: 9 additions & 3 deletions src/ape/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
TraceFrame,
TransactionSignature,
)
from ape.utils import BaseInterfaceModel, abstractmethod, cached_property, raises_not_implemented
from ape.utils import (
BaseInterfaceModel,
ExtraModelAttributes,
abstractmethod,
cached_property,
raises_not_implemented,
)

if TYPE_CHECKING:
from ape.api.providers import BlockAPI
Expand Down Expand Up @@ -261,8 +267,8 @@ class ReceiptAPI(BaseInterfaceModel):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.txn_hash}>"

def __getattr__(self, item: str) -> Any:
return getattr(self.transaction, item)
def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]:
yield ExtraModelAttributes(name="transaction", attributes=self.transaction)

@validator("transaction", pre=True)
def confirm_transaction(cls, value):
Expand Down
43 changes: 20 additions & 23 deletions src/ape/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)

from eth_abi.abi import encode
from eth_abi.packed import encode_packed
Expand All @@ -21,7 +32,6 @@
from pydantic import BaseModel, root_validator, validator
from web3.types import FilterParams

from ape.exceptions import ApeAttributeError
from ape.types.address import AddressType, RawAddress
from ape.types.coverage import (
ContractCoverage,
Expand All @@ -32,7 +42,7 @@
)
from ape.types.signatures import MessageSignature, SignableMessage, TransactionSignature
from ape.types.trace import CallTreeNode, ControlFlow, GasReport, SourceTraceback, TraceFrame
from ape.utils import BaseInterfaceModel, cached_property
from ape.utils import BaseInterfaceModel, ExtraModelAttributes, cached_property
from ape.utils.misc import ZERO_ADDRESS, to_int

if TYPE_CHECKING:
Expand Down Expand Up @@ -294,23 +304,13 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"<{self.event_name} {self._event_args_str}>"

def __getattr__(self, item: str) -> Any:
"""
Access properties from the log via ``.`` access.
Args:
item (str): The name of the property.
"""

try:
return self.__getattribute__(item)
except AttributeError:
pass

if item not in self.event_arguments:
raise ApeAttributeError(f"{self.__class__.__name__} has no attribute '{item}'.")

return self.event_arguments[item]
def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]:
yield ExtraModelAttributes(
name=self.event_name,
attributes=self.event_arguments,
include_getattr=True,
include_getitem=True,
)

def __contains__(self, item: str) -> bool:
return item in self.event_arguments
Expand Down Expand Up @@ -338,9 +338,6 @@ def __eq__(self, other: Any) -> bool:
# call __eq__ on parent class
return super().__eq__(other)

def __getitem__(self, item: str) -> Any:
return self.event_arguments[item]

def get(self, item: str, default: Optional[Any] = None) -> Any:
return self.event_arguments.get(item, default)

Expand Down
2 changes: 2 additions & 0 deletions src/ape/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ape.utils.basemodel import (
BaseInterface,
BaseInterfaceModel,
ExtraModelAttributes,
ManagerAccessMixin,
injected_before_use,
)
Expand Down Expand Up @@ -75,6 +76,7 @@
"EMPTY_BYTES32",
"expand_environment_variables",
"extract_nested_value",
"ExtraModelAttributes",
"get_relative_path",
"gas_estimation_error_message",
"get_package_version",
Expand Down
Loading

0 comments on commit d70ed9b

Please sign in to comment.