From 54c0a5e236553e3ab110f92b8c2dd92bcf55cef1 Mon Sep 17 00:00:00 2001 From: antazoey Date: Fri, 8 Dec 2023 13:00:35 -0600 Subject: [PATCH] feat!: pydantic v2 support [APE-1412] (#55) --- .mdformat.toml | 1 - .pre-commit-config.yaml | 6 +++--- README.md | 4 ++-- evm_trace/__init__.py | 8 +++---- evm_trace/_pydantic_compat.py | 7 ------- evm_trace/base.py | 31 ++++++++++------------------ evm_trace/geth.py | 29 +++++++++++--------------- evm_trace/parity.py | 39 ++++++++++++++--------------------- evm_trace/utils.py | 7 ------- evm_trace/vmtrace.py | 15 +++++--------- pyproject.toml | 3 +++ setup.cfg | 1 + setup.py | 13 +++++++----- tests/conftest.py | 4 ++-- tests/test_gas.py | 2 +- tests/test_geth.py | 4 ++-- tests/test_parity.py | 2 +- 17 files changed, 71 insertions(+), 105 deletions(-) delete mode 100644 .mdformat.toml delete mode 100644 evm_trace/_pydantic_compat.py delete mode 100644 evm_trace/utils.py diff --git a/.mdformat.toml b/.mdformat.toml deleted file mode 100644 index 01b2fb0..0000000 --- a/.mdformat.toml +++ /dev/null @@ -1 +0,0 @@ -number = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bae32dc..a7a4502 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml @@ -10,7 +10,7 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.11.0 hooks: - id: black name: black @@ -21,7 +21,7 @@ repos: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.5.1 + rev: v1.7.1 hooks: - id: mypy additional_dependencies: [types-PyYAML, types-requests, types-setuptools, pydantic] diff --git a/README.md b/README.md index 0cf1ad8..99682d9 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ web3 = Web3(HTTPProvider("https://path.to.my.node")) txn_hash = "0x..." struct_logs = web3.manager.request_blocking("debug_traceTransaction", [txn_hash]).structLogs for item in struct_logs: - frame = TraceFrame.parse_obj(item) + frame = TraceFrame.model_validate(item) ``` If you want to get the call-tree node, you can do: @@ -69,7 +69,7 @@ If you are using a node that supports the `trace_transaction` RPC, you can use ` from evm_trace import CallType, ParityTraceList raw_trace_list = web3.manager.request_blocking("trace_transaction", [txn_hash]) -trace_list = ParityTraceList.parse_obj(raw_trace_list) +trace_list = ParityTraceList.model_validate(raw_trace_list) ``` And to make call-tree nodes, you can do: diff --git a/evm_trace/__init__.py b/evm_trace/__init__.py index 28fb6b7..61001bc 100644 --- a/evm_trace/__init__.py +++ b/evm_trace/__init__.py @@ -1,12 +1,12 @@ -from .base import CallTreeNode -from .enums import CallType -from .geth import ( +from evm_trace.base import CallTreeNode +from evm_trace.enums import CallType +from evm_trace.geth import ( TraceFrame, create_trace_frames, get_calltree_from_geth_call_trace, get_calltree_from_geth_trace, ) -from .parity import ParityTrace, ParityTraceList, get_calltree_from_parity_trace +from evm_trace.parity import ParityTrace, ParityTraceList, get_calltree_from_parity_trace __all__ = [ "CallTreeNode", diff --git a/evm_trace/_pydantic_compat.py b/evm_trace/_pydantic_compat.py deleted file mode 100644 index eda1190..0000000 --- a/evm_trace/_pydantic_compat.py +++ /dev/null @@ -1,7 +0,0 @@ -try: - from pydantic.v1 import Field, ValidationError, validator # type: ignore -except ImportError: - from pydantic import Field, ValidationError, validator # type: ignore - - -__all__ = ["Field", "validator", "ValidationError"] diff --git a/evm_trace/base.py b/evm_trace/base.py index c1765e6..be562dc 100644 --- a/evm_trace/base.py +++ b/evm_trace/base.py @@ -1,25 +1,19 @@ from functools import cached_property, singledispatchmethod from typing import List, Optional -from ethpm_types import BaseModel as _BaseModel -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes +from pydantic import BaseModel as _BaseModel +from pydantic import ConfigDict, field_validator -from evm_trace._pydantic_compat import validator from evm_trace.display import get_tree_display from evm_trace.enums import CallType class BaseModel(_BaseModel): - class Config: - # NOTE: Due to https://github.com/samuelcolvin/pydantic/issues/1241 we have - # to add this cached property workaround in order to avoid this error: - - # TypeError: cannot pickle '_thread.RLock' object - - keep_untouched = (cached_property, singledispatchmethod) - arbitrary_types_allowed = True - underscore_attrs_are_private = True - copy_on_model_validation = "none" + model_config = ConfigDict( + ignored_types=(cached_property, singledispatchmethod), + arbitrary_types_allowed=True, + ) class CallTreeNode(BaseModel): @@ -77,17 +71,14 @@ def __repr__(self) -> str: def __getitem__(self, index: int) -> "CallTreeNode": return self.calls[index] - @validator("calldata", "returndata", "address", pre=True) + @field_validator("calldata", "returndata", "address", mode="before") def validate_bytes(cls, value): return HexBytes(value) if isinstance(value, str) else value - @validator("value", "depth", pre=True) + @field_validator("value", "depth", mode="before") def validate_ints(cls, value): - if not value: - return 0 - - return int(value, 16) if isinstance(value, str) else value + return (int(value, 16) if isinstance(value, str) else value) if value else 0 - @validator("gas_limit", "gas_cost", pre=True) + @field_validator("gas_limit", "gas_cost", mode="before") def validate_optional_ints(cls, value): return int(value, 16) if isinstance(value, str) else value diff --git a/evm_trace/geth.py b/evm_trace/geth.py index 0771445..64b9747 100644 --- a/evm_trace/geth.py +++ b/evm_trace/geth.py @@ -1,20 +1,19 @@ import math from typing import Dict, Iterator, List, Optional +from eth_pydantic_types import HashBytes20, HexBytes from eth_utils import to_int -from ethpm_types import HexBytes +from pydantic import Field, RootModel, field_validator -from evm_trace._pydantic_compat import Field, validator from evm_trace.base import BaseModel, CallTreeNode from evm_trace.enums import CALL_OPCODES, CallType -from evm_trace.utils import to_address -class TraceMemory(BaseModel): - __root__: List[HexBytes] = [] +class TraceMemory(RootModel[List[HexBytes]]): + root: List[HexBytes] = [] def get(self, offset: HexBytes, size: HexBytes): - return extract_memory(offset, size, self.__root__) + return extract_memory(offset, size, self.root) class TraceFrame(BaseModel): @@ -49,15 +48,15 @@ class TraceFrame(BaseModel): storage: Dict[HexBytes, HexBytes] = {} """Contract storage.""" - contract_address: Optional[HexBytes] = None + contract_address: Optional[HashBytes20] = None """The address producing the frame.""" - @validator("pc", "gas", "gas_cost", "depth", pre=True) + @field_validator("pc", "gas", "gas_cost", "depth", mode="before") def validate_ints(cls, value): return int(value, 16) if isinstance(value, str) else value @property - def address(self) -> Optional[HexBytes]: + def address(self) -> Optional[HashBytes20]: """ The address of this CALL frame. Only returns a value if this frame's opcode is a call-based opcode. @@ -66,7 +65,7 @@ def address(self) -> Optional[HexBytes]: if not self.contract_address and ( self.op in CALL_OPCODES and CallType.CREATE.value not in self.op ): - self.contract_address = HexBytes(self.stack[-2][-20:]) + self.contract_address = HashBytes20.__eth_pydantic_validate__(self.stack[-2][-20:]) return self.contract_address @@ -104,7 +103,7 @@ def _get_create_frames(frame: TraceFrame, frames: Iterator[Dict]) -> List[TraceF create_frames = [frame] start_depth = frame.depth for next_frame in frames: - next_frame_obj = TraceFrame.parse_obj(next_frame) + next_frame_obj = TraceFrame.model_validate(next_frame) depth = next_frame_obj.depth if CallType.CREATE.value in next_frame_obj.op: @@ -116,11 +115,7 @@ def _get_create_frames(frame: TraceFrame, frames: Iterator[Dict]) -> List[TraceF # the first frame after the CREATE with an equal depth. if len(next_frame_obj.stack) > 0: raw_addr = HexBytes(next_frame_obj.stack[-1][-40:]) - try: - frame.contract_address = HexBytes(to_address(raw_addr)) - except Exception: - # Potentially, a transaction was made with poor data. - frame.contract_address = raw_addr + frame.contract_address = HashBytes20.__eth_pydantic_validate__(raw_addr) create_frames.append(next_frame_obj) break @@ -279,7 +274,7 @@ def _create_node( node_kwargs["last_create_depth"].pop() for subcall in node_kwargs.get("calls", [])[::-1]: if subcall.call_type in (CallType.CREATE, CallType.CREATE2): - subcall.address = HexBytes(to_address(frame.stack[-1][-40:])) + subcall.address = HashBytes20.__eth_pydantic_validate__(frame.stack[-1][-40:]) if len(frame.stack) >= 5: subcall.calldata = frame.memory.get(frame.stack[-4], frame.stack[-5]) diff --git a/evm_trace/parity.py b/evm_trace/parity.py index 6944f9d..1b384ce 100644 --- a/evm_trace/parity.py +++ b/evm_trace/parity.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Optional, Union, cast -from evm_trace._pydantic_compat import Field, validator +from pydantic import Field, RootModel, field_validator + from evm_trace.base import BaseModel, CallTreeNode from evm_trace.enums import CallType @@ -18,7 +19,7 @@ class CallAction(BaseModel): # only used to recover the specific call type call_type: str = Field(alias="callType", repr=False) - @validator("value", "gas", pre=True) + @field_validator("value", "gas", mode="before") def convert_integer(cls, v): return int(v, 16) @@ -32,7 +33,7 @@ class CreateAction(BaseModel): init: str value: int - @validator("value", "gas", pre=True) + @field_validator("value", "gas", mode="before") def convert_integer(cls, v): return int(v, 16) @@ -41,7 +42,7 @@ class SelfDestructAction(BaseModel): address: str balance: int - @validator("balance", pre=True) + @field_validator("balance", mode="before") def convert_integer(cls, v): return int(v, 16) if isinstance(v, str) else int(v) @@ -59,7 +60,7 @@ class ActionResult(BaseModel): for gas per zero byte and ``16`` gas per non-zero byte. """ - @validator("gas_used", pre=True) + @field_validator("gas_used", mode="before") def convert_integer(cls, v): return int(v, 16) if isinstance(v, str) else int(v) @@ -95,26 +96,19 @@ class ParityTrace(BaseModel): trace_address: List[int] = Field(alias="traceAddress") transaction_hash: str = Field(alias="transactionHash") - @validator("call_type", pre=True) - def convert_call_type(cls, v, values) -> CallType: - if isinstance(values["action"], CallAction): - v = values["action"].call_type - value = v.upper() + @field_validator("call_type", mode="before") + def convert_call_type(cls, value, info) -> CallType: + if isinstance(info.data["action"], CallAction): + value = info.data["action"].call_type + + value = value.upper() if value == "SUICIDE": value = "SELFDESTRUCT" return CallType(value) -class ParityTraceList(BaseModel): - __root__: List[ParityTrace] - - # pydantic models with custom root don't have this by default - def __iter__(self): - return iter(self.__root__) - - def __getitem__(self, item): - return self.__root__[item] +ParityTraceList = RootModel[List[ParityTrace]] def get_calltree_from_parity_trace( @@ -137,7 +131,7 @@ def get_calltree_from_parity_trace( Returns: :class:`~evm_trace.base.CallTreeNode` """ - root = root or traces[0] + root = root or traces.root[0] failed = root.error is not None node_kwargs: Dict[Any, Any] = { "call_type": root.call_type, @@ -187,7 +181,7 @@ def get_calltree_from_parity_trace( address=selfdestruct_action.address, ) - trace_list: List[ParityTrace] = list(traces) + trace_list: List[ParityTrace] = traces.root subtraces: List[ParityTrace] = [ sub for sub in trace_list @@ -196,5 +190,4 @@ def get_calltree_from_parity_trace( ] node_kwargs["calls"] = [get_calltree_from_parity_trace(traces, root=sub) for sub in subtraces] node_kwargs = {**node_kwargs, **root_kwargs} - node = CallTreeNode.parse_obj(node_kwargs) - return node + return CallTreeNode.model_validate(node_kwargs) diff --git a/evm_trace/utils.py b/evm_trace/utils.py deleted file mode 100644 index 9b5b198..0000000 --- a/evm_trace/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -from eth_utils import to_checksum_address -from hexbytes import HexBytes - - -def to_address(value): - # clear the padding and expand to 32 bytes - return to_checksum_address(HexBytes(value)[-20:].rjust(20, b"\x00")) diff --git a/evm_trace/vmtrace.py b/evm_trace/vmtrace.py index 41db066..951a914 100644 --- a/evm_trace/vmtrace.py +++ b/evm_trace/vmtrace.py @@ -4,13 +4,11 @@ from eth.vm.memory import Memory from eth.vm.stack import Stack +from eth_pydantic_types import Address, HexBytes from eth_utils import to_int -from ethpm_types import HexBytes from msgspec import Struct from msgspec.json import Decoder -from evm_trace.utils import to_address - # opcodes grouped by the number of items they pop from the stack # fmt: off POP_OPCODES = { @@ -141,7 +139,7 @@ def to_trace_frames( ) if op.op in ["CALL", "DELEGATECALL", "STATICCALL"]: - call_address = to_address(stack.values[-2][1]) + call_address = Address.__eth_pydantic_validate__(stack.values[-2][1]) if op.ex: if op.ex.mem: @@ -188,9 +186,6 @@ def from_rpc_response(buffer: bytes) -> Union[VMTrace, List[VMTrace]]: """ Decode structured data from a raw `trace_replayTransaction` or `trace_replayBlockTransactions`. """ - resp = Decoder(RPCResponse, dec_hook=dec_hook).decode(buffer) - - if isinstance(resp.result, list): - return [i.vmTrace for i in resp.result] - - return resp.result.vmTrace + response = Decoder(RPCResponse, dec_hook=dec_hook).decode(buffer) + result: Union[List[RPCTraceResult], RPCTraceResult] = response.result + return [i.vmTrace for i in result] if isinstance(result, list) else result.vmTrace diff --git a/pyproject.toml b/pyproject.toml index e4dc1e3..ab24d8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,3 +37,6 @@ force_grid_wrap = 0 include_trailing_comma = true multi_line_output = 3 use_parentheses = true + +[tool.mdformat] +number = true diff --git a/setup.cfg b/setup.cfg index f1bf796..7121939 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,7 @@ exclude = .eggs docs build + evm_trace/version.py per-file-ignores = # The traces have to be formatted this way for the tests. tests/expected_traces.py: E501 diff --git a/setup.py b/setup.py index f371e95..b36e55c 100644 --- a/setup.py +++ b/setup.py @@ -4,20 +4,23 @@ extras_require = { "test": [ # `test` GitHub Action jobs uses this "pytest>=6.0", # Core testing package - "pytest-xdist", # multi-process runner + "pytest-xdist", # Multi-process runner "pytest-cov", # Coverage analyzer plugin "hypothesis>=6.2.0,<7.0", # Strategy-based fuzzer "eth-hash[pysha3]", # For eth-utils address checksumming ], "lint": [ - "black>=23.9.1,<24", # Auto-formatter and linter - "mypy>=1.5.1,<2", # Static type analyzer + "black>=23.11.0,<24", # Auto-formatter and linter + "mypy>=1.7.1,<2", # Static type analyzer "types-setuptools", # Needed for mypy type shed "flake8>=6.1.0,<7", # Style linter + "flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code + "flake8-print>=5.0.0,<6", # Detect print statements left in code "isort>=5.10.1,<6", # Import sorting linter "mdformat>=0.7.17", # Auto-formatter for markdown "mdformat-gfm>=0.3.5", # Needed for formatting GitHub-flavored markdown "mdformat-frontmatter>=0.4.1", # Needed for frontmatters-style headers in issue templates + "mdformat-pyproject>=0.0.1", # Allows configuring in pyproject.toml ], "release": [ # `release` GitHub Action job uses this "setuptools", # Installation tool @@ -57,11 +60,11 @@ url="https://github.com/ApeWorX/evm-trace", include_package_data=True, install_requires=[ - "pydantic>=1.10.1,<3", + "pydantic>=2.4.2,<3", "py-evm>=0.7.0a3,<0.8", "eth-utils>=2.1,<3", - "ethpm-types>=0.5.0,<0.6", "msgspec>=0.8", + "eth-pydantic-types>=0.1.0a4", ], python_requires=">=3.8,<4", extras_require=extras_require, diff --git a/tests/conftest.py b/tests/conftest.py index 72bb395..9192492 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,8 +72,8 @@ def call_tree_data(request): @pytest.fixture def parity_create2_trace_list(): - trace_list = [ParityTrace.parse_obj(x) for x in PARITY_CREATE2_TRACE] - return ParityTraceList(__root__=trace_list) + trace_list = [ParityTrace.model_validate(x) for x in PARITY_CREATE2_TRACE] + return ParityTraceList(root=trace_list) @pytest.fixture diff --git a/tests/test_gas.py b/tests/test_gas.py index b056207..7a15c06 100644 --- a/tests/test_gas.py +++ b/tests/test_gas.py @@ -1,6 +1,6 @@ from typing import List -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from evm_trace import CallTreeNode from evm_trace.gas import GasReport, get_gas_report, merge_reports diff --git a/tests/test_geth.py b/tests/test_geth.py index 0dfa65f..f4b2edd 100644 --- a/tests/test_geth.py +++ b/tests/test_geth.py @@ -1,9 +1,9 @@ import re import pytest -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes +from pydantic import ValidationError -from evm_trace._pydantic_compat import ValidationError from evm_trace.enums import CallType from evm_trace.geth import ( TraceFrame, diff --git a/tests/test_parity.py b/tests/test_parity.py index 89d993e..952e6df 100644 --- a/tests/test_parity.py +++ b/tests/test_parity.py @@ -34,7 +34,7 @@ def test_parity(name): assert name in EXPECTED_OUTPUT_MAP, f"Missing expected output set for '{name}'." path = DATA_PATH / f"{name}.json" - traces = ParityTraceList.parse_file(path) + traces = ParityTraceList.model_validate_json(path.read_text()) actual = repr(get_calltree_from_parity_trace(traces)) expected = EXPECTED_OUTPUT_MAP[name].strip()