From a9b73196cafdae905f3d99013dd555c66594a64e Mon Sep 17 00:00:00 2001 From: antazoey Date: Thu, 18 Apr 2024 12:22:54 -0600 Subject: [PATCH] refactor!: TraceAPI (#1864) --- docs/userguides/testing.md | 25 +- docs/userguides/trace.md | 114 ++++ docs/userguides/transactions.md | 119 +--- setup.py | 1 + src/ape/api/__init__.py | 2 + src/ape/api/compiler.py | 33 +- src/ape/api/networks.py | 12 +- src/ape/api/projects.py | 6 +- src/ape/api/providers.py | 58 +- src/ape/api/trace.py | 60 ++ src/ape/api/transactions.py | 36 +- src/ape/exceptions.py | 32 +- src/ape/managers/chain.py | 66 +- src/ape/pytest/gas.py | 11 +- src/ape/types/__init__.py | 4 +- src/ape/types/trace.py | 230 +------ src/ape/utils/trace.py | 214 +------ src/ape_ethereum/_print.py | 88 ++- src/ape_ethereum/ecosystem.py | 206 +++--- src/ape_ethereum/provider.py | 401 ++++-------- src/ape_ethereum/trace.py | 634 +++++++++++++++++++ src/ape_ethereum/transactions.py | 91 +-- src/ape_node/provider.py | 128 +--- src/ape_test/provider.py | 8 +- tests/conftest.py | 12 + tests/functional/geth/conftest.py | 5 - tests/functional/geth/test_contract.py | 14 +- tests/functional/geth/test_contract_event.py | 10 + tests/functional/geth/test_gas_tracker.py | 23 + tests/functional/geth/test_provider.py | 68 +- tests/functional/geth/test_receipt.py | 54 ++ tests/functional/geth/test_trace.py | 158 ++++- tests/functional/test_gas_tracker.py | 30 + tests/functional/test_network_manager.py | 4 +- tests/functional/test_provider.py | 22 +- tests/functional/test_receipt.py | 50 +- tests/functional/test_trace.py | 96 +++ tests/functional/test_transaction.py | 1 + tests/functional/utils/test_trace.py | 15 - tests/integration/cli/test_test.py | 14 +- 40 files changed, 1767 insertions(+), 1388 deletions(-) create mode 100644 docs/userguides/trace.md create mode 100644 src/ape/api/trace.py create mode 100644 src/ape_ethereum/trace.py create mode 100644 tests/functional/geth/test_contract_event.py create mode 100644 tests/functional/geth/test_gas_tracker.py create mode 100644 tests/functional/geth/test_receipt.py create mode 100644 tests/functional/test_gas_tracker.py create mode 100644 tests/functional/test_trace.py delete mode 100644 tests/functional/utils/test_trace.py diff --git a/docs/userguides/testing.md b/docs/userguides/testing.md index 93aac64122..8533f0af53 100644 --- a/docs/userguides/testing.md +++ b/docs/userguides/testing.md @@ -310,6 +310,8 @@ Similar to `pytest.raises()`, you can use `ape.reverts()` to assert that contrac From our earlier example we can see this in action: ```python +import ape + def test_authorization(my_contract, owner, not_owner): my_contract.set_owner(sender=owner) assert owner == my_contract.owner() @@ -328,6 +330,9 @@ If the message in the `ContractLogicError` raised by the transaction failure is You may also supply an `re.Pattern` object to assert on a message pattern, rather than on an exact match. ```python +import ape +import re + # Matches explicitly "foo" or "bar" with ape.reverts(re.compile(r"^(foo|bar)$")): ... @@ -358,6 +363,8 @@ def check_value(_value: uint256) -> bool: We can explicitly cause a transaction revert and check the failed line by supplying an expected `dev_message`: ```python +import ape + def test_authorization(my_contract, owner): with ape.reverts(dev_message="dev: invalid value"): my_contract.check_value(sender=owner) @@ -376,6 +383,8 @@ Because `dev_message` relies on transaction tracing to function, you must use a You may also supply an `re.Pattern` object to assert on a dev message pattern, rather than on an exact match. ```python +import ape + # Matches explictly "dev: foo" or "dev: bar" with ape.reverts(dev_message=re.compile(r"^dev: (foo|bar)$")): ... @@ -511,12 +520,10 @@ To run an entire test using a specific network / provider combination, use the ` ```python import pytest - @pytest.mark.use_network("fantom:local:test") def test_my_fantom_test(chain): assert chain.provider.network.ecosystem.name == "fantom" - @pytest.mark.use_network("ethereum:local:test") def test_my_ethereum_test(chain): assert chain.provider.network.ecosystem.name == "ethereum" @@ -544,13 +551,11 @@ This is useful if certain fixtures must run in certain networks. ```python import pytest - @pytest.fixture def stark_contract(networks, project): with networks.parse_network_choice("starknet:local"): yield project.MyStarknetContract.deploy() - def test_starknet_thing(stark_contract, stark_account): # Uses the starknet connection via the stark_contract fixture receipt = stark_contract.my_method(sender=stark_account) @@ -565,10 +570,11 @@ Thus, you can enter and exit a provider's context as much as you need in tests. ## Gas Reporting To include a gas report at the end of your tests, you can use the `--gas` flag. -**NOTE**: This feature requires using a provider with tracing support, such as [ape-hardhat](https://github.com/ApeWorX/ape-hardhat). +**NOTE**: This feature works best when using a provider with tracing support, such as [ape-foundry](https://github.com/ApeWorX/ape-foundry). +When not using a provider with adequate tracing support, such as `EthTester`, gas reporting is limited to receipt-level data. ```bash -ape test --network ethereum:local:hardhat --gas +ape test --network ethereum:local:foundry --gas ``` At the end of test suite, you will see tables such as: @@ -583,12 +589,6 @@ At the end of test suite, you will see tables such as: changeOnStatus 2 23827 45739 34783 34783 getSecret 1 24564 24564 24564 24564 - Transferring ETH Gas - - Method Times called Min. Max. Mean Median - ─────────────────────────────────────────────────────── - to:test0 2 2400 9100 5750 5750 - TestContract Gas Method Times called Min. Max. Mean Median @@ -649,6 +649,7 @@ ape test --coverage ``` **NOTE**: Some types of coverage require using a provider that supports transaction tracing, such as `ape-hardhat` or `ape-foundry`. +Without using a provider with adequate tracing support, coverage is limited to receipt-level data. Afterwards, you should see a coverage report looking something like: diff --git a/docs/userguides/trace.md b/docs/userguides/trace.md new file mode 100644 index 0000000000..dd8461243c --- /dev/null +++ b/docs/userguides/trace.md @@ -0,0 +1,114 @@ +# Traces + +A transaction's trace frames are the individual steps the transaction took. +Using traces, Ape is able to offer features like: + +1. Showing a pretty call-tree from a transaction receipt +2. Gas reporting in `ape test` +3. Coverage tools in `ape test` + +Some network providers, such as Alchemy and Foundry, implement `debug_traceTransaction` and Parity's `trace_transaction` affording tracing capabilities in Ape. +**WARN**: Without RPCs for obtaining traces, some features such as gas-reporting and coverage are limited. + +To see a transaction trace, use the [show_trace()](../methoddocs/api.html#ape.api.transactions.ReceiptAPI.show_trace) method on a receipt API object. + +Here is an example using `show_trace()` in Python code to print out a transaction's trace. +**NOTE**: This code runs assuming you are connected to `ethereum:mainnet` using a provider with tracing RPCs. +To learn more about networks in Ape, see the [networks guide](./networks.html). + +```python +from ape import chain + +tx = chain.provider.get_receipt('0xb7d7f1d5ce7743e821d3026647df486f517946ef1342a1ae93c96e4a8016eab7') + +# Show the steps the transaction took. +tx.show_trace() +``` + +You should see a (less-abridged) trace like: + +``` +Call trace for '0xb7d7f1d5ce7743e821d3026647df486f517946ef1342a1ae93c96e4a8016eab7' +tx.origin=0x5668EAd1eDB8E2a4d724C8fb9cB5fFEabEB422dc +DSProxy.execute(_target=LoanShifterTaker, _data=0x35..0000) -> "" [1421947 gas] +└── (delegate) LoanShifterTaker.moveLoan( + _exchangeData=[ + 0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE, + ZERO_ADDRESS, + + ... + # Abridged because is super long # + ... + + + │ └── LendingRateOracle.getMarketBorrowRate(_asset=DAI) -> + │ 35000000000000000000000000 [1164 gas] + ├── DSProxy.authority() -> DSGuard [1291 gas] + ├── DSGuard.forbid(src=LoanShifterReceiver, dst=DSProxy, sig=0x1c..0000) [5253 gas] + └── DefisaverLogger.Log( + _contract=DSProxy, + _caller=tx.origin, + _logName="LoanShifter", + _data=0x00..0000 + ) [6057 gas] +``` + +Similarly, you can use the provider directly to get a trace. +This is useful if you want to interact with the trace or change some parameters for creating the trace. + +```python +from ape import chain + +# Change the `debug_traceTransaction` parameter dictionary +trace = chain.provider.get_transaction_trace( + "0x...", debug_trace_transaction_parameters={"enableMemory": False} +) + +# You can still print the pretty call-trace (as we did in the example above) +print(trace) + +# Interact with low-level logs for deeper analysis. +struct_logs = trace.get_raw_frames() +``` + +## Tracing Calls + +Some network providers trace calls in addition to transactions. +EVM-based providers best achieve this by implementing the `debug_traceCall` RPC. + +If you want to see the trace of call when making the call, use the `show_trace=` flag: + +```python +token.balanceOf(account, show_trace=True) +``` + +**WARN**: If your provider does not properly support call-tracing (e.g. doesn't implement `debug_traceCall`), traces are limited to the top-level call. + +Ape traces calls automatically when using `--gas` or `--coverage` in tests to build reports. +Learn more about testing in Ape in the [testing guide](./testing.html) and in the following sections. + +## Gas Reports + +To view the gas report of a transaction receipt, use the [ReceiptAPI.show_gas_report()](../methoddocs/api.html?highlight=receiptapi#ape.api.transactions.ReceiptAPI.show_gas_report) method: + +```python +from ape import networks + +txn_hash = "0x053cba5c12172654d894f66d5670bab6215517a94189a9ffc09bc40a589ec04d" +receipt = networks.provider.get_receipt(txn_hash) +receipt.show_gas_report() +``` + +It outputs tables of contracts and methods with gas usages that look like this: + +``` + DAI Gas + + Method Times called Min. Max. Mean Median + ──────────────────────────────────────────────────────────────── + balanceOf 4 1302 13028 1302 1302 + allowance 2 1377 1377 1337 1337 +│ approve 1 22414 22414 22414 22414 +│ burn 1 11946 11946 11946 11946 +│ mint 1 25845 25845 25845 25845 +``` diff --git a/docs/userguides/transactions.md b/docs/userguides/transactions.md index e906c7daea..5c431fb3a9 100644 --- a/docs/userguides/transactions.md +++ b/docs/userguides/transactions.md @@ -187,122 +187,9 @@ ethereum: ## Traces -If you are using a provider that is able to fetch transaction traces, such as the [ape-hardhat](https://github.com/ApeWorX/ape-hardhat) provider, you can call the [`ReceiptAPI.show_trace()`](../methoddocs/api.html?highlight=receiptapi#ape.api.transactions.ReceiptAPI.show_trace) method. - -```python -from ape import accounts, project - -owner = accounts.load("acct") -contract = project.Contract.deploy(sender=owner) -receipt = contract.methodWithoutArguments() -receipt.show_trace() -``` - -**NOTE**: If your provider does not support traces, you will see a `NotImplementedError` saying that the method is not supported. - -The trace might look something like: - -```bash -Call trace for '0x43abb1fdadfdae68f84ce8cd2582af6ab02412f686ee2544aa998db662a5ef50' -txn.origin=0x1e59ce931B4CFea3fe4B875411e280e173cB7A9C -ContractA.methodWithoutArguments() -> 0x00..7a9c [469604 gas] -├── SYMBOL.supercluster(x=234444) -> [ -│ [23523523235235, 11111111111, 234444], -│ [ -│ 345345347789999991, -│ 99999998888882, -│ 345457847457457458457457457 -│ ], -│ [234444, 92222229999998888882, 3454], -│ [ -│ 111145345347789999991, -│ 333399998888882, -│ 234545457847457457458457457457 -│ ] -│ ] [461506 gas] -├── SYMBOL.methodB1(lolol="ice-cream", dynamo=345457847457457458457457457) [402067 gas] -│ ├── ContractC.getSomeList() -> [ -│ │ 3425311345134513461345134534531452345, -│ │ 111344445534535353, -│ │ 993453434534534534534977788884443333 -│ │ ] [370103 gas] -│ └── ContractC.methodC1( -│ windows95="simpler", -│ jamaica=345457847457457458457457457, -│ cardinal=ContractA -│ ) [363869 gas] -├── SYMBOL.callMe(blue=tx.origin) -> tx.origin [233432 gas] -├── SYMBOL.methodB2(trombone=tx.origin) [231951 gas] -│ ├── ContractC.paperwork(ContractA) -> ( -│ │ os="simpler", -│ │ country=345457847457457458457457457, -│ │ wings=ContractA -│ │ ) [227360 gas] -│ ├── ContractC.methodC1(windows95="simpler", jamaica=0, cardinal=ContractC) [222263 gas] -│ ├── ContractC.methodC2() [147236 gas] -│ └── ContractC.methodC2() [122016 gas] -├── ContractC.addressToValue(tx.origin) -> 0 [100305 gas] -├── SYMBOL.bandPractice(tx.origin) -> 0 [94270 gas] -├── SYMBOL.methodB1(lolol="lemondrop", dynamo=0) [92321 gas] -│ ├── ContractC.getSomeList() -> [ -│ │ 3425311345134513461345134534531452345, -│ │ 111344445534535353, -│ │ 993453434534534534534977788884443333 -│ │ ] [86501 gas] -│ └── ContractC.methodC1(windows95="simpler", jamaica=0, cardinal=ContractA) [82729 gas] -└── SYMBOL.methodB1(lolol="snitches_get_stiches", dynamo=111) [55252 gas] - ├── ContractC.getSomeList() -> [ - │ 3425311345134513461345134534531452345, - │ 111344445534535353, - │ 993453434534534534534977788884443333 - │ ] [52079 gas] - └── ContractC.methodC1(windows95="simpler", jamaica=111, cardinal=ContractA) [48306 gas] -``` - -Additionally, you can view the traces of other transactions on your network. - -```python -from ape import networks - -txn_hash = "0x053cba5c12172654d894f66d5670bab6215517a94189a9ffc09bc40a589ec04d" -receipt = networks.provider.get_receipt(txn_hash) -receipt.show_trace() -``` - -In Ape, you can also show the trace for a call. -Use the `show_trace=` kwarg on a contract call and Ape will display the trace before returning the data. - -```python -token.balanceOf(account, show_trace=True) -``` - -**NOTE**: This may not work on all providers, but it should work on common ones such as `ape-hardhat` or `ape-node`. - -## Gas Reports - -To view the gas report of a transaction receipt, use the [`ReceiptAPI.show_gas_report()`](../methoddocs/api.html?highlight=receiptapi#ape.api.transactions.ReceiptAPI.show_gas_report) method: - -```python -from ape import networks - -txn_hash = "0x053cba5c12172654d894f66d5670bab6215517a94189a9ffc09bc40a589ec04d" -receipt = networks.provider.get_receipt(txn_hash) -receipt.show_gas_report() -``` - -It will output tables of contracts and methods with gas usages that look like this: - -```bash - DAI Gas - - Method Times called Min. Max. Mean Median - ──────────────────────────────────────────────────────────────── - balanceOf 4 1302 13028 1302 1302 - allowance 2 1377 1377 1337 1337 -│ approve 1 22414 22414 22414 22414 -│ burn 1 11946 11946 11946 11946 -│ mint 1 25845 25845 25845 25845 -``` +Transaction traces are the steps in the contract the transaction took. +Traces both power a myriad of features in Ape as well are themselves a tool for developers to use to debug transactions. +To learn more about traces, see the [traces userguide](./trace.md). ## Estimate Gas Cost diff --git a/setup.py b/setup.py index ec884b7cdf..8aa7882bb5 100644 --- a/setup.py +++ b/setup.py @@ -120,6 +120,7 @@ "eth-account>=0.11.2,<0.12", "eth-typing>=3.5.2,<4", "eth-utils>=2.3.1,<3", + "hexbytes", # Peer "py-geth>=4.4.0,<5", "trie>=3.0.0,<4", # Peer: stricter pin needed for uv support. "web3[tester]>=6.17.2,<7", diff --git a/src/ape/api/__init__.py b/src/ape/api/__init__.py index caba392f01..8ce7a497fd 100644 --- a/src/ape/api/__init__.py +++ b/src/ape/api/__init__.py @@ -20,6 +20,7 @@ from .projects import DependencyAPI, ProjectAPI from .providers import BlockAPI, ProviderAPI, SubprocessProvider, TestProviderAPI, UpstreamProvider from .query import QueryAPI, QueryType +from .trace import TraceAPI from .transactions import ReceiptAPI, TransactionAPI __all__ = [ @@ -49,6 +50,7 @@ "TestAccountAPI", "TestAccountContainerAPI", "TestProviderAPI", + "TraceAPI", "TransactionAPI", "UpstreamProvider", ] diff --git a/src/ape/api/compiler.py b/src/ape/api/compiler.py index f160343d61..b41d611d15 100644 --- a/src/ape/api/compiler.py +++ b/src/ape/api/compiler.py @@ -1,18 +1,17 @@ from functools import cached_property from pathlib import Path -from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple +from typing import Dict, List, Optional, Sequence, Set from eth_pydantic_types import HexBytes from ethpm_types import ContractType from ethpm_types.source import Content, ContractSource -from evm_trace.geth import TraceFrame as EvmTraceFrame -from evm_trace.geth import create_call_node_data from packaging.version import Version from ape.api.config import PluginConfig +from ape.api.trace import TraceAPI from ape.exceptions import APINotImplementedError, ContractLogicError from ape.types.coverage import ContractSourceCoverage -from ape.types.trace import SourceTraceback, TraceFrame +from ape.types.trace import SourceTraceback from ape.utils import ( BaseInterfaceModel, abstractmethod, @@ -205,7 +204,7 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: @raises_not_implemented def trace_source( # type: ignore[empty-body] - self, contract_type: ContractType, trace: Iterator[TraceFrame], calldata: HexBytes + self, contract_type: ContractType, trace: TraceAPI, calldata: HexBytes ) -> SourceTraceback: """ Get a source-traceback for the given contract type. @@ -214,8 +213,8 @@ def trace_source( # type: ignore[empty-body] Args: contract_type (``ContractType``): A contract type that was created by this compiler. - trace (Iterator[:class:`~ape.types.trace.TraceFrame`]): The resulting frames from - executing a function defined in the given contract type. + trace (:class:`~ape.api.trace.TraceAPI`]): The resulting trace from executing a + function defined in the given contract type. calldata (``HexBytes``): Calldata passed to the top-level call. Returns: @@ -238,26 +237,6 @@ def flatten_contract(self, path: Path, **kwargs) -> Content: # type: ignore[emp ``ethpm_types.source.Content``: The flattened contract content. """ - def _create_contract_from_call( - self, frame: TraceFrame - ) -> Tuple[Optional[ContractSource], HexBytes]: - evm_frame = EvmTraceFrame(**frame.raw) - data = create_call_node_data(evm_frame) - calldata = data.get("calldata", HexBytes("")) - if not (address := (data.get("address", frame.contract_address) or None)): - return None, calldata - - try: - address = self.provider.network.ecosystem.decode_address(address) - except Exception: - return None, calldata - - if address not in self.chain_manager.contracts: - return None, calldata - - called_contract = self.chain_manager.contracts[address] - return self.project_manager._create_contract_source(called_contract), calldata - @raises_not_implemented def init_coverage_profile( self, source_coverage: ContractSourceCoverage, contract_source: ContractSource diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 4d04437795..bd26d8fd45 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -35,7 +35,7 @@ SignatureError, ) from ape.logging import logger -from ape.types import AddressType, AutoGasLimit, CallTreeNode, ContractLog, GasLimit, RawAddress +from ape.types import AddressType, AutoGasLimit, ContractLog, GasLimit, RawAddress from ape.utils import ( DEFAULT_TRANSACTION_ACCEPTANCE_TIMEOUT, BaseInterfaceModel, @@ -53,6 +53,7 @@ if TYPE_CHECKING: from .explorers import ExplorerAPI from .providers import BlockAPI, ProviderAPI, UpstreamProvider + from .trace import TraceAPI from .transactions import ReceiptAPI, TransactionAPI @@ -584,18 +585,19 @@ def get_method_selector(self, abi: MethodABI) -> HexBytes: return HexBytes(keccak(text=abi.selector)[:4]) - def enrich_calltree(self, call: CallTreeNode, **kwargs) -> CallTreeNode: + def enrich_trace(self, trace: "TraceAPI", **kwargs) -> "TraceAPI": """ Enhance the data in the call tree using information about the ecosystem. Args: - call (:class:`~ape.types.trace.CallTreeNode`): The call tree node to enrich. - kwargs: Additional kwargs to help with enrichment. + call (:class:`~ape.api.trace.TraceAPI`): The trace to enrich. + kwargs: Additional kwargs to control enrichment, defined at the + plugin level. Returns: :class:`~ape.types.trace.CallTreeNode` """ - return call + return trace @raises_not_implemented def get_python_types( # type: ignore[empty-body] diff --git a/src/ape/api/projects.py b/src/ape/api/projects.py index 1731e72c05..04d724461c 100644 --- a/src/ape/api/projects.py +++ b/src/ape/api/projects.py @@ -143,8 +143,10 @@ def contracts(self) -> Dict[str, ContractType]: contract_type.name = contract_name if contract_type.name is None else contract_type.name contracts[contract_type.name] = contract_type - self._contracts = contracts - return self._contracts + if contracts: + self._contracts = contracts + + return self._contracts or {} @property def _cache_folder(self) -> Path: diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index c2b8465ef5..1e8dcfe4c2 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -11,7 +11,7 @@ from pathlib import Path from signal import SIGINT, SIGTERM, signal from subprocess import DEVNULL, PIPE, Popen -from typing import Any, Dict, Iterator, List, Optional, Union, cast +from typing import Any, Dict, Iterable, Iterator, List, Optional, Union, cast from eth_pydantic_types import HexBytes from ethpm_types.abi import EventABI @@ -20,6 +20,7 @@ from ape.api.config import PluginConfig from ape.api.networks import NetworkAPI from ape.api.query import BlockTransactionQuery +from ape.api.trace import TraceAPI from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( APINotImplementedError, @@ -30,16 +31,7 @@ VirtualMachineError, ) from ape.logging import LogLevel, logger -from ape.types import ( - AddressType, - BlockID, - CallTreeNode, - ContractCode, - ContractLog, - LogFilter, - SnapshotID, - TraceFrame, -) +from ape.types import AddressType, BlockID, ContractCode, ContractLog, LogFilter, SnapshotID from ape.utils import BaseInterfaceModel, JoinableQueue, abstractmethod, cached_property, spawn from ape.utils.misc import ( EMPTY_BYTES32, @@ -317,6 +309,30 @@ def network_choice(self) -> str: return f"{self.network.choice}:{self.name}" + @abstractmethod + def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: + """ + Make a raw RPC request to the provider. + Advanced featues such as tracing may utilize this to by-pass unnecessary + class-serializations. + """ + + @raises_not_implemented + def stream_request( # type: ignore[empty-body] + self, method: str, params: Iterable, iter_path: str = "result.item" + ) -> Iterator[Any]: + """ + Stream a request, great for large requests like events or traces. + + Args: + method (str): The RPC method to call. + params (Iterable): Parameters for the method.s + iter_path (str): The response dict-path to the items. + + Returns: + An iterator of items. + """ + def get_storage_at(self, *args, **kwargs) -> HexBytes: warnings.warn( "'provider.get_storage_at()' is deprecated. Use 'provider.get_storage()'.", @@ -694,15 +710,16 @@ def unlock_account(self, address: AddressType) -> bool: # type: ignore[empty-bo @raises_not_implemented def get_transaction_trace( # type: ignore[empty-body] self, txn_hash: Union[HexBytes, str] - ) -> Iterator[TraceFrame]: + ) -> TraceAPI: """ Provide a detailed description of opcodes. Args: - txn_hash (str): The hash of a transaction to trace. + transaction_hash (Union[HexBytes, str]): The hash of a transaction + to trace. Returns: - Iterator(:class:`~ape.type.trace.TraceFrame`): Transaction execution trace. + :class:`~ape.api.trace.TraceAPI`: A transaction trace. """ @raises_not_implemented @@ -775,19 +792,6 @@ def poll_logs( # type: ignore[empty-body] Iterator[:class:`~ape.types.ContractLog`] """ - @raises_not_implemented - def get_call_tree(self, txn_hash: str) -> CallTreeNode: # type: ignore[empty-body] - """ - Create a tree structure of calls for a transaction. - - Args: - txn_hash (str): The hash of a transaction to trace. - - Returns: - :class:`~ape.types.trace.CallTreeNode`: Transaction execution - call-tree objects. - """ - def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI: """ Set default values on the transaction. diff --git a/src/ape/api/trace.py b/src/ape/api/trace.py new file mode 100644 index 0000000000..858635fa4c --- /dev/null +++ b/src/ape/api/trace.py @@ -0,0 +1,60 @@ +import sys +from abc import abstractmethod +from typing import IO, Any, Dict, List, Optional, Sequence + +from ape.types import ContractFunctionPath +from ape.types.trace import GasReport +from ape.utils.basemodel import BaseInterfaceModel + + +class TraceAPI(BaseInterfaceModel): + """ + The class returned from + :meth:`~ape.api.providers.ProviderAPI.get_transaction_trace`. + """ + + @abstractmethod + def show(self, verbose: bool = False, file: IO[str] = sys.stdout): + """ + Show the enriched trace. + """ + + @abstractmethod + def get_gas_report( + self, exclude: Optional[Sequence["ContractFunctionPath"]] = None + ) -> GasReport: + """ + Get the gas report. + """ + + @abstractmethod + def show_gas_report(self, verbose: bool = False, file: IO[str] = sys.stdout): + """ + Show the gas report. + """ + + @property + @abstractmethod + def return_value(self) -> Any: + """ + The return value deduced from the trace. + """ + + @property + @abstractmethod + def revert_message(self) -> Optional[str]: + """ + The revert message deduced from the trace. + """ + + @abstractmethod + def get_raw_frames(self) -> List[Dict]: + """ + Get raw trace frames for deeper analysis. + """ + + @abstractmethod + def get_raw_calltree(self) -> Dict: + """ + Get a raw calltree for deeper analysis. + """ diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 2c227d49b0..90a759c94a 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -24,7 +24,6 @@ AutoGasLimit, ContractLogContainer, SourceTraceback, - TraceFrame, TransactionSignature, ) from ape.utils import ( @@ -39,6 +38,7 @@ if TYPE_CHECKING: from ape.api.providers import BlockAPI + from ape.api.trace import TraceAPI from ape.contracts import ContractEvent @@ -157,7 +157,7 @@ def receipt(self) -> Optional["ReceiptAPI"]: return None @property - def trace(self) -> Iterator[TraceFrame]: + def trace(self) -> "TraceAPI": """ The transaction trace. Only works if this transaction was published and you are using a provider that support tracing. @@ -306,10 +306,6 @@ def _validate_transaction(cls, value): def validate_txn_hash(cls, value): return HexBytes(value).hex() - @property - def call_tree(self) -> Optional[Any]: - return None - @cached_property def debug_logs_typed(self) -> List[Tuple[Any]]: """Return any debug log data outputted by the transaction.""" @@ -358,11 +354,11 @@ def ran_out_of_gas(self) -> bool: """ @property - def trace(self) -> Iterator[TraceFrame]: + def trace(self) -> "TraceAPI": """ - The trace of the transaction, if available from your provider. + The :class:`~ape.api.trace.TraceAPI` of the transaction. """ - return self.provider.get_transaction_trace(txn_hash=self.txn_hash) + return self.provider.get_transaction_trace(self.txn_hash) @property def _explorer(self) -> Optional[ExplorerAPI]: @@ -497,21 +493,11 @@ def return_value(self) -> Any: since this is not available from the receipt object. """ - if not (call_tree := self.call_tree) or not (method_abi := self.method_called): - return None + if trace := self.trace: + ret_val = trace.return_value + return ret_val[0] if isinstance(ret_val, tuple) and len(ret_val) == 1 else ret_val - if isinstance(call_tree.outputs, (str, HexBytes, int)): - output = self.provider.network.ecosystem.decode_returndata( - method_abi, HexBytes(call_tree.outputs) - ) - else: - # Already enriched. - output = call_tree.outputs - - if isinstance(output, tuple) and len(output) < 2: - output = output[0] if len(output) == 1 else None - - return output + return None @property @raises_not_implemented @@ -558,9 +544,9 @@ def track_gas(self): if not address or not self._test_runner: return - if self.provider.supports_tracing and (call_tree := self.call_tree): + if self.provider.supports_tracing and (trace := self.trace): tracker = self._test_runner.gas_tracker - tracker.append_gas(call_tree.enrich(in_line=False), address) + tracker.append_gas(trace, address) elif ( (contract_type := self.chain_manager.contracts.get(address)) diff --git a/src/ape/exceptions.py b/src/ape/exceptions.py index 54fe24d53e..8a642668ad 100644 --- a/src/ape/exceptions.py +++ b/src/ape/exceptions.py @@ -7,18 +7,7 @@ from inspect import getframeinfo, stack from pathlib import Path from types import CodeType, TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Collection, - Dict, - Iterable, - Iterator, - List, - Optional, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, List, Optional, Union, cast import click from eth_typing import Hash32 @@ -32,8 +21,9 @@ if TYPE_CHECKING: from ape.api.networks import NetworkAPI from ape.api.providers import SubprocessProvider + from ape.api.trace import TraceAPI from ape.api.transactions import ReceiptAPI, TransactionAPI - from ape.types import AddressType, BlockID, SnapshotID, SourceTraceback, TraceFrame + from ape.types import AddressType, BlockID, SnapshotID, SourceTraceback FailedTxn = Union["TransactionAPI", "ReceiptAPI"] @@ -184,7 +174,7 @@ def __init__( base_err: Optional[Exception] = None, code: Optional[int] = None, txn: Optional[FailedTxn] = None, - trace: Optional[Iterator["TraceFrame"]] = None, + trace: Optional["TraceAPI"] = None, contract_address: Optional["AddressType"] = None, source_traceback: Optional["SourceTraceback"] = None, ): @@ -262,7 +252,7 @@ def __init__( self, revert_message: Optional[str] = None, txn: Optional[FailedTxn] = None, - trace: Optional[Iterator["TraceFrame"]] = None, + trace: Optional["TraceAPI"] = None, contract_address: Optional["AddressType"] = None, source_traceback: Optional["SourceTraceback"] = None, base_err: Optional[Exception] = None, @@ -497,9 +487,13 @@ class TransactionNotFoundError(ProviderError): Raised when unable to find a transaction. """ - def __init__(self, txn_hash: str, error_messsage: Optional[str] = None): - message = f"Transaction '{txn_hash}' not found." - suffix = f" Error: {error_messsage}" if error_messsage else "" + def __init__(self, transaction_hash: Optional[str] = None, error_message: Optional[str] = None): + message = ( + f"Transaction '{transaction_hash}' not found." + if transaction_hash + else "Transaction not found" + ) + suffix = f" Error: {error_message}" if error_message else "" super().__init__(f"{message}{suffix}") @@ -764,7 +758,7 @@ def __init__( abi: ErrorABI, inputs: Dict[str, Any], txn: Optional[FailedTxn] = None, - trace: Optional[Iterator["TraceFrame"]] = None, + trace: Optional["TraceAPI"] = None, contract_address: Optional["AddressType"] = None, base_err: Optional[Exception] = None, source_traceback: Optional["SourceTraceback"] = None, diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index 97e9898520..8e15dac9af 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -3,13 +3,16 @@ from contextlib import contextmanager from functools import partial from pathlib import Path +from statistics import mean, median from typing import IO, Collection, Dict, Iterator, List, Optional, Set, Type, Union, cast import pandas as pd from eth_pydantic_types import HexBytes from ethpm_types import ABI, ContractType from rich import get_console +from rich.box import SIMPLE from rich.console import Console as RichConsole +from rich.table import Table from ape.api import BlockAPI, ReceiptAPI from ape.api.address import BaseAddress @@ -30,14 +33,16 @@ CustomError, ProviderNotConnectedError, QueryEngineError, + TransactionNotFoundError, UnknownSnapshotError, ) from ape.logging import logger from ape.managers.base import BaseManager -from ape.types import AddressType, BlockID, CallTreeNode, SnapshotID, SourceTraceback +from ape.types import AddressType, BlockID, GasReport, SnapshotID, SourceTraceback from ape.utils import ( BaseInterfaceModel, - TraceStyles, + is_evm_precompile, + is_zero_hex, log_instead_of_fail, nonreentrant, singledispatchmethod, @@ -1363,29 +1368,44 @@ class ReportManager(BaseManager): rich_console_map: Dict[str, RichConsole] = {} - def show_trace( - self, - call_tree: CallTreeNode, - sender: Optional[AddressType] = None, - transaction_hash: Optional[str] = None, - revert_message: Optional[str] = None, - file: Optional[IO[str]] = None, - verbose: bool = False, - ): - root = call_tree.as_rich_tree(verbose=verbose) - console = self._get_console(file) + def show_gas(self, report: GasReport, file: Optional[IO[str]] = None): + tables: List[Table] = [] + + for contract_id, method_calls in report.items(): + title = f"{contract_id} Gas" + table = Table(title=title, box=SIMPLE) + table.add_column("Method") + table.add_column("Times called", justify="right") + table.add_column("Min.", justify="right") + table.add_column("Max.", justify="right") + table.add_column("Mean", justify="right") + table.add_column("Median", justify="right") + has_at_least_1_row = False + + for method_call, gases in sorted(method_calls.items()): + if not gases: + continue - if transaction_hash: - console.print(f"Call trace for [bold blue]'{transaction_hash}'[/]") - if revert_message: - console.print(f"[bold red]{revert_message}[/]") - if sender: - console.print(f"tx.origin=[{TraceStyles.CONTRACTS}]{sender}[/]") + if not method_call or is_zero_hex(method_call) or is_evm_precompile(method_call): + continue + + elif method_call == "__new__": + # Looks better in the gas report. + method_call = "__init__" + + has_at_least_1_row = True + table.add_row( + method_call, + f"{len(gases)}", + f"{min(gases)}", + f"{max(gases)}", + f"{int(round(mean(gases)))}", + f"{int(round(median(gases)))}", + ) - console.print(root) + if has_at_least_1_row: + tables.append(table) - def show_gas(self, call_tree: CallTreeNode, file: Optional[IO[str]] = None): - tables = call_tree.as_gas_tables() self.echo(*tables, file=file) def echo(self, *rich_items, file: Optional[IO[str]] = None): @@ -1667,6 +1687,6 @@ def get_receipt(self, transaction_hash: str) -> ReceiptAPI: """ receipt = self.chain_manager.history[transaction_hash] if not isinstance(receipt, ReceiptAPI): - raise ChainError(f"No receipt found with hash '{transaction_hash}'.") + raise TransactionNotFoundError(transaction_hash=transaction_hash) return receipt diff --git a/src/ape/pytest/gas.py b/src/ape/pytest/gas.py index 0ff24ea5e8..5b187498d4 100644 --- a/src/ape/pytest/gas.py +++ b/src/ape/pytest/gas.py @@ -4,8 +4,9 @@ from ethpm_types.source import ContractSource from evm_trace.gas import merge_reports +from ape.api import TraceAPI from ape.pytest.config import ConfigWrapper -from ape.types import AddressType, CallTreeNode, ContractFunctionPath, GasReport +from ape.types import AddressType, ContractFunctionPath, GasReport from ape.utils import parse_gas_table from ape.utils.basemodel import ManagerAccessMixin from ape.utils.trace import _exclude_gas @@ -37,17 +38,13 @@ def show_session_gas(self) -> bool: self.chain_manager._reports.echo(*tables) return True - def append_gas( - self, - call_tree: CallTreeNode, - contract_address: AddressType, - ): + def append_gas(self, trace: TraceAPI, contract_address: AddressType): contract_type = self.chain_manager.contracts.get(contract_address) if not contract_type: # Skip unknown contracts. return - report = call_tree.get_gas_report(exclude=self.gas_exclusions) + report = trace.get_gas_report(exclude=self.gas_exclusions) self._merge(report) def append_toplevel_gas(self, contract: ContractSource, method: MethodABI, gas_cost: int): diff --git a/src/ape/types/__init__.py b/src/ape/types/__init__.py index 2eed9fab73..6e891e2bc1 100644 --- a/src/ape/types/__init__.py +++ b/src/ape/types/__init__.py @@ -44,7 +44,7 @@ CoverageStatement, ) from ape.types.signatures import MessageSignature, SignableMessage, TransactionSignature -from ape.types.trace import CallTreeNode, ControlFlow, GasReport, SourceTraceback, TraceFrame +from ape.types.trace import ControlFlow, GasReport, SourceTraceback from ape.utils import ( BaseInterfaceModel, ExtraAttributesMixin, @@ -488,7 +488,6 @@ def generator(self) -> Iterator: "AddressType", "BlockID", "Bytecode", - "CallTreeNode", "Checksum", "Closure", "Compiler", @@ -510,6 +509,5 @@ def generator(self) -> Iterator: "SnapshotID", "Source", "SourceTraceback", - "TraceFrame", "TransactionSignature", ] diff --git a/src/ape/types/trace.py b/src/ape/types/trace.py index 68cb3367d8..139cf3a48c 100644 --- a/src/ape/types/trace.py +++ b/src/ape/types/trace.py @@ -1,23 +1,17 @@ -from itertools import chain, tee +from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Union +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Union from eth_pydantic_types import HexBytes from ethpm_types import ASTNode, BaseModel, ContractType from ethpm_types.ast import SourceLocation from ethpm_types.source import Closure, Content, Function, SourceStatement, Statement -from evm_trace.gas import merge_reports -from pydantic import Field, RootModel -from rich.table import Table -from rich.tree import Tree +from pydantic import RootModel -from ape.types.address import AddressType -from ape.utils.basemodel import BaseInterfaceModel -from ape.utils.misc import is_evm_precompile, is_zero_hex, log_instead_of_fail -from ape.utils.trace import _exclude_gas, parse_as_str, parse_gas_table, parse_rich_tree +from ape.utils.misc import log_instead_of_fail if TYPE_CHECKING: - from ape.types import ContractFunctionPath + from ape.api.trace import TraceAPI GasReport = Dict[str, Dict[str, List[int]]] @@ -26,204 +20,6 @@ """ -class CallTreeNode(BaseInterfaceModel): - contract_id: str - """ - The identifier representing the contract in this node. - A non-enriched identifier is an address; a more enriched - identifier is a token symbol or contract type name. - """ - - method_id: Optional[str] = None - """ - The identifier representing the method in this node. - A non-enriched identifier is a method selector. - An enriched identifier is method signature. - """ - - txn_hash: Optional[str] = None - """ - The transaction hash, if known and/or exists. - """ - - failed: bool = False - """ - ``True`` where this tree represents a failed call. - """ - - inputs: Optional[Any] = None - """ - The inputs to the call. - Non-enriched inputs are raw bytes or values. - Enriched inputs are decoded. - """ - - outputs: Optional[Any] = None - """ - The output to the call. - Non-enriched inputs are raw bytes or values. - Enriched outputs are decoded. - """ - - value: Optional[int] = None - """ - The value sent with the call, if applicable. - """ - - gas_cost: Optional[int] = None - """ - The gas cost of the call, if known. - """ - - call_type: Optional[str] = None - """ - A str indicating what type of call it is. - See ``evm_trace.enums.CallType`` for EVM examples. - """ - - calls: List["CallTreeNode"] = [] - """ - The list of subcalls made by this call. - """ - - raw: Dict = Field({}, exclude=True, repr=False) - """ - The raw tree, as a dictionary, associated with the call. - """ - - @log_instead_of_fail(default="") - def __repr__(self) -> str: - return parse_as_str(self) - - def __str__(self) -> str: - return parse_as_str(self) - - def _repr_pretty_(self, *args, **kwargs): - enriched_tree = self.enrich(use_symbol_for_tokens=True) - self.chain_manager._reports.show_trace(enriched_tree) - - def enrich(self, **kwargs) -> "CallTreeNode": - """ - Enrich the properties on this call tree using data from contracts - and using information about the ecosystem. - - Args: - **kwargs: Key-word arguments to pass to - :meth:`~ape.api.networks.EcosystemAPI.enrich_calltree`, such as - ``use_symbol_for_tokens``. - - Returns: - :class:`~ape.types.trace.CallTreeNode`: This call tree node with - its properties enriched. - """ - - return self.provider.network.ecosystem.enrich_calltree(self, **kwargs) - - def add(self, sub_call: "CallTreeNode"): - """ - Add a sub call to this node. This implies this call called the sub-call. - - Args: - sub_call (:class:`~ape.types.trace.CallTreeNode`): The sub-call to add. - """ - - self.calls.append(sub_call) - - def as_rich_tree(self, verbose: bool = False) -> Tree: - """ - Return this object as a ``rich.tree.Tree`` for pretty-printing. - - Returns: - ``Tree`` - """ - - return parse_rich_tree(self, verbose=verbose) - - def as_gas_tables(self, exclude: Optional[List["ContractFunctionPath"]] = None) -> List[Table]: - """ - Return this object as list of rich gas tables for pretty printing. - - Args: - exclude (Optional[List[:class:`~ape.types.ContractFunctionPath`]]): - A list of contract / method combinations to exclude from the gas - tables. - - Returns: - List[``rich.table.Table``] - """ - - report = self.get_gas_report(exclude=exclude) - return parse_gas_table(report) - - def get_gas_report(self, exclude: Optional[List["ContractFunctionPath"]] = None) -> "GasReport": - """ - Get a unified gas-report of all the calls made in this tree. - - Args: - exclude (Optional[List[:class:`~ape.types.ContractFunctionPath`]]): - A list of contract / method combinations to exclude from the gas - tables. - - Returns: - :class:`~ape.types.trace.GasReport` - """ - - exclusions = exclude or [] - if ( - not self.contract_id - or not self.method_id - or _exclude_gas(exclusions, self.contract_id, self.method_id) - ): - return merge_reports(*(c.get_gas_report(exclude) for c in self.calls)) - - elif not is_zero_hex(self.method_id) and not is_evm_precompile(self.method_id): - reports = [ - *[c.get_gas_report(exclude) for c in self.calls], - { - self.contract_id: { - self.method_id: [self.gas_cost] if self.gas_cost is not None else [] - } - }, - ] - return merge_reports(*reports) - - return merge_reports(*(c.get_gas_report(exclude) for c in self.calls)) - - -class TraceFrame(BaseInterfaceModel): - """ - A low-level data structure modeling a transaction trace frame - from the Geth RPC ``debug_traceTransaction``. - """ - - pc: int - """Program counter.""" - - op: str - """Opcode.""" - - gas: int - """Remaining gas.""" - - gas_cost: int - """The cost to execute this opcode.""" - - depth: int - """ - The number of external jumps away the initially called contract (starts at 0). - """ - - contract_address: Optional[AddressType] = None - """ - The contract address, if this is a call trace frame. - """ - - raw: Dict = Field({}, exclude=True, repr=False) - """ - The raw trace frame from the provider. - """ - - class ControlFlow(BaseModel): """ A collection of linear source nodes up until a jump. @@ -490,24 +286,18 @@ class SourceTraceback(RootModel[List[ControlFlow]]): """ @classmethod - def create( - cls, - contract_type: ContractType, - trace: Iterator[TraceFrame], - data: Union[HexBytes, str], - ): - trace, second_trace = tee(trace) - if not second_trace or not (accessor := next(second_trace, None)): - return cls.model_validate([]) + def create(cls, contract_type: ContractType, trace: "TraceAPI", data: Union[HexBytes, str]): + # Use the trace as a 'ManagerAccessMixin'. + compilers = trace.compiler_manager if not (source_id := contract_type.source_id): return cls.model_validate([]) ext = f".{source_id.split('.')[-1]}" - if ext not in accessor.compiler_manager.registered_compilers: + if ext not in compilers.registered_compilers: return cls.model_validate([]) - compiler = accessor.compiler_manager.registered_compilers[ext] + compiler = compilers.registered_compilers[ext] try: return compiler.trace_source(contract_type, trace, HexBytes(data)) except NotImplementedError: diff --git a/src/ape/utils/trace.py b/src/ape/utils/trace.py index acfe043d9d..155d21823d 100644 --- a/src/ape/utils/trace.py +++ b/src/ape/utils/trace.py @@ -1,21 +1,15 @@ -import json from fnmatch import fnmatch from statistics import mean, median -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Sequence, Tuple -from eth_pydantic_types import HexBytes -from eth_utils import is_0x_prefixed from rich.box import SIMPLE from rich.table import Table -from rich.tree import Tree -from ape.utils.misc import is_evm_precompile, is_zero_hex +from ape.utils import is_evm_precompile, is_zero_hex if TYPE_CHECKING: - from ape.types import CallTreeNode, ContractFunctionPath, CoverageReport, GasReport + from ape.types import ContractFunctionPath, CoverageReport, GasReport -_WRAP_THRESHOLD = 50 -_INDENT = 2 USER_ASSERT_TAG = "USER_ASSERT" @@ -48,118 +42,6 @@ class TraceStyles: """The gas used of the call.""" -def parse_rich_tree(call: "CallTreeNode", verbose: bool = False) -> Tree: - tree = _create_tree(call, verbose=verbose) - for sub_call in call.calls: - sub_tree = parse_rich_tree(sub_call, verbose=verbose) - tree.add(sub_tree) - - return tree - - -def _create_tree(call: "CallTreeNode", verbose: bool = False) -> Tree: - signature = parse_as_str(call, stylize=True, verbose=verbose) - return Tree(signature) - - -def parse_as_str(call: "CallTreeNode", stylize: bool = False, verbose: bool = False) -> str: - contract = str(call.contract_id) - method = ( - "__new__" - if call.call_type - and "CREATE" in call.call_type - and call.method_id - and is_0x_prefixed(call.method_id) - else str(call.method_id or "") - ) - if "(" in method: - # Only show short name, not ID name - # (it is the full signature when multiple methods have the same name). - method = method.split("(")[0].strip() or method - - if stylize: - contract = f"[{TraceStyles.CONTRACTS}]{contract}[/]" - method = f"[{TraceStyles.METHODS}]{method}[/]" - - call_path = f"{contract}.{method}" - - if call.call_type is not None and call.call_type.upper() == "DELEGATECALL": - delegate = "(delegate)" - if stylize: - delegate = f"[orange]{delegate}[/]" - - call_path = f"{delegate} {call_path}" - - signature = call_path - arguments_str = _get_inputs_str(call.inputs, stylize=stylize) - if call.call_type and "CREATE" in call.call_type and is_0x_prefixed(arguments_str): - # Unenriched CREATE calldata is a massive hex. - arguments_str = "" - - signature = f"{signature}{arguments_str}" - - if ( - call.call_type - and "CREATE" not in call.call_type - and call.outputs not in ((), [], None, {}, "") - ): - if return_str := _get_outputs_str(call.outputs, stylize=stylize): - signature = f"{signature} -> {return_str}" - - if call.value: - value = str(call.value) - if stylize: - value = f"[{TraceStyles.VALUE}]{value}[/]" - - signature += f" {value}" - - if call.gas_cost: - gas_value = f"[{call.gas_cost} gas]" - if stylize: - gas_value = f"[{TraceStyles.GAS_COST}]{gas_value}[/]" - - signature += f" {gas_value}" - - if verbose: - verbose_items = {k: v for k, v in call.raw.items() if type(v) in (int, str, bytes, float)} - extra = json.dumps(verbose_items, indent=2) - signature = f"{signature}\n{extra}" - - return signature - - -def _get_inputs_str(inputs: Any, stylize: bool = False) -> str: - color = TraceStyles.INPUTS if stylize else None - if inputs in ["0x", None, (), [], {}]: - return "()" - - elif isinstance(inputs, dict): - return _dict_to_str(inputs, color=color) - - elif isinstance(inputs, bytes): - return HexBytes(inputs).hex() - - return f"({inputs})" - - -def _get_outputs_str(outputs: Any, stylize: bool = False) -> Optional[str]: - if outputs in ["0x", None, (), [], {}]: - return None - - elif isinstance(outputs, dict): - color = TraceStyles.OUTPUTS if stylize else None - return _dict_to_str(outputs, color=color) - - elif isinstance(outputs, (list, tuple)): - return ( - f"[{TraceStyles.OUTPUTS}]{_list_to_str(outputs)}[/]" - if stylize - else _list_to_str(outputs) - ) - - return f"[{TraceStyles.OUTPUTS}]{outputs}[/]" if stylize else str(outputs) - - def parse_gas_table(report: "GasReport") -> List[Table]: tables: List[Table] = [] @@ -325,96 +207,8 @@ def _parse_verbose_coverage(coverage: "CoverageReport", statement: bool = True) return tables -def _dict_to_str(dictionary: Dict, color: Optional[str] = None) -> str: - length = sum(len(str(v)) for v in [*dictionary.keys(), *dictionary.values()]) - do_wrap = length > _WRAP_THRESHOLD - - index = 0 - end_index = len(dictionary) - 1 - kv_str = "(\n" if do_wrap else "(" - - for key, value in dictionary.items(): - if do_wrap: - kv_str += _INDENT * " " - - if isinstance(value, (list, tuple)): - value = _list_to_str(value, 1 if do_wrap else 0) - - value_str = f"[{color}]{value}[/]" if color is not None else str(value) - kv_str += f"{key}={value_str}" if key and not key.isnumeric() else value_str - if index < end_index: - kv_str += ", " - - if do_wrap: - kv_str += "\n" - - index += 1 - - return f"{kv_str})" - - -def _list_to_str(ls: Union[List, Tuple], depth: int = 0) -> str: - if not isinstance(ls, (list, tuple)) or len(str(ls)) < _WRAP_THRESHOLD: - return str(ls) - - elif ls and isinstance(ls[0], (list, tuple)): - # List of lists - sub_lists = [_list_to_str(i) for i in ls] - - # Use multi-line if exceeds threshold OR any of the sub-lists use multi-line - extra_chars_len = (len(sub_lists) - 1) * 2 - use_multiline = len(str(sub_lists)) + extra_chars_len > _WRAP_THRESHOLD or any( - ["\n" in ls for ls in sub_lists] - ) - - if not use_multiline: - # Happens for lists like '[[0], [1]]' that are short. - return f"[{', '.join(sub_lists)}]" - - value = "[\n" - num_sub_lists = len(sub_lists) - index = 0 - spacing = _INDENT * " " * 2 - for formatted_list in sub_lists: - if "\n" in formatted_list: - # Multi-line sub list. Append 1 more spacing to each line. - indented_item = f"\n{spacing}".join(formatted_list.splitlines()) - value = f"{value}{spacing}{indented_item}" - else: - # Single line sub-list - value = f"{value}{spacing}{formatted_list}" - - if index < num_sub_lists - 1: - value = f"{value}," - - value = f"{value}\n" - index += 1 - - value = f"{value}{_INDENT * ' '}]" - return value - - return _list_to_multiline_str(ls, depth=depth) - - -def _list_to_multiline_str(value: Union[List, Tuple], depth: int = 0) -> str: - spacing = _INDENT * " " - new_val = "[\n" - num_values = len(value) - for idx in range(num_values): - ls_spacing = spacing * (depth + 1) - new_val += f"{ls_spacing}{value[idx]}" - if idx < num_values - 1: - new_val += "," - - new_val += "\n" - - new_val += spacing * depth - new_val += "]" - return new_val - - def _exclude_gas( - exclusions: List["ContractFunctionPath"], contract_id: str, method_id: str + exclusions: Sequence["ContractFunctionPath"], contract_id: str, method_id: str ) -> bool: for exclusion in exclusions: if exclusion.method_name is None and fnmatch(contract_id, exclusion.contract_name): diff --git a/src/ape_ethereum/_print.py b/src/ape_ethereum/_print.py index cafb602409..542d6643ae 100644 --- a/src/ape_ethereum/_print.py +++ b/src/ape_ethereum/_print.py @@ -25,53 +25,43 @@ from eth_typing import ChecksumAddress from eth_utils import decode_hex from ethpm_types import ContractType, MethodABI +from evm_trace import CallTreeNode +from hexbytes import HexBytes from typing_extensions import TypeGuard import ape -from ape.types import CallTreeNode +from ape_ethereum._console_log_abi import CONSOLE_LOG_ABI -from ._console_log_abi import CONSOLE_LOG_ABI - -CONSOLE_CONTRACT_ID = cast(ChecksumAddress, "0x000000000000000000636F6e736F6c652e6c6f67") -VYPER_PRINT_METHOD_ID = "0x23cdd8e8" # log(string,bytes) +CONSOLE_ADDRESS = cast(ChecksumAddress, "0x000000000000000000636F6e736F6c652e6c6f67") +VYPER_PRINT_METHOD_ID = HexBytes("0x23cdd8e8") # log(string,bytes) console_contract = ContractType(abi=CONSOLE_LOG_ABI, contractName="console") -def is_console_log(call: Any) -> TypeGuard[CallTreeNode]: - """Determine if a call is a starndard console.log() call""" +def is_console_log(call: CallTreeNode) -> TypeGuard[CallTreeNode]: + """Determine if a call is a standard console.log() call""" return ( - isinstance(call, CallTreeNode) - and call.contract_id == CONSOLE_CONTRACT_ID - and call.method_id in console_contract.identifier_lookup + call.address == HexBytes(CONSOLE_ADDRESS) + and call.calldata[:4].hex() in console_contract.identifier_lookup ) -def is_vyper_print(call: Any) -> TypeGuard[CallTreeNode]: - """Determine if a call is a starndard Vyper print() call""" - if ( - isinstance(call, CallTreeNode) - and call.contract_id == CONSOLE_CONTRACT_ID - and call.method_id == VYPER_PRINT_METHOD_ID - and isinstance(call.inputs, str) - ): - bcalldata = decode_hex(call.inputs) - schema, _ = decode(["string", "bytes"], bcalldata) - try: - # Now we look at the first arg to try and determine if it's an ABI signature - first_type = schema.strip("()").split(",")[0] - # TODO: Tighten this up. This is not entirely accurate, but should mostly get us there. - if ( - first_type.startswith("uint") - or first_type.startswith("int") - or first_type.startswith("bytes") - or first_type == "string" - ): - return True - except IndexError: - # Empty string as first arg? - pass - return False +def is_vyper_print(call: CallTreeNode) -> TypeGuard[CallTreeNode]: + """Determine if a call is a standard Vyper print() call""" + if call.address != HexBytes(CONSOLE_ADDRESS) or call.calldata[:4] != VYPER_PRINT_METHOD_ID: + return False + + schema, _ = decode(["string", "bytes"], call.calldata[4:]) + types = schema.strip("()").split(",") + + # Now we look at the first arg to try and determine if it's an ABI signature + # TODO: Tighten this up. This is not entirely accurate, but should mostly get us there. + return len(types) > 0 and ( + types[0].startswith("uint") + or types[0].startswith("int") + or types[0].startswith("bytes") + or types[0] == "string" + ) def console_log(method_abi: MethodABI, calldata: str) -> Tuple[Any]: @@ -81,23 +71,23 @@ def console_log(method_abi: MethodABI, calldata: str) -> Tuple[Any]: return tuple(data.values()) -def vyper_print(calldata: str) -> Tuple[Any]: +def vyper_print(calldata: HexBytes) -> Tuple[Any]: """Return logged data for print() calls""" - bcalldata = decode_hex(calldata) - schema, payload = decode(["string", "bytes"], bcalldata) + schema, payload = decode(["string", "bytes"], calldata) data = decode(schema.strip("()").split(","), payload) return tuple(data) -def extract_debug_logs(call_tree: CallTreeNode) -> Iterable[Tuple[Any]]: +def extract_debug_logs(call: CallTreeNode) -> Iterable[Tuple[Any]]: """Filter calls to console.log() and print() from a transactions call tree""" - for call in call_tree.calls: - if is_vyper_print(call) and call.inputs is not None: - yield vyper_print(call.inputs) - elif is_console_log(call) and call.inputs is not None: - assert call.method_id is not None # is_console_log check already checked - method_abi = console_contract.identifier_lookup.get(call.method_id) - if isinstance(method_abi, MethodABI): - yield console_log(method_abi, call.inputs) - elif call.calls is not None: - yield from extract_debug_logs(call) + if is_vyper_print(call) and call.calldata is not None: + yield vyper_print(call.calldata[4:]) + + elif is_console_log(call) and call.calldata is not None: + method_abi = console_contract.identifier_lookup.get(call.calldata[:4].hex()) + if isinstance(method_abi, MethodABI): + yield console_log(method_abi, call.calldata[4:].hex()) + + elif call.calls is not None: + for sub_call in call.calls: + yield from extract_debug_logs(sub_call) diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index f5af78ada1..c43099654a 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -1,5 +1,4 @@ import re -from copy import deepcopy from decimal import Decimal from functools import cached_property from typing import Any, ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union, cast @@ -22,7 +21,7 @@ from pydantic import Field, computed_field, field_validator, model_validator from pydantic_settings import SettingsConfigDict -from ape.api import BlockAPI, EcosystemAPI, PluginConfig, ReceiptAPI, TransactionAPI +from ape.api import BlockAPI, EcosystemAPI, PluginConfig, ReceiptAPI, TraceAPI, TransactionAPI from ape.api.networks import LOCAL_NETWORK_NAME from ape.contracts.base import ContractCall from ape.exceptions import ( @@ -37,7 +36,6 @@ from ape.types import ( AddressType, AutoGasLimit, - CallTreeNode, ContractLog, GasLimit, RawAddress, @@ -65,6 +63,7 @@ ProxyInfo, ProxyType, ) +from ape_ethereum.trace import Trace, TransactionTrace from ape_ethereum.transactions import ( AccessListTransaction, BaseTransaction, @@ -660,7 +659,7 @@ def decode_calldata(self, abi: Union[ConstructorABI, MethodABI], calldata: bytes try: raw_input_values = decode(raw_input_types, calldata, strict=False) - except InsufficientDataBytes as err: + except (InsufficientDataBytes, OverflowError, NonEmptyPaddingBytes) as err: raise DecodingError(str(err)) from err input_values = [ @@ -744,7 +743,7 @@ def _enrich_value(self, value: Any, **kwargs) -> Any: elif isinstance(value, str) and is_hex_address(value): address = self.decode_address(value) - return self._enrich_address(address, **kwargs) + return self._enrich_contract_id(address, **kwargs) elif isinstance(value, str): # Surround non-address strings with quotes. @@ -772,6 +771,10 @@ def decode_primitive_value( elif isinstance(output_type, str) and is_array(output_type): sub_type = "[".join(output_type.split("[")[:-1]) + + if not isinstance(value, (list, tuple)): + value = (value,) + return [self.decode_primitive_value(v, sub_type) for v in value] elif isinstance(output_type, tuple): @@ -998,87 +1001,129 @@ def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]: ), ) - def enrich_calltree(self, call: CallTreeNode, **kwargs) -> CallTreeNode: - kwargs["use_symbol_for_tokens"] = kwargs.get("use_symbol_for_tokens", False) - kwargs["in_place"] = kwargs.get("in_place", True) + def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI: + if not isinstance(trace, Trace): + return trace - if call.txn_hash: - receipt = self.chain_manager.get_receipt(call.txn_hash) - kwargs["sender"] = receipt.sender + elif trace._enriched_calltree is not None: + # Already enriched. + return trace - # Enrich subcalls before any _return_ statement. - enriched_call = call if kwargs["in_place"] else deepcopy(call) - enriched_call.calls = [self.enrich_calltree(c, **kwargs) for c in enriched_call.calls] + if sender := trace.transaction.get("from"): + kwargs["sender"] = sender - not_address_type: bool = not self.conversion_manager.is_type( - enriched_call.contract_id, AddressType - ) - if not_address_type and is_hex_address(enriched_call.contract_id): - enriched_call.contract_id = self.decode_address(enriched_call.contract_id) + # Get the un-enriched calltree. + data = trace.get_calltree().model_dump(mode="json", by_alias=True) + + # Return value was discovered already. + if isinstance(trace, TransactionTrace): + return_value = trace.__dict__.get("return_value") if data.get("depth", 0) == 0 else None + if return_value is not None: + kwargs["return_value"] = return_value - elif not_address_type: + enriched_calltree = self._enrich_calltree(data, **kwargs) + + # Cache the result back on the trace. + trace._enriched_calltree = enriched_calltree + + return trace + + def _enrich_calltree(self, call: Dict, **kwargs) -> Dict: + if "contract_id" in call: # Already enriched. - return enriched_call + return call + + if self._test_runner and self._test_runner.gas_tracker.enabled: + default_symbol_for_tokens = not self._test_runner.gas_tracker.enabled + else: + default_symbol_for_tokens = True + + kwargs["use_symbol_for_tokens"] = kwargs.get( + "use_symbol_for_tokens", default_symbol_for_tokens + ) + call_type = call.get("call_type", "") + is_create = "CREATE" in call_type + + # Enrich sub-calls first. + if subcalls := call.get("calls"): + call["calls"] = [self._enrich_calltree(c, **kwargs) for c in subcalls] + + # Figure out the contract. + address = call.pop("address", "") + try: + call["contract_id"] = address = str(self.decode_address(address)) + except Exception: + # Tx was made with a weird address. + call["contract_id"] = address + + if calldata := call.get("calldata"): + calldata_bytes = HexBytes(calldata) + call["method_id"] = calldata_bytes[:4].hex() + call["calldata"] = calldata if is_create else calldata_bytes[4:].hex() + + else: + call["method_id"] = "0x" - # Collapse pre-compile address calls - address = cast(AddressType, enriched_call.contract_id) - address_int = int(address, 16) - if 1 <= address_int <= 9: - sub_calls = [self.enrich_calltree(c, **kwargs) for c in enriched_call.calls] - if len(sub_calls) == 1: - return sub_calls[0] + try: + address_int = int(address, 16) + except Exception: + pass + else: + # Collapse pre-compile address calls + if 1 <= address_int <= 9: + if len(call.get("calls", [])) == 1: + return call["calls"][0] - intermediary_node = CallTreeNode(contract_id=f"{address_int}") - for sub_tree in sub_calls: - intermediary_node.add(sub_tree) + return {"contract_id": f"{address_int}", "calls": call["calls"]} - return intermediary_node + depth = call.get("depth", 0) + if depth == 0 and address in self.account_manager: + call["contract_id"] = f"__{self.fee_token_symbol}_transfer__" + else: + call["contract_id"] = self._enrich_contract_id(call["contract_id"], **kwargs) if not (contract_type := self.chain_manager.contracts.get(address)): - return enriched_call + # Without a contract, we can enrich no further. + return call - enriched_call.contract_id = self._enrich_address(address, **kwargs) method_abi: Optional[Union[MethodABI, ConstructorABI]] = None - if "CREATE" in (enriched_call.call_type or ""): + if is_create: method_abi = contract_type.constructor name = "__new__" - elif enriched_call.method_id is None: - name = enriched_call.method_id or "0x" - - else: - method_id_bytes = HexBytes(enriched_call.method_id) + elif call["method_id"] != "0x": + method_id_bytes = HexBytes(call["method_id"]) if method_id_bytes in contract_type.methods: method_abi = contract_type.methods[method_id_bytes] assert isinstance(method_abi, MethodABI) # For mypy # Check if method name duplicated. If that is the case, use selector. times = len([x for x in contract_type.methods if x.name == method_abi.name]) - name = ( - method_abi.name if times == 1 else method_abi.selector - ) or enriched_call.method_id - enriched_call = self._enrich_calldata( - enriched_call, method_abi, contract_type, **kwargs - ) + name = (method_abi.name if times == 1 else method_abi.selector) or call["method_id"] + call = self._enrich_calldata(call, method_abi, contract_type, **kwargs) + else: - name = enriched_call.method_id or "0x" + name = call["method_id"] + else: + name = call.get("method_id") or "0x" - enriched_call.method_id = name + call["method_id"] = name if method_abi: - enriched_call = self._enrich_calldata( - enriched_call, method_abi, contract_type, **kwargs - ) + call = self._enrich_calldata(call, method_abi, contract_type, **kwargs) - if isinstance(method_abi, MethodABI): - enriched_call = self._enrich_returndata(enriched_call, method_abi, **kwargs) + if kwargs.get("return_value"): + # Return value was separately enriched. + call["returndata"] = kwargs["return_value"] + elif isinstance(method_abi, MethodABI): + call = self._enrich_returndata(call, method_abi, **kwargs) else: # For constructors, don't include outputs, as it is likely a large amount of bytes. - enriched_call.outputs = None + call["returndata"] = None - return enriched_call + return call - def _enrich_address(self, address: AddressType, **kwargs) -> str: + def _enrich_contract_id(self, address: AddressType, **kwargs) -> str: if address and address == kwargs.get("sender"): return "tx.origin" @@ -1110,22 +1155,16 @@ def _enrich_address(self, address: AddressType, **kwargs) -> str: return str(symbol) name = contract_type.name.strip() if contract_type.name else None - return name or self._get_contract_id_from_address(address) - - def _get_contract_id_from_address(self, address: "AddressType") -> str: - if address in self.account_manager: - return f"Transferring {self.fee_token_symbol}" - - return address + return name or address def _enrich_calldata( self, - call: CallTreeNode, + call: Dict, method_abi: Union[MethodABI, ConstructorABI], contract_type: ContractType, **kwargs, - ) -> CallTreeNode: - calldata = call.inputs + ) -> Dict: + calldata = call["calldata"] if isinstance(calldata, (str, bytes, int)): calldata_arg = HexBytes(calldata) else: @@ -1133,7 +1172,7 @@ def _enrich_calldata( # Mostly for mypy's sake. return call - if call.call_type and "CREATE" in call.call_type: + if call.get("call_type") and "CREATE" in call.get("call_type", ""): # Strip off bytecode bytecode = ( contract_type.deployment_bytecode.to_bytes() @@ -1144,26 +1183,28 @@ def _enrich_calldata( calldata_arg = HexBytes(calldata_arg.split(bytecode)[-1]) try: - call.inputs = self.decode_calldata(method_abi, calldata_arg) + call["calldata"] = self.decode_calldata(method_abi, calldata_arg) except DecodingError: - call.inputs = ["" for _ in method_abi.inputs] + call["calldata"] = ["" for _ in method_abi.inputs] else: - call.inputs = {k: self._enrich_value(v, **kwargs) for k, v in call.inputs.items()} + call["calldata"] = { + k: self._enrich_value(v, **kwargs) for k, v in call["calldata"].items() + } return call - def _enrich_returndata( - self, call: CallTreeNode, method_abi: MethodABI, **kwargs - ) -> CallTreeNode: - if call.call_type and "CREATE" in call.call_type: - call.outputs = "" + def _enrich_returndata(self, call: Dict, method_abi: MethodABI, **kwargs) -> Dict: + if "CREATE" in call.get("call_type", ""): + call["returndata"] = "" return call default_return_value = "" - if (isinstance(call.outputs, str) and is_0x_prefixed(call.outputs)) or isinstance( - call.outputs, (int, bytes) - ): - return_value_bytes = HexBytes(call.outputs) + returndata = call.get("returndata") + + if ( + returndata and isinstance(returndata, str) and is_0x_prefixed(returndata) + ) or isinstance(returndata, (int, bytes)): + return_value_bytes = HexBytes(returndata) else: return_value_bytes = None @@ -1175,7 +1216,7 @@ def _enrich_returndata( try: return_values = ( self.decode_returndata(method_abi, return_value_bytes) - if not call.failed + if not call.get("failed") else None ) except DecodingError: @@ -1183,6 +1224,9 @@ def _enrich_returndata( # Empty result, but it failed decoding because of its length. return_values = ("",) + # Cache un-enriched return_value in trace. + call["unenriched_return_values"] = return_values + values = ( tuple([default_return_value for _ in method_abi.outputs]) if return_values is None @@ -1198,7 +1242,7 @@ def _enrich_returndata( ): output_val = "" - call.outputs = output_val + call["returndata"] = output_val return call def get_python_types(self, abi_type: ABIType) -> Union[Type, Sequence]: diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index d6eb1e7f0f..62417806fc 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -4,11 +4,9 @@ import time from abc import ABC from concurrent.futures import ThreadPoolExecutor -from copy import copy from functools import cached_property, wraps -from itertools import tee from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast +from typing import Any, Dict, Iterable, Iterator, List, Optional, Union, cast import ijson # type: ignore import requests @@ -16,14 +14,6 @@ from eth_typing import BlockNumber, HexStr from eth_utils import add_0x_prefix, is_hex, to_hex from ethpm_types import EventABI -from evm_trace import CallTreeNode as EvmCallTreeNode -from evm_trace import ParityTraceList -from evm_trace import TraceFrame as EvmTraceFrame -from evm_trace import ( - create_trace_frames, - get_calltree_from_geth_call_trace, - get_calltree_from_parity_trace, -) from pydantic.dataclasses import dataclass from requests import HTTPError from web3 import HTTPProvider, IPCProvider, Web3 @@ -41,8 +31,7 @@ from web3.providers.auto import load_provider_from_environment from web3.types import FeeHistory, RPCEndpoint, TxParams -from ape.api import Address, BlockAPI, ProviderAPI, ReceiptAPI, TransactionAPI -from ape.api.networks import LOCAL_NETWORK_NAME +from ape.api import Address, BlockAPI, ProviderAPI, ReceiptAPI, TraceAPI, TransactionAPI from ape.exceptions import ( ApeException, APINotImplementedError, @@ -62,16 +51,15 @@ AddressType, AutoGasLimit, BlockID, - CallTreeNode, ContractCode, ContractLog, LogFilter, SourceTraceback, - TraceFrame, ) from ape.utils import gas_estimation_error_message, to_int from ape.utils.misc import DEFAULT_MAX_RETRIES_TX -from ape_ethereum._print import CONSOLE_CONTRACT_ID, console_contract +from ape_ethereum._print import CONSOLE_ADDRESS, console_contract +from ape_ethereum.trace import CallTrace, TraceApproach, TransactionTrace from ape_ethereum.transactions import AccessList, AccessListTransaction DEFAULT_PORT = 8545 @@ -134,6 +122,14 @@ class Web3Provider(ProviderAPI, ABC): _web3: Optional[Web3] = None _client_version: Optional[str] = None + _call_trace_approach: Optional[TraceApproach] = None + """ + Is ``None`` until known. + NOTE: This gets set in `ape_ethereum.trace.Trace`. + """ + + _supports_debug_trace_call: Optional[bool] = None + def __new__(cls, *args, **kwargs): assert_web3_provider_uri_env_var_not_set() @@ -272,12 +268,12 @@ def max_gas(self) -> int: def supports_tracing(self) -> bool: try: # NOTE: Txn hash is purposely not a real hash. - # If we get any exception besides not implemented error, - # then we support tracing on this provider. - self.get_call_tree("__CHECK_IF_SUPPORTS_TRACING__") - except APINotImplementedError: + self.make_request("debug_traceTransaction", ["__CHECK_IF_SUPPORTS_TRACING__"]) + except NotImplementedError: return False + except Exception: + # We know tracing works because we didn't get a NotImplementedError. return True return True @@ -317,28 +313,11 @@ def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = N else: tx_to_trace[key] = val - try: - call_trace = self._trace_call([tx_to_trace, "latest"]) - except Exception: - call_trace = None - - traces = None - tb = None - if call_trace and txn_params.get("to"): - traces = (self._create_trace_frame(t) for t in call_trace[1]) - try: - if contract_type := self.chain_manager.contracts.get(txn_params["to"]): - tb = SourceTraceback.create( - contract_type, traces, HexBytes(txn_params["data"]) - ) - except ProviderNotConnectedError: - pass - + trace = CallTrace(tx=txn) tx_error = self.get_virtual_machine_error( err, txn=txn, - trace=traces, - source_traceback=tb, + trace=trace, ) # If this is the cause of a would-be revert, @@ -351,22 +330,10 @@ def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional[BlockID] = N message, base_err=tx_error, txn=txn, source_traceback=tx_error.source_traceback ) from err - def _trace_call(self, arguments: List[Any]) -> Tuple[Dict, Iterator[EvmTraceFrame]]: - result = self._make_request("debug_traceCall", arguments) - trace_data = result.get("structLogs", []) - return result, create_trace_frames(trace_data) - @cached_property def chain_id(self) -> int: default_chain_id = None - if ( - self.network.name - not in ( - "custom", - LOCAL_NETWORK_NAME, - ) - and not self.network.is_fork - ): + if self.network.name != "custom" and not self.network.is_dev: # If using a live network, the chain ID is hardcoded. default_chain_id = self.network.chain_id @@ -435,109 +402,99 @@ def get_storage( raise # Raise original error + def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: + if "call_trace_approach" not in kwargs: + kwargs["call_trace_approach"] = self._call_trace_approach + + return self._get_transaction_trace(transaction_hash, **kwargs) + def send_call( self, txn: TransactionAPI, block_id: Optional[BlockID] = None, state: Optional[Dict] = None, - **kwargs, + **kwargs: Any, ) -> HexBytes: if block_id is not None: kwargs["block_identifier"] = block_id - if kwargs.pop("skip_trace", False): - return self._send_call(txn, **kwargs) - elif self._test_runner is not None: + + if state is not None: + kwargs["state_override"] = state + + skip_trace = kwargs.pop("skip_trace", False) + arguments = self._prepare_call(txn, **kwargs) + if skip_trace: + return self._eth_call(arguments) + + show_gas = kwargs.pop("show_gas_report", False) + show_trace = kwargs.pop("show_trace", False) + + if self._test_runner is not None: track_gas = self._test_runner.gas_tracker.enabled track_coverage = self._test_runner.coverage_tracker.enabled else: track_gas = False track_coverage = False - show_trace = kwargs.pop("show_trace", False) - show_gas = kwargs.pop("show_gas_report", False) - needs_trace = track_gas or track_coverage or show_trace or show_gas - if not needs_trace or not self.provider.supports_tracing or not txn.receiver: - return self._send_call(txn, **kwargs) + needs_trace = track_gas or track_coverage or show_gas or show_trace + if not needs_trace: + return self._eth_call(arguments) # The user is requesting information related to a call's trace, # such as gas usage data. - try: - with self.chain_manager.isolate(): - return self._send_call_as_txn( - txn, - track_gas=track_gas, - track_coverage=track_coverage, - show_trace=show_trace, - show_gas=show_gas, - **kwargs, - ) - except APINotImplementedError: - return self._send_call(txn, **kwargs) + # When looking at gas, we cannot use token symbols in enrichment. + # Else, the table is difficult to understand. + use_symbol_for_tokens = track_gas or show_gas - def _send_call_as_txn( - self, - txn: TransactionAPI, - track_gas: bool = False, - track_coverage: bool = False, - show_trace: bool = False, - show_gas: bool = False, - **kwargs, - ) -> HexBytes: - account = self.account_manager.test_accounts[0] - receipt = account.call(txn, **kwargs) - if not (call_tree := receipt.call_tree): - return self._send_call(txn, **kwargs) - - # Grab raw returndata before enrichment - returndata = call_tree.outputs - - if (track_gas or track_coverage) and show_gas and not show_trace: - # Optimization to enrich early and in_place=True. - call_tree.enrich() - - if track_gas: - # in_place=False in case show_trace is True - receipt.track_gas() + trace = CallTrace( + tx=arguments[0], + arguments=arguments[1:], + use_symbol_for_tokens=use_symbol_for_tokens, + supports_debug_trace_call=self._supports_debug_trace_call, + ) - if track_coverage: - receipt.track_coverage() + if track_gas and self._test_runner is not None and txn.receiver: + self._test_runner.gas_tracker.append_gas(trace, txn.receiver) + + if track_coverage and self._test_runner is not None and txn.receiver: + contract_type = self.chain_manager.contracts.get(txn.receiver) + if contract_type: + method_id = HexBytes(txn.data) + selector = ( + contract_type.methods[method_id].selector + if method_id in contract_type.methods + else None + ) + source_traceback = SourceTraceback.create(contract_type, trace, method_id) + self._test_runner.coverage_tracker.cover( + source_traceback, function=selector, contract=contract_type.name + ) if show_gas: - # in_place=False in case show_trace is True - self.chain_manager._reports.show_gas(call_tree.enrich(in_place=False)) + trace.show_gas_report() if show_trace: - call_tree = call_tree.enrich(use_symbol_for_tokens=True) - self.chain_manager._reports.show_trace(call_tree) - - return HexBytes(returndata) - - def _send_call(self, txn: TransactionAPI, **kwargs) -> HexBytes: - arguments = self._prepare_call(txn, **kwargs) - try: - return self._eth_call(arguments) - except TransactionError as err: - if not err.txn: - err.txn = txn + trace.show() - raise # The tx error + return HexBytes(trace.return_value) def _eth_call(self, arguments: List) -> HexBytes: - # Force the usage of hex-type to support a wider-range of nodes. - txn_dict = copy(arguments[0]) - if isinstance(txn_dict.get("type"), int): - txn_dict["type"] = HexBytes(txn_dict["type"]).hex() - - # Remove unnecessary values to support a wider-range of nodes. - txn_dict.pop("chainId", None) - - arguments[0] = txn_dict try: - result = self._make_request("eth_call", arguments) + result = self.make_request("eth_call", arguments) except Exception as err: - receiver = txn_dict.get("to") - raise self.get_virtual_machine_error(err, contract_address=receiver) from err + trace = CallTrace(tx=arguments[0], arguments=arguments[1:], use_tokens_for_symbols=True) + contract_address = arguments[0]["to"] + contract_type = self.chain_manager.contracts.get(contract_address) + method_id = arguments[0].get("data", "")[:10] or None + tb = ( + SourceTraceback.create(contract_type, trace, method_id) + if method_id and contract_type + else None + ) + raise self.get_virtual_machine_error( + err, trace=trace, contract_address=contract_address, source_traceback=tb + ) from err if "error" in result: raise ProviderError(result["error"]["message"]) @@ -546,7 +503,9 @@ def _eth_call(self, arguments: List) -> HexBytes: def _prepare_call(self, txn: TransactionAPI, **kwargs) -> List: # NOTE: Using JSON mode since used as request data. - txn_dict = txn.model_dump(by_alias=True, mode="json") + txn_dict = ( + txn.model_dump(by_alias=True, mode="json") if isinstance(txn, TransactionAPI) else txn + ) fields_to_convert = ("data", "chainId", "value") for field in fields_to_convert: value = txn_dict.get(field) @@ -588,7 +547,13 @@ def get_receipt( try: receipt_data = self.web3.eth.wait_for_transaction_receipt(hex_hash, timeout=timeout) except TimeExhausted as err: - raise TransactionNotFoundError(txn_hash, error_messsage=str(err)) from err + msg_str = str(err) + if f"HexBytes('{txn_hash}')" in msg_str: + msg_str = msg_str.replace(f"HexBytes('{txn_hash}')", f"'{txn_hash}'") + + raise TransactionNotFoundError( + transaction_hash=txn_hash, error_message=msg_str + ) from err ecosystem_config = self.network.ecosystem_config.model_dump(by_alias=True) network_config: Dict = ecosystem_config.get(self.network.name, {}) @@ -727,7 +692,7 @@ class YieldAction: fake_last_block = self.get_block(self.web3.eth.block_number - required_confirmations) last_num = fake_last_block.number or 0 last_hash = fake_last_block.hash or HexBytes(0) - last = YieldAction(number=last_num, hash=last_hash, time=time.time()) + last: YieldAction = YieldAction(number=last_num, hash=last_hash, time=time.time()) # A helper method for various points of ensuring we didn't timeout. def assert_chain_activity(): @@ -905,7 +870,7 @@ def fetch_log_page(block_range): # NOTE: Using JSON mode since used as request data. filter_params = page_filter.model_dump(mode="json") - logs = self._make_request("eth_getLogs", [filter_params]) + logs = self.make_request("eth_getLogs", [filter_params]) return self.network.ecosystem.decode_logs(logs, *log_filter.events) with ThreadPoolExecutor(self.concurrency) as pool: @@ -1029,63 +994,17 @@ def _post_send_transaction(self, tx: TransactionAPI, receipt: ReceiptAPI): def _post_connect(self): # Register the console contract for trace enrichment - self.chain_manager.contracts._cache_contract_type(CONSOLE_CONTRACT_ID, console_contract) + self.chain_manager.contracts._cache_contract_type(CONSOLE_ADDRESS, console_contract) - def _create_call_tree_node( - self, evm_call: EvmCallTreeNode, txn_hash: Optional[str] = None - ) -> CallTreeNode: - address = evm_call.address - try: - contract_id = str(self.provider.network.ecosystem.decode_address(address)) - except ValueError: - # Use raw value since it is not a real address. - contract_id = address.hex() - - call_type = evm_call.call_type.value - return CallTreeNode( - calls=[self._create_call_tree_node(x, txn_hash=txn_hash) for x in evm_call.calls], - call_type=call_type, - contract_id=contract_id, - failed=evm_call.failed, - gas_cost=evm_call.gas_cost, - inputs=evm_call.calldata if "CREATE" in call_type else evm_call.calldata[4:].hex(), - method_id=evm_call.calldata[:4].hex(), - outputs=evm_call.returndata.hex(), - raw=evm_call.model_dump(by_alias=True), - txn_hash=txn_hash, - ) - - def _create_trace_frame(self, evm_frame: EvmTraceFrame) -> TraceFrame: - address_bytes = evm_frame.address - try: - address = ( - self.network.ecosystem.decode_address(address_bytes.hex()) - if address_bytes - else None - ) - except ValueError: - # Might not be a real address. - address = cast(AddressType, address_bytes.hex()) if address_bytes else None - - return TraceFrame( - pc=evm_frame.pc, - op=evm_frame.op, - gas=evm_frame.gas, - gas_cost=evm_frame.gas_cost, - depth=evm_frame.depth, - contract_address=address, - raw=evm_frame.model_dump(by_alias=True), - ) - - def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any: + def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: parameters = parameters or [] try: - result = self.web3.provider.make_request(RPCEndpoint(endpoint), parameters) + result = self.web3.provider.make_request(RPCEndpoint(rpc), parameters) except HTTPError as err: if "method not allowed" in str(err).lower(): raise APINotImplementedError( - f"RPC method '{endpoint}' is not implemented by this node instance." + f"RPC method '{rpc}' is not implemented by this node instance." ) raise ProviderError(str(err)) from err @@ -1103,7 +1022,7 @@ def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any or "RPC Endpoint has not been implemented" in message ): raise APINotImplementedError( - f"RPC method '{endpoint}' is not implemented by this node instance." + f"RPC method '{rpc}' is not implemented by this node instance." ) raise ProviderError(message) @@ -1149,7 +1068,7 @@ def create_access_list( if block_id is not None: arguments.append(block_id) - result = self._make_request("eth_createAccessList", arguments) + result = self.make_request("eth_createAccessList", arguments) return [AccessList.model_validate(x) for x in result.get("accessList", [])] def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: @@ -1187,7 +1106,7 @@ def _handle_execution_reverted( self, exception: Union[Exception, str], txn: Optional[TransactionAPI] = None, - trace: Optional[Iterator[TraceFrame]] = None, + trace: Optional[TraceAPI] = None, contract_address: Optional[AddressType] = None, source_traceback: Optional[SourceTraceback] = None, ) -> ContractLogicError: @@ -1206,36 +1125,22 @@ def _handle_execution_reverted( no_reason = message == "execution reverted" if isinstance(exception, Web3ContractLogicError) and no_reason: - if data is None: - # Check for custom exception data and use that as the message instead. - # This allows compiler exception enrichment to function. - err_trace = None - try: - if trace: - trace, err_trace = tee(trace) - elif txn: - err_trace = self.provider.get_transaction_trace(txn.txn_hash.hex()) - - try: - trace_ls: List[TraceFrame] = list(err_trace) if err_trace else [] - except Exception as err: - logger.error(f"Failed getting traceback: {err}") - trace_ls = [] - - data = trace_ls[-1].raw if len(trace_ls) > 0 else {} - memory = data.get("memory", []) - return_value = "".join([x[2:] for x in memory[4:]]) - if return_value: - message = f"0x{return_value}" - no_reason = False - - except (ApeException, NotImplementedError): - # Either provider does not support or isn't a custom exception. - pass - - elif data != "no data" and is_hex(data): + # Check for custom exception data and use that as the message instead. + # This allows compiler exception enrichment to function. + if data != "no data" and is_hex(data): message = add_0x_prefix(data) + else: + if trace is None and txn is not None: + trace = self.provider.get_transaction_trace(txn.txn_hash.hex()) + + if trace is not None and (revert_message := trace.revert_message): + message = revert_message + no_reason = False + if revert_message := trace.revert_message: + message = revert_message + no_reason = False + result = ( ContractLogicError(txn=txn, **params) if no_reason @@ -1258,6 +1163,9 @@ def _handle_execution_reverted( return enriched + def _get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: + return TransactionTrace(transaction_hash=transaction_hash, **kwargs) + class EthereumNodeProvider(Web3Provider, ABC): # optimal values for geth @@ -1266,9 +1174,6 @@ class EthereumNodeProvider(Web3Provider, ABC): name: str = "node" - can_use_parity_traces: Optional[bool] = None - """Is ``None`` until known.""" - @property def uri(self) -> str: if "url" in self.provider_settings: @@ -1322,7 +1227,7 @@ def data_dir(self) -> Path: def _ots_api_level(self) -> Optional[int]: # NOTE: Returns None when OTS namespace is not enabled. try: - result = self._make_request("ots_getApiLevel") + result = self.make_request("ots_getApiLevel") except (NotImplementedError, ApeException, ValueError): return None @@ -1391,70 +1296,10 @@ def _complete_connect(self): self.network.verify_chain_id(chain_id) def disconnect(self): - self.can_use_parity_traces = None + self._call_trace_approach = None self._web3 = None self._client_version = None - def get_transaction_trace(self, txn_hash: Union[HexBytes, str]) -> Iterator[TraceFrame]: - if isinstance(txn_hash, HexBytes): - txn_hash_str = str(to_hex(txn_hash)) - else: - txn_hash_str = txn_hash - - frames = self._stream_request( - "debug_traceTransaction", - [txn_hash_str, {"enableMemory": True}], - "result.structLogs.item", - ) - for frame in create_trace_frames(frames): - yield self._create_trace_frame(frame) - - def _get_transaction_trace_using_call_tracer(self, txn_hash: str) -> Dict: - return self._make_request( - "debug_traceTransaction", [txn_hash, {"enableMemory": True, "tracer": "callTracer"}] - ) - - def get_call_tree(self, txn_hash: str) -> CallTreeNode: - if self.can_use_parity_traces is True: - return self._get_parity_call_tree(txn_hash) - - elif self.can_use_parity_traces is False: - return self._get_geth_call_tree(txn_hash) - - elif "erigon" in self.client_version.lower(): - tree = self._get_parity_call_tree(txn_hash) - self.can_use_parity_traces = True - return tree - - try: - # Try the Parity traces first, in case node client supports it. - tree = self._get_parity_call_tree(txn_hash) - except (ValueError, APINotImplementedError, ProviderError): - self.can_use_parity_traces = False - return self._get_geth_call_tree(txn_hash) - except Exception as err: - logger.error(f"Unknown exception while checking for Parity-trace support: {err} ") - self.can_use_parity_traces = False - return self._get_geth_call_tree(txn_hash) - - # Parity style works. - self.can_use_parity_traces = True - return tree - - def _get_parity_call_tree(self, txn_hash: str) -> CallTreeNode: - result = self._make_request("trace_transaction", [txn_hash]) - if not result: - raise ProviderError(f"Failed to get trace for '{txn_hash}'.") - - traces = ParityTraceList.model_validate(result) - evm_call = get_calltree_from_parity_trace(traces) - return self._create_call_tree_node(evm_call, txn_hash=txn_hash) - - def _get_geth_call_tree(self, txn_hash: str) -> CallTreeNode: - calls = self._get_transaction_trace_using_call_tracer(txn_hash) - evm_call = get_calltree_from_geth_call_trace(calls) - return self._create_call_tree_node(evm_call, txn_hash=txn_hash) - def _log_connection(self, client_name: str): msg = f"Connecting to existing {client_name.strip()} node at" suffix = ( @@ -1468,7 +1313,7 @@ def ots_get_contract_creator(self, address: AddressType) -> Optional[Dict]: if self._ots_api_level is None: return None - result = self._make_request("ots_getContractCreator", [address]) + result = self.make_request("ots_getContractCreator", [address]) if result is None: # NOTE: Skip the explorer part of the error message via `has_explorer=True`. raise ContractNotFoundError(address, has_explorer=True, provider_name=self.name) @@ -1482,7 +1327,7 @@ def _get_contract_creation_receipt(self, address: AddressType) -> Optional[Recei return None - def _stream_request(self, method: str, params: List, iter_path="result.item"): + def stream_request(self, method: str, params: Iterable, iter_path: str = "result.item"): payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params} results = ijson.sendable_list() coroutine = ijson.items_coro(results, iter_path) diff --git a/src/ape_ethereum/trace.py b/src/ape_ethereum/trace.py new file mode 100644 index 0000000000..c2231fa26a --- /dev/null +++ b/src/ape_ethereum/trace.py @@ -0,0 +1,634 @@ +import json +import sys +from abc import abstractmethod +from enum import Enum +from functools import cached_property +from typing import IO, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union + +from eth_abi import decode +from eth_utils import is_0x_prefixed, to_hex +from evm_trace import ( + CallTreeNode, + CallType, + ParityTraceList, + TraceFrame, + get_calltree_from_geth_call_trace, + get_calltree_from_geth_trace, + get_calltree_from_parity_trace, +) +from evm_trace.gas import merge_reports +from hexbytes import HexBytes +from rich.tree import Tree + +from ape.api.trace import TraceAPI +from ape.exceptions import ProviderError, TransactionNotFoundError +from ape.logging import logger +from ape.types import ContractFunctionPath, GasReport +from ape.utils import ZERO_ADDRESS, is_evm_precompile, is_zero_hex +from ape.utils.trace import TraceStyles, _exclude_gas +from ape_ethereum._print import extract_debug_logs + +_INDENT = 2 +_WRAP_THRESHOLD = 50 +_REVERT_PREFIX = "0x08c379a00000000000000000000000000000000000000000000000000000000000000020" + + +class TraceApproach(Enum): + """RPC trace_transaction.""" + + """No tracing support; think of EthTester.""" + BASIC = 0 + + """RPC 'trace_transaction'.""" + PARITY = 1 + + """RPC debug_traceTransaction using tracer='callTracer'.""" + GETH_CALL_TRACER = 2 + + """ + RPC debug_traceTransaction using struct-log tracer + and sophisticated parsing from the evm-trace library. + NOT RECOMMENDED. + """ + GETH_STRUCT_LOG_PARSE = 3 + + +class Trace(TraceAPI): + """ + Set to ``True`` to use an ERC-20's SYMBOL as the contract's identifier. + Is ``True`` when showing pretty traces without gas tables. When gas is + involved, Ape must use the ``.name`` as the identifier for all contracts. + """ + + """When None, attempts to deduce.""" + call_trace_approach: Optional[TraceApproach] = None + + _enriched_calltree: Optional[Dict] = None + + def __repr__(self) -> str: + try: + return f"{self}" + except Exception as err: + # Don't let __repr__ fail. + logger.debug(f"Problem transaction trace: {err}") + return "" + + def __str__(self) -> str: + return _call_to_str(self.enriched_calltree) + + def _repr_pretty_(self, *args, **kwargs): + self.show() + + @property + @abstractmethod + def raw_trace_frames(self) -> List[Dict]: + """ + The raw trace frames. + """ + + @property + @abstractmethod + def transaction(self) -> Dict: + """ + The transaction data (obtained differently on + calls versus transactions). + """ + + @abstractmethod + def get_calltree(self) -> CallTreeNode: + """ + Get an un-enriched call-tree node. + """ + + @cached_property + def debug_logs(self) -> Iterable[Tuple[Any]]: + """ + Calls from ``console.log()`` and ``print()`` from a transactions call tree. + """ + return list(extract_debug_logs(self.get_calltree())) + + @property + def enriched_calltree(self) -> Dict: + """ + The fully enriched calltree node. + """ + if self._enriched_calltree is not None: + return self._enriched_calltree + + # Side-effect: sets `_enriched_calltree` if using Ethereum node provider. + self.provider.network.ecosystem.enrich_trace(self) + + if self._enriched_calltree is None: + # If still None (shouldn't be), set to avoid repeated attempts. + self._enriched_calltree = {} + + return self._enriched_calltree + + @cached_property + def return_value(self) -> Any: + calltree = self.enriched_calltree + + # Check if was cached from enrichment. + if "return_value" in self.__dict__: + return self.__dict__["return_value"] + + return calltree.get("unenriched_return_values", calltree.get("returndata")) + + @cached_property + def revert_message(self) -> Optional[str]: + try: + frames = self.raw_trace_frames + except Exception as err: + logger.error(f"Failed getting traceback: {err}") + frames = [] + + data = frames[-1] if len(frames) > 0 else {} + memory = data.get("memory", []) + if ret := "".join([x[2:] for x in memory[4:]]): + return HexBytes(ret).hex() + + return None + + """ API Methods """ + + def show(self, verbose: bool = False, file: IO[str] = sys.stdout): + call = self.enriched_calltree + revert_message = None + + if call.get("failed", False): + default_message = "reverted without message" + returndata = HexBytes(call.get("returndata", b"")) + if to_hex(returndata).startswith(_REVERT_PREFIX): + decoded_result = decode(("string",), returndata[4:]) + if len(decoded_result) == 1: + revert_message = f'reverted with message: "{decoded_result[0]}"' + else: + revert_message = default_message + + elif address := ( + self.transaction.get("to") or self.transaction.get("contract_address") + ): + # Try to enrich revert error using ABI. + if provider := self.network_manager.active_provider: + ecosystem = provider.network.ecosystem + else: + # Default to Ethereum. + ecosystem = self.network_manager.ethereum + + try: + instance = ecosystem.decode_custom_error(returndata, address) + except NotImplementedError: + pass + else: + revert_message = repr(instance) + + else: + revert_message = default_message + + root = self._get_tree(verbose=verbose) + console = self.chain_manager._reports._get_console(file=file) + if txn_hash := getattr(self, "transaction_hash", None): + # Only works on TransactionTrace (not CallTrace). + console.print(f"Call trace for [bold blue]'{txn_hash}'[/]") + + if revert_message: + console.print(f"[bold red]{revert_message}[/]") + + if sender := self.transaction.get("from"): + console.print(f"tx.origin=[{TraceStyles.CONTRACTS}]{sender}[/]") + + console.print(root) + + def get_gas_report(self, exclude: Optional[Sequence[ContractFunctionPath]] = None) -> GasReport: + call = self.enriched_calltree + tx = self.transaction + + # Enrich transfers. + contract_id = call.get("contract_id", "") + is_transfer = contract_id.startswith("__") and contract_id.endswith("transfer__") + if is_transfer and tx.get("to") is not None and tx["to"] in self.account_manager: + receiver_id = self.account_manager[tx["to"]].alias or tx["to"] + call["method_id"] = f"to:{receiver_id}" + + elif is_transfer and (receiver := tx.get("to")): + call["method_id"] = f"to:{receiver}" + + exclusions = exclude or [] + + if ( + not call.get("contract_id") + or not call.get("method_id") + or _exclude_gas(exclusions, call.get("contract_id", ""), call.get("method_id", "")) + ): + return merge_reports(*(c.get_gas_report(exclude) for c in call.get("calls", []))) + + elif not is_zero_hex(call["method_id"]) and not is_evm_precompile(call["method_id"]): + reports = [ + *[c.get_gas_report(exclude) for c in call.get("calls", [])], + { + call["contract_id"]: { + call["method_id"]: ( + [call.get("gas_cost")] if call.get("gas_cost") is not None else [] + ) + } + }, + ] + return merge_reports(*reports) + + return merge_reports(*(c.get_gas_report(exclude) for c in call.get("calls", []))) + + def show_gas_report(self, verbose: bool = False, file: IO[str] = sys.stdout): + gas_report = self.get_gas_report() + self.chain_manager._reports.show_gas(gas_report, file=file) + + def get_raw_frames(self) -> List[Dict]: + return self.raw_trace_frames + + def get_raw_calltree(self) -> Dict: + return self.get_calltree().model_dump(mode="json", by_alias=True) + + """ Shared helpers """ + + def _get_tx_calltree_kwargs(self) -> Dict: + if (receiver := self.transaction.get("to")) and receiver != ZERO_ADDRESS: + call_type = CallType.CALL + else: + call_type = CallType.CREATE + receiver = self.transaction.get("contract_address") + + return { + "address": receiver, + "call_type": call_type, + "calldata": self.transaction.get("data", b""), + "gas_cost": self.transaction.get("gasCost"), + "failed": False, + "value": self.transaction.get("value", 0), + } + + def _debug_trace_transaction_struct_logs_to_call(self) -> CallTreeNode: + init_kwargs = self._get_tx_calltree_kwargs() + return get_calltree_from_geth_trace( + (TraceFrame.model_validate(f) for f in self.raw_trace_frames), **init_kwargs + ) + + def _get_tree(self, verbose: bool = False) -> Tree: + return parse_rich_tree(self.enriched_calltree, verbose=verbose) + + +class TransactionTrace(Trace): + transaction_hash: str + debug_trace_transaction_parameters: Dict = {"enableMemory": True} + + @cached_property + def raw_trace_frames(self) -> List[Dict]: + """ + The raw trace ``"structLogs"`` from ``debug_traceTransaction`` + for deeper investigation. + """ + return list(self._stream_struct_logs()) + + @cached_property + def transaction(self) -> Dict: + receipt = self.chain_manager.get_receipt(self.transaction_hash) + data = receipt.transaction.model_dump(mode="json", by_alias=True) + return {**data, **receipt.model_dump(by_alias=True)} + + def _stream_struct_logs(self) -> Iterator[Dict]: + parameters = self.debug_trace_transaction_parameters + yield from self.provider.stream_request( + "debug_traceTransaction", + [self.transaction_hash, parameters], + iter_path="result.item", + ) + + def get_calltree(self) -> CallTreeNode: + if self.call_trace_approach is TraceApproach.BASIC: + return self._get_basic_calltree() + + elif self.call_trace_approach is TraceApproach.PARITY: + return self._trace_transaction() + + elif self.call_trace_approach is TraceApproach.GETH_CALL_TRACER: + return self._debug_trace_transaction_call_tracer() + + elif self.call_trace_approach is TraceApproach.GETH_STRUCT_LOG_PARSE: + return self._debug_trace_transaction_struct_logs_to_call() + + elif "erigon" in self.provider.client_version.lower(): + # Based on the client version, we know parity works. + call = self._trace_transaction() + self._set_approach(TraceApproach.PARITY) + return call + + return self._discover_calltrace_approach() + + def _discover_calltrace_approach(self) -> CallTreeNode: + # NOTE: This method is only called once, if at all. + # After discovery, short-circuits to the correct approach. + # It tries to create an evm_trace.CallTreeNode using + # all the approaches in order from fastest to slowest. + + TA = TraceApproach + approaches = { + TA.PARITY: self._trace_transaction, + TA.GETH_CALL_TRACER: self._debug_trace_transaction_call_tracer, + TA.GETH_STRUCT_LOG_PARSE: self._debug_trace_transaction_struct_logs_to_call, + TA.BASIC: self._get_basic_calltree, + } + + reason = "" + for approach, fn in approaches.items(): + try: + call = fn() + except Exception as err: + reason = f"{err}" + continue + + self._set_approach(approach) + return call + + # Not sure this would happen, as the basic-approach should + # always work. + raise ProviderError(f"Unable to create CallTreeNode. Reason: {reason}") + + def _debug_trace_transaction(self, parameters: Optional[Dict] = None) -> Dict: + parameters = parameters or self.debug_trace_transaction_parameters + return self.provider.make_request( + "debug_traceTransaction", [self.transaction_hash, parameters] + ) + + def _debug_trace_transaction_call_tracer(self) -> CallTreeNode: + parameters = {**self.debug_trace_transaction_parameters, "tracer": "callTracer"} + data = self._debug_trace_transaction(parameters) + return get_calltree_from_geth_call_trace(data) + + def _trace_transaction(self) -> CallTreeNode: + try: + data = self.provider.make_request("trace_transaction", [self.transaction_hash]) + except ProviderError as err: + if "transaction not found" in str(err).lower(): + raise TransactionNotFoundError(transaction_hash=self.transaction_hash) from err + + raise # The ProviderError as-is + + parity_objects = ParityTraceList.model_validate(data) + return get_calltree_from_parity_trace(parity_objects) + + def _get_basic_calltree(self) -> CallTreeNode: + init_kwargs = self._get_tx_calltree_kwargs() + receipt = self.chain_manager.get_receipt(self.transaction_hash) + init_kwargs["gas_cost"] = receipt.gas_used + + # Figure out the 'returndata' using 'eth_call' RPC. + tx = receipt.transaction.model_copy(update={"nonce": None}) + return_value = self.provider.send_call(tx, block_id=receipt.block_number) + init_kwargs["returndata"] = return_value + + return CallTreeNode(**init_kwargs) + + def _set_approach(self, approach: TraceApproach): + self.call_trace_approach = approach + if hasattr(self.provider, "_call_trace_approach"): + self.provider._call_trace_approach = approach + + +class CallTrace(Trace): + tx: Dict + arguments: List[Any] = [] + + """debug_traceCall must use the struct-log tracer.""" + call_trace_approach: TraceApproach = TraceApproach.GETH_STRUCT_LOG_PARSE + supports_debug_trace_call: Optional[bool] = None + + @property + def raw_trace_frames(self) -> List[Dict]: + return self._traced_call.get("structLogs", []) + + @property + def return_value(self) -> Any: + return self._traced_call.get("returnValue", "") + + @cached_property + def _traced_call(self) -> Dict: + if self.supports_debug_trace_call is True: + return self._debug_trace_call() + elif self.supports_debug_trace_call is False: + return {} + + try: + result = self._debug_trace_call() + except Exception: + self._set_supports_trace_call(False) + return {} + + self._set_supports_trace_call(True) + return result + + @property + def transaction(self) -> Dict: + return self.tx + + def get_calltree(self) -> CallTreeNode: + calltree = self._debug_trace_transaction_struct_logs_to_call() + calltree.gas_cost = self._traced_call.get("gas", calltree.gas_cost) + calltree.failed = self._traced_call.get("failed", calltree.failed) + return calltree + + def _set_supports_trace_call(self, value: bool): + self.supports_debug_trace_call = value + if hasattr(self.provider, "_supports_debug_trace_call"): + self.provider._supports_debug_trace_call = True + + def _debug_trace_call(self): + arguments = [self.transaction, *self.arguments] + + # Block ID is required, at least for regular geth nodes. + if len(arguments) == 1: + arguments.append("latest") + + return self.provider.make_request("debug_traceCall", arguments) + + +def parse_rich_tree(call: Dict, verbose: bool = False) -> Tree: + tree = _create_tree(call, verbose=verbose) + for sub_call in call["calls"]: + sub_tree = parse_rich_tree(sub_call, verbose=verbose) + tree.add(sub_tree) + + return tree + + +def _call_to_str(call: Dict, stylize: bool = False, verbose: bool = False) -> str: + contract = str(call.get("contract_id", "")) + is_create = "CREATE" in call.get("call_type", "") + method = ( + "__new__" + if is_create and call["method_id"] and is_0x_prefixed(call["method_id"]) + else str(call.get("method_id") or "") + ) + if "(" in method: + # Only show short name, not ID name + # (it is the full signature when multiple methods have the same name). + method = method.split("(")[0].strip() or method + + if stylize: + contract = f"[{TraceStyles.CONTRACTS}]{contract}[/]" + method = f"[{TraceStyles.METHODS}]{method}[/]" + + call_path = f"{contract}.{method}" + + if call.get("call_type") is not None and call["call_type"].upper() == "DELEGATECALL": + delegate = "(delegate)" + if stylize: + delegate = f"[orange]{delegate}[/]" + + call_path = f"{delegate} {call_path}" + + arguments_str = _get_inputs_str(call.get("calldata"), stylize=stylize) + if is_create and is_0x_prefixed(arguments_str): + # Un-enriched CREATE calldata is a massive hex. + arguments_str = "" + + signature = f"{call_path}{arguments_str}" + returndata = call.get("returndata", "") + + if not is_create and returndata not in ((), [], None, {}, ""): + if return_str := _get_outputs_str(returndata, stylize=stylize): + signature = f"{signature} -> {return_str}" + + if call.get("value"): + value = str(call["value"]) + if stylize: + value = f"[{TraceStyles.VALUE}]{value}[/]" + + signature += f" {value}" + + if call.get("gas_cost"): + gas_value = f"[{call['gas_cost']} gas]" + if stylize: + gas_value = f"[{TraceStyles.GAS_COST}]{gas_value}[/]" + + signature += f" {gas_value}" + + if verbose: + verbose_items = {k: v for k, v in call.items() if type(v) in (int, str, bytes, float)} + extra = json.dumps(verbose_items, indent=2) + signature = f"{signature}\n{extra}" + + return signature + + +def _create_tree(call: Dict, verbose: bool = False) -> Tree: + signature = _call_to_str(call, stylize=True, verbose=verbose) + return Tree(signature) + + +def _get_inputs_str(inputs: Any, stylize: bool = False) -> str: + color = TraceStyles.INPUTS if stylize else None + if inputs in ["0x", None, (), [], {}]: + return "()" + + elif isinstance(inputs, dict): + return _dict_to_str(inputs, color=color) + + elif isinstance(inputs, bytes): + return HexBytes(inputs).hex() + + return f"({inputs})" + + +def _get_outputs_str(outputs: Any, stylize: bool = False) -> Optional[str]: + if outputs in ["0x", None, (), [], {}]: + return None + + elif isinstance(outputs, dict): + color = TraceStyles.OUTPUTS if stylize else None + return _dict_to_str(outputs, color=color) + + elif isinstance(outputs, (list, tuple)): + return ( + f"[{TraceStyles.OUTPUTS}]{_list_to_str(outputs)}[/]" + if stylize + else _list_to_str(outputs) + ) + + return f"[{TraceStyles.OUTPUTS}]{outputs}[/]" if stylize else str(outputs) + + +def _dict_to_str(dictionary: Dict, color: Optional[str] = None) -> str: + length = sum(len(str(v)) for v in [*dictionary.keys(), *dictionary.values()]) + do_wrap = length > _WRAP_THRESHOLD + + index = 0 + end_index = len(dictionary) - 1 + kv_str = "(\n" if do_wrap else "(" + + for key, value in dictionary.items(): + if do_wrap: + kv_str += _INDENT * " " + + if isinstance(value, (list, tuple)): + value = _list_to_str(value, 1 if do_wrap else 0) + + value_str = f"[{color}]{value}[/]" if color is not None else str(value) + kv_str += f"{key}={value_str}" if key and not key.isnumeric() else value_str + if index < end_index: + kv_str += ", " + + if do_wrap: + kv_str += "\n" + + index += 1 + + return f"{kv_str})" + + +def _list_to_str(ls: Union[List, Tuple], depth: int = 0) -> str: + if not isinstance(ls, (list, tuple)) or len(str(ls)) < _WRAP_THRESHOLD: + return str(ls) + + elif ls and isinstance(ls[0], (list, tuple)): + # List of lists + sub_lists = [_list_to_str(i) for i in ls] + + # Use multi-line if exceeds threshold OR any of the sub-lists use multi-line + extra_chars_len = (len(sub_lists) - 1) * 2 + use_multiline = len(str(sub_lists)) + extra_chars_len > _WRAP_THRESHOLD or any( + ["\n" in ls for ls in sub_lists] + ) + + if not use_multiline: + # Happens for lists like '[[0], [1]]' that are short. + return f"[{', '.join(sub_lists)}]" + + value = "[\n" + num_sub_lists = len(sub_lists) + index = 0 + spacing = _INDENT * " " * 2 + for formatted_list in sub_lists: + if "\n" in formatted_list: + # Multi-line sub list. Append 1 more spacing to each line. + indented_item = f"\n{spacing}".join(formatted_list.splitlines()) + value = f"{value}{spacing}{indented_item}" + else: + # Single line sub-list + value = f"{value}{spacing}{formatted_list}" + + if index < num_sub_lists - 1: + value = f"{value}," + + value = f"{value}\n" + index += 1 + + value = f"{value}]" + return value + + return _list_to_multiline_str(ls, depth=depth) + + +def _list_to_multiline_str(value: Union[List, Tuple], depth: int = 0) -> str: + spacing = _INDENT * " " + ls_spacing = spacing * (depth + 1) + joined = ",\n".join([f"{ls_spacing}{v}" for v in value]) + new_val = f"[\n{joined}\n{spacing * depth}]" + return new_val diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py index 15727c5b8d..7da13753a2 100644 --- a/src/ape_ethereum/transactions.py +++ b/src/ape_ethereum/transactions.py @@ -10,18 +10,18 @@ serializable_unsigned_transaction_from_dict, ) from eth_pydantic_types import HexBytes -from eth_utils import decode_hex, encode_hex, keccak, to_hex, to_int +from eth_utils import decode_hex, encode_hex, keccak, to_int from ethpm_types import ContractType from ethpm_types.abi import EventABI, MethodABI from pydantic import BaseModel, Field, field_validator, model_validator from ape.api import ReceiptAPI, TransactionAPI from ape.contracts import ContractEvent -from ape.exceptions import APINotImplementedError, OutOfGasError, SignatureError, TransactionError +from ape.exceptions import OutOfGasError, SignatureError, TransactionError from ape.logging import logger -from ape.types import AddressType, CallTreeNode, ContractLog, ContractLogContainer, SourceTraceback +from ape.types import AddressType, ContractLog, ContractLogContainer, SourceTraceback from ape.utils import ZERO_ADDRESS -from ape_ethereum._print import extract_debug_logs +from ape_ethereum.trace import Trace class TransactionStatusEnum(IntEnum): @@ -204,28 +204,23 @@ def total_fees_paid(self) -> int: def failed(self) -> bool: return self.status != TransactionStatusEnum.NO_ERROR - @cached_property - def call_tree(self) -> Optional[CallTreeNode]: - return self.provider.get_call_tree(self.txn_hash) - @cached_property def debug_logs_typed(self) -> List[Tuple[Any]]: """ Extract messages to console outputted by contracts via print() or console.log() statements """ - try: - self.call_tree - # Some providers do not implement this, so skip - except APINotImplementedError: + trace = self.trace + # Some providers do not implement this, so skip. + except NotImplementedError: logger.debug("Call tree not available, skipping debug log extraction") - return list() + return [] - # If the call tree is not available, no logs are available - if self.call_tree is None: - return list() + # If the trace is not available, no logs are available. + if trace is None or not isinstance(trace, Trace): + return [] - return list(extract_debug_logs(self.call_tree)) + return list(trace.debug_logs) @cached_property def contract_type(self) -> Optional[ContractType]: @@ -266,68 +261,10 @@ def raise_for_status(self): raise TransactionError(f"Transaction '{txn_hash}' failed.", txn=self) def show_trace(self, verbose: bool = False, file: IO[str] = sys.stdout): - if not (call_tree := self.call_tree): - return - - call_tree.enrich(use_symbol_for_tokens=True) - revert_message = None - - if call_tree.failed: - default_message = "reverted without message" - returndata = HexBytes(call_tree.raw["returndata"]) - if to_hex(returndata).startswith( - "0x08c379a00000000000000000000000000000000000000000000000000000000000000020" - ): - # Extra revert-message - decoded_result = decode(("string",), returndata[4:]) - if len(decoded_result) == 1: - revert_message = f'reverted with message: "{decoded_result[0]}"' - else: - revert_message = default_message - - elif address := (self.receiver or self.contract_address): - # Try to enrich revert error using ABI. - if provider := self.network_manager.active_provider: - ecosystem = provider.network.ecosystem - else: - # Default to Ethereum. - ecosystem = self.network_manager.ethereum - - try: - instance = ecosystem.decode_custom_error(returndata, address) - except NotImplementedError: - pass - else: - revert_message = repr(instance) - - self.chain_manager._reports.show_trace( - call_tree, - sender=self.sender, - transaction_hash=self.txn_hash, - revert_message=revert_message, - verbose=verbose, - file=file, - ) + self.trace.show(verbose=verbose, file=file) def show_gas_report(self, file: IO[str] = sys.stdout): - if not (call_tree := self.call_tree): - return - - call_tree.enrich() - - # Enrich transfers. - if ( - call_tree.contract_id.startswith("Transferring ") - and self.receiver is not None - and self.receiver in self.account_manager - ): - receiver_id = self.account_manager[self.receiver].alias or self.receiver - call_tree.method_id = f"to:{receiver_id}" - - elif call_tree.contract_id.startswith("Transferring "): - call_tree.method_id = f"to:{self.receiver}" - - self.chain_manager._reports.show_gas(call_tree, file=file) + self.trace.show_gas_report() def show_source_traceback(self, file: IO[str] = sys.stdout): self.chain_manager._reports.show_source_traceback( diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index 59a9abfa50..9251dc7b24 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -1,14 +1,12 @@ import atexit import shutil -from itertools import tee from pathlib import Path from subprocess import DEVNULL, PIPE, Popen -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union from eth_pydantic_types import HexBytes from eth_typing import HexStr from eth_utils import add_0x_prefix, to_hex, to_wei -from evm_trace import CallType, get_calltree_from_geth_trace from evmchains import get_random_rpc from geth.accounts import ensure_account_exists # type: ignore from geth.chain import initialize_chain # type: ignore @@ -19,10 +17,9 @@ from web3.middleware import geth_poa_middleware from yarl import URL -from ape.api import PluginConfig, SubprocessProvider, TestProviderAPI, TransactionAPI -from ape.exceptions import ProviderError +from ape.api import PluginConfig, SubprocessProvider, TestProviderAPI from ape.logging import LogLevel, logger -from ape.types import BlockID, CallTreeNode, SnapshotID, SourceTraceback +from ape.types import SnapshotID from ape.utils import ( DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_CHAIN_ID, @@ -229,7 +226,6 @@ def __init__(self): class GethDev(EthereumNodeProvider, TestProviderAPI, SubprocessProvider): _process: Optional[GethDevProcess] = None name: str = "node" - can_use_parity_traces: Optional[bool] = False @property def process_name(self) -> str: @@ -252,7 +248,7 @@ def __repr__(self) -> str: @property def auto_mine(self) -> bool: - return self._make_request("eth_mining", []) + return self.make_request("eth_mining", []) @auto_mine.setter def auto_mine(self, value): @@ -341,7 +337,7 @@ def restore(self, snapshot_id: SnapshotID): logger.error("Unable to set head to future block.") return - self._make_request("debug_setHead", [block_number_hex_str]) + self.make_request("debug_setHead", [block_number_hex_str]) @raises_not_implemented def set_timestamp(self, new_timestamp: int): @@ -351,120 +347,6 @@ def set_timestamp(self, new_timestamp: int): def mine(self, num_blocks: int = 1): pass - def send_call( - self, - txn: TransactionAPI, - block_id: Optional[BlockID] = None, - state: Optional[Dict] = None, - **kwargs: Any, - ) -> HexBytes: - if block_id is not None: - kwargs["block_identifier"] = block_id - - if state is not None: - kwargs["state_override"] = state - - skip_trace = kwargs.pop("skip_trace", False) - arguments = self._prepare_call(txn, **kwargs) - if skip_trace: - return self._eth_call(arguments) - - show_gas = kwargs.pop("show_gas_report", False) - show_trace = kwargs.pop("show_trace", False) - - if self._test_runner is not None: - track_gas = self._test_runner.gas_tracker.enabled - track_coverage = self._test_runner.coverage_tracker.enabled - else: - track_gas = False - track_coverage = False - - needs_trace = track_gas or track_coverage or show_gas or show_trace - if not needs_trace: - return self._eth_call(arguments) - - # The user is requesting information related to a call's trace, - # such as gas usage data. - - result, trace_frames = self._trace_call(arguments) - trace_frames, frames_copy = tee(trace_frames) - return_value = HexBytes(result["returnValue"]) - root_node_kwargs = { - "gas_cost": result.get("gas", 0), - "address": txn.receiver, - "calldata": txn.data, - "value": txn.value, - "call_type": CallType.CALL, - "failed": False, - "returndata": return_value, - } - - evm_call_tree = get_calltree_from_geth_trace(trace_frames, **root_node_kwargs) - - # NOTE: Don't pass txn_hash here, as it will fail (this is not a real txn). - call_tree = self._create_call_tree_node(evm_call_tree) - - if track_gas and show_gas and not show_trace and call_tree: - # Optimization to enrich early and in_place=True. - call_tree.enrich() - - if track_gas and call_tree and self._test_runner is not None and txn.receiver: - # Gas report being collected, likely for showing a report - # at the end of a test run. - # Use `in_place=False` in case also `show_trace=True` - enriched_call_tree = call_tree.enrich(in_place=False) - self._test_runner.gas_tracker.append_gas(enriched_call_tree, txn.receiver) - - if track_coverage and self._test_runner is not None and txn.receiver: - contract_type = self.chain_manager.contracts.get(txn.receiver) - if contract_type: - traceframes = (self._create_trace_frame(x) for x in frames_copy) - method_id = HexBytes(txn.data) - selector = ( - contract_type.methods[method_id].selector - if method_id in contract_type.methods - else None - ) - source_traceback = SourceTraceback.create(contract_type, traceframes, method_id) - self._test_runner.coverage_tracker.cover( - source_traceback, function=selector, contract=contract_type.name - ) - - if show_gas: - enriched_call_tree = call_tree.enrich(in_place=False) - self.chain_manager._reports.show_gas(enriched_call_tree) - - if show_trace: - call_tree = call_tree.enrich(use_symbol_for_tokens=True) - self.chain_manager._reports.show_trace(call_tree) - - return return_value - - def _eth_call(self, arguments: List) -> HexBytes: - try: - result = self._make_request("eth_call", arguments) - except Exception as err: - trace, trace2 = tee(self._create_trace_frame(x) for x in self._trace_call(arguments)[1]) - contract_address = arguments[0]["to"] - contract_type = self.chain_manager.contracts.get(contract_address) - method_id = arguments[0].get("data", "")[:10] or None - tb = ( - SourceTraceback.create(contract_type, trace, method_id) - if method_id and contract_type - else None - ) - raise self.get_virtual_machine_error( - err, trace=trace2, contract_address=contract_address, source_traceback=tb - ) from err - - if "error" in result: - raise ProviderError(result["error"]["message"]) - - return HexBytes(result) - - def get_call_tree(self, txn_hash: str, **root_node_kwargs) -> CallTreeNode: - return self._get_geth_call_tree(txn_hash, **root_node_kwargs) - def build_command(self) -> List[str]: return self._process.command if self._process else [] diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index bac089acbc..b591c1e2e3 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -55,7 +55,7 @@ def evm_backend(self) -> PyEVMBackend: def tester(self): chain_id = self.settings.chain_id if self._web3 is not None: - connected_chain_id = self._make_request("eth_chainId") + connected_chain_id = self.make_request("eth_chainId") if connected_chain_id == chain_id: # Is already connected and settings have not changed. return @@ -150,10 +150,14 @@ def settings(self) -> EthTesterProviderConfig: {**self.config.provider.model_dump(), **self.provider_settings} ) + @property + def supports_tracing(self) -> bool: + return False + @cached_property def chain_id(self) -> int: try: - result = self._make_request("eth_chainId") + result = self.make_request("eth_chainId") except ProviderNotConnectedError: result = self.settings.chain_id diff --git a/tests/conftest.py b/tests/conftest.py index 67269442bb..686597de0b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,8 @@ from ape.exceptions import APINotImplementedError, UnknownSnapshotError from ape.logging import LogLevel, logger from ape.managers.config import CONFIG_FILE_NAME +from ape.pytest.config import ConfigWrapper +from ape.pytest.gas import GasTracker from ape.types import AddressType from ape.utils import DEFAULT_TEST_CHAIN_ID, ZERO_ADDRESS, create_tempdir from ape.utils.basemodel import only_raise_attribute_error @@ -589,3 +591,13 @@ def custom_network_chain_id_1(): @pytest.fixture def custom_network(ethereum, custom_networks_config): return ethereum.apenet + + +@pytest.fixture +def config_wrapper(mocker): + return ConfigWrapper(mocker.MagicMock()) + + +@pytest.fixture +def gas_tracker(config_wrapper): + return GasTracker(config_wrapper) diff --git a/tests/functional/geth/conftest.py b/tests/functional/geth/conftest.py index afeaf9edf6..5af9c2d742 100644 --- a/tests/functional/geth/conftest.py +++ b/tests/functional/geth/conftest.py @@ -8,11 +8,6 @@ from tests.functional.data.python import TRACE_RESPONSE -@pytest.fixture -def txn_hash(): - return "0x053cba5c12172654d894f66d5670bab6215517a94189a9ffc09bc40a589ec04d" - - @pytest.fixture def parity_trace_response(): return TRACE_RESPONSE diff --git a/tests/functional/geth/test_contract.py b/tests/functional/geth/test_contract.py index 280b60dc4d..bbaf08d50d 100644 --- a/tests/functional/geth/test_contract.py +++ b/tests/functional/geth/test_contract.py @@ -29,7 +29,19 @@ def test_contract_interaction(geth_provider, geth_account, geth_contract, mocker @geth_process_test -def test_revert(accounts, not_owner, geth_contract): +def test_contract_call_show_trace(geth_contract, geth_account): + """ + Show the `show_trace=True` does not corrupt the value. + Note: The provider uses `debug_traceCall` to get the result instead of + `eth_call`. + """ + geth_contract.setNumber(203, sender=geth_account) + actual = geth_contract.myNumber(show_trace=True) + assert actual == 203 + + +@geth_process_test +def test_tx_revert(accounts, not_owner, geth_contract): # 'sender' is not the owner so it will revert (with a message) with pytest.raises(ContractLogicError, match="!authorized") as err: geth_contract.setNumber(5, sender=not_owner) diff --git a/tests/functional/geth/test_contract_event.py b/tests/functional/geth/test_contract_event.py new file mode 100644 index 0000000000..d2f0c312ae --- /dev/null +++ b/tests/functional/geth/test_contract_event.py @@ -0,0 +1,10 @@ +from tests.conftest import geth_process_test + + +@geth_process_test +def test_contract_event(geth_contract, geth_account): + geth_contract.setNumber(101010, sender=geth_account) + actual = geth_contract.NumberChange[-1] + assert actual.event_name == "NumberChange" + assert actual.contract_address == geth_contract.address + assert actual.event_arguments["newNum"] == 101010 diff --git a/tests/functional/geth/test_gas_tracker.py b/tests/functional/geth/test_gas_tracker.py new file mode 100644 index 0000000000..6b4026c61e --- /dev/null +++ b/tests/functional/geth/test_gas_tracker.py @@ -0,0 +1,23 @@ +from tests.conftest import geth_process_test + + +@geth_process_test +def test_append_gas(gas_tracker, geth_account, geth_contract): + tx = geth_contract.setNumber(924, sender=geth_account) + trace = tx.trace + gas_tracker.append_gas(trace, geth_contract.address) + report = gas_tracker.session_gas_report + contract_name = geth_contract.contract_type.name + assert contract_name in report + assert "setNumber" in report[contract_name] + assert tx.gas_used in report[contract_name]["setNumber"] + + +@geth_process_test +def test_append_gas_deploy(gas_tracker, geth_contract): + tx = geth_contract.receipt + trace = tx.trace + gas_tracker.append_gas(trace, geth_contract.address) + report = gas_tracker.session_gas_report + expected = {geth_contract.contract_type.name: {"__new__": [tx.gas_used]}} + assert report == expected diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index 9418bda823..e1afd69998 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -4,6 +4,7 @@ import pytest from eth_pydantic_types import HashBytes32 from eth_typing import HexStr +from eth_utils import keccak from evmchains import PUBLIC_CHAIN_META from hexbytes import HexBytes from web3.exceptions import ContractLogicError as Web3ContractLogicError @@ -17,11 +18,13 @@ TransactionError, TransactionNotFoundError, ) +from ape.utils import to_int from ape_ethereum.ecosystem import Block from ape_ethereum.provider import DEFAULT_SETTINGS, EthereumNodeProvider from ape_ethereum.transactions import ( AccessList, AccessListTransaction, + DynamicFeeTransaction, TransactionStatusEnum, TransactionType, ) @@ -103,15 +106,6 @@ def test_repr_on_live_network_and_disconnected(networks): assert actual == expected -@geth_process_test -def test_get_logs(geth_contract, geth_account): - geth_contract.setNumber(101010, sender=geth_account) - actual = geth_contract.NumberChange[-1] - assert actual.event_name == "NumberChange" - assert actual.contract_address == geth_contract.address - assert actual.event_arguments["newNum"] == 101010 - - @geth_process_test def test_chain_id_when_connected(geth_provider): assert geth_provider.chain_id == 1337 @@ -201,10 +195,33 @@ def test_get_block_not_found(geth_provider): @geth_process_test -def test_get_receipt_not_exists_with_timeout(geth_provider, txn_hash): +def test_get_block_pending(geth_provider, geth_account, geth_second_account, accounts): + """ + Pending timestamps can be weird. + This ensures we can check those are various strange states of geth. + """ + actual = geth_provider.get_block("latest") + assert isinstance(actual, Block) + + snap = geth_provider.snapshot() + + # Transact to increase block + geth_account.transfer(geth_second_account, "1 gwei") + actual = geth_provider.get_block("latest") + assert isinstance(actual, Block) + + # Restore state before transaction + geth_provider.restore(snap) + actual = geth_provider.get_block("latest") + assert isinstance(actual, Block) + + +@geth_process_test +def test_get_receipt_not_exists_with_timeout(geth_provider): + txn_hash = "0x0123" expected = ( f"Transaction '{txn_hash}' not found. " - rf"Error: Transaction HexBytes\('{txn_hash}'\) " + rf"Error: Transaction '{txn_hash}' " "is not in the chain after 0 seconds" ) with pytest.raises(TransactionNotFoundError, match=expected): @@ -361,6 +378,23 @@ def test_send_transaction_when_no_error_and_receipt_fails( geth_provider._web3 = start_web3 +@geth_process_test +def test_send_call(geth_provider, ethereum, geth_contract): + txn = DynamicFeeTransaction.model_validate( + { + "chainId": 1337, + "to": geth_contract.address, + "gas": 4716984, + "value": 0, + "data": HexBytes(keccak(text="myNumber()")[:4]), + "type": 2, + "accessList": [], + } + ) + actual = geth_provider.send_call(txn) + assert to_int(actual) == 0 + + @geth_process_test def test_network_choice(geth_provider): actual = geth_provider.network_choice @@ -386,10 +420,10 @@ def test_make_request_not_exists(geth_provider): APINotImplementedError, match="RPC method 'ape_thisDoesNotExist' is not implemented by this node instance.", ): - geth_provider._make_request("ape_thisDoesNotExist") + geth_provider.make_request("ape_thisDoesNotExist") -def test_geth_not_found(): +def test_geth_bin_not_found(): bin_name = "__NOT_A_REAL_EXECUTABLE_HOPEFULLY__" with pytest.raises(NodeSoftwareNotInstalledError): _ = GethDevProcess(Path.cwd(), executable=bin_name) @@ -431,14 +465,14 @@ def test_base_fee_no_history(geth_provider, mocker, ret): @geth_process_test -def test_estimate_gas(geth_contract, geth_provider, geth_account): +def test_estimate_gas_cost(geth_contract, geth_provider, geth_account): txn = geth_contract.setNumber.as_transaction(900, sender=geth_account) estimate = geth_provider.estimate_gas_cost(txn) assert estimate > 0 @geth_process_test -def test_estimate_gas_of_static_fee_txn(geth_contract, geth_provider, geth_account): +def test_estimate_gas_cost_of_static_fee_txn(geth_contract, geth_provider, geth_account): txn = geth_contract.setNumber.as_transaction(900, sender=geth_account, type=0) estimate = geth_provider.estimate_gas_cost(txn) assert estimate > 0 @@ -498,10 +532,10 @@ def hacked_send_call(*args, **kwargs): orig = networks.active_provider networks.active_provider = provider - _ = provider.send_call(tx, block_id=block_id) == HexStr("0x") + _ = provider.send_call(tx, block_id=block_id, skip_trace=True) == HexStr("0x") networks.active_provider = orig # put back ASAP - actual = provider._send_call.call_args[-1]["block_identifier"] + actual = provider._prepare_call.call_args[-1]["block_identifier"] assert actual == block_id diff --git a/tests/functional/geth/test_receipt.py b/tests/functional/geth/test_receipt.py new file mode 100644 index 0000000000..0c267065e0 --- /dev/null +++ b/tests/functional/geth/test_receipt.py @@ -0,0 +1,54 @@ +from ape.api import TraceAPI +from ape.utils import ManagerAccessMixin +from tests.conftest import geth_process_test + + +@geth_process_test +def test_return_value_list(geth_account, geth_contract): + tx = geth_contract.getFilledArray.transact(sender=geth_account) + assert tx.return_value == [1, 2, 3] + + +@geth_process_test +def test_return_value_nested_address_array(geth_account, geth_contract, zero_address): + tx = geth_contract.getNestedAddressArray.transact(sender=geth_account) + expected = [ + [geth_account.address, geth_account.address, geth_account.address], + [zero_address, zero_address, zero_address], + ] + actual = tx.return_value + assert actual == expected + + +@geth_process_test +def test_return_value_nested_struct_in_tuple(geth_account, geth_contract): + tx = geth_contract.getNestedStructWithTuple1.transact(sender=geth_account) + actual = tx.return_value + assert actual[0].t.a == geth_account.address + assert actual[0].foo == 1 + assert actual[1] == 1 + + +@geth_process_test +def test_trace(geth_account, geth_contract): + tx = geth_contract.getNestedStructWithTuple1.transact(sender=geth_account) + assert isinstance(tx.trace, TraceAPI) + + +@geth_process_test +def test_track_gas(mocker, geth_account, geth_contract, gas_tracker): + tx = geth_contract.getNestedStructWithTuple1.transact(sender=geth_account) + mock_test_runner = mocker.MagicMock() + mock_test_runner.gas_tracker = gas_tracker + + ManagerAccessMixin._test_runner = mock_test_runner + + try: + tx.track_gas() + finally: + ManagerAccessMixin._test_runner = None + + report = gas_tracker.session_gas_report or {} + contract_name = geth_contract.contract_type.name + assert contract_name in report + assert "getNestedStructWithTuple1" in report[contract_name] diff --git a/tests/functional/geth/test_trace.py b/tests/functional/geth/test_trace.py index 83fbfd0024..61712464bd 100644 --- a/tests/functional/geth/test_trace.py +++ b/tests/functional/geth/test_trace.py @@ -4,6 +4,7 @@ import pytest from ape.utils import run_in_tempdir +from ape_ethereum.trace import CallTrace, Trace, TraceApproach, TransactionTrace from tests.conftest import geth_process_test LOCAL_TRACE = r""" @@ -23,7 +24,7 @@ │ 333399998888882, │ 234545457847457457458457457457 │ \] -│ \] \[\d+ gas\] +│ \] \[\d+ gas\] ├── SYMBOL\.methodB1\(lolol="ice-cream", dynamo=345457847457457458457457457\) \[\d+ gas\] │ ├── ContractC\.getSomeList\(\) -> \[ │ │ 3425311345134513461345134534531452345, @@ -117,7 +118,7 @@ def assert_rich_output(rich_capture: List[str], expected: str): for actual, expected in zip(actual_lines, expected_lines): fail_message = f"""\n - \tPattern: {expected},\n + \tPattern: {expected}\n \tLine : {actual}\n \n Complete output: @@ -139,37 +140,140 @@ def assert_rich_output(rich_capture: List[str], expected: str): @geth_process_test -def test_get_call_tree(geth_contract, geth_account, geth_provider): +def test_str_and_repr(geth_contract, geth_account, geth_provider): receipt = geth_contract.setNumber(10, sender=geth_account) - result = geth_provider.get_call_tree(receipt.txn_hash) - expected = ( - rf"{geth_contract.address}.0x3fb5c1cb" - r"\(0x000000000000000000000000000000000000000000000000000000000000000a\) \[\d+ gas\]" - ) - actual = repr(result) - assert re.match(expected, actual) + trace = geth_provider.get_transaction_trace(receipt.txn_hash) + expected = rf"{geth_contract.contract_type.name}\.setNumber\(\s*num=\d+\s*\) \[\d+ gas\]" + for actual in (str(trace), repr(trace)): + assert re.match(expected, actual) @geth_process_test -def test_get_call_tree_deploy(geth_contract, geth_provider): +def test_str_and_repr_deploy(geth_contract, geth_provider): receipt = geth_contract.receipt - result = geth_provider.get_call_tree(receipt.txn_hash) - result.enrich() + trace = geth_provider.get_transaction_trace(receipt.txn_hash) + _ = trace.enriched_calltree expected = rf"{geth_contract.contract_type.name}\.__new__\(\s*num=\d+\s*\) \[\d+ gas\]" - actual = repr(result) - assert re.match(expected, actual) + for actual in (str(trace), repr(trace)): + assert re.match(expected, actual), f"Unexpected repr: {actual}" @geth_process_test -def test_get_call_tree_erigon(mock_web3, mock_geth, parity_trace_response, txn_hash): +def test_str_and_repr_erigon( + parity_trace_response, geth_provider, mock_web3, networks, mock_geth, geth_contract +): mock_web3.client_version = "erigon_MOCK" - mock_web3.provider.make_request.return_value = parity_trace_response - result = mock_geth.get_call_tree(txn_hash) - actual = repr(result) - expected = r"0xC17f2C69aE2E66FD87367E3260412EEfF637F70E.0x96d373e5\(\) \[\d+ gas\]" + + def _request(rpc, arguments): + if rpc == "trace_transaction": + return parity_trace_response + + return geth_provider.web3.provider.make_request(rpc, arguments) + + mock_web3.provider.make_request.side_effect = _request + mock_web3.eth = geth_provider.web3.eth + orig_provider = networks.active_provider + networks.active_provider = mock_geth + expected = r"0x[a-fA-F0-9]{40}\.0x[a-fA-F0-9]+\(\) \[\d+ gas\]" + + try: + trace = mock_geth.get_transaction_trace(geth_contract.receipt.txn_hash) + assert isinstance(trace, Trace) + for actual in (str(trace), repr(trace)): + assert re.match(expected, actual), actual + + finally: + networks.active_provider = orig_provider + + +@geth_process_test +def test_str_multiline(geth_contract, geth_account): + tx = geth_contract.getNestedAddressArray.transact(sender=geth_account) + actual = f"{tx.trace}" + expected = r""" +VyperContract\.getNestedAddressArray\(\) -> \[ + \['tx\.origin', 'tx\.origin', 'tx\.origin'\], + \['ZERO_ADDRESS', 'ZERO_ADDRESS', 'ZERO_ADDRESS'\] +\] \[\d+ gas\] +""" + assert re.match(expected.strip(), actual.strip()) + + +@geth_process_test +def test_str_list_of_lists(geth_contract, geth_account): + tx = geth_contract.getNestedArrayMixedDynamic.transact(sender=geth_account) + actual = f"{tx.trace}" + expected = r""" +VyperContract\.getNestedArrayMixedDynamic\(\) -> \[ + \[\[\[0\], \[0, 1\], \[0, 1, 2\]\]\], + \[ + \[\[0\], \[0, 1\], \[0, 1, 2\]\], + \[\[0\], \[0, 1\], \[0, 1, 2\]\] + \], + \[\], + \[\], + \[\] +\] \[\d+ gas\] +""" + assert re.match(expected.strip(), actual.strip()) + + +@geth_process_test +def test_get_gas_report(gas_tracker, geth_account, geth_contract): + tx = geth_contract.setNumber(924, sender=geth_account) + trace = tx.trace + actual = trace.get_gas_report() + contract_name = geth_contract.contract_type.name + expected = {contract_name: {"setNumber": [tx.gas_used]}} + assert actual == expected + + +@geth_process_test +def test_get_gas_report_deploy(gas_tracker, geth_contract): + tx = geth_contract.receipt + trace = tx.trace + actual = trace.get_gas_report() + contract_name = geth_contract.contract_type.name + expected = {contract_name: {"__new__": [tx.gas_used]}} + assert actual == expected + + +@geth_process_test +def test_transaction_trace_create(vyper_contract_instance): + trace = TransactionTrace(transaction_hash=vyper_contract_instance.receipt.txn_hash) + actual = f"{trace}" + expected = r"VyperContract\.__new__\(num=0\) \[\d+ gas\]" assert re.match(expected, actual) +@geth_process_test +def test_get_transaction_trace_erigon_calltree( + parity_trace_response, geth_provider, mock_web3, mocker +): + # hash defined in parity_trace_response + tx_hash = "0x3cef4aaa52b97b6b61aa32b3afcecb0d14f7862ca80fdc76504c37a9374645c4" + default_make_request = geth_provider.web3.provider.make_request + + def hacked_make_request(rpc, arguments): + if rpc == "trace_transaction": + return parity_trace_response + + return default_make_request(rpc, arguments) + + mock_web3.provider.make_request.side_effect = hacked_make_request + original_web3 = geth_provider._web3 + geth_provider._web3 = mock_web3 + trace = geth_provider.get_transaction_trace(tx_hash, call_trace_approach=TraceApproach.PARITY) + trace.__dict__["transaction"] = mocker.MagicMock() # doesn't matter. + result = trace.enriched_calltree + + # Defined in parity_mock_response + assert result["contract_id"] == "0xC17f2C69aE2E66FD87367E3260412EEfF637F70E" + assert result["method_id"] == "0x96d373e5" + + geth_provider._web3 = original_web3 + + @geth_process_test def test_printing_debug_logs_vyper(geth_provider, geth_account, vyper_printing): num = 789 @@ -187,3 +291,17 @@ def test_printing_debug_logs_compat(geth_provider, geth_account, vyper_printing) assert receipt.status assert len(list(receipt.debug_logs_typed)) == 1 assert receipt.debug_logs_typed[0][0] == num + + +def test_call_trace_supports_debug_trace_call(geth_contract, geth_account): + tx = { + "chainId": "0x539", + "to": "0x77c7E3905c21177Be97956c6620567596492C497", + "value": "0x0", + "data": "0x23fd0e40", + "type": 2, + "accessList": [], + } + trace = CallTrace(tx=tx) + _ = trace._traced_call + assert trace.supports_debug_trace_call diff --git a/tests/functional/test_gas_tracker.py b/tests/functional/test_gas_tracker.py new file mode 100644 index 0000000000..09fb43555d --- /dev/null +++ b/tests/functional/test_gas_tracker.py @@ -0,0 +1,30 @@ +def test_append_gas(gas_tracker, owner, vyper_contract_instance): + tx = vyper_contract_instance.setNumber(924, sender=owner) + trace = tx.trace + gas_tracker.append_gas(trace, vyper_contract_instance.address) + report = gas_tracker.session_gas_report + contract_name = vyper_contract_instance.contract_type.name + assert contract_name in report + assert "setNumber" in report[contract_name] + assert tx.gas_used in report[contract_name]["setNumber"] + + +def test_append_gas_deploy(gas_tracker, vyper_contract_instance): + tx = vyper_contract_instance.receipt + trace = tx.trace + gas_tracker.append_gas(trace, vyper_contract_instance.address) + report = gas_tracker.session_gas_report + contract_name = vyper_contract_instance.contract_type.name + assert contract_name in report + assert "__new__" in report[contract_name] + assert tx.gas_used in report[contract_name]["__new__"] + + +def test_append_gas_transfer(gas_tracker, sender, receiver): + tx = sender.transfer(receiver, 0) + trace = tx.trace + gas_tracker.append_gas(trace, receiver.address) + report = gas_tracker.session_gas_report + + # ETH-transfers are not included in the final report. + assert report is None diff --git a/tests/functional/test_network_manager.py b/tests/functional/test_network_manager.py index a8364f71da..40abb83abc 100644 --- a/tests/functional/test_network_manager.py +++ b/tests/functional/test_network_manager.py @@ -222,7 +222,7 @@ def test_parse_network_choice_multiple_contexts( assert ( eth_tester_provider.chain_id == DEFAULT_TEST_CHAIN_ID ), "Test setup failed - expecting to start on default chain ID" - assert eth_tester_provider._make_request("eth_chainId") == DEFAULT_TEST_CHAIN_ID + assert eth_tester_provider.make_request("eth_chainId") == DEFAULT_TEST_CHAIN_ID with first_context: start_count = len(first_context.connected_providers) @@ -234,7 +234,7 @@ def test_parse_network_choice_multiple_contexts( assert len(second_context.connected_providers) == expected_next_count assert eth_tester_provider.chain_id == DEFAULT_TEST_CHAIN_ID - assert eth_tester_provider._make_request("eth_chainId") == DEFAULT_TEST_CHAIN_ID + assert eth_tester_provider.make_request("eth_chainId") == DEFAULT_TEST_CHAIN_ID def test_getattr_ecosystem_with_hyphenated_name(networks, ethereum): diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index 58624c4a02..804fd39c53 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -111,7 +111,7 @@ def test_get_receipt_not_exists_with_timeout(eth_tester_provider): unknown_txn = "0x053cba5c12172654d894f66d5670bab6215517a94189a9ffc09bc40a589ec04d" expected = ( f"Transaction '{unknown_txn}' not found. " - rf"Error: Transaction HexBytes\('{unknown_txn}'\) " + rf"Error: Transaction '{unknown_txn}' " "is not in the chain after 0 seconds" ) with pytest.raises(TransactionNotFoundError, match=expected): @@ -353,7 +353,7 @@ def test_make_request_not_exists(eth_tester_provider): APINotImplementedError, match="RPC method 'ape_thisDoesNotExist' is not implemented by this node instance.", ): - eth_tester_provider._make_request("ape_thisDoesNotExist") + eth_tester_provider.make_request("ape_thisDoesNotExist") @pytest.mark.parametrize("msg", ("Method not found", "Method ape_thisDoesNotExist not found")) @@ -364,7 +364,6 @@ def test_make_request_not_exists_dev_nodes(eth_tester_provider, mock_web3, msg): """ real_web3 = eth_tester_provider._web3 mock_web3.eth = real_web3.eth - eth_tester_provider._web3 = mock_web3 def custom_make_request(rpc, params): if rpc == "ape_thisDoesNotExist": @@ -373,11 +372,16 @@ def custom_make_request(rpc, params): return real_web3.provider.make_request(rpc, params) mock_web3.provider.make_request.side_effect = custom_make_request - with pytest.raises( - APINotImplementedError, - match="RPC method 'ape_thisDoesNotExist' is not implemented by this node instance.", - ): - eth_tester_provider._make_request("ape_thisDoesNotExist") + + eth_tester_provider._web3 = mock_web3 + try: + with pytest.raises( + APINotImplementedError, + match="RPC method 'ape_thisDoesNotExist' is not implemented by this node instance.", + ): + eth_tester_provider.make_request("ape_thisDoesNotExist") + finally: + eth_tester_provider._web3 = real_web3 def test_make_request_handles_http_error_method_not_allowed(eth_tester_provider, mock_web3): @@ -399,7 +403,7 @@ def custom_make_request(rpc, params): APINotImplementedError, match="RPC method 'ape_thisDoesNotExist' is not implemented by this node instance.", ): - eth_tester_provider._make_request("ape_thisDoesNotExist") + eth_tester_provider.make_request("ape_thisDoesNotExist") def test_base_fee(eth_tester_provider): diff --git a/tests/functional/test_receipt.py b/tests/functional/test_receipt.py index 7a256cadac..1f762fca93 100644 --- a/tests/functional/test_receipt.py +++ b/tests/functional/test_receipt.py @@ -1,7 +1,9 @@ import pytest +from rich.table import Table +from rich.tree import Tree from ape.api import ReceiptAPI -from ape.exceptions import APINotImplementedError, ContractLogicError, OutOfGasError +from ape.exceptions import ContractLogicError, OutOfGasError from ape.utils import ManagerAccessMixin from ape_ethereum.transactions import DynamicFeeTransaction, Receipt, TransactionStatusEnum @@ -16,22 +18,40 @@ def invoke_receipt(vyper_contract_instance, owner): return vyper_contract_instance.setNumber(1, sender=owner) +@pytest.fixture +def trace_print_capture(mocker, chain): + console_factory = mocker.MagicMock() + capture = mocker.MagicMock() + console_factory.return_value = capture + orig = chain._reports._get_console + chain._reports._get_console = console_factory + try: + yield capture.print + finally: + chain._reports._get_console = orig + + def test_receipt_properties(chain, invoke_receipt): assert invoke_receipt.block_number == chain.blocks.head.number assert invoke_receipt.timestamp == chain.blocks.head.timestamp assert invoke_receipt.datetime == chain.blocks.head.datetime -def test_show_trace(invoke_receipt): - # See trace-supported provider plugin tests for better tests (e.g. ape-hardhat) - with pytest.raises(APINotImplementedError): - invoke_receipt.show_trace() +def test_show_trace(trace_print_capture, invoke_receipt): + invoke_receipt.show_trace() + actual = trace_print_capture.call_args[0][0] + assert isinstance(actual, Tree) + label = f"{actual.label}" + assert "VyperContract" in label + assert "setNumber" in label + assert f"[{invoke_receipt.gas_used} gas]" in label -def test_show_gas_report(invoke_receipt): - # See trace-supported provider plugin tests for better tests (e.g. ape-hardhat) - with pytest.raises(APINotImplementedError): - invoke_receipt.show_gas_report() +def test_show_gas_report(trace_print_capture, invoke_receipt): + invoke_receipt.show_gas_report() + actual = trace_print_capture.call_args[0][0] + assert isinstance(actual, Table) + assert actual.title == "VyperContract Gas" def test_decode_logs_specify_abi(invoke_receipt, vyper_contract_instance): @@ -210,3 +230,15 @@ def test_transaction_validated_from_dict(ethereum, owner, deploy_receipt): assert receipt.transaction.sender == owner.address assert receipt.transaction.value == 123 assert receipt.transaction.data == b"hello" + + +def test_return_value(owner, vyper_contract_instance): + """ + ``.return_value`` still works when using EthTester provider! + It works by using eth_call to get the result rather than + tracing-RPCs. + """ + receipt = vyper_contract_instance.getNestedArrayMixedDynamic.transact(sender=owner) + actual = receipt.return_value + assert len(actual) == 5 + assert actual[1][1] == [[0], [0, 1], [0, 1, 2]] diff --git a/tests/functional/test_trace.py b/tests/functional/test_trace.py new file mode 100644 index 0000000000..d6f95c3719 --- /dev/null +++ b/tests/functional/test_trace.py @@ -0,0 +1,96 @@ +import re + +from evm_trace import CallTreeNode, CallType + +from ape_ethereum.trace import CallTrace, TransactionTrace, parse_rich_tree + + +def test_parse_rich_tree(vyper_contract_instance): + """ + Show that when full selector is set as the method ID, + the tree-output only shows the short method name. + """ + contract_id = vyper_contract_instance.contract_type.name + method_id = vyper_contract_instance.contract_type.methods["setAddress"].selector + call = CallTreeNode(address=vyper_contract_instance.address, call_type=CallType.CALL) + data = { + **call.model_dump(by_alias=True, mode="json"), + "method_id": method_id, + "contract_id": contract_id, + } + actual = parse_rich_tree(data).label + expected = f"[#ff8c00]{contract_id}[/].[bright_green]setAddress[/]()" + assert actual == expected + + +def test_get_gas_report(gas_tracker, owner, vyper_contract_instance): + tx = vyper_contract_instance.setNumber(924, sender=owner) + trace = tx.trace + actual = trace.get_gas_report() + contract_name = vyper_contract_instance.contract_type.name + expected = {contract_name: {"setNumber": [tx.gas_used]}} + assert actual == expected + + +def test_get_gas_report_deploy(gas_tracker, vyper_contract_instance): + tx = vyper_contract_instance.receipt + trace = tx.trace + actual = trace.get_gas_report() + contract_name = vyper_contract_instance.contract_type.name + expected = {contract_name: {"__new__": [tx.gas_used]}} + assert actual == expected + + +def test_get_gas_report_transfer(gas_tracker, sender, receiver): + tx = sender.transfer(receiver, 0) + trace = tx.trace + actual = trace.get_gas_report() + expected = {"__ETH_transfer__": {"to:TEST::2": [tx.gas_used]}} + assert actual == expected + + +def test_transaction_trace_create(vyper_contract_instance): + trace = TransactionTrace(transaction_hash=vyper_contract_instance.receipt.txn_hash) + actual = f"{trace}" + expected = r"VyperContract\.__new__\(num=0\) \[\d+ gas\]" + assert re.match(expected, actual) + + +def test_transaction_trace_multiline(vyper_contract_instance, owner): + tx = vyper_contract_instance.getNestedAddressArray.transact(sender=owner) + actual = f"{tx.trace}" + expected = r""" +VyperContract\.getNestedAddressArray\(\) -> \[ + \['tx\.origin', 'tx\.origin', 'tx\.origin'\], + \['ZERO_ADDRESS', 'ZERO_ADDRESS', 'ZERO_ADDRESS'\] +\] \[\d+ gas\] +""" + assert re.match(expected.strip(), actual.strip()) + + +def test_transaction_trace_list_of_lists(vyper_contract_instance, owner): + tx = vyper_contract_instance.getNestedArrayMixedDynamic.transact(sender=owner) + actual = f"{tx.trace}" + expected = r""" +VyperContract\.getNestedArrayMixedDynamic\(\) -> \[ + \[\[\[0\], \[0, 1\], \[0, 1, 2\]\]\], + \[ + \[\[0\], \[0, 1\], \[0, 1, 2\]\], + \[\[0\], \[0, 1\], \[0, 1, 2\]\] + \], + \[\], + \[\], + \[\] +\] \[\d+ gas\] +""" + assert re.match(expected.strip(), actual.strip()) + + +def test_call_trace_debug_trace_call_not_supported(owner, vyper_contract_instance): + """ + When using EthTester, we can still see the top-level trace of a call. + """ + tx = {"to": vyper_contract_instance.address, "from": owner.address} + trace = CallTrace(tx=tx) + actual = f"{trace}" + assert actual == "VyperContract.0x()" diff --git a/tests/functional/test_transaction.py b/tests/functional/test_transaction.py index 8e9303a337..b01c70172f 100644 --- a/tests/functional/test_transaction.py +++ b/tests/functional/test_transaction.py @@ -248,6 +248,7 @@ def test_txn_hash_when_access_list_is_raw(ethereum, owner): # to this state, but somehow they have. txn.access_list = ACCESS_LIST_HEXBYTES + # Ignore the Pydantic warning from access-list being the wrong type. with warnings.catch_warnings(): warnings.simplefilter("ignore") actual = txn.txn_hash.hex() diff --git a/tests/functional/utils/test_trace.py b/tests/functional/utils/test_trace.py deleted file mode 100644 index 54616e7dca..0000000000 --- a/tests/functional/utils/test_trace.py +++ /dev/null @@ -1,15 +0,0 @@ -from ape.types import CallTreeNode -from ape.utils.trace import parse_rich_tree - - -def test_parse_rich_tree(vyper_contract_instance): - """ - Show that when full selector is set as the method ID, - the tree-output only shows the short method name. - """ - contract_id = vyper_contract_instance.contract_type.name - method_id = vyper_contract_instance.contract_type.methods["setAddress"].selector - call = CallTreeNode(contract_id=contract_id, method_id=method_id) - actual = parse_rich_tree(call).label - expected = f"[#ff8c00]{contract_id}[/].[bright_green]setAddress[/]()" - assert actual == expected diff --git a/tests/integration/cli/test_test.py b/tests/integration/cli/test_test.py index 22a3d8ac9e..43d18d7c67 100644 --- a/tests/integration/cli/test_test.py +++ b/tests/integration/cli/test_test.py @@ -63,7 +63,11 @@ def load_dependencies(project): @pytest.fixture -def setup_pytester(pytester): +def setup_pytester(pytester, owner): + # Mine to a new block so we are not capturing old transactions + # in the tests. + owner.transfer(owner, 0) + def setup(project_name: str): project_path = BASE_PROJECTS_PATH / project_name tests_path = project_path / "tests" @@ -124,7 +128,7 @@ def run_gas_test( gas_header_line_index = index assert gas_header_line_index is not None, "'Gas Profile' not in output." - expected = expected_report.split("\n")[1:] + expected = [x for x in expected_report.rstrip().split("\n")[1:]] start_index = gas_header_line_index + 1 end_index = start_index + len(expected) actual = [x.rstrip() for x in result.outlines[start_index:end_index]] @@ -209,10 +213,12 @@ def test_gas_flag_when_not_supported(setup_pytester, project, pytester, eth_test setup_pytester(project.path.name) path = f"{project.path}/tests/test_contract.py::test_contract_interaction_in_tests" result = pytester.runpytest(path, "--gas") - assert ( + actual = "\n".join(result.outlines) + expected = ( "Provider 'test' does not support transaction tracing. " "The gas profile is limited to receipt-level data." - ) in "\n".join(result.outlines) + ) + assert expected in actual @geth_process_test