Skip to content

Commit

Permalink
fix: issue with custom errors on estimate gas and static fee txns [AP…
Browse files Browse the repository at this point in the history
…E-1421] (#1680)
  • Loading branch information
antazoey authored Sep 29, 2023
1 parent db49185 commit 0eb2478
Show file tree
Hide file tree
Showing 20 changed files with 249 additions and 175 deletions.
41 changes: 33 additions & 8 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class EcosystemAPI(BaseInterfaceModel):
fee_token_decimals: int = 18
"""The number of the decimals the fee token has."""

_default_network: str = LOCAL_NETWORK_NAME
_default_network: Optional[str] = None

def __repr__(self) -> str:
return f"<{self.name}>"
Expand Down Expand Up @@ -254,7 +254,25 @@ def default_network(self) -> str:
Returns:
str
"""
return self._default_network

if network := self._default_network:
# Was set programatically.
return network

elif network := self.config.get("default_network"):
# Default found in config.
return network

elif LOCAL_NETWORK_NAME in self.networks:
# Default to the LOCAL_NETWORK_NAME, at last resort.
return LOCAL_NETWORK_NAME

elif len(self.networks) >= 1:
# Use the first network.
return self.networks[0]

# Very unlikely scenario.
raise ValueError("No networks found.")

def set_default_network(self, network_name: str):
"""
Expand Down Expand Up @@ -425,7 +443,7 @@ def get_network_data(self, network_name: str) -> Dict:
Returns:
dict: A dictionary containing the providers in a network.
"""
data: Dict[str, Any] = {"name": network_name}
data: Dict[str, Any] = {"name": str(network_name)}

# Only add isDefault key when True
if network_name == self.default_network:
Expand All @@ -435,10 +453,10 @@ def get_network_data(self, network_name: str) -> Dict:
network = self[network_name]

if network.explorer:
data["explorer"] = network.explorer.name
data["explorer"] = str(network.explorer.name)

for provider_name in network.providers:
provider_data = {"name": provider_name}
provider_data: Dict = {"name": str(provider_name)}

# Only add isDefault key when True
if provider_name == network.default_provider:
Expand Down Expand Up @@ -906,12 +924,19 @@ def default_provider(self) -> Optional[str]:
Optional[str]
"""

if self._default_provider:
return self._default_provider
if provider := self._default_provider:
# Was set programatically.
return provider

elif provider_from_config := self._network_config.get("default_provider"):
# The default is found in the Network's config class.
return provider_from_config

if len(self.providers) > 0:
elif len(self.providers) > 0:
# No default set anywhere - use the first installed.
return list(self.providers)[0]

# There are no providers at all for this network.
return None

@property
Expand Down
15 changes: 13 additions & 2 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,15 +1420,26 @@ def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI:
txn.max_fee = int(self.base_fee * multiplier + txn.max_priority_fee)
# else: Assume user specified the correct amount or txn will fail and waste gas

if txn.gas_limit is None:
multiplier = self.network.auto_gas_multiplier
gas_limit = self.network.gas_limit if txn.gas_limit is None else txn.gas_limit
if gas_limit in (None, "auto") or isinstance(gas_limit, AutoGasLimit):
multiplier = (
gas_limit.multiplier
if isinstance(gas_limit, AutoGasLimit)
else self.network.auto_gas_multiplier
)
if multiplier != 1.0:
gas = min(int(self.estimate_gas_cost(txn) * multiplier), self.max_gas)
else:
gas = self.estimate_gas_cost(txn)

txn.gas_limit = gas

elif gas_limit == "max":
txn.gas_limit = self.max_gas

elif gas_limit is not None and isinstance(gas_limit, int):
txn.gas_limit = gas_limit

if txn.required_confirmations is None:
txn.required_confirmations = self.network.required_confirmations
elif not isinstance(txn.required_confirmations, int) or txn.required_confirmations < 0:
Expand Down
7 changes: 1 addition & 6 deletions src/ape/managers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ape._pydantic_compat import root_validator
from ape.api import ConfigDict, DependencyAPI, PluginConfig
from ape.exceptions import ConfigError, NetworkError
from ape.exceptions import ConfigError
from ape.logging import logger
from ape.utils import BaseInterfaceModel, load_config

Expand Down Expand Up @@ -182,11 +182,6 @@ def _plugin_configs(self) -> Dict[str, PluginConfig]:
configs["compiler"] = compiler_dict
self.compiler = CompilerConfig(**compiler_dict)

try:
self.network_manager.set_default_ecosystem(self.default_ecosystem)
except NetworkError as err:
logger.warning(str(err))

dependencies = user_config.pop("dependencies", []) or []
if not isinstance(dependencies, list):
raise ConfigError("'dependencies' config item must be a list of dicts.")
Expand Down
89 changes: 38 additions & 51 deletions src/ape/managers/networks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
from functools import cached_property
from typing import Dict, Iterator, List, Optional, Set, Union

import yaml

from ape.api import EcosystemAPI, ProviderAPI, ProviderContextManager
from ape.api.networks import LOCAL_NETWORK_NAME, NetworkAPI
from ape.api.networks import NetworkAPI
from ape.exceptions import ApeAttributeError, NetworkError
from ape.logging import logger

from .base import BaseManager
from ape.managers.base import BaseManager


class NetworkManager(BaseManager):
Expand All @@ -27,7 +27,6 @@ class NetworkManager(BaseManager):

_active_provider: Optional[ProviderAPI] = None
_default: Optional[str] = None
_ecosystems_by_project: Dict[str, Dict[str, EcosystemAPI]] = {}

def __repr__(self):
provider = self.active_provider
Expand Down Expand Up @@ -137,54 +136,21 @@ def provider_names(self) -> Set[str]:

return names

@property
@cached_property
def ecosystems(self) -> Dict[str, EcosystemAPI]:
"""
All the registered ecosystems in ``ape``, such as ``ethereum``.
"""

project_name = self.config_manager.PROJECT_FOLDER.stem
if project_name in self._ecosystems_by_project:
return self._ecosystems_by_project[project_name]

ecosystem_dict = {}
for plugin_name, ecosystem_class in self.plugin_manager.ecosystems:
ecosystem = ecosystem_class( # type: ignore
name=plugin_name,
data_folder=self.config_manager.DATA_FOLDER / plugin_name,
request_header=self.config_manager.REQUEST_HEADER,
)
ecosystem_config = self.config_manager.get_config(plugin_name).dict()
default_network = ecosystem_config.get("default_network", LOCAL_NETWORK_NAME)
def to_kwargs(name: str) -> Dict:
return {
"name": name,
"data_folder": self.config_manager.DATA_FOLDER / name,
"request_header": self.config_manager.REQUEST_HEADER,
}

try:
ecosystem.set_default_network(default_network)
except NetworkError as err:
message = f"Failed setting default network: {err}"
logger.error(message)

if ecosystem_config:
for network_name, network in ecosystem.networks.items():
network_name = network_name.replace("-", "_")
if network_name not in ecosystem_config:
continue

network_config = ecosystem_config[network_name]
if "default_provider" not in network_config:
continue

default_provider = network_config["default_provider"]
if default_provider:
try:
network.set_default_provider(default_provider)
except NetworkError as err:
message = f"Failed setting default provider: {err}"
logger.error(message)

ecosystem_dict[plugin_name] = ecosystem

self._ecosystems_by_project[project_name] = ecosystem_dict
return ecosystem_dict
ecosystems = self.plugin_manager.ecosystems
return {n: cls(**to_kwargs(n)) for n, cls in ecosystems} # type: ignore

def create_adhoc_geth_provider(self, uri: str) -> ProviderAPI:
"""
Expand Down Expand Up @@ -485,6 +451,9 @@ def default_ecosystem(self) -> EcosystemAPI:
if self._default:
return ecosystems[self._default]

elif self.config_manager.default_ecosystem:
return ecosystems[self.config_manager.default_ecosystem]

# If explicit default is not set, use first registered ecosystem
elif len(ecosystems) > 0:
return list(ecosystems.values())[0]
Expand Down Expand Up @@ -529,16 +498,17 @@ def network_data(self) -> Dict:

return data

def _get_ecosystem_data(self, ecosystem_name) -> Dict:
def _get_ecosystem_data(self, ecosystem_name: str) -> Dict:
ecosystem = self[ecosystem_name]
ecosystem_data = {"name": ecosystem_name}
ecosystem_data: Dict = {"name": str(ecosystem_name)}

# Only add isDefault key when True
if ecosystem_name == self.default_ecosystem.name:
ecosystem_data["isDefault"] = True

ecosystem_data["networks"] = []
for network_name in getattr(self, ecosystem_name).networks.keys():

for network_name in getattr(self, ecosystem_name).networks:
network_data = ecosystem.get_network_data(network_name)
ecosystem_data["networks"].append(network_data)

Expand All @@ -556,7 +526,24 @@ def networks_yaml(self) -> str:
str
"""

return yaml.dump(self.network_data, sort_keys=False)
data = self.network_data
if not isinstance(data, dict):
raise TypeError(
f"Unexpected network data type: {type(data)}. "
f"Expecting dict. YAML dump will fail."
)

try:
return yaml.dump(data, sort_keys=False)
except ValueError as err:
try:
data_str = json.dumps(data)
except Exception:
data_str = str(data)

raise NetworkError(
f"Network data did not dump to YAML: {data_str}\nAcual err: {err}"
) from err


def _validate_filter(arg: Optional[Union[List[str], str]], options: Set[str]):
Expand Down
9 changes: 7 additions & 2 deletions src/ape/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,9 @@ def _warn_not_fully_implemented_error(self, results, plugin_name):
# Likely only ever a single class in a registration, but just in case.
api_name = " - ".join([p.__name__ for p in classes])
for api_cls in classes:
if hasattr(api_cls, "__abstractmethods__") and api_cls.__abstractmethods__:
if (
abstract_methods := getattr(api_cls, "__abstractmethods__", None)
) and isinstance(abstract_methods, dict):
unimplemented_methods.extend(api_cls.__abstractmethods__)

else:
Expand All @@ -204,8 +206,11 @@ def _warn_not_fully_implemented_error(self, results, plugin_name):

elif hasattr(results, "__name__"):
api_name = results.__name__
if hasattr(results, "__abstractmethods__") and results.__abstractmethods__:
if (abstract_methods := getattr(results, "__abstractmethods__", None)) and isinstance(
abstract_methods, dict
):
unimplemented_methods.extend(results.__abstractmethods__)

else:
api_name = results

Expand Down
9 changes: 6 additions & 3 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,13 @@ def create_transaction(self, **kwargs) -> TransactionAPI:
if "type" in kwargs:
if kwargs["type"] is None:
version = TransactionType.DYNAMIC
elif not isinstance(kwargs["type"], int):
version = TransactionType(self.conversion_manager.convert(kwargs["type"], int))
else:
elif isinstance(kwargs["type"], TransactionType):
version = kwargs["type"]
elif isinstance(kwargs["type"], int):
version = TransactionType(kwargs["type"])
else:
# Using hex values or alike.
version = TransactionType(self.conversion_manager.convert(kwargs["type"], int))

elif "gas_price" in kwargs:
version = TransactionType.STATIC
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ def ethereum(networks):


@pytest.fixture(autouse=True)
def eth_tester_provider():
def eth_tester_provider(ethereum):
if not ape.networks.active_provider or ape.networks.provider.name != "test":
with ape.networks.ethereum.local.use_provider("test") as provider:
with ethereum.local.use_provider("test") as provider:
yield provider
else:
yield ape.networks.provider
Expand Down
22 changes: 10 additions & 12 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from ethpm_types import ContractType, HexBytes

import ape
from ape.api import EcosystemAPI, NetworkAPI, TransactionAPI
from ape.api.networks import LOCAL_NETWORK_NAME
from ape.api import TransactionAPI
from ape.contracts import ContractContainer, ContractInstance
from ape.exceptions import ChainError, ContractLogicError
from ape.logging import LogLevel
Expand Down Expand Up @@ -48,15 +47,6 @@ def pytest_collection_finish(session):
yield


@pytest.fixture
def mock_network_api(mocker):
mock = mocker.MagicMock(spec=NetworkAPI)
mock_ecosystem = mocker.MagicMock(spec=EcosystemAPI)
mock_ecosystem.virtual_machine_error_class = _ContractLogicError
mock.ecosystem = mock_ecosystem
return mock


@pytest.fixture
def mock_web3(mocker):
return mocker.MagicMock()
Expand Down Expand Up @@ -416,9 +406,10 @@ def use_debug(logger):

@pytest.fixture
def dummy_live_network(chain):
original_network = chain.provider.network.name
chain.provider.network.name = "goerli"
yield chain.provider.network
chain.provider.network.name = LOCAL_NETWORK_NAME
chain.provider.network.name = original_network


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -514,3 +505,10 @@ def minimal_proxy_container():
@pytest.fixture
def minimal_proxy(owner, minimal_proxy_container):
return owner.deploy(minimal_proxy_container)


@pytest.fixture
def mock_explorer(mocker):
explorer = mocker.MagicMock()
explorer.name = "mock" # Needed for network data serialization.
return explorer
1 change: 1 addition & 0 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_chain_id_live_network_not_connected(networks):
def test_chain_id_live_network_connected_uses_web3_chain_id(mocker, geth_provider):
mock_network = mocker.MagicMock()
mock_network.chain_id = 999999999 # Shouldn't use hardcoded network
mock_network.name = "mock"
orig_network = geth_provider.network

try:
Expand Down
Loading

0 comments on commit 0eb2478

Please sign in to comment.