From f720add44d732b6e98f264610cef1b99bee7a15c Mon Sep 17 00:00:00 2001 From: antazoey Date: Sun, 10 Dec 2023 12:43:14 -0600 Subject: [PATCH] refactor!: a collection of breaking changes [APE-1417] (#1677) --- docs/userguides/clis.md | 6 +- src/ape/_cli.py | 4 +- src/ape/api/__init__.py | 10 +- src/ape/api/networks.py | 5 +- src/ape/api/projects.py | 10 - src/ape/api/providers.py | 1191 +------------------- src/ape/cli/__init__.py | 6 +- src/ape/cli/arguments.py | 2 +- src/ape/cli/choices.py | 53 +- src/ape/cli/options.py | 11 +- src/ape/cli/utils.py | 5 - src/ape/contracts/base.py | 47 +- src/ape/exceptions.py | 19 +- src/ape/logging.py | 3 - src/ape/managers/chain.py | 10 - src/ape/managers/converters.py | 68 +- src/ape/managers/project/dependency.py | 27 +- src/ape/managers/project/manager.py | 27 +- src/ape/types/trace.py | 3 - src/ape_compile/__init__.py | 18 +- src/ape_ethereum/ecosystem.py | 16 +- src/ape_ethereum/provider.py | 1019 +++++++++++++++++ src/ape_geth/provider.py | 20 +- src/ape_test/provider.py | 25 +- tests/functional/conversion/test_ether.py | 7 +- tests/functional/test_cli.py | 42 +- tests/functional/test_compilers.py | 2 +- tests/functional/test_contract_instance.py | 29 +- tests/functional/test_exceptions.py | 3 +- tests/functional/test_provider.py | 2 +- 30 files changed, 1235 insertions(+), 1455 deletions(-) create mode 100644 src/ape_ethereum/provider.py diff --git a/docs/userguides/clis.md b/docs/userguides/clis.md index 0ef9628e14..09c63ffc99 100644 --- a/docs/userguides/clis.md +++ b/docs/userguides/clis.md @@ -127,13 +127,13 @@ Alternatively, you can call the [get_user_selected_account()](../methoddocs/cli. ```python import click -from ape.cli import get_user_selected_account +from ape.cli import select_account @click.command() def cmd(): - account = get_user_selected_account("Select an account to use") - click.echo(f"You selected {account.address}.") + account = select_account("Select an account to use") + click.echo(f"You selected {account.address}.") ``` Similarly, there are a couple custom arguments for aliases alone that are useful when making CLIs for account creation. diff --git a/src/ape/_cli.py b/src/ape/_cli.py index 82c9985bbb..0b7c2e340a 100644 --- a/src/ape/_cli.py +++ b/src/ape/_cli.py @@ -7,8 +7,8 @@ import importlib_metadata as metadata import yaml -from ape.cli import Abort, ape_cli_context -from ape.exceptions import ApeException, handle_ape_exception +from ape.cli import ape_cli_context +from ape.exceptions import Abort, ApeException, handle_ape_exception from ape.logging import logger from ape.plugins import clean_plugin_name diff --git a/src/ape/api/__init__.py b/src/ape/api/__init__.py index fc54436c5e..caba392f01 100644 --- a/src/ape/api/__init__.py +++ b/src/ape/api/__init__.py @@ -18,14 +18,7 @@ create_network_type, ) from .projects import DependencyAPI, ProjectAPI -from .providers import ( - BlockAPI, - ProviderAPI, - SubprocessProvider, - TestProviderAPI, - UpstreamProvider, - Web3Provider, -) +from .providers import BlockAPI, ProviderAPI, SubprocessProvider, TestProviderAPI, UpstreamProvider from .query import QueryAPI, QueryType from .transactions import ReceiptAPI, TransactionAPI @@ -58,5 +51,4 @@ "TestProviderAPI", "TransactionAPI", "UpstreamProvider", - "Web3Provider", ] diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 2c797bbe9a..13616a371b 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -140,13 +140,10 @@ def encode_contract_blueprint( # type: ignore[empty-body] :class:`~ape.ape.transactions.TransactionAPI` """ - def serialize_transaction(self, transaction: "TransactionAPI") -> bytes: + def serialize_transaction(self) -> bytes: """ Serialize a transaction to bytes. - Args: - transaction (:class:`~ape.api.transactions.TransactionAPI`): The transaction to encode. - Returns: bytes """ diff --git a/src/ape/api/projects.py b/src/ape/api/projects.py index ec027667c9..aa2a7a26b5 100644 --- a/src/ape/api/projects.py +++ b/src/ape/api/projects.py @@ -104,16 +104,6 @@ def cached_manifest(self) -> Optional[PackageManifest]: manifest = self._cached_manifest if manifest.contract_types and not self.contracts: - # Extract contract types from cached manifest. - # This helps migrate to >= 0.6.3. - # TODO: Remove once Ape 0.7 is released. - for contract_type in manifest.contract_types.values(): - if not contract_type.name: - continue - - path = self._cache_folder / f"{contract_type.name}.json" - path.write_text(contract_type.model_dump_json()) - # Rely on individual cache files. self._contracts = manifest.contract_types manifest.contract_types = {} diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index f592527f21..57874edefd 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -6,10 +6,6 @@ import shutil import sys import time -from abc import ABC -from concurrent.futures import ThreadPoolExecutor -from copy import copy -from itertools import tee from logging import FileHandler, Formatter, Logger, getLogger from pathlib import Path from signal import SIGINT, SIGTERM, signal @@ -17,48 +13,30 @@ from typing import Any, Dict, Iterator, List, Optional, Union, cast from eth_pydantic_types import HexBytes -from eth_typing import BlockNumber, HexStr -from eth_utils import add_0x_prefix, to_hex from ethpm_types.abi import EventABI -from evm_trace import CallTreeNode as EvmCallTreeNode -from evm_trace import TraceFrame as EvmTraceFrame from pydantic import Field, computed_field, model_validator -from pydantic.dataclasses import dataclass -from web3 import Web3 -from web3.exceptions import ContractLogicError as Web3ContractLogicError -from web3.exceptions import MethodUnavailable, TimeExhausted, TransactionNotFound -from web3.types import RPCEndpoint, TxParams from ape.api.config import PluginConfig -from ape.api.networks import LOCAL_NETWORK_NAME, NetworkAPI +from ape.api.networks import NetworkAPI from ape.api.query import BlockTransactionQuery from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( - ApeException, APINotImplementedError, - BlockNotFoundError, - ContractLogicError, - OutOfGasError, ProviderError, - ProviderNotConnectedError, RPCTimeoutError, SubprocessError, SubprocessTimeoutError, - TransactionError, - TransactionNotFoundError, VirtualMachineError, ) -from ape.logging import LogLevel, logger, sanitize_url +from ape.logging import LogLevel, logger from ape.types import ( AddressType, - AutoGasLimit, BlockID, CallTreeNode, ContractCode, ContractLog, LogFilter, SnapshotID, - SourceTraceback, TraceFrame, ) from ape.utils import ( @@ -67,13 +45,10 @@ JoinableQueue, abstractmethod, cached_property, - gas_estimation_error_message, raises_not_implemented, - run_until_complete, spawn, - to_int, ) -from ape.utils.misc import DEFAULT_MAX_RETRIES_TX, _create_raises_not_implemented_error +from ape.utils.misc import _create_raises_not_implemented_error class BlockAPI(BaseInterfaceModel): @@ -226,25 +201,28 @@ def chain_id(self) -> int: """ @abstractmethod - def get_balance(self, address: AddressType) -> int: + def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: """ Get the balance of an account. Args: address (``AddressType``): The address of the account. + block_id (:class:`~ape.types.BlockID`): Optionally specify a block + ID. Defaults to using the latest block. Returns: int: The account balance. """ @abstractmethod - def get_code(self, address: AddressType, **kwargs) -> ContractCode: + def get_code(self, address: AddressType, block_id: Optional[BlockID] = None) -> ContractCode: """ Get the bytes a contract. Args: address (``AddressType``): The address of the contract. - **kwargs: Additional, provider-specific kwargs. + block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + for checking a previous account nonce. Returns: :class:`~ape.types.ContractCode`: The contract bytecode. @@ -258,42 +236,54 @@ def network_choice(self) -> str: return f"{self.network.choice}:{self.name}" @raises_not_implemented - def get_storage_at(self, address: AddressType, slot: int) -> bytes: # type: ignore[empty-body] + def get_storage( # type: ignore[empty-body] + self, address: AddressType, slot: int, block_id: Optional[BlockID] = None + ) -> HexBytes: """ Gets the raw value of a storage slot of a contract. Args: - address (str): The address of the contract. + address (AddressType): The address of the contract. slot (int): Storage slot to read the value of. + block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + for checking a previous storage value. Returns: - bytes: The value of the storage slot. + HexBytes: The value of the storage slot. """ @abstractmethod - def get_nonce(self, address: AddressType) -> int: + def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: """ Get the number of times an account has transacted. Args: - address (``AddressType``): The address of the account. + address (AddressType): The address of the account. + block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + for checking a previous account nonce. Returns: int """ @abstractmethod - def estimate_gas_cost(self, txn: TransactionAPI) -> int: + def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = None) -> int: """ Estimate the cost of gas for a transaction. Args: txn (:class:`~ape.api.transactions.TransactionAPI`): - The transaction to estimate the gas for. + The transaction to estimate the gas for. + block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + to use when estimating the transaction. Useful for checking a + past estimation cost of a transaction. Returns: int: The estimated cost of gas to execute the transaction - reported in the fee-currency's smallest unit, e.g. Wei. + reported in the fee-currency's smallest unit, e.g. Wei. If the + provider's network has been configured with a gas limit override, it + will be returned. If the gas limit configuration is "max" this will + return the block maximum gas limit. """ @property @@ -305,12 +295,11 @@ def gas_price(self) -> int: """ @property + @abstractmethod def max_gas(self) -> int: """ The max gas limit value you can use. """ - # TODO: Make abstract - return 0 @property def config(self) -> PluginConfig: @@ -368,14 +357,25 @@ def get_block(self, block_id: BlockID) -> BlockAPI: """ @abstractmethod - def send_call(self, txn: TransactionAPI, **kwargs) -> bytes: # Return value of function + def send_call( + self, + txn: TransactionAPI, + block_id: Optional[BlockID] = None, + state: Optional[Dict] = None, + **kwargs, + ) -> HexBytes: # Return value of function """ Execute a new transaction call immediately without creating a - transaction on the blockchain. + transaction on the block chain. Args: - txn (:class:`~ape.api.transactions.TransactionAPI`): The transaction - to send as a call. + txn: :class:`~ape.api.transactions.TransactionAPI` + block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + to use to send a call at a historical point of a contract. + checking a past estimation cost of a transaction. + state (Optional[Dict]): Modify the state of the blockchain + prior to sending the call, for testing purposes. + **kwargs: Provider-specific extra kwargs. Returns: str: The result of the transaction call. @@ -586,7 +586,7 @@ def set_storage( # type: ignore[empty-body] Args: address (str): The address of the contract. slot (int): Storage slot to write the value to. - value: (bytes): The value to overwrite the raw storage slot with. + value: (HexBytes): The value to overwrite the raw storage slot with. """ @raises_not_implemented @@ -808,1105 +808,6 @@ def _increment_call_func_coverage_hit_count(self, txn: TransactionAPI): self._test_runner.coverage_tracker.hit_function(contract_src, method) -def _sanitize_web3_url(msg: str) -> str: - if "URI: " not in msg: - return msg - - parts = msg.split("URI: ") - prefix = parts[0].strip() - rest = parts[1].split(" ") - - # * To remove the `,` from the url http://127.0.0.1:8545, - if "," in rest[0]: - rest[0] = rest[0].rstrip(",") - sanitized_url = sanitize_url(rest[0]) - return f"{prefix} URI: {sanitized_url} {' '.join(rest[1:])}" - - -class Web3Provider(ProviderAPI, ABC): - """ - A base provider mixin class that uses the - `web3.py `__ python package. - """ - - _web3: Optional[Web3] = None - _client_version: Optional[str] = None - - def __init__(self, *args, **kwargs): - logger.create_logger("web3.RequestManager", handlers=(_sanitize_web3_url,)) - logger.create_logger("web3.providers.HTTPProvider", handlers=(_sanitize_web3_url,)) - super().__init__(*args, **kwargs) - - @property - def web3(self) -> Web3: - """ - Access to the ``web3`` object as if you did ``Web3(HTTPProvider(uri))``. - """ - - if not self._web3: - raise ProviderNotConnectedError() - - return self._web3 - - @property - def http_uri(self) -> Optional[str]: - if ( - hasattr(self.web3.provider, "endpoint_uri") - and isinstance(self.web3.provider.endpoint_uri, str) - and self.web3.provider.endpoint_uri.startswith("http") - ): - return self.web3.provider.endpoint_uri - - elif hasattr(self, "uri"): - # NOTE: Some providers define this - return self.uri - - return None - - @property - def ws_uri(self) -> Optional[str]: - if ( - hasattr(self.web3.provider, "endpoint_uri") - and isinstance(self.web3.provider.endpoint_uri, str) - and self.web3.provider.endpoint_uri.startswith("ws") - ): - return self.web3.provider.endpoint_uri - - return None - - @property - def client_version(self) -> str: - if not self._web3: - return "" - - # NOTE: Gets reset to `None` on `connect()` and `disconnect()`. - if self._client_version is None: - self._client_version = self.web3.client_version - - return self._client_version - - @property - def base_fee(self) -> int: - latest_block_number = self.get_block("latest").number - if latest_block_number is None: - # Possibly no blocks yet. - logger.debug("Latest block has no number. Using base fee of '0'.") - return 0 - - try: - fee_history = self.web3.eth.fee_history(1, BlockNumber(latest_block_number)) - except ValueError as exc: - # Use the less-accurate approach (OK for testing). - logger.debug( - "Failed using `web3.eth.fee_history` for network " - f"'{self.network_choice}'. Error: {exc}" - ) - return self._get_last_base_fee() - - if len(fee_history["baseFeePerGas"]) < 2: - logger.debug("Not enough fee_history. Defaulting less-accurate approach.") - return self._get_last_base_fee() - - pending_base_fee = fee_history["baseFeePerGas"][1] - if pending_base_fee is None: - # Non-EIP-1559 chains or we time-travelled pre-London fork. - return self._get_last_base_fee() - - return pending_base_fee - - def _get_last_base_fee(self) -> int: - block = self.get_block("latest") - base_fee = getattr(block, "base_fee", None) - if base_fee is not None: - return base_fee - - raise APINotImplementedError("No base fee found in block.") - - @property - def is_connected(self) -> bool: - if self._web3 is None: - return False - - return run_until_complete(self._web3.is_connected()) - - @property - def max_gas(self) -> int: - block = self.web3.eth.get_block("latest") - return block["gasLimit"] - - @cached_property - def supports_tracing(self) -> bool: - try: - # NOTE: Txn hash is purposely not a real hash. - # If we get any exception besides not implemented error, - # then we support tracing on this provider. - self.get_call_tree("__CHECK_IF_SUPPORTS_TRACING__") - except APINotImplementedError: - return False - except Exception: - return True - - return True - - def update_settings(self, new_settings: dict): - self.disconnect() - self.provider_settings.update(new_settings) - self.connect() - - def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: - """ - Estimate the cost of gas for a transaction. - - Args: - txn (:class:`~ape.api.transactions.TransactionAPI`): - The transaction to estimate the gas for. - kwargs: - * ``block_identifier`` (:class:`~ape.types.BlockID`): The block ID - to use when estimating the transaction. Useful for checking a - past estimation cost of a transaction. Also, you can alias - ``block_id``. - * ``state_overrides`` (Dict): Modify the state of the blockchain - prior to estimation. - - Returns: - int: The estimated cost of gas to execute the transaction - reported in the fee-currency's smallest unit, e.g. Wei. If the - provider's network has been configured with a gas limit override, it - will be returned. If the gas limit configuration is "max" this will - return the block maximum gas limit. - """ - - txn_dict = txn.model_dump(mode="json") - - # Force the use of hex values to support a wider range of nodes. - if isinstance(txn_dict.get("type"), int): - txn_dict["type"] = HexBytes(txn_dict["type"]).hex() - - # NOTE: "auto" means to enter this method, so remove it from dict - if "gas" in txn_dict and ( - txn_dict["gas"] == "auto" or isinstance(txn_dict["gas"], AutoGasLimit) - ): - txn_dict.pop("gas") - # Also pop these, they are overridden by "auto" - txn_dict.pop("maxFeePerGas", None) - txn_dict.pop("maxPriorityFeePerGas", None) - - try: - block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) - txn_params = cast(TxParams, txn_dict) - return self.web3.eth.estimate_gas(txn_params, block_identifier=block_id) - except (ValueError, Web3ContractLogicError) as err: - tx_error = self.get_virtual_machine_error(err, txn=txn) - - # If this is the cause of a would-be revert, - # raise ContractLogicError so that we can confirm tx-reverts. - if isinstance(tx_error, ContractLogicError): - raise tx_error from err - - message = gas_estimation_error_message(tx_error) - raise TransactionError( - message, base_err=tx_error, txn=txn, source_traceback=tx_error.source_traceback - ) from err - - @cached_property - def chain_id(self) -> int: - default_chain_id = None - if ( - self.network.name - not in ( - "adhoc", - LOCAL_NETWORK_NAME, - ) - and not self.network.is_fork - ): - # If using a live network, the chain ID is hardcoded. - default_chain_id = self.network.chain_id - - try: - if hasattr(self.web3, "eth"): - return self.web3.eth.chain_id - - except ProviderNotConnectedError: - if default_chain_id is not None: - return default_chain_id - - raise # Original error - - if default_chain_id is not None: - return default_chain_id - - raise ProviderNotConnectedError() - - @property - def gas_price(self) -> int: - price = self.web3.eth.generate_gas_price() or 0 - return to_int(price) - - @property - def priority_fee(self) -> int: - try: - return self.web3.eth.max_priority_fee - except MethodUnavailable as err: - # The user likely should be using a more-catered plugin. - raise APINotImplementedError( - "eth_maxPriorityFeePerGas not supported in this RPC. Please specify manually." - ) from err - - def get_block(self, block_id: BlockID) -> BlockAPI: - if isinstance(block_id, str) and block_id.isnumeric(): - block_id = int(block_id) - - try: - block_data = dict(self.web3.eth.get_block(block_id)) - except Exception as err: - raise BlockNotFoundError(block_id, reason=str(err)) from err - - # Some nodes (like anvil) will not have a base fee if set to 0. - if "baseFeePerGas" in block_data and block_data.get("baseFeePerGas") is None: - block_data["baseFeePerGas"] = 0 - - return self.network.ecosystem.decode_block(block_data) - - def get_nonce(self, address: AddressType, **kwargs) -> int: - """ - Get the number of times an account has transacted. - - Args: - address (AddressType): The address of the account. - kwargs: - * ``block_identifier`` (:class:`~ape.types.BlockID`): The block ID - for checking a previous account nonce. Also, you can use alias - ``block_id``. - - Returns: - int - """ - - block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) - return self.web3.eth.get_transaction_count(address, block_identifier=block_id) - - def get_balance(self, address: AddressType) -> int: - return self.web3.eth.get_balance(address) - - def get_code(self, address: AddressType, **kwargs) -> ContractCode: - """ - Get the bytes a contract. - - Args: - address (``AddressType``): The address of the contract. - kwargs: - * ``block_identifier`` (:class:`~ape.types.BlockID`): The block ID - for checking a previous account nonce. Also, you can use - alias ``block_id``. - - Returns: - :class:`~ape.types.ContractCode`: The contract bytecode. - """ - - block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) - return self.web3.eth.get_code(address, block_identifier=block_id) - - def get_storage_at(self, address: AddressType, slot: int, **kwargs) -> bytes: - """ - Gets the raw value of a storage slot of a contract. - - Args: - address (AddressType): The address of the contract. - slot (int): Storage slot to read the value of. - kwargs: - * ``block_identifier`` (:class:`~ape.types.BlockID`): The block ID - for checking previous contract storage values. Also, you can use - alias ``block_id``. - - Returns: - bytes: The value of the storage slot. - """ - - block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) - try: - return self.web3.eth.get_storage_at( - address, slot, block_identifier=block_id # type: ignore - ) - except ValueError as err: - if "RPC Endpoint has not been implemented" in str(err): - raise APINotImplementedError(str(err)) from err - - raise # Raise original error - - def send_call(self, txn: TransactionAPI, **kwargs) -> bytes: - """ - Execute a new transaction call immediately without creating a - transaction on the block chain. - - Args: - txn: :class:`~ape.api.transactions.TransactionAPI` - kwargs: - * ``block_identifier`` (:class:`~ape.types.BlockID`): The block ID - to use to send a call at a historical point of a contract. Also, - you can us alias ``block_id``. - checking a past estimation cost of a transaction. - * ``state_overrides`` (Dict): Modify the state of the blockchain - prior to sending the call, for testing purposes. - * ``show_trace`` (bool): Set to ``True`` to display the call's - trace. Defaults to ``False``. - * ``show_gas_report (bool): Set to ``True`` to display the call's - gas report. Defaults to ``False``. - * ``skip_trace`` (bool): Set to ``True`` to skip the trace no matter - what. This is useful if you are making a more background contract call - of some sort, such as proxy-checking, and you are running a global - call-tracer such as using the ``--gas`` flag in tests. - - Returns: - str: The result of the transaction call. - """ - skip_trace = kwargs.pop("skip_trace", False) - if skip_trace: - return self._send_call(txn, **kwargs) - - if self._test_runner is not None: - track_gas = self._test_runner.gas_tracker.enabled - track_coverage = self._test_runner.coverage_tracker.enabled - else: - track_gas = False - track_coverage = False - - show_trace = kwargs.pop("show_trace", False) - show_gas = kwargs.pop("show_gas_report", False) - needs_trace = track_gas or track_coverage or show_trace or show_gas - if not needs_trace or not self.provider.supports_tracing or not txn.receiver: - return self._send_call(txn, **kwargs) - - # The user is requesting information related to a call's trace, - # such as gas usage data. - try: - with self.chain_manager.isolate(): - return self._send_call_as_txn( - txn, - track_gas=track_gas, - track_coverage=track_coverage, - show_trace=show_trace, - show_gas=show_gas, - **kwargs, - ) - - except APINotImplementedError: - return self._send_call(txn, **kwargs) - - def _send_call_as_txn( - self, - txn: TransactionAPI, - track_gas: bool = False, - track_coverage: bool = False, - show_trace: bool = False, - show_gas: bool = False, - **kwargs, - ) -> bytes: - account = self.account_manager.test_accounts[0] - receipt = account.call(txn, **kwargs) - if not (call_tree := receipt.call_tree): - return self._send_call(txn, **kwargs) - - # Grab raw returndata before enrichment - returndata = call_tree.outputs - - if (track_gas or track_coverage) and show_gas and not show_trace: - # Optimization to enrich early and in_place=True. - call_tree.enrich() - - if track_gas: - # in_place=False in case show_trace is True - receipt.track_gas() - - if track_coverage: - receipt.track_coverage() - - if show_gas: - # in_place=False in case show_trace is True - self.chain_manager._reports.show_gas(call_tree.enrich(in_place=False)) - - if show_trace: - call_tree = call_tree.enrich(use_symbol_for_tokens=True) - self.chain_manager._reports.show_trace(call_tree) - - return HexBytes(returndata) - - def _send_call(self, txn: TransactionAPI, **kwargs) -> bytes: - arguments = self._prepare_call(txn, **kwargs) - try: - return self._eth_call(arguments) - except TransactionError as err: - if not err.txn: - err.txn = txn - - raise # The tx error - - def _eth_call(self, arguments: List) -> bytes: - # Force the usage of hex-type to support a wider-range of nodes. - txn_dict = copy(arguments[0]) - if isinstance(txn_dict.get("type"), int): - txn_dict["type"] = HexBytes(txn_dict["type"]).hex() - - # Remove unnecessary values to support a wider-range of nodes. - txn_dict.pop("chainId", None) - - arguments[0] = txn_dict - try: - result = self._make_request("eth_call", arguments) - except Exception as err: - receiver = txn_dict["to"] - raise self.get_virtual_machine_error(err, contract_address=receiver) from err - - if "error" in result: - raise ProviderError(result["error"]["message"]) - - return HexBytes(result) - - def _prepare_call(self, txn: TransactionAPI, **kwargs) -> List: - txn_dict = txn.model_dump(mode="json") - fields_to_convert = ("data", "chainId", "value") - for field in fields_to_convert: - value = txn_dict.get(field) - if value is not None and not isinstance(value, str): - txn_dict[field] = to_hex(value) - - # Remove unneeded properties - txn_dict.pop("gas", None) - txn_dict.pop("gasLimit", None) - txn_dict.pop("maxFeePerGas", None) - txn_dict.pop("maxPriorityFeePerGas", None) - - block_identifier = kwargs.pop("block_identifier", kwargs.pop("block_id", "latest")) - if isinstance(block_identifier, int): - block_identifier = to_hex(block_identifier) - - arguments = [txn_dict, block_identifier] - if "state_override" in kwargs: - arguments.append(kwargs["state_override"]) - - return arguments - - def get_receipt( - self, - txn_hash: str, - required_confirmations: int = 0, - timeout: Optional[int] = None, - **kwargs, - ) -> ReceiptAPI: - """ - Get the information about a transaction from a transaction hash. - - Args: - txn_hash (str): The hash of the transaction to retrieve. - required_confirmations (int): The amount of block confirmations - to wait before returning the receipt. Defaults to ``0``. - timeout (Optional[int]): The amount of time to wait for a receipt - before timing out. Defaults ``None``. - - Raises: - :class:`~ape.exceptions.TransactionNotFoundError`: Likely the exception raised - when the transaction receipt is not found (depends on implementation). - - Returns: - :class:`~api.providers.ReceiptAPI`: - The receipt of the transaction with the given hash. - """ - - if required_confirmations < 0: - raise TransactionError("Required confirmations cannot be negative.") - - timeout = ( - timeout if timeout is not None else self.provider.network.transaction_acceptance_timeout - ) - - try: - receipt_data = self.web3.eth.wait_for_transaction_receipt( - HexBytes(txn_hash), timeout=timeout - ) - except TimeExhausted as err: - raise TransactionNotFoundError(txn_hash, error_messsage=str(err)) from err - - network_config: Dict = self.network.config.model_dump(mode="json").get( - self.network.name, {} - ) - max_retries = network_config.get("max_get_transaction_retries", DEFAULT_MAX_RETRIES_TX) - txn = {} - for attempt in range(max_retries): - try: - txn = dict(self.web3.eth.get_transaction(HexStr(txn_hash))) - break - except TransactionNotFound: - if attempt < max_retries - 1: # if this wasn't the last attempt - time.sleep(1) # Wait for 1 second before retrying. - continue # Continue to the next iteration, effectively retrying the operation. - else: # if it was the last attempt - raise # Re-raise the last exception. - - receipt = self.network.ecosystem.decode_receipt( - { - "provider": self, - "required_confirmations": required_confirmations, - **txn, - **receipt_data, - } - ) - return receipt.await_confirmations() - - def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAPI]: - if isinstance(block_id, str): - block_id = HexStr(block_id) - - if block_id.isnumeric(): - block_id = add_0x_prefix(block_id) - - block = cast(Dict, self.web3.eth.get_block(block_id, full_transactions=True)) - for transaction in block.get("transactions", []): - yield self.network.ecosystem.create_transaction(**transaction) - - def get_transactions_by_account_nonce( - self, - account: AddressType, - start_nonce: int = 0, - stop_nonce: int = -1, - ) -> Iterator[ReceiptAPI]: - if start_nonce > stop_nonce: - raise ValueError("Starting nonce cannot be greater than stop nonce for search") - - if not self.network.is_local and (stop_nonce - start_nonce) > 2: - # NOTE: RPC usage might be acceptable to find 1 or 2 transactions reasonably quickly - logger.warning( - "Performing this action is likely to be very slow and may " - f"use {20 * (stop_nonce - start_nonce)} or more RPC calls. " - "Consider installing an alternative data query provider plugin." - ) - - yield from self._find_txn_by_account_and_nonce( - account, - start_nonce, - stop_nonce, - 0, # first block - self.chain_manager.blocks.head.number or 0, # last block (or 0 if genesis-only chain) - ) - - def _find_txn_by_account_and_nonce( - self, - account: AddressType, - start_nonce: int, - stop_nonce: int, - start_block: int, - stop_block: int, - ) -> Iterator[ReceiptAPI]: - # binary search between `start_block` and `stop_block` to yield txns from account, - # ordered from `start_nonce` to `stop_nonce` - - if start_block == stop_block: - # Honed in on one block where there's a delta in nonce, so must be the right block - for txn in self.get_transactions_by_block(stop_block): - assert isinstance(txn.nonce, int) # NOTE: just satisfying mypy here - if txn.sender == account and txn.nonce >= start_nonce: - yield self.get_receipt(txn.txn_hash.hex()) - - # Nothing else to search for - - else: - # Break up into smaller chunks - # NOTE: biased to `stop_block` - block_number = start_block + (stop_block - start_block) // 2 + 1 - txn_count_prev_to_block = self.web3.eth.get_transaction_count(account, block_number - 1) - - if start_nonce < txn_count_prev_to_block: - yield from self._find_txn_by_account_and_nonce( - account, - start_nonce, - min(txn_count_prev_to_block - 1, stop_nonce), # NOTE: In case >1 txn in block - start_block, - block_number - 1, - ) - - if txn_count_prev_to_block <= stop_nonce: - yield from self._find_txn_by_account_and_nonce( - account, - max(start_nonce, txn_count_prev_to_block), # NOTE: In case >1 txn in block - stop_nonce, - block_number, - stop_block, - ) - - def poll_blocks( - self, - stop_block: Optional[int] = None, - required_confirmations: Optional[int] = None, - new_block_timeout: Optional[int] = None, - ) -> Iterator[BlockAPI]: - # Wait half the time as the block time - # to get data faster. - block_time = self.network.block_time - wait_time = block_time / 2 - - # The timeout occurs when there is no chain activity - # after a certain time. - timeout = ( - (10.0 if self.network.is_dev else 50 * block_time) - if new_block_timeout is None - else new_block_timeout - ) - - # Only yield confirmed blocks. - if required_confirmations is None: - required_confirmations = self.network.required_confirmations - - @dataclass - class YieldAction: - hash: bytes - number: int - time: float - - # Pretend we _did_ yield the last confirmed item, for logic's sake. - fake_last_block = self.get_block(self.web3.eth.block_number - required_confirmations) - last_num = fake_last_block.number or 0 - last_hash = fake_last_block.hash or HexBytes(0) - last = YieldAction(number=last_num, hash=last_hash, time=time.time()) - - # A helper method for various points of ensuring we didn't timeout. - def assert_chain_activity(): - time_waiting = time.time() - last.time - if time_waiting > timeout: - raise ProviderError("Timed out waiting for next block.") - - # Begin the daemon. - while True: - # The next block we want is simply 1 after the last. - next_block = last.number + 1 - - head = self.get_block("latest") - - try: - if head.number is None or head.hash is None: - raise ProviderError("Head block has no number or hash.") - # Use an "adjused" head, based on the required confirmations. - adjusted_head = self.get_block(head.number - required_confirmations) - if adjusted_head.number is None or adjusted_head.hash is None: - raise ProviderError("Adjusted head block has no number or hash.") - except Exception: - # TODO: I did encounter this sometimes in a re-org, needs better handling - # and maybe bubbling up the block number/hash exceptions above. - assert_chain_activity() - continue - - if adjusted_head.number == last.number and adjusted_head.hash == last.hash: - # The chain has not moved! Verify we have activity. - assert_chain_activity() - time.sleep(wait_time) - continue - - elif adjusted_head.number < last.number or ( - adjusted_head.number == last.number and adjusted_head.hash != last.hash - ): - # Re-org detected! Error and catch up the chain. - logger.error( - "Chain has reorganized since returning the last block. " - "Try adjusting the required network confirmations." - ) - # Catch up the chain by setting the "next" to this tiny head. - next_block = adjusted_head.number - - # NOTE: Drop down to code outside of switch-of-ifs - - elif adjusted_head.number < next_block: - # Wait for the next block. - # But first, let's make sure the chain is still active. - assert_chain_activity() - time.sleep(wait_time) - continue - - # NOTE: Should only get here if yielding blocks! - # Either because it is finally time or because a re-org allows us. - for block_idx in range(next_block, adjusted_head.number + 1): - block = self.get_block(block_idx) - if block.number is None or block.hash is None: - raise ProviderError("Block has no number or hash.") - yield block - - # This is the point at which the daemon will end, - # provider the user passes in a `stop_block` arg. - if stop_block is not None and block.number >= stop_block: - return - - # Set the last action, used for checking timeouts and re-orgs. - last = YieldAction( - number=block.number, hash=block.hash, time=time.time() - ) # type: ignore - - def poll_logs( - self, - stop_block: Optional[int] = None, - address: Optional[AddressType] = None, - topics: Optional[List[Union[str, List[str]]]] = None, - required_confirmations: Optional[int] = None, - new_block_timeout: Optional[int] = None, - events: List[EventABI] = [], - ) -> Iterator[ContractLog]: - if required_confirmations is None: - required_confirmations = self.network.required_confirmations - - if stop_block is not None: - if stop_block <= (self.provider.get_block("latest").number or 0): - raise ValueError("'stop' argument must be in the future.") - - for block in self.poll_blocks(stop_block, required_confirmations, new_block_timeout): - if block.number is None: - raise ValueError("Block number cannot be None") - - log_params: Dict[str, Any] = { - "start_block": block.number, - "stop_block": block.number, - "events": events, - } - if address is not None: - log_params["addresses"] = [address] - if topics is not None: - log_params["topics"] = topics - - log_filter = LogFilter(**log_params) - yield from self.get_contract_logs(log_filter) - - def block_ranges(self, start=0, stop=None, page=None): - if stop is None: - stop = self.chain_manager.blocks.height - if page is None: - page = self.block_page_size - - for start_block in range(start, stop + 1, page): - stop_block = min(stop, start_block + page - 1) - yield start_block, stop_block - - def get_contract_creation_receipts( - self, - address: AddressType, - start_block: int = 0, - stop_block: Optional[int] = None, - contract_code: Optional[HexBytes] = None, - ) -> Iterator[ReceiptAPI]: - if stop_block is None: - stop_block = self.chain_manager.blocks.height - - if contract_code is None: - contract_code = HexBytes(self.get_code(address)) - - mid_block = (stop_block - start_block) // 2 + start_block - # NOTE: biased towards mid_block == start_block - - if start_block == mid_block: - for tx in self.chain_manager.blocks[mid_block].transactions: - if (receipt := tx.receipt) and receipt.contract_address == address: - yield receipt - - if mid_block + 1 <= stop_block: - yield from self.get_contract_creation_receipts( - address, - start_block=mid_block + 1, - stop_block=stop_block, - contract_code=contract_code, - ) - - # TODO: Handle when code is nonzero but doesn't match - # TODO: Handle when code is empty after it's not (re-init) - elif HexBytes(self.get_code(address, block_id=mid_block)) == contract_code: - # If the code exists, we need to look backwards. - yield from self.get_contract_creation_receipts( - address, - start_block=start_block, - stop_block=mid_block, - contract_code=contract_code, - ) - - elif mid_block + 1 <= stop_block: - # The code does not exist yet, we need to look ahead. - yield from self.get_contract_creation_receipts( - address, - start_block=mid_block + 1, - stop_block=stop_block, - contract_code=contract_code, - ) - - def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: - height = self.chain_manager.blocks.height - start_block = log_filter.start_block - stop_block_arg = log_filter.stop_block if log_filter.stop_block is not None else height - stop_block = min(stop_block_arg, height) - block_ranges = self.block_ranges(start_block, stop_block, self.block_page_size) - - def fetch_log_page(block_range): - start, stop = block_range - page_filter = log_filter.model_copy(update=dict(start_block=start, stop_block=stop)) - # eth-tester expects a different format, let web3 handle the conversions for it - raw = "EthereumTester" not in self.client_version - logs = self._get_logs(page_filter.model_dump(mode="json"), raw) - return self.network.ecosystem.decode_logs(logs, *log_filter.events) - - with ThreadPoolExecutor(self.concurrency) as pool: - for page in pool.map(fetch_log_page, block_ranges): - yield from page - - def _get_logs(self, filter_params, raw=True) -> List[Dict]: - if not raw: - return [vars(d) for d in self.web3.eth.get_logs(filter_params)] - - return self._make_request("eth_getLogs", [filter_params]) - - def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI: - # NOTE: Use "expected value" for Chain ID, so if it doesn't match actual, we raise - txn.chain_id = self.network.chain_id - - from ape_ethereum.transactions import StaticFeeTransaction, TransactionType - - txn_type = TransactionType(txn.type) - if ( - txn_type == TransactionType.STATIC - and isinstance(txn, StaticFeeTransaction) - and txn.gas_price is None - ): - txn.gas_price = self.gas_price - elif txn_type == TransactionType.DYNAMIC: - if txn.max_priority_fee is None: - txn.max_priority_fee = self.priority_fee - - if txn.max_fee is None: - multiplier = self.network.base_fee_multiplier - txn.max_fee = int(self.base_fee * multiplier + txn.max_priority_fee) - # else: Assume user specified the correct amount or txn will fail and waste gas - - gas_limit = self.network.gas_limit if txn.gas_limit is None else txn.gas_limit - if gas_limit in (None, "auto") or isinstance(gas_limit, AutoGasLimit): - multiplier = ( - gas_limit.multiplier - if isinstance(gas_limit, AutoGasLimit) - else self.network.auto_gas_multiplier - ) - if multiplier != 1.0: - gas = min(int(self.estimate_gas_cost(txn) * multiplier), self.max_gas) - else: - gas = self.estimate_gas_cost(txn) - - txn.gas_limit = gas - - elif gas_limit == "max": - txn.gas_limit = self.max_gas - - elif gas_limit is not None and isinstance(gas_limit, int): - txn.gas_limit = gas_limit - - if txn.required_confirmations is None: - txn.required_confirmations = self.network.required_confirmations - elif not isinstance(txn.required_confirmations, int) or txn.required_confirmations < 0: - raise TransactionError("'required_confirmations' must be a positive integer.") - - return txn - - def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: - try: - if txn.signature or not txn.sender: - txn_hash = self.web3.eth.send_raw_transaction(txn.serialize_transaction()) - else: - if txn.sender not in self.web3.eth.accounts: - self.chain_manager.provider.unlock_account(txn.sender) - - txn_data = cast(TxParams, txn.model_dump(mode="json")) - txn_hash = self.web3.eth.send_transaction(cast(TxParams, txn_data)) - - except (ValueError, Web3ContractLogicError) as err: - vm_err = self.get_virtual_machine_error(err, txn=txn) - raise vm_err from err - - receipt = self.get_receipt( - txn_hash.hex(), - required_confirmations=( - txn.required_confirmations - if txn.required_confirmations is not None - else self.network.required_confirmations - ), - ) - - # NOTE: Ensure to cache even the failed receipts. - self.chain_manager.history.append(receipt) - - if receipt.failed: - txn_dict = receipt.transaction.model_dump(mode="json") - txn_params = cast(TxParams, txn_dict) - - # Replay txn to get revert reason - try: - self.web3.eth.call(txn_params) - except Exception as err: - vm_err = self.get_virtual_machine_error(err, txn=txn) - vm_err.txn = txn - raise vm_err from err - - logger.info(f"Confirmed {receipt.txn_hash} (total fees paid = {receipt.total_fees_paid})") - return receipt - - def _create_call_tree_node( - self, evm_call: EvmCallTreeNode, txn_hash: Optional[str] = None - ) -> CallTreeNode: - address = evm_call.address - try: - contract_id = str(self.provider.network.ecosystem.decode_address(address)) - except ValueError: - # Use raw value since it is not a real address. - contract_id = address.hex() - - call_type_str: str = getattr(evm_call.call_type, "value", f"{evm_call.call_type}") - input_data = evm_call.calldata if "CREATE" in call_type_str else evm_call.calldata[4:].hex() - return CallTreeNode( - calls=[self._create_call_tree_node(x, txn_hash=txn_hash) for x in evm_call.calls], - call_type=call_type_str, - contract_id=contract_id, - failed=evm_call.failed, - gas_cost=evm_call.gas_cost, - inputs=input_data, - method_id=evm_call.calldata[:4].hex(), - outputs=evm_call.returndata.hex(), - raw=evm_call.model_dump(by_alias=True, mode="json"), - txn_hash=txn_hash, - ) - - def _create_trace_frame(self, evm_frame: EvmTraceFrame) -> TraceFrame: - address_bytes = evm_frame.address - try: - address = ( - self.network.ecosystem.decode_address(address_bytes.hex()) - if address_bytes - else None - ) - except ValueError: - # Might not be a real address. - address = cast(AddressType, address_bytes.hex()) if address_bytes else None - - return TraceFrame( - pc=evm_frame.pc, - op=evm_frame.op, - gas=evm_frame.gas, - gas_cost=evm_frame.gas_cost, - depth=evm_frame.depth, - contract_address=address, - raw=evm_frame.model_dump(by_alias=True, mode="json"), - ) - - def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any: - parameters = parameters or [] - coroutine = self.web3.provider.make_request(RPCEndpoint(endpoint), parameters) - result = run_until_complete(coroutine) - - if "error" in result: - error = result["error"] - message = ( - error["message"] if isinstance(error, dict) and "message" in error else str(error) - ) - raise ProviderError(message) - - elif "result" in result: - return result.get("result", {}) - - return result - - def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: - """ - Get a virtual machine error from an error returned from your RPC. - If from a contract revert / assert statement, you will be given a - special :class:`~ape.exceptions.ContractLogicError` that can be - checked in ``ape.reverts()`` tests. - - **NOTE**: The default implementation is based on ``geth`` output. - ``ProviderAPI`` implementations override when needed. - - Args: - exception (Exception): The error returned from your RPC client. - - Returns: - :class:`~ape.exceptions.VirtualMachineError`: An error representing what - went wrong in the call. - """ - - txn = kwargs.get("txn") - if isinstance(exception, Web3ContractLogicError): - # This happens from `assert` or `require` statements. - return self._handle_execution_reverted(exception, **kwargs) - - if not len(exception.args): - return VirtualMachineError(base_err=exception, **kwargs) - - err_data = exception.args[0] if (hasattr(exception, "args") and exception.args) else None - if isinstance(err_data, str) and "execution reverted" in err_data: - return self._handle_execution_reverted(exception, **kwargs) - - if not isinstance(err_data, dict): - return VirtualMachineError(base_err=exception, **kwargs) - - if not (err_msg := err_data.get("message")): - return VirtualMachineError(base_err=exception, **kwargs) - - if txn is not None and "nonce too low" in str(err_msg): - txn = cast(TransactionAPI, txn) - new_err_msg = f"Nonce '{txn.nonce}' is too low" - return VirtualMachineError( - new_err_msg, base_err=exception, code=err_data.get("code"), **kwargs - ) - - elif "out of gas" in str(err_msg): - return OutOfGasError(code=err_data.get("code"), base_err=exception, **kwargs) - - return VirtualMachineError(str(err_msg), code=err_data.get("code"), **kwargs) - - def _handle_execution_reverted( - self, - exception: Union[Exception, str], - txn: Optional[TransactionAPI] = None, - trace: Optional[Iterator[TraceFrame]] = None, - contract_address: Optional[AddressType] = None, - source_traceback: Optional[SourceTraceback] = None, - ) -> ContractLogicError: - message = str(exception).split(":")[-1].strip() - params: Dict = { - "trace": trace, - "contract_address": contract_address, - "source_traceback": source_traceback, - } - no_reason = message == "execution reverted" - - if isinstance(exception, Web3ContractLogicError) and no_reason: - # Check for custom exception data and use that as the message instead. - # This allows compiler exception enrichment to function. - err_trace = None - try: - if trace: - trace, err_trace = tee(trace) - elif txn: - err_trace = self.provider.get_transaction_trace(txn.txn_hash.hex()) - - trace_ls: List[TraceFrame] = list(err_trace) if err_trace else [] - data = trace_ls[-1].raw if len(trace_ls) > 0 else {} - memory = data.get("memory", []) - return_value = "".join([x[2:] for x in memory[4:]]) - if return_value: - message = f"0x{return_value}" - no_reason = False - - except (ApeException, NotImplementedError): - # Either provider does not support or isn't a custom exception. - pass - - result = ( - ContractLogicError(txn=txn, **params) - if no_reason - else ContractLogicError( - base_err=exception if isinstance(exception, Exception) else None, - revert_message=message, - txn=txn, - **params, - ) - ) - return self.compiler_manager.enrich_error(result) - - class UpstreamProvider(ProviderAPI): """ A provider that can also be set as another provider's upstream. diff --git a/src/ape/cli/__init__.py b/src/ape/cli/__init__.py index 8204b3fcf7..c73ab5a10b 100644 --- a/src/ape/cli/__init__.py +++ b/src/ape/cli/__init__.py @@ -9,8 +9,8 @@ NetworkChoice, OutputFormat, PromptChoice, - get_user_selected_account, output_format_choice, + select_account, ) from ape.cli.commands import NetworkBoundCommand from ape.cli.options import ( @@ -25,10 +25,8 @@ verbosity_option, ) from ape.cli.paramtype import AllFilePaths, Path -from ape.cli.utils import Abort __all__ = [ - "Abort", "account_option", "AccountAliasPromptChoice", "Alias", @@ -38,7 +36,7 @@ "contract_file_paths_argument", "contract_option", "existing_alias_argument", - "get_user_selected_account", + "select_account", "incompatible_with", "network_option", "NetworkBoundCommand", diff --git a/src/ape/cli/arguments.py b/src/ape/cli/arguments.py index eba1a22b47..b9405672c6 100644 --- a/src/ape/cli/arguments.py +++ b/src/ape/cli/arguments.py @@ -34,7 +34,7 @@ def existing_alias_argument(account_type: _ACCOUNT_TYPE_FILTER = None, **kwargs) **kwargs: click.argument overrides. """ - type_ = kwargs.pop("type", Alias(account_type=account_type)) + type_ = kwargs.pop("type", Alias(key=account_type)) return click.argument("alias", type=type_, **kwargs) diff --git a/src/ape/cli/choices.py b/src/ape/cli/choices.py index 20f8c749e5..8e330933ea 100644 --- a/src/ape/cli/choices.py +++ b/src/ape/cli/choices.py @@ -17,25 +17,25 @@ ] -def _get_accounts(account_type: _ACCOUNT_TYPE_FILTER) -> List[AccountAPI]: +def _get_accounts(key: _ACCOUNT_TYPE_FILTER) -> List[AccountAPI]: add_test_accounts = False - if account_type is None: + if key is None: account_list = list(accounts) # Include test accounts at end. add_test_accounts = True - elif isinstance(account_type, type): + elif isinstance(key, type): # Filtering by type. - account_list = accounts.get_accounts_by_type(account_type) + account_list = accounts.get_accounts_by_type(key) - elif isinstance(account_type, (list, tuple, set)): + elif isinstance(key, (list, tuple, set)): # Given an account list. - account_list = account_type # type: ignore + account_list = key # type: ignore else: # Filtering by callable. - account_list = [a for a in accounts if account_type(a)] # type: ignore + account_list = [a for a in accounts if key(a)] # type: ignore sorted_accounts = sorted(account_list, key=lambda a: a.alias or "") if add_test_accounts: @@ -54,15 +54,15 @@ class Alias(click.Choice): name = "alias" - def __init__(self, account_type: _ACCOUNT_TYPE_FILTER = None): + def __init__(self, key: _ACCOUNT_TYPE_FILTER = None): # NOTE: we purposely skip the constructor of `Choice` self.case_sensitive = False - self._account_type = account_type + self._key_filter = key self.choices = _LazySequence(self._choices_iterator) @property def _choices_iterator(self) -> Iterator[str]: - for acct in _get_accounts(account_type=self._account_type): + for acct in _get_accounts(key=self._key_filter): if acct.alias is None: continue @@ -124,7 +124,7 @@ def convert( def fail_from_invalid_choice(self, param): return self.fail("Invalid choice.", param=param) - def get_user_selected_choice(self) -> str: + def select(self) -> str: choices = "\n".join(self.choices) choice = click.prompt(f"Select one of the following:\n{choices}").strip() if not choice.isnumeric(): @@ -138,8 +138,8 @@ def get_user_selected_choice(self) -> str: raise IndexError(f"Choice index '{choice_idx}' out of range.") -def get_user_selected_account( - prompt_message: Optional[str] = None, account_type: _ACCOUNT_TYPE_FILTER = None +def select_account( + prompt_message: Optional[str] = None, key: _ACCOUNT_TYPE_FILTER = None ) -> AccountAPI: """ Prompt the user to pick from their accounts and return that account. @@ -149,7 +149,7 @@ def get_user_selected_account( Args: prompt_message (Optional[str]): Customize the prompt message. - account_type (Union[None, Type[AccountAPI], Callable[[AccountAPI], bool]]): + key (Union[None, Type[AccountAPI], Callable[[AccountAPI], bool]]): If given, the user may only select a matching account. You can provide a list of accounts, an account class type, or a callable for filtering the accounts. @@ -158,11 +158,11 @@ def get_user_selected_account( :class:`~ape.api.accounts.AccountAPI` """ - if account_type and isinstance(account_type, type) and not issubclass(account_type, AccountAPI): - raise AccountsError(f"Cannot return accounts with type '{account_type}'.") + if key and isinstance(key, type) and not issubclass(key, AccountAPI): + raise AccountsError(f"Cannot return accounts with type '{key}'.") - prompt = AccountAliasPromptChoice(prompt_message=prompt_message, account_type=account_type) - return prompt.get_user_selected_account() + prompt = AccountAliasPromptChoice(prompt_message=prompt_message, key=key) + return prompt.select_account() class AccountAliasPromptChoice(PromptChoice): @@ -173,12 +173,12 @@ class AccountAliasPromptChoice(PromptChoice): def __init__( self, - account_type: _ACCOUNT_TYPE_FILTER = None, + key: _ACCOUNT_TYPE_FILTER = None, prompt_message: Optional[str] = None, name: str = "account", ): # NOTE: we purposely skip the constructor of `PromptChoice` - self._account_type = account_type + self._key_filter = key self._prompt_message = prompt_message or "Select an account" self.name = name self.choices = _LazySequence(self._choices_iterator) @@ -231,17 +231,12 @@ def print_choices(self): @property def _choices_iterator(self) -> Iterator[str]: - # Yield real accounts. - for account in _get_accounts(account_type=self._account_type): + # NOTE: Includes test accounts unless filtered out by key. + for account in _get_accounts(key=self._key_filter): if account and (alias := account.alias): yield alias - # Yield test accounts. - if self._account_type is None: - for idx, _ in enumerate(accounts.test_accounts): - yield f"TEST::{idx}" - - def get_user_selected_account(self) -> AccountAPI: + def select_account(self) -> AccountAPI: """ Returns the selected account. @@ -249,7 +244,7 @@ def get_user_selected_account(self) -> AccountAPI: :class:`~ape.api.accounts.AccountAPI` """ - if not self.choices: + if not self.choices or len(self.choices) == 0: raise AccountsError("No accounts found.") elif len(self.choices) == 1 and self.choices[0].startswith("TEST::"): return accounts.test_accounts[int(self.choices[0].replace("TEST::", ""))] diff --git a/src/ape/cli/options.py b/src/ape/cli/options.py index 63ea5556b1..efdc3b236a 100644 --- a/src/ape/cli/options.py +++ b/src/ape/cli/options.py @@ -11,8 +11,7 @@ OutputFormat, output_format_choice, ) -from ape.cli.utils import Abort -from ape.exceptions import ContractError +from ape.exceptions import Abort, ProjectError from ape.logging import DEFAULT_LOG_LEVEL, ApeLogger, LogLevel, logger from ape.managers.base import ManagerAccessMixin @@ -178,7 +177,7 @@ def skip_confirmation_option(help=""): def _account_callback(ctx, param, value): if param and not value: - return param.type.get_user_selected_account() + return param.type.select_account() return value @@ -191,7 +190,7 @@ def account_option(account_type: _ACCOUNT_TYPE_FILTER = None): return click.option( "--account", - type=AccountAliasPromptChoice(account_type=account_type), + type=AccountAliasPromptChoice(key=account_type), callback=_account_callback, ) @@ -201,7 +200,7 @@ def _load_contracts(ctx, param, value) -> Optional[Union[ContractType, List[Cont return None if len(project.contracts) == 0: - raise ContractError("Project has no contracts.") + raise ProjectError("Project has no contracts.") # If the user passed in `multiple=True`, then `value` is a list, # and therefore we should also return a list. @@ -209,7 +208,7 @@ def _load_contracts(ctx, param, value) -> Optional[Union[ContractType, List[Cont def get_contract(contract_name: str) -> ContractType: if contract_name not in project.contracts: - raise ContractError(f"No contract named '{value}'") + raise ProjectError(f"No contract named '{value}'") return project.contracts[contract_name] diff --git a/src/ape/cli/utils.py b/src/ape/cli/utils.py index 52c8e4a314..e69de29bb2 100644 --- a/src/ape/cli/utils.py +++ b/src/ape/cli/utils.py @@ -1,5 +0,0 @@ -from ape.exceptions import Abort - -# TODO: Delete as part of 0.7 breaking changes -# Kept here for backwards compatibility. -__all__ = ["Abort"] diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 892d13f491..3ef081d83c 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -18,8 +18,9 @@ ApeAttributeError, ArgumentsLengthError, ChainError, - ContractError, + ContractDataError, ContractLogicError, + ContractNotFoundError, CustomError, MethodNonPayableError, TransactionNotFoundError, @@ -140,7 +141,8 @@ def encode_input(self, *args) -> HexBytes: def decode_input(self, calldata: bytes) -> Tuple[str, Dict[str, Any]]: matching_abis = [] - err = ContractError( + rest_calldata = None + err = ContractDataError( f"Unable to find matching method ABI for calldata '{calldata.hex()}'. " "Try prepending a method ID to the beginning of the calldata." ) @@ -154,7 +156,7 @@ def decode_input(self, calldata: bytes) -> Tuple[str, Dict[str, Any]]: if len(matching_abis) == 1: abi = matching_abis[0] decoded_input = self.provider.network.ecosystem.decode_calldata( - matching_abis[0], HexBytes(rest_calldata) + matching_abis[0], HexBytes(rest_calldata or "") ) return abi.selector, decoded_input @@ -164,7 +166,6 @@ def decode_input(self, calldata: bytes) -> Tuple[str, Dict[str, Any]]: # Brute-force find method ABI valid_results = [] for abi in self.abis: - decoded_calldata = {} try: decoded_calldata = self.provider.network.ecosystem.decode_calldata( abi, HexBytes(calldata) @@ -185,8 +186,11 @@ def decode_input(self, calldata: bytes) -> Tuple[str, Dict[str, Any]]: class ContractCallHandler(ContractMethodHandler): def __call__(self, *args, **kwargs) -> Any: if not self.contract.is_contract: - network = self.provider.network.name - raise _get_non_contract_error(self.contract.address, network) + raise ContractNotFoundError( + self.contract.address, + self.provider.network.explorer is not None, + self.provider.name, + ) selected_abi = _select_method_abi(self.abis, args) arguments = self.conversion_manager.convert_method_args(selected_abi, args) @@ -351,8 +355,11 @@ def __call__(self, *args, **kwargs) -> ReceiptAPI: def _as_transaction(self, *args) -> ContractTransaction: if not self.contract.is_contract: - network = self.provider.network.name - raise _get_non_contract_error(self.contract.address, network) + raise ContractNotFoundError( + self.contract.address, + self.provider.network.explorer is not None, + self.provider.name, + ) selected_abi = _select_method_abi(self.abis, args) @@ -757,7 +764,7 @@ def decode_input(self, calldata: bytes) -> Tuple[str, Dict[str, Any]]: method = None if not method: - raise ContractError( + raise ContractDataError( f"Unable to find method ABI from calldata '{calldata.hex()}'. " "Try prepending the method ID to the beginning of the calldata." ) @@ -811,7 +818,7 @@ def __call__(self, *args, **kwargs) -> ReceiptAPI: if has_value and has_non_payable_fallback and self.contract_type.receive is None: # User is sending a value when the contract doesn't accept it. - raise ContractError( + raise MethodNonPayableError( "Contract's fallback is non-payable and there is no receive ABI. " "Unable to send value." ) @@ -819,7 +826,9 @@ def __call__(self, *args, **kwargs) -> ReceiptAPI: elif has_value and has_data and has_non_payable_fallback: # User is sending both value and data. When sending data, the fallback # is always triggered. Thus, since it is non-payable, it would fail. - raise ContractError("Sending both value= and data= but fallback is non-payable.") + raise MethodNonPayableError( + "Sending both value= and data= but fallback is non-payable." + ) return super().__call__(*args, **kwargs) @@ -827,7 +836,7 @@ def __call__(self, *args, **kwargs) -> ReceiptAPI: def from_receipt(cls, receipt: ReceiptAPI, contract_type: ContractType) -> "ContractInstance": address = receipt.contract_address if not address: - raise ContractError( + raise ChainError( "Receipt missing 'contract_address' field. " "Was this from a deploy transaction (e.g. `project.MyContract.deploy()`)?" ) @@ -1004,7 +1013,8 @@ def get_event_by_signature(self, signature: str) -> ContractEvent: name_from_sig = signature.split("(")[0].strip() options = self._events_.get(name_from_sig, []) - err = ContractError(f"No event found with signature '{signature}'.") + + err = ContractDataError(f"No event found with signature '{signature}'.") if not options: raise err @@ -1028,7 +1038,7 @@ def get_error_by_signature(self, signature: str) -> Type[CustomError]: name_from_sig = signature.split("(")[0].strip() options = self._errors_.get(name_from_sig, []) - err = ContractError(f"No error found with signature '{signature}'.") + err = ContractDataError(f"No error found with signature '{signature}'.") if not options: raise err @@ -1328,7 +1338,7 @@ def deploy(self, *args, publish: bool = False, **kwargs) -> ContractInstance: address = receipt.contract_address if not address: - raise ContractError(f"'{receipt.txn_hash}' did not create a contract.") + raise ChainError(f"'{receipt.txn_hash}' did not create a contract.") styled_address = click.style(receipt.contract_address, bold=True) contract_name = self.contract_type.name or "" @@ -1385,13 +1395,6 @@ def declare(self, *args, **kwargs) -> ReceiptAPI: return receipt -def _get_non_contract_error(address: str, network_name: str) -> ContractError: - return ContractError( - f"Unable to make contract call. " - f"'{address}' is not a contract on network '{network_name}'." - ) - - class ContractNamespace: """ A class that bridges contract containers in a namespace. diff --git a/src/ape/exceptions.py b/src/ape/exceptions.py index dae2726689..711d8456b8 100644 --- a/src/ape/exceptions.py +++ b/src/ape/exceptions.py @@ -68,13 +68,16 @@ class SignatureError(AccountsError): """ -class ContractError(ApeException): +class ContractDataError(ApeException): """ - Raised when issues occur with contracts. + Raised when issues occur with local contract. + **NOTE**: This error has nothing to do with on-chain + contract logic errors; it is more about ABI-related + issues and alike. """ -class ArgumentsLengthError(ContractError): +class ArgumentsLengthError(ContractDataError): """ Raised when calling a contract method with the wrong number of arguments. """ @@ -85,10 +88,6 @@ def __init__( inputs: Union[MethodABI, ConstructorABI, int, List, None] = None, **kwargs, ): - # For backwards compat. # TODO: Remove in 0.7 - if "inputs_length" in kwargs and not inputs: - inputs = kwargs["inputs_length"] - prefix = ( f"The number of the given arguments ({arguments_length}) " f"do not match what is defined in the ABI" @@ -126,7 +125,7 @@ def __init__( super().__init__(f"{prefix}{suffix}") -class DecodingError(ContractError): +class DecodingError(ContractDataError): """ Raised when issues occur while decoding data from a contract call, transaction, or event. @@ -137,13 +136,13 @@ def __init__(self, message: Optional[str] = None): super().__init__(message) -class MethodNonPayableError(ContractError): +class MethodNonPayableError(ContractDataError): """ Raises when sending funds to a non-payable method """ -class TransactionError(ContractError): +class TransactionError(ApeException): """ Raised when issues occur related to transactions. """ diff --git a/src/ape/logging.py b/src/ape/logging.py index 433b488afb..9f269a58ad 100644 --- a/src/ape/logging.py +++ b/src/ape/logging.py @@ -283,8 +283,5 @@ def sanitize_url(url: str) -> str: logger = ApeLogger.create() -# TODO: Can remove this type alias after 0.7 -CliLogger = ApeLogger - __all__ = ["DEFAULT_LOG_LEVEL", "logger", "LogLevel", "ApeLogger"] diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index 701596d9d3..032d959ef7 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -1372,16 +1372,6 @@ def echo(self, *rich_items, file: Optional[IO[str]] = None): console = self._get_console(file=file) console.print(*rich_items) - @property - def track_gas(self) -> bool: - # TODO: Delete in 0.7; call _test_runner directly if needed. - return self._test_runner is not None and self._test_runner.gas_tracker.enabled - - def append_gas(self, *args, **kwargs): - # TODO: Delete in 0.7 and have all plugins call `_test_runner.gas_tracker.append_gas()`. - if self._test_runner: - self._test_runner.gas_tracker.append_gas(*args, **kwargs) - def show_source_traceback( self, traceback: SourceTraceback, file: Optional[IO[str]] = None, failing: bool = True ): diff --git a/src/ape/managers/converters.py b/src/ape/managers/converters.py index b61db73dfa..be1c77c43f 100644 --- a/src/ape/managers/converters.py +++ b/src/ape/managers/converters.py @@ -145,64 +145,6 @@ def convert(self, value: Any) -> AddressType: return to_checksum_address(to_hex(value)) -class ListTupleConverter(ConverterAPI): - """ - A converter that converts all items in a tuple or list recursively. - """ - - def is_convertible(self, value: Any) -> bool: - return isinstance(value, (list, tuple)) - - def convert(self, value: Union[List, Tuple]) -> Union[List, Tuple]: - """ - Convert the items inside the given list or tuple. - - Args: - value (Union[List, Tuple]): The collection to convert. - - Returns: - Union[list, tuple]: Depending on the input - """ - - converted_value: List[Any] = [] - - for v in value: - # Ignore already-primitive values and assume they are already converted - # (since we don't know the type they are supposed to be, only that - # they should be primitive). Also, ignore dicts because they are handled - # at the ecosystem level. - if ( - (isinstance(v, str) and is_hex(v)) - or isinstance(v, (int, bytes)) - or isinstance(v, dict) - ): - converted_value.append(v) - continue - - # Try all of them to see if one converts it over (only use first one) - conversion_found = False - # NOTE: Double loop required because we might not know the exact type of the inner - # items. The UX of having to specify all inner items seemed poor as well. - for typ in self.conversion_manager._converters: - for check_fn, convert_fn in map( - lambda c: (c.is_convertible, c.convert), - self.conversion_manager._converters[typ], - ): - if check_fn(v): - converted_value.append(convert_fn(v)) - conversion_found = True - break - - if conversion_found: - break - - if not conversion_found: - # NOTE: If no conversions found, just insert the original - converted_value.append(v) - - return value.__class__(converted_value) - - class TimestampConverter(ConverterAPI): """ Converts either a string, datetime object, or a timedelta object to a timestamp. @@ -263,8 +205,6 @@ def _converters(self) -> Dict[Type, List[ConverterAPI]]: bytes: [HexConverter()], int: [TimestampConverter(), HexIntConverter(), StringIntConverter()], Decimal: [], - list: [ListTupleConverter()], - tuple: [ListTupleConverter()], bool: [], str: [], } @@ -372,17 +312,17 @@ def convert_method_args( arguments: Sequence[Any], ): input_types = [i.canonical_type for i in abi.inputs] - pre_processed_args = [] + converted_arguments = [] for ipt, argument in zip(input_types, arguments): # Handle primitive-addresses separately since they may not occur # on the tuple-conversion if they are integers or bytes. if str(ipt) == "address": converted_value = self.convert(argument, AddressType) - pre_processed_args.append(converted_value) + converted_arguments.append(converted_value) else: - pre_processed_args.append(argument) + converted_arguments.append(argument) - return self.convert(pre_processed_args, tuple) + return converted_arguments def convert_method_kwargs(self, kwargs) -> Dict: fields = TransactionAPI.model_fields diff --git a/src/ape/managers/project/dependency.py b/src/ape/managers/project/dependency.py index 83b480efee..8a4d6634d4 100644 --- a/src/ape/managers/project/dependency.py +++ b/src/ape/managers/project/dependency.py @@ -191,12 +191,6 @@ class GithubDependency(DependencyAPI): such as ``dapphub/erc20``. """ - # TODO: Remove at >= 0.7 - branch: Optional[str] = None - """ - **DEPRECATED**: Use ``ref:``. - """ - ref: Optional[str] = None """ The branch or tag to use. @@ -205,21 +199,10 @@ class GithubDependency(DependencyAPI): """ @cached_property - def _reference(self) -> Optional[str]: + def version_id(self) -> str: if self.ref: return self.ref - elif self.branch: - logger.warning("'branch:' config is deprecated. Use 'ref:' instead.") - return self.branch - - return None - - @cached_property - def version_id(self) -> str: - if self._reference: - return self._reference - elif self.version and self.version != "latest": return self.version @@ -232,8 +215,8 @@ def uri(self) -> AnyUrl: if self.version: version = f"v{self.version}" if not self.version.startswith("v") else self.version _uri = f"{_uri}/releases/tag/{version}" - elif self._reference: - _uri = f"{_uri}/tree/{self._reference}" + elif self.ref: + _uri = f"{_uri}/tree/{self.ref}" return HttpUrl(_uri) @@ -250,8 +233,8 @@ def extract_manifest(self, use_cache: bool = True) -> PackageManifest: temp_project_path = (Path(temp_dir) / self.name).resolve() temp_project_path.mkdir(exist_ok=True, parents=True) - if self._reference: - github_client.clone_repo(self.github, temp_project_path, branch=self._reference) + if self.ref: + github_client.clone_repo(self.github, temp_project_path, branch=self.ref) else: try: diff --git a/src/ape/managers/project/manager.py b/src/ape/managers/project/manager.py index 47368cda44..651d7e306d 100644 --- a/src/ape/managers/project/manager.py +++ b/src/ape/managers/project/manager.py @@ -1,6 +1,6 @@ import shutil from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Type, Union, cast +from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union, cast from eth_pydantic_types import Bip122Uri, HexStr from ethpm_types import ContractInstance as EthPMContractInstance @@ -573,9 +573,7 @@ def get_contract(self, contract_name: str) -> ContractContainer: raise ProjectError(f"No contract found with name '{contract_name}'.") - def extensions_with_missing_compilers( - self, extensions: Optional[List[str]] = None - ) -> List[str]: + def extensions_with_missing_compilers(self, extensions: Optional[List[str]] = None) -> Set[str]: """ All file extensions in the ``contracts/`` directory (recursively) that do not correspond to a registered compiler. @@ -585,10 +583,10 @@ def extensions_with_missing_compilers( are in this list. Useful for checking against a subset of source files. Returns: - List[str]: A list of file extensions found in the ``contracts/`` directory + Set[str]: A list of file extensions found in the ``contracts/`` directory that do not have associated compilers installed. """ - extensions_found = [] + extensions_found = set() def _append_extensions_in_dir(directory: Path): if not directory.is_dir(): @@ -597,16 +595,19 @@ def _append_extensions_in_dir(directory: Path): for file in directory.iterdir(): if file.is_dir(): _append_extensions_in_dir(file) - elif ( - file.suffix - and file.suffix not in extensions_found - and file.suffix not in self.compiler_manager.registered_compilers - ): - extensions_found.append(file.suffix) + else: + extensions_found.add(file.suffix) _append_extensions_in_dir(self.contracts_folder) + + # Filter out extensions that have associated compilers. + extensions_found = { + x for x in extensions_found if x not in self.compiler_manager.registered_compilers + } + + # Filer by the given extensions. if extensions: - extensions_found = [e for e in extensions_found if e in extensions] + extensions_found = {e for e in extensions_found if e in extensions} return extensions_found diff --git a/src/ape/types/trace.py b/src/ape/types/trace.py index 8fa06f2c77..e248957cf9 100644 --- a/src/ape/types/trace.py +++ b/src/ape/types/trace.py @@ -697,7 +697,6 @@ def add_builtin_jump( self, name: str, _type: str, - compiler_name: str, full_name: Optional[str] = None, source_path: Optional[Path] = None, pcs: Optional[Set[int]] = None, @@ -710,12 +709,10 @@ def add_builtin_jump( Args: name (str): The name of the compiler built-in. _type (str): A str describing the type of check. - compiler_name (str): The name of the compiler. full_name (Optional[str]): A full-name ID. source_path (Optional[Path]): The source file related, if there is one. pcs (Optional[Set[int]]): Program counter values mapping to this check. """ - # TODO: Assess if compiler_name is needed or get rid of in v0.7. pcs = pcs or set() closure = Closure(name=name, full_name=full_name or name) depth = self.last.depth - 1 if self.last else 0 diff --git a/src/ape_compile/__init__.py b/src/ape_compile/__init__.py index 364807706c..fdcf939999 100644 --- a/src/ape_compile/__init__.py +++ b/src/ape_compile/__init__.py @@ -1,16 +1,12 @@ -from typing import Any, List, Optional +from typing import List -from pydantic import Field, field_validator +from pydantic import field_validator from ape import plugins from ape.api import PluginConfig -from ape.logging import logger class Config(PluginConfig): - # TODO: Remove in 0.7 - evm_version: Optional[Any] = Field(None, exclude=True, repr=False) - include_dependencies: bool = False """ Set to ``True`` to compile dependencies during ``ape compile``. @@ -27,16 +23,6 @@ class Config(PluginConfig): Source exclusion globs across all file types. """ - @field_validator("evm_version") - def warn_deprecate(cls, value): - if value: - logger.warning( - "`evm_version` config is deprecated. " - "Please set in respective compiler plugin config." - ) - - return None - @field_validator("exclude", mode="before") def validate_exclude(cls, value): return value or [] diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 4e26264d5a..2b85779b7d 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -22,13 +22,7 @@ from ape.api import BlockAPI, EcosystemAPI, PluginConfig, ReceiptAPI, TransactionAPI from ape.api.networks import LOCAL_NETWORK_NAME from ape.contracts.base import ContractCall -from ape.exceptions import ( - ApeException, - APINotImplementedError, - ContractError, - ConversionError, - DecodingError, -) +from ape.exceptions import ApeException, APINotImplementedError, ConversionError, DecodingError from ape.types import ( AddressType, AutoGasLimit, @@ -280,7 +274,7 @@ def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfo]: sequence_pattern = r"363d3d373d3d3d363d30545af43d82803e903d91601857fd5bf3" if re.match(sequence_pattern, code): # the implementation is stored in the slot matching proxy address - slot = self.provider.get_storage_at(address, address) + slot = self.provider.get_storage(address, address) target = self.conversion_manager.convert(slot[-20:], AddressType) return ProxyInfo(type=ProxyType.Sequence, target=target) @@ -296,7 +290,7 @@ def str_to_slot(text): for _type, slot in slots.items(): try: # TODO perf: use a batch call here when ape adds support - storage = self.provider.get_storage_at(address, slot) + storage = self.provider.get_storage(address, slot) except APINotImplementedError: continue @@ -316,7 +310,7 @@ def str_to_slot(text): if safe_pattern.hex() in code: try: singleton = ContractCall(MASTER_COPY_ABI, address)(skip_trace=True) - slot_0 = self.provider.get_storage_at(address, 0) + slot_0 = self.provider.get_storage(address, 0) target = self.conversion_manager.convert(slot_0[-20:], AddressType) # NOTE: `target` is set in initialized proxies if target != ZERO_ADDRESS and target == singleton: @@ -843,7 +837,7 @@ def _enrich_address(self, address: AddressType, **kwargs) -> str: try: symbol = contract.symbol(skip_trace=True) - except ContractError: + except ApeException: symbol = None if isinstance(symbol, str): diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py new file mode 100644 index 0000000000..07eb489b69 --- /dev/null +++ b/src/ape_ethereum/provider.py @@ -0,0 +1,1019 @@ +import time +from abc import ABC +from concurrent.futures import ThreadPoolExecutor +from copy import copy +from functools import cached_property +from itertools import tee +from typing import Any, Dict, Iterator, List, Optional, Union, cast + +from eth_pydantic_types import HexBytes +from eth_typing import BlockNumber, HexStr +from eth_utils import add_0x_prefix, to_hex +from ethpm_types import EventABI +from evm_trace import CallTreeNode as EvmCallTreeNode +from evm_trace import TraceFrame as EvmTraceFrame +from pydantic.dataclasses import dataclass +from web3 import Web3 +from web3.exceptions import ContractLogicError as Web3ContractLogicError +from web3.exceptions import MethodUnavailable, TimeExhausted, TransactionNotFound +from web3.types import RPCEndpoint, TxParams + +from ape.api import BlockAPI, ProviderAPI, ReceiptAPI, TransactionAPI +from ape.api.networks import LOCAL_NETWORK_NAME +from ape.exceptions import ( + ApeException, + APINotImplementedError, + BlockNotFoundError, + ContractLogicError, + OutOfGasError, + ProviderError, + ProviderNotConnectedError, + TransactionError, + TransactionNotFoundError, + VirtualMachineError, +) +from ape.logging import logger, sanitize_url +from ape.types import ( + AddressType, + AutoGasLimit, + BlockID, + CallTreeNode, + ContractCode, + ContractLog, + LogFilter, + SourceTraceback, + TraceFrame, +) +from ape.utils import gas_estimation_error_message, run_until_complete, to_int +from ape.utils.misc import DEFAULT_MAX_RETRIES_TX + + +def _sanitize_web3_url(msg: str) -> str: + if "URI: " not in msg: + return msg + + parts = msg.split("URI: ") + prefix = parts[0].strip() + rest = parts[1].split(" ") + + # * To remove the `,` from the url http://127.0.0.1:8545, + if "," in rest[0]: + rest[0] = rest[0].rstrip(",") + sanitized_url = sanitize_url(rest[0]) + return f"{prefix} URI: {sanitized_url} {' '.join(rest[1:])}" + + +class Web3Provider(ProviderAPI, ABC): + """ + A base provider mixin class that uses the + `web3.py `__ python package. + """ + + _web3: Optional[Web3] = None + _client_version: Optional[str] = None + + def __init__(self, *args, **kwargs): + logger.create_logger("web3.RequestManager", handlers=(_sanitize_web3_url,)) + logger.create_logger("web3.providers.HTTPProvider", handlers=(_sanitize_web3_url,)) + super().__init__(*args, **kwargs) + + @property + def web3(self) -> Web3: + """ + Access to the ``web3`` object as if you did ``Web3(HTTPProvider(uri))``. + """ + + if web3 := self._web3: + return web3 + + raise ProviderNotConnectedError() + + @property + def http_uri(self) -> Optional[str]: + if ( + hasattr(self.web3.provider, "endpoint_uri") + and isinstance(self.web3.provider.endpoint_uri, str) + and self.web3.provider.endpoint_uri.startswith("http") + ): + return self.web3.provider.endpoint_uri + + elif uri := getattr(self, "uri", None): + # NOTE: Some providers define this + return uri + + return None + + @property + def ws_uri(self) -> Optional[str]: + if ( + hasattr(self.web3.provider, "endpoint_uri") + and isinstance(self.web3.provider.endpoint_uri, str) + and self.web3.provider.endpoint_uri.startswith("ws") + ): + return self.web3.provider.endpoint_uri + + return None + + @property + def client_version(self) -> str: + if not self._web3: + return "" + + # NOTE: Gets reset to `None` on `connect()` and `disconnect()`. + if self._client_version is None: + self._client_version = self.web3.client_version + + return self._client_version + + @property + def base_fee(self) -> int: + latest_block_number = self.get_block("latest").number + if latest_block_number is None: + # Possibly no blocks yet. + logger.debug("Latest block has no number. Using base fee of '0'.") + return 0 + + try: + fee_history = self.web3.eth.fee_history(1, BlockNumber(latest_block_number)) + except ValueError as exc: + # Use the less-accurate approach (OK for testing). + logger.debug( + "Failed using `web3.eth.fee_history` for network " + f"'{self.network_choice}'. Error: {exc}" + ) + return self._get_last_base_fee() + + if len(fee_history["baseFeePerGas"]) < 2: + logger.debug("Not enough fee_history. Defaulting less-accurate approach.") + return self._get_last_base_fee() + + pending_base_fee = fee_history["baseFeePerGas"][1] + if pending_base_fee is None: + # Non-EIP-1559 chains or we time-travelled pre-London fork. + return self._get_last_base_fee() + + return pending_base_fee + + def _get_last_base_fee(self) -> int: + block = self.get_block("latest") + base_fee = getattr(block, "base_fee", None) + if base_fee is not None: + return base_fee + + raise APINotImplementedError("No base fee found in block.") + + @property + def is_connected(self) -> bool: + if self._web3 is None: + return False + + return run_until_complete(self._web3.is_connected()) + + @property + def max_gas(self) -> int: + block = self.web3.eth.get_block("latest") + return block["gasLimit"] + + @cached_property + def supports_tracing(self) -> bool: + try: + # NOTE: Txn hash is purposely not a real hash. + # If we get any exception besides not implemented error, + # then we support tracing on this provider. + self.get_call_tree("__CHECK_IF_SUPPORTS_TRACING__") + except APINotImplementedError: + return False + except Exception: + return True + + return True + + def update_settings(self, new_settings: dict): + self.disconnect() + self.provider_settings.update(new_settings) + self.connect() + + def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = None) -> int: + txn_dict = txn.model_dump(by_alias=True, mode="json") + + # Force the use of hex values to support a wider range of nodes. + if isinstance(txn_dict.get("type"), int): + txn_dict["type"] = HexBytes(txn_dict["type"]).hex() + + # NOTE: "auto" means to enter this method, so remove it from dict + if "gas" in txn_dict and ( + txn_dict["gas"] == "auto" or isinstance(txn_dict["gas"], AutoGasLimit) + ): + txn_dict.pop("gas") + # Also pop these, they are overridden by "auto" + txn_dict.pop("maxFeePerGas", None) + txn_dict.pop("maxPriorityFeePerGas", None) + + txn_params = cast(TxParams, txn_dict) + try: + return self.web3.eth.estimate_gas(txn_params, block_identifier=block_id) + except (ValueError, Web3ContractLogicError) as err: + tx_error = self.get_virtual_machine_error(err, txn=txn) + + # If this is the cause of a would-be revert, + # raise ContractLogicError so that we can confirm tx-reverts. + if isinstance(tx_error, ContractLogicError): + raise tx_error from err + + message = gas_estimation_error_message(tx_error) + raise TransactionError( + message, base_err=tx_error, txn=txn, source_traceback=tx_error.source_traceback + ) from err + + @cached_property + def chain_id(self) -> int: + default_chain_id = None + if ( + self.network.name + not in ( + "adhoc", + LOCAL_NETWORK_NAME, + ) + and not self.network.is_fork + ): + # If using a live network, the chain ID is hardcoded. + default_chain_id = self.network.chain_id + + try: + if hasattr(self.web3, "eth"): + return self.web3.eth.chain_id + + except ProviderNotConnectedError: + if default_chain_id is not None: + return default_chain_id + + raise # Original error + + if default_chain_id is not None: + return default_chain_id + + raise ProviderNotConnectedError() + + @property + def gas_price(self) -> int: + price = self.web3.eth.generate_gas_price() or 0 + return to_int(price) + + @property + def priority_fee(self) -> int: + try: + return self.web3.eth.max_priority_fee + except MethodUnavailable as err: + # The user likely should be using a more-catered plugin. + raise APINotImplementedError( + "eth_maxPriorityFeePerGas not supported in this RPC. Please specify manually." + ) from err + + def get_block(self, block_id: BlockID) -> BlockAPI: + if isinstance(block_id, str) and block_id.isnumeric(): + block_id = int(block_id) + + try: + block_data = dict(self.web3.eth.get_block(block_id)) + except Exception as err: + raise BlockNotFoundError(block_id, reason=str(err)) from err + + # Some nodes (like anvil) will not have a base fee if set to 0. + if "baseFeePerGas" in block_data and block_data.get("baseFeePerGas") is None: + block_data["baseFeePerGas"] = 0 + + return self.network.ecosystem.decode_block(block_data) + + def get_nonce(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + return self.web3.eth.get_transaction_count(address, block_identifier=block_id) + + def get_balance(self, address: AddressType, block_id: Optional[BlockID] = None) -> int: + return self.web3.eth.get_balance(address, block_identifier=block_id) + + def get_code(self, address: AddressType, block_id: Optional[BlockID] = None) -> ContractCode: + return self.web3.eth.get_code(address, block_identifier=block_id) + + def get_storage( + self, address: AddressType, slot: int, block_id: Optional[BlockID] = None + ) -> HexBytes: + try: + return HexBytes(self.web3.eth.get_storage_at(address, slot, block_identifier=block_id)) + except ValueError as err: + if "RPC Endpoint has not been implemented" in str(err): + raise APINotImplementedError(str(err)) from err + + raise # Raise original error + + def send_call( + self, + txn: TransactionAPI, + block_id: Optional[BlockID] = None, + state: Optional[Dict] = None, + **kwargs, + ) -> HexBytes: + if kwargs.pop("skip_trace", False): + return self._send_call(txn, **kwargs) + elif self._test_runner is not None: + track_gas = self._test_runner.gas_tracker.enabled + track_coverage = self._test_runner.coverage_tracker.enabled + else: + track_gas = False + track_coverage = False + + show_trace = kwargs.pop("show_trace", False) + show_gas = kwargs.pop("show_gas_report", False) + needs_trace = track_gas or track_coverage or show_trace or show_gas + if not needs_trace or not self.provider.supports_tracing or not txn.receiver: + return self._send_call(txn, **kwargs) + + # The user is requesting information related to a call's trace, + # such as gas usage data. + try: + with self.chain_manager.isolate(): + return self._send_call_as_txn( + txn, + track_gas=track_gas, + track_coverage=track_coverage, + show_trace=show_trace, + show_gas=show_gas, + **kwargs, + ) + + except APINotImplementedError: + return self._send_call(txn, **kwargs) + + def _send_call_as_txn( + self, + txn: TransactionAPI, + track_gas: bool = False, + track_coverage: bool = False, + show_trace: bool = False, + show_gas: bool = False, + **kwargs, + ) -> HexBytes: + account = self.account_manager.test_accounts[0] + receipt = account.call(txn, **kwargs) + if not (call_tree := receipt.call_tree): + return self._send_call(txn, **kwargs) + + # Grab raw returndata before enrichment + returndata = call_tree.outputs + + if (track_gas or track_coverage) and show_gas and not show_trace: + # Optimization to enrich early and in_place=True. + call_tree.enrich() + + if track_gas: + # in_place=False in case show_trace is True + receipt.track_gas() + + if track_coverage: + receipt.track_coverage() + + if show_gas: + # in_place=False in case show_trace is True + self.chain_manager._reports.show_gas(call_tree.enrich(in_place=False)) + + if show_trace: + call_tree = call_tree.enrich(use_symbol_for_tokens=True) + self.chain_manager._reports.show_trace(call_tree) + + return HexBytes(returndata) + + def _send_call(self, txn: TransactionAPI, **kwargs) -> HexBytes: + arguments = self._prepare_call(txn, **kwargs) + try: + return self._eth_call(arguments) + except TransactionError as err: + if not err.txn: + err.txn = txn + + raise # The tx error + + def _eth_call(self, arguments: List) -> HexBytes: + # Force the usage of hex-type to support a wider-range of nodes. + txn_dict = copy(arguments[0]) + if isinstance(txn_dict.get("type"), int): + txn_dict["type"] = HexBytes(txn_dict["type"]).hex() + + # Remove unnecessary values to support a wider-range of nodes. + txn_dict.pop("chainId", None) + + arguments[0] = txn_dict + try: + result = self._make_request("eth_call", arguments) + except Exception as err: + receiver = txn_dict["to"] + raise self.get_virtual_machine_error(err, contract_address=receiver) from err + + if "error" in result: + raise ProviderError(result["error"]["message"]) + + return HexBytes(result) + + def _prepare_call(self, txn: TransactionAPI, **kwargs) -> List: + txn_dict = txn.model_dump(by_alias=True, mode="json") + fields_to_convert = ("data", "chainId", "value") + for field in fields_to_convert: + value = txn_dict.get(field) + if value is not None and not isinstance(value, str): + txn_dict[field] = to_hex(value) + + # Remove unneeded properties + txn_dict.pop("gas", None) + txn_dict.pop("gasLimit", None) + txn_dict.pop("maxFeePerGas", None) + txn_dict.pop("maxPriorityFeePerGas", None) + + # NOTE: Block ID is required so if given None, default to `"latest"`. + block_identifier = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) or "latest" + if isinstance(block_identifier, int): + block_identifier = to_hex(primitive=block_identifier) + + arguments = [txn_dict, block_identifier] + if "state_override" in kwargs: + arguments.append(kwargs["state_override"]) + + return arguments + + def get_receipt( + self, + txn_hash: str, + required_confirmations: int = 0, + timeout: Optional[int] = None, + **kwargs, + ) -> ReceiptAPI: + if required_confirmations < 0: + raise TransactionError("Required confirmations cannot be negative.") + + timeout = ( + timeout if timeout is not None else self.provider.network.transaction_acceptance_timeout + ) + + hex_hash = HexBytes(txn_hash) + try: + receipt_data = self.web3.eth.wait_for_transaction_receipt(hex_hash, timeout=timeout) + except TimeExhausted as err: + raise TransactionNotFoundError(txn_hash, error_messsage=str(err)) from err + + ecosystem_config = self.network.config.model_dump(by_alias=True, mode="json") + network_config: Dict = ecosystem_config.get(self.network.name, {}) + max_retries = network_config.get("max_get_transaction_retries", DEFAULT_MAX_RETRIES_TX) + txn = {} + for attempt in range(max_retries): + try: + txn = dict(self.web3.eth.get_transaction(HexStr(txn_hash))) + break + except TransactionNotFound: + if attempt < max_retries - 1: # if this wasn't the last attempt + time.sleep(1) # Wait for 1 second before retrying. + continue # Continue to the next iteration, effectively retrying the operation. + else: # if it was the last attempt + raise # Re-raise the last exception. + + receipt = self.network.ecosystem.decode_receipt( + { + "provider": self, + "required_confirmations": required_confirmations, + **txn, + **receipt_data, + } + ) + return receipt.await_confirmations() + + def get_transactions_by_block(self, block_id: BlockID) -> Iterator[TransactionAPI]: + if isinstance(block_id, str): + block_id = HexStr(block_id) + + if block_id.isnumeric(): + block_id = add_0x_prefix(block_id) + + block = cast(Dict, self.web3.eth.get_block(block_id, full_transactions=True)) + for transaction in block.get("transactions", []): + yield self.network.ecosystem.create_transaction(**transaction) + + def get_transactions_by_account_nonce( + self, + account: AddressType, + start_nonce: int = 0, + stop_nonce: int = -1, + ) -> Iterator[ReceiptAPI]: + if start_nonce > stop_nonce: + raise ValueError("Starting nonce cannot be greater than stop nonce for search") + + if not self.network.is_local and (stop_nonce - start_nonce) > 2: + # NOTE: RPC usage might be acceptable to find 1 or 2 transactions reasonably quickly + logger.warning( + "Performing this action is likely to be very slow and may " + f"use {20 * (stop_nonce - start_nonce)} or more RPC calls. " + "Consider installing an alternative data query provider plugin." + ) + + yield from self._find_txn_by_account_and_nonce( + account, + start_nonce, + stop_nonce, + 0, # first block + self.chain_manager.blocks.head.number or 0, # last block (or 0 if genesis-only chain) + ) + + def _find_txn_by_account_and_nonce( + self, + account: AddressType, + start_nonce: int, + stop_nonce: int, + start_block: int, + stop_block: int, + ) -> Iterator[ReceiptAPI]: + # binary search between `start_block` and `stop_block` to yield txns from account, + # ordered from `start_nonce` to `stop_nonce` + + if start_block == stop_block: + # Honed in on one block where there's a delta in nonce, so must be the right block + for txn in self.get_transactions_by_block(stop_block): + assert isinstance(txn.nonce, int) # NOTE: just satisfying mypy here + if txn.sender == account and txn.nonce >= start_nonce: + yield self.get_receipt(txn.txn_hash.hex()) + + # Nothing else to search for + + else: + # Break up into smaller chunks + # NOTE: biased to `stop_block` + block_number = start_block + (stop_block - start_block) // 2 + 1 + txn_count_prev_to_block = self.web3.eth.get_transaction_count(account, block_number - 1) + + if start_nonce < txn_count_prev_to_block: + yield from self._find_txn_by_account_and_nonce( + account, + start_nonce, + min(txn_count_prev_to_block - 1, stop_nonce), # NOTE: In case >1 txn in block + start_block, + block_number - 1, + ) + + if txn_count_prev_to_block <= stop_nonce: + yield from self._find_txn_by_account_and_nonce( + account, + max(start_nonce, txn_count_prev_to_block), # NOTE: In case >1 txn in block + stop_nonce, + block_number, + stop_block, + ) + + def poll_blocks( + self, + stop_block: Optional[int] = None, + required_confirmations: Optional[int] = None, + new_block_timeout: Optional[int] = None, + ) -> Iterator[BlockAPI]: + # Wait half the time as the block time + # to get data faster. + block_time = self.network.block_time + wait_time = block_time / 2 + + # The timeout occurs when there is no chain activity + # after a certain time. + timeout = ( + (10.0 if self.network.is_dev else 50 * block_time) + if new_block_timeout is None + else new_block_timeout + ) + + # Only yield confirmed blocks. + if required_confirmations is None: + required_confirmations = self.network.required_confirmations + + @dataclass + class YieldAction: + hash: bytes + number: int + time: float + + # Pretend we _did_ yield the last confirmed item, for logic's sake. + fake_last_block = self.get_block(self.web3.eth.block_number - required_confirmations) + last_num = fake_last_block.number or 0 + last_hash = fake_last_block.hash or HexBytes(0) + last = YieldAction(number=last_num, hash=last_hash, time=time.time()) + + # A helper method for various points of ensuring we didn't timeout. + def assert_chain_activity(): + time_waiting = time.time() - last.time + if time_waiting > timeout: + raise ProviderError("Timed out waiting for next block.") + + # Begin the daemon. + while True: + # The next block we want is simply 1 after the last. + next_block = last.number + 1 + + head = self.get_block("latest") + + try: + if head.number is None or head.hash is None: + raise ProviderError("Head block has no number or hash.") + # Use an "adjused" head, based on the required confirmations. + adjusted_head = self.get_block(head.number - required_confirmations) + if adjusted_head.number is None or adjusted_head.hash is None: + raise ProviderError("Adjusted head block has no number or hash.") + except Exception: + # TODO: I did encounter this sometimes in a re-org, needs better handling + # and maybe bubbling up the block number/hash exceptions above. + assert_chain_activity() + continue + + if adjusted_head.number == last.number and adjusted_head.hash == last.hash: + # The chain has not moved! Verify we have activity. + assert_chain_activity() + time.sleep(wait_time) + continue + + elif adjusted_head.number < last.number or ( + adjusted_head.number == last.number and adjusted_head.hash != last.hash + ): + # Re-org detected! Error and catch up the chain. + logger.error( + "Chain has reorganized since returning the last block. " + "Try adjusting the required network confirmations." + ) + # Catch up the chain by setting the "next" to this tiny head. + next_block = adjusted_head.number + + # NOTE: Drop down to code outside of switch-of-ifs + + elif adjusted_head.number < next_block: + # Wait for the next block. + # But first, let's make sure the chain is still active. + assert_chain_activity() + time.sleep(wait_time) + continue + + # NOTE: Should only get here if yielding blocks! + # Either because it is finally time or because a re-org allows us. + for block_idx in range(next_block, adjusted_head.number + 1): + block = self.get_block(block_idx) + if block.number is None or block.hash is None: + raise ProviderError("Block has no number or hash.") + yield block + + # This is the point at which the daemon will end, + # provider the user passes in a `stop_block` arg. + if stop_block is not None and block.number >= stop_block: + return + + # Set the last action, used for checking timeouts and re-orgs. + last = YieldAction( + number=block.number, hash=block.hash, time=time.time() + ) # type: ignore + + def poll_logs( + self, + stop_block: Optional[int] = None, + address: Optional[AddressType] = None, + topics: Optional[List[Union[str, List[str]]]] = None, + required_confirmations: Optional[int] = None, + new_block_timeout: Optional[int] = None, + events: List[EventABI] = [], + ) -> Iterator[ContractLog]: + if required_confirmations is None: + required_confirmations = self.network.required_confirmations + + if stop_block is not None: + if stop_block <= (self.provider.get_block("latest").number or 0): + raise ValueError("'stop' argument must be in the future.") + + for block in self.poll_blocks(stop_block, required_confirmations, new_block_timeout): + if block.number is None: + raise ValueError("Block number cannot be None") + + log_params: Dict[str, Any] = { + "start_block": block.number, + "stop_block": block.number, + "events": events, + } + if address is not None: + log_params["addresses"] = [address] + if topics is not None: + log_params["topics"] = topics + + log_filter = LogFilter(**log_params) + yield from self.get_contract_logs(log_filter) + + def block_ranges(self, start: int = 0, stop: Optional[int] = None, page: Optional[int] = None): + if stop is None: + stop = self.chain_manager.blocks.height + if page is None: + page = self.block_page_size + + for start_block in range(start, stop + 1, page): + stop_block = min(stop, start_block + page - 1) + yield start_block, stop_block + + def get_contract_creation_receipts( + self, + address: AddressType, + start_block: int = 0, + stop_block: Optional[int] = None, + contract_code: Optional[HexBytes] = None, + ) -> Iterator[ReceiptAPI]: + if stop_block is None: + stop_block = self.chain_manager.blocks.height + + if contract_code is None: + contract_code = HexBytes(self.get_code(address)) + + mid_block = (stop_block - start_block) // 2 + start_block + # NOTE: biased towards mid_block == start_block + + if start_block == mid_block: + for tx in self.chain_manager.blocks[mid_block].transactions: + if (receipt := tx.receipt) and receipt.contract_address == address: + yield receipt + + if mid_block + 1 <= stop_block: + yield from self.get_contract_creation_receipts( + address, + start_block=mid_block + 1, + stop_block=stop_block, + contract_code=contract_code, + ) + + # TODO: Handle when code is nonzero but doesn't match + # TODO: Handle when code is empty after it's not (re-init) + elif HexBytes(self.get_code(address, block_id=mid_block)) == contract_code: + # If the code exists, we need to look backwards. + yield from self.get_contract_creation_receipts( + address, + start_block=start_block, + stop_block=mid_block, + contract_code=contract_code, + ) + + elif mid_block + 1 <= stop_block: + # The code does not exist yet, we need to look ahead. + yield from self.get_contract_creation_receipts( + address, + start_block=mid_block + 1, + stop_block=stop_block, + contract_code=contract_code, + ) + + def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: + height = self.chain_manager.blocks.height + start_block = log_filter.start_block + stop_block_arg = log_filter.stop_block if log_filter.stop_block is not None else height + stop_block = min(stop_block_arg, height) + block_ranges = self.block_ranges(start_block, stop_block, self.block_page_size) + + def fetch_log_page(block_range): + start, stop = block_range + update = {"start_block": start, "stop_block": stop} + page_filter = log_filter.model_copy(update=update) + # eth-tester expects a different format, let web3 handle the conversions for it. + raw = "EthereumTester" not in self.client_version + logs = self._get_logs(page_filter.model_dump(mode="json"), raw) + return self.network.ecosystem.decode_logs(logs, *log_filter.events) + + with ThreadPoolExecutor(self.concurrency) as pool: + for page in pool.map(fetch_log_page, block_ranges): + yield from page + + def _get_logs(self, filter_params, raw=True) -> List[Dict]: + if not raw: + return [vars(d) for d in self.web3.eth.get_logs(filter_params)] + + return self._make_request("eth_getLogs", [filter_params]) + + def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI: + # NOTE: Use "expected value" for Chain ID, so if it doesn't match actual, we raise + txn.chain_id = self.network.chain_id + + from ape_ethereum.transactions import StaticFeeTransaction, TransactionType + + txn_type = TransactionType(txn.type) + if ( + txn_type == TransactionType.STATIC + and isinstance(txn, StaticFeeTransaction) + and txn.gas_price is None + ): + txn.gas_price = self.gas_price + elif txn_type == TransactionType.DYNAMIC: + if txn.max_priority_fee is None: + txn.max_priority_fee = self.priority_fee + + if txn.max_fee is None: + multiplier = self.network.base_fee_multiplier + txn.max_fee = int(self.base_fee * multiplier + txn.max_priority_fee) + # else: Assume user specified the correct amount or txn will fail and waste gas + + gas_limit = self.network.gas_limit if txn.gas_limit is None else txn.gas_limit + if gas_limit in (None, "auto") or isinstance(gas_limit, AutoGasLimit): + multiplier = ( + gas_limit.multiplier + if isinstance(gas_limit, AutoGasLimit) + else self.network.auto_gas_multiplier + ) + if multiplier != 1.0: + gas = min(int(self.estimate_gas_cost(txn) * multiplier), self.max_gas) + else: + gas = self.estimate_gas_cost(txn) + + txn.gas_limit = gas + + elif gas_limit == "max": + txn.gas_limit = self.max_gas + + elif gas_limit is not None and isinstance(gas_limit, int): + txn.gas_limit = gas_limit + + if txn.required_confirmations is None: + txn.required_confirmations = self.network.required_confirmations + elif not isinstance(txn.required_confirmations, int) or txn.required_confirmations < 0: + raise TransactionError("'required_confirmations' must be a positive integer.") + + return txn + + def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: + try: + if txn.signature or not txn.sender: + txn_hash = self.web3.eth.send_raw_transaction(txn.serialize_transaction()) + else: + if txn.sender not in self.web3.eth.accounts: + self.chain_manager.provider.unlock_account(txn.sender) + + txn_data = cast(TxParams, txn.model_dump(by_alias=True, mode="json")) + txn_hash = self.web3.eth.send_transaction(txn_data) + + except (ValueError, Web3ContractLogicError) as err: + vm_err = self.get_virtual_machine_error(err, txn=txn) + raise vm_err from err + + receipt = self.get_receipt( + txn_hash.hex(), + required_confirmations=( + txn.required_confirmations + if txn.required_confirmations is not None + else self.network.required_confirmations + ), + ) + + # NOTE: Ensure to cache even the failed receipts. + self.chain_manager.history.append(receipt) + + if receipt.failed: + txn_dict = receipt.transaction.model_dump(by_alias=True, mode="json") + txn_params = cast(TxParams, txn_dict) + + # Replay txn to get revert reason + try: + self.web3.eth.call(txn_params) + except Exception as err: + vm_err = self.get_virtual_machine_error(err, txn=txn) + raise vm_err from err + + logger.info(f"Confirmed {receipt.txn_hash} (total fees paid = {receipt.total_fees_paid})") + return receipt + + def _create_call_tree_node( + self, evm_call: EvmCallTreeNode, txn_hash: Optional[str] = None + ) -> CallTreeNode: + address = evm_call.address + try: + contract_id = str(self.provider.network.ecosystem.decode_address(address)) + except ValueError: + # Use raw value since it is not a real address. + contract_id = address.hex() + + call_type = evm_call.call_type.value + return CallTreeNode( + calls=[self._create_call_tree_node(x, txn_hash=txn_hash) for x in evm_call.calls], + call_type=call_type, + contract_id=contract_id, + failed=evm_call.failed, + gas_cost=evm_call.gas_cost, + inputs=evm_call.calldata if "CREATE" in call_type else evm_call.calldata[4:].hex(), + method_id=evm_call.calldata[:4].hex(), + outputs=evm_call.returndata.hex(), + raw=evm_call.model_dump(by_alias=True, mode="json"), + txn_hash=txn_hash, + ) + + def _create_trace_frame(self, evm_frame: EvmTraceFrame) -> TraceFrame: + address_bytes = evm_frame.address + try: + address = ( + self.network.ecosystem.decode_address(address_bytes.hex()) + if address_bytes + else None + ) + except ValueError: + # Might not be a real address. + address = cast(AddressType, address_bytes.hex()) if address_bytes else None + + return TraceFrame( + pc=evm_frame.pc, + op=evm_frame.op, + gas=evm_frame.gas, + gas_cost=evm_frame.gas_cost, + depth=evm_frame.depth, + contract_address=address, + raw=evm_frame.model_dump(by_alias=True, mode="json"), + ) + + def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any: + parameters = parameters or [] + coroutine = self.web3.provider.make_request(RPCEndpoint(endpoint), parameters) + result = run_until_complete(coroutine) + + if "error" in result: + error = result["error"] + message = ( + error["message"] if isinstance(error, dict) and "message" in error else str(error) + ) + raise ProviderError(message) + + elif "result" in result: + return result.get("result", {}) + + return result + + def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: + txn = kwargs.get("txn") + if isinstance(exception, Web3ContractLogicError): + # This happens from `assert` or `require` statements. + return self._handle_execution_reverted(exception, **kwargs) + + if not len(exception.args): + return VirtualMachineError(base_err=exception, **kwargs) + + err_data = exception.args[0] if (hasattr(exception, "args") and exception.args) else None + if isinstance(err_data, str) and "execution reverted" in err_data: + return self._handle_execution_reverted(exception, **kwargs) + + elif not isinstance(err_data, dict): + return VirtualMachineError(base_err=exception, **kwargs) + + elif not (err_msg := err_data.get("message")): + return VirtualMachineError(base_err=exception, **kwargs) + + elif txn is not None and "nonce too low" in str(err_msg): + txn = cast(TransactionAPI, txn) + new_err_msg = f"Nonce '{txn.nonce}' is too low" + return VirtualMachineError( + new_err_msg, base_err=exception, code=err_data.get("code"), **kwargs + ) + + elif "out of gas" in str(err_msg): + return OutOfGasError(code=err_data.get("code"), base_err=exception, **kwargs) + + return VirtualMachineError(str(err_msg), code=(err_data or {}).get("code"), **kwargs) + + def _handle_execution_reverted( + self, + exception: Union[Exception, str], + txn: Optional[TransactionAPI] = None, + trace: Optional[Iterator[TraceFrame]] = None, + contract_address: Optional[AddressType] = None, + source_traceback: Optional[SourceTraceback] = None, + ) -> ContractLogicError: + message = str(exception).split(":")[-1].strip() + params: Dict = { + "trace": trace, + "contract_address": contract_address, + "source_traceback": source_traceback, + } + no_reason = message == "execution reverted" + + if isinstance(exception, Web3ContractLogicError) and no_reason: + # Check for custom exception data and use that as the message instead. + # This allows compiler exception enrichment to function. + err_trace = None + try: + if trace: + trace, err_trace = tee(trace) + elif txn: + err_trace = self.provider.get_transaction_trace(txn.txn_hash.hex()) + + trace_ls: List[TraceFrame] = list(err_trace) if err_trace else [] + data = trace_ls[-1].raw if len(trace_ls) > 0 else {} + memory = data.get("memory", []) + return_value = "".join([x[2:] for x in memory[4:]]) + if return_value: + message = f"0x{return_value}" + no_reason = False + + except (ApeException, NotImplementedError): + # Either provider does not support or isn't a custom exception. + pass + + result = ( + ContractLogicError(txn=txn, **params) + if no_reason + else ContractLogicError( + base_err=exception if isinstance(exception, Exception) else None, + revert_message=message, + txn=txn, + **params, + ) + ) + return self.compiler_manager.enrich_error(result) diff --git a/src/ape_geth/provider.py b/src/ape_geth/provider.py index e80a3fb1c8..d047dba355 100644 --- a/src/ape_geth/provider.py +++ b/src/ape_geth/provider.py @@ -44,7 +44,6 @@ TestProviderAPI, TransactionAPI, UpstreamProvider, - Web3Provider, ) from ape.exceptions import ( ApeException, @@ -53,7 +52,7 @@ ProviderError, ) from ape.logging import LogLevel, logger, sanitize_url -from ape.types import AddressType, CallTreeNode, SnapshotID, SourceTraceback, TraceFrame +from ape.types import AddressType, BlockID, CallTreeNode, SnapshotID, SourceTraceback, TraceFrame from ape.utils import ( DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_CHAIN_ID, @@ -63,6 +62,7 @@ raises_not_implemented, spawn, ) +from ape_ethereum.provider import Web3Provider DEFAULT_PORT = 8545 DEFAULT_HOSTNAME = "localhost" @@ -580,7 +580,19 @@ def set_timestamp(self, new_timestamp: int): def mine(self, num_blocks: int = 1): pass - def send_call(self, txn: TransactionAPI, **kwargs: Any) -> bytes: + def send_call( + self, + txn: TransactionAPI, + block_id: Optional[BlockID] = None, + state: Optional[Dict] = None, + **kwargs: Any, + ) -> HexBytes: + if block_id is not None: + kwargs["block_identifier"] = block_id + + if state is not None: + kwargs["state_override"] = state + skip_trace = kwargs.pop("skip_trace", False) arguments = self._prepare_call(txn, **kwargs) if skip_trace: @@ -662,7 +674,7 @@ def _trace_call(self, arguments: List[Any]) -> Tuple[Dict, Iterator[EvmTraceFram trace_data = result.get("structLogs", []) return result, create_trace_frames(trace_data) - def _eth_call(self, arguments: List) -> bytes: + def _eth_call(self, arguments: List) -> HexBytes: try: result = self._make_request("eth_call", arguments) except Exception as err: diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index 6cd76af3cb..f803671ed3 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -16,7 +16,7 @@ from web3.providers.eth_tester.defaults import API_ENDPOINTS, static_return from web3.types import TxParams -from ape.api import PluginConfig, ReceiptAPI, TestProviderAPI, TransactionAPI, Web3Provider +from ape.api import PluginConfig, ReceiptAPI, TestProviderAPI, TransactionAPI from ape.exceptions import ( ContractLogicError, ProviderError, @@ -25,8 +25,9 @@ UnknownSnapshotError, VirtualMachineError, ) -from ape.types import SnapshotID +from ape.types import BlockID, SnapshotID from ape.utils import DEFAULT_TEST_CHAIN_ID, gas_estimation_error_message +from ape_ethereum.provider import Web3Provider class EthTesterProviderConfig(PluginConfig): @@ -78,7 +79,9 @@ def update_settings(self, new_settings: Dict): self.disconnect() self.connect() - def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: + def estimate_gas_cost( + self, txn: TransactionAPI, block_id: Optional[BlockID] = None, **kwargs + ) -> int: if isinstance(self.network.gas_limit, int): return self.network.gas_limit @@ -86,7 +89,6 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: block = self.web3.eth.get_block("latest") return block["gasLimit"] - block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) estimate_gas = self.web3.eth.estimate_gas txn_dict = txn.model_dump(mode="json") txn_dict.pop("gas", None) @@ -103,7 +105,7 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: # and then set it back. expected_nonce, actual_nonce = gas_match.groups() txn.nonce = int(expected_nonce) - txn_params: TxParams = cast(TxParams, txn.model_dump(mode="json")) + txn_params: TxParams = cast(TxParams, txn.model_dump(by_alias=True, mode="json")) value = estimate_gas(txn_params, block_identifier=block_id) txn.nonce = int(actual_nonce) return value @@ -148,11 +150,16 @@ def base_fee(self) -> int: """ return self._get_last_base_fee() - def send_call(self, txn: TransactionAPI, **kwargs) -> bytes: + def send_call( + self, + txn: TransactionAPI, + block_id: Optional[BlockID] = None, + state: Optional[Dict] = None, + **kwargs, + ) -> HexBytes: data = txn.model_dump(mode="json", exclude_none=True) - block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) state = kwargs.pop("state_override", None) - call_kwargs = {"block_identifier": block_id, "state_override": state} + call_kwargs: Dict = {"block_identifier": block_id, "state_override": state} # Remove unneeded properties data.pop("gas", None) @@ -170,7 +177,7 @@ def send_call(self, txn: TransactionAPI, **kwargs) -> bytes: raise self.get_virtual_machine_error(err, txn=txn) from err self._increment_call_func_coverage_hit_count(txn) - return result + return HexBytes(result) def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: try: diff --git a/tests/functional/conversion/test_ether.py b/tests/functional/conversion/test_ether.py index 43b2b3a488..5d5cfda3ee 100644 --- a/tests/functional/conversion/test_ether.py +++ b/tests/functional/conversion/test_ether.py @@ -28,8 +28,7 @@ def test_bad_type(convert): convert(value="something", type=float) expected = ( - "Type '' must be one of " - "[AddressType, bytes, int, Decimal, list, tuple, bool, str]." + "Type '' must be one of " "[AddressType, bytes, int, Decimal, bool, str]." ) assert str(err.value) == expected @@ -39,7 +38,3 @@ def test_no_registered_converter(convert): convert(value="something", type=ChecksumAddress) assert str(err.value) == "No conversion registered to handle 'something'." - - -def test_lists(convert): - assert convert(["1 ether"], list) == [int(1e18)] diff --git a/tests/functional/test_cli.py b/tests/functional/test_cli.py index 3437760f4e..90fc5fe161 100644 --- a/tests/functional/test_cli.py +++ b/tests/functional/test_cli.py @@ -10,9 +10,9 @@ account_option, contract_file_paths_argument, existing_alias_argument, - get_user_selected_account, network_option, non_existing_alias_argument, + select_account, verbosity_option, ) from ape.exceptions import AccountsError @@ -103,60 +103,58 @@ def get_expected_account_str(acct): return f"__expected_output__: {acct.address}" -def test_get_user_selected_account_no_accounts_found(no_accounts): +def test_select_account_no_accounts_found(no_accounts): with pytest.raises(AccountsError, match="No accounts found."): - assert not get_user_selected_account() + assert not select_account() -def test_get_user_selected_account_one_account(runner, one_account): +def test_select_account_one_account(runner, one_account): # No input needed when only one account - with runner.isolation("0\n"): - account = get_user_selected_account() - + account = select_account() assert account == one_account -def test_get_user_selected_account_multiple_accounts_requires_input( +def test_select_account_multiple_accounts_requires_input( runner, keyfile_account, second_keyfile_account ): with runner.isolation(input="0\n"): - account = get_user_selected_account() + account = select_account() assert account == keyfile_account -def test_get_user_selected_account_custom_prompt(runner, keyfile_account, second_keyfile_account): +def test_select_account_custom_prompt(runner, keyfile_account, second_keyfile_account): prompt = "THIS_IS_A_CUSTOM_PROMPT" with runner.isolation(input="0\n") as out_streams: - get_user_selected_account(prompt) + select_account(prompt) output = out_streams[0].getvalue().decode() assert prompt in output -def test_get_user_selected_account_specify_type(runner, one_keyfile_account): - account = get_user_selected_account(account_type=type(one_keyfile_account)) +def test_select_account_specify_type(runner, one_keyfile_account): + with runner.isolation(): + account = select_account(key=type(one_keyfile_account)) + assert account == one_keyfile_account -def test_get_user_selected_account_unknown_type(runner, keyfile_account): +def test_select_account_unknown_type(runner, keyfile_account): with pytest.raises(AccountsError) as err: - get_user_selected_account(account_type=str) # type: ignore + select_account(key=str) # type: ignore assert "Cannot return accounts with type ''" in str(err.value) -def test_get_user_selected_account_with_account_list( - runner, keyfile_account, second_keyfile_account -): - account = get_user_selected_account(account_type=[keyfile_account]) +def test_select_account_with_account_list(runner, keyfile_account, second_keyfile_account): + account = select_account(key=[keyfile_account]) assert account == keyfile_account - account = get_user_selected_account(account_type=[second_keyfile_account]) + account = select_account(key=[second_keyfile_account]) assert account == second_keyfile_account with runner.isolation(input="1\n"): - account = get_user_selected_account(account_type=[keyfile_account, second_keyfile_account]) + account = select_account(key=[keyfile_account, second_keyfile_account]) assert account == second_keyfile_account @@ -297,7 +295,7 @@ def test_prompt_choice(runner, opt): """ def choice_callback(ctx, param, value): - return param.type.get_user_selected_choice() + return param.type.select() choice = PromptChoice(["foo", "bar"]) assert hasattr(choice, "name") diff --git a/tests/functional/test_compilers.py b/tests/functional/test_compilers.py index 47b3c003b8..8460e4f933 100644 --- a/tests/functional/test_compilers.py +++ b/tests/functional/test_compilers.py @@ -14,7 +14,7 @@ def test_get_imports(project, compilers): def test_missing_compilers_without_source_files(project): result = project.extensions_with_missing_compilers() - assert result == [] + assert result == set() @skip_if_plugin_installed("vyper", "solidity") diff --git a/tests/functional/test_contract_instance.py b/tests/functional/test_contract_instance.py index 5bc90b429d..366ed54a8d 100644 --- a/tests/functional/test_contract_instance.py +++ b/tests/functional/test_contract_instance.py @@ -13,7 +13,7 @@ APINotImplementedError, ArgumentsLengthError, ChainError, - ContractError, + ContractDataError, ContractLogicError, CustomError, MethodNonPayableError, @@ -105,21 +105,17 @@ def test_call_view_method(owner, contract_instance): assert contract_instance.myNumber() == 2 -def test_call_use_block_identifier(contract_instance, owner, chain): +def test_call_use_block_id(contract_instance, owner, chain): expected = 2 contract_instance.setNumber(expected, sender=owner) block_id = chain.blocks.height # int contract_instance.setNumber(3, sender=owner) # latest - actual = contract_instance.myNumber(block_identifier=block_id) - assert actual == expected - - # Ensure alias "block_id" works actual = contract_instance.myNumber(block_id=block_id) assert actual == expected # Ensure works with hex block_id = to_hex(block_id) - actual = contract_instance.myNumber(block_identifier=block_id) + actual = contract_instance.myNumber(block_id=block_id) assert actual == expected # Ensure alias "block_id" works. @@ -127,7 +123,7 @@ def test_call_use_block_identifier(contract_instance, owner, chain): assert actual == expected # Ensure works keywords like "latest" - actual = contract_instance.myNumber(block_identifier="latest") + actual = contract_instance.myNumber(block_id="latest") assert actual == 3 @@ -184,16 +180,13 @@ def test_revert_custom_exception(not_owner, error_contract): assert custom_err.inputs == {"addr": addr, "counter": 123} # type: ignore -def test_call_using_block_identifier( - vyper_contract_instance, owner, chain, networks_connected_to_tester -): +def test_call_using_block_id(vyper_contract_instance, owner, chain, networks_connected_to_tester): contract = vyper_contract_instance contract.setNumber(1, sender=owner) height = chain.blocks.height contract.setNumber(33, sender=owner) - actual_0 = contract.myNumber(block_identifier=height) - actual_1 = contract.myNumber(block_id=height) - assert actual_0 == actual_1 == 1 + actual = contract.myNumber(block_id=height) + assert actual == 1 def test_repr(vyper_contract_instance): @@ -552,7 +545,7 @@ def test_from_receipt_when_receipt_not_deploy(contract_instance, owner): "Receipt missing 'contract_address' field. " "Was this from a deploy transaction (e.g. `project.MyContract.deploy()`)?" ) - with pytest.raises(ContractError, match=expected_err): + with pytest.raises(ChainError, match=expected_err): ContractInstance.from_receipt(receipt, contract_instance.contract_type) @@ -667,7 +660,7 @@ def test_decode_ambiguous_input(solidity_contract_instance, calldata_with_addres f"Unable to find matching method ABI for calldata '{anonymous_calldata.hex()}'. " "Try prepending a method ID to the beginning of the calldata." ) - with pytest.raises(ContractError, match=expected): + with pytest.raises(ContractDataError, match=expected): method.decode_input(anonymous_calldata) @@ -769,7 +762,7 @@ def test_value_to_non_payable_fallback_and_no_receive( expected = ( r"Contract's fallback is non-payable and there is no receive ABI\. Unable to send value\." ) - with pytest.raises(ContractError, match=expected): + with pytest.raises(MethodNonPayableError, match=expected): contract(sender=owner, value=1) # Show can bypass by using `as_transaction()` and `owner.call()`. @@ -789,7 +782,7 @@ def test_fallback_with_data_and_value_and_receive(solidity_fallback_contract, ow is non-payable. """ expected = "Sending both value= and data= but fallback is non-payable." - with pytest.raises(ContractError, match=expected): + with pytest.raises(MethodNonPayableError, match=expected): solidity_fallback_contract(sender=owner, data="0x123", value=1) # Show can bypass by using `as_transaction()` and `owner.call()`. diff --git a/tests/functional/test_exceptions.py b/tests/functional/test_exceptions.py index 7ee405e341..644467e3bb 100644 --- a/tests/functional/test_exceptions.py +++ b/tests/functional/test_exceptions.py @@ -1,8 +1,7 @@ import re from ape.api import ReceiptAPI -from ape.cli import Abort -from ape.exceptions import TransactionError +from ape.exceptions import Abort, TransactionError from ape_ethereum.transactions import Receipt diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index 45b04f7e8b..4ac944ea4e 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -6,10 +6,10 @@ from eth_utils import ValidationError from web3.exceptions import ContractPanicError -from ape.api.providers import _sanitize_web3_url from ape.exceptions import BlockNotFoundError, ContractLogicError, TransactionNotFoundError from ape.types import LogFilter from ape.utils import DEFAULT_TEST_CHAIN_ID +from ape_ethereum.provider import _sanitize_web3_url from ape_ethereum.transactions import TransactionType