Skip to content

Commit

Permalink
feat!: pydantic v2 support [APE-1412] (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Dec 8, 2023
1 parent 9a5feb3 commit 54c0a5e
Show file tree
Hide file tree
Showing 17 changed files with 71 additions and 105 deletions.
1 change: 0 additions & 1 deletion .mdformat.toml

This file was deleted.

6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml

Expand All @@ -10,7 +10,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.11.0
hooks:
- id: black
name: black
Expand All @@ -21,7 +21,7 @@ repos:
- id: flake8

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies: [types-PyYAML, types-requests, types-setuptools, pydantic]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ web3 = Web3(HTTPProvider("https://path.to.my.node"))
txn_hash = "0x..."
struct_logs = web3.manager.request_blocking("debug_traceTransaction", [txn_hash]).structLogs
for item in struct_logs:
frame = TraceFrame.parse_obj(item)
frame = TraceFrame.model_validate(item)
```

If you want to get the call-tree node, you can do:
Expand Down Expand Up @@ -69,7 +69,7 @@ If you are using a node that supports the `trace_transaction` RPC, you can use `
from evm_trace import CallType, ParityTraceList

raw_trace_list = web3.manager.request_blocking("trace_transaction", [txn_hash])
trace_list = ParityTraceList.parse_obj(raw_trace_list)
trace_list = ParityTraceList.model_validate(raw_trace_list)
```

And to make call-tree nodes, you can do:
Expand Down
8 changes: 4 additions & 4 deletions evm_trace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from .base import CallTreeNode
from .enums import CallType
from .geth import (
from evm_trace.base import CallTreeNode
from evm_trace.enums import CallType
from evm_trace.geth import (
TraceFrame,
create_trace_frames,
get_calltree_from_geth_call_trace,
get_calltree_from_geth_trace,
)
from .parity import ParityTrace, ParityTraceList, get_calltree_from_parity_trace
from evm_trace.parity import ParityTrace, ParityTraceList, get_calltree_from_parity_trace

__all__ = [
"CallTreeNode",
Expand Down
7 changes: 0 additions & 7 deletions evm_trace/_pydantic_compat.py

This file was deleted.

31 changes: 11 additions & 20 deletions evm_trace/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
from functools import cached_property, singledispatchmethod
from typing import List, Optional

from ethpm_types import BaseModel as _BaseModel
from ethpm_types import HexBytes
from eth_pydantic_types import HexBytes
from pydantic import BaseModel as _BaseModel
from pydantic import ConfigDict, field_validator

from evm_trace._pydantic_compat import validator
from evm_trace.display import get_tree_display
from evm_trace.enums import CallType


class BaseModel(_BaseModel):
class Config:
# NOTE: Due to https://github.com/samuelcolvin/pydantic/issues/1241 we have
# to add this cached property workaround in order to avoid this error:

# TypeError: cannot pickle '_thread.RLock' object

keep_untouched = (cached_property, singledispatchmethod)
arbitrary_types_allowed = True
underscore_attrs_are_private = True
copy_on_model_validation = "none"
model_config = ConfigDict(
ignored_types=(cached_property, singledispatchmethod),
arbitrary_types_allowed=True,
)


class CallTreeNode(BaseModel):
Expand Down Expand Up @@ -77,17 +71,14 @@ def __repr__(self) -> str:
def __getitem__(self, index: int) -> "CallTreeNode":
return self.calls[index]

@validator("calldata", "returndata", "address", pre=True)
@field_validator("calldata", "returndata", "address", mode="before")
def validate_bytes(cls, value):
return HexBytes(value) if isinstance(value, str) else value

@validator("value", "depth", pre=True)
@field_validator("value", "depth", mode="before")
def validate_ints(cls, value):
if not value:
return 0

return int(value, 16) if isinstance(value, str) else value
return (int(value, 16) if isinstance(value, str) else value) if value else 0

@validator("gas_limit", "gas_cost", pre=True)
@field_validator("gas_limit", "gas_cost", mode="before")
def validate_optional_ints(cls, value):
return int(value, 16) if isinstance(value, str) else value
29 changes: 12 additions & 17 deletions evm_trace/geth.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import math
from typing import Dict, Iterator, List, Optional

from eth_pydantic_types import HashBytes20, HexBytes
from eth_utils import to_int
from ethpm_types import HexBytes
from pydantic import Field, RootModel, field_validator

from evm_trace._pydantic_compat import Field, validator
from evm_trace.base import BaseModel, CallTreeNode
from evm_trace.enums import CALL_OPCODES, CallType
from evm_trace.utils import to_address


class TraceMemory(BaseModel):
__root__: List[HexBytes] = []
class TraceMemory(RootModel[List[HexBytes]]):
root: List[HexBytes] = []

def get(self, offset: HexBytes, size: HexBytes):
return extract_memory(offset, size, self.__root__)
return extract_memory(offset, size, self.root)


class TraceFrame(BaseModel):
Expand Down Expand Up @@ -49,15 +48,15 @@ class TraceFrame(BaseModel):
storage: Dict[HexBytes, HexBytes] = {}
"""Contract storage."""

contract_address: Optional[HexBytes] = None
contract_address: Optional[HashBytes20] = None
"""The address producing the frame."""

@validator("pc", "gas", "gas_cost", "depth", pre=True)
@field_validator("pc", "gas", "gas_cost", "depth", mode="before")
def validate_ints(cls, value):
return int(value, 16) if isinstance(value, str) else value

@property
def address(self) -> Optional[HexBytes]:
def address(self) -> Optional[HashBytes20]:
"""
The address of this CALL frame.
Only returns a value if this frame's opcode is a call-based opcode.
Expand All @@ -66,7 +65,7 @@ def address(self) -> Optional[HexBytes]:
if not self.contract_address and (
self.op in CALL_OPCODES and CallType.CREATE.value not in self.op
):
self.contract_address = HexBytes(self.stack[-2][-20:])
self.contract_address = HashBytes20.__eth_pydantic_validate__(self.stack[-2][-20:])

return self.contract_address

Expand Down Expand Up @@ -104,7 +103,7 @@ def _get_create_frames(frame: TraceFrame, frames: Iterator[Dict]) -> List[TraceF
create_frames = [frame]
start_depth = frame.depth
for next_frame in frames:
next_frame_obj = TraceFrame.parse_obj(next_frame)
next_frame_obj = TraceFrame.model_validate(next_frame)
depth = next_frame_obj.depth

if CallType.CREATE.value in next_frame_obj.op:
Expand All @@ -116,11 +115,7 @@ def _get_create_frames(frame: TraceFrame, frames: Iterator[Dict]) -> List[TraceF
# the first frame after the CREATE with an equal depth.
if len(next_frame_obj.stack) > 0:
raw_addr = HexBytes(next_frame_obj.stack[-1][-40:])
try:
frame.contract_address = HexBytes(to_address(raw_addr))
except Exception:
# Potentially, a transaction was made with poor data.
frame.contract_address = raw_addr
frame.contract_address = HashBytes20.__eth_pydantic_validate__(raw_addr)

create_frames.append(next_frame_obj)
break
Expand Down Expand Up @@ -279,7 +274,7 @@ def _create_node(
node_kwargs["last_create_depth"].pop()
for subcall in node_kwargs.get("calls", [])[::-1]:
if subcall.call_type in (CallType.CREATE, CallType.CREATE2):
subcall.address = HexBytes(to_address(frame.stack[-1][-40:]))
subcall.address = HashBytes20.__eth_pydantic_validate__(frame.stack[-1][-40:])
if len(frame.stack) >= 5:
subcall.calldata = frame.memory.get(frame.stack[-4], frame.stack[-5])

Expand Down
39 changes: 16 additions & 23 deletions evm_trace/parity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union, cast

from evm_trace._pydantic_compat import Field, validator
from pydantic import Field, RootModel, field_validator

from evm_trace.base import BaseModel, CallTreeNode
from evm_trace.enums import CallType

Expand All @@ -18,7 +19,7 @@ class CallAction(BaseModel):
# only used to recover the specific call type
call_type: str = Field(alias="callType", repr=False)

@validator("value", "gas", pre=True)
@field_validator("value", "gas", mode="before")
def convert_integer(cls, v):
return int(v, 16)

Expand All @@ -32,7 +33,7 @@ class CreateAction(BaseModel):
init: str
value: int

@validator("value", "gas", pre=True)
@field_validator("value", "gas", mode="before")
def convert_integer(cls, v):
return int(v, 16)

Expand All @@ -41,7 +42,7 @@ class SelfDestructAction(BaseModel):
address: str
balance: int

@validator("balance", pre=True)
@field_validator("balance", mode="before")
def convert_integer(cls, v):
return int(v, 16) if isinstance(v, str) else int(v)

Expand All @@ -59,7 +60,7 @@ class ActionResult(BaseModel):
for gas per zero byte and ``16`` gas per non-zero byte.
"""

@validator("gas_used", pre=True)
@field_validator("gas_used", mode="before")
def convert_integer(cls, v):
return int(v, 16) if isinstance(v, str) else int(v)

Expand Down Expand Up @@ -95,26 +96,19 @@ class ParityTrace(BaseModel):
trace_address: List[int] = Field(alias="traceAddress")
transaction_hash: str = Field(alias="transactionHash")

@validator("call_type", pre=True)
def convert_call_type(cls, v, values) -> CallType:
if isinstance(values["action"], CallAction):
v = values["action"].call_type
value = v.upper()
@field_validator("call_type", mode="before")
def convert_call_type(cls, value, info) -> CallType:
if isinstance(info.data["action"], CallAction):
value = info.data["action"].call_type

value = value.upper()
if value == "SUICIDE":
value = "SELFDESTRUCT"

return CallType(value)


class ParityTraceList(BaseModel):
__root__: List[ParityTrace]

# pydantic models with custom root don't have this by default
def __iter__(self):
return iter(self.__root__)

def __getitem__(self, item):
return self.__root__[item]
ParityTraceList = RootModel[List[ParityTrace]]


def get_calltree_from_parity_trace(
Expand All @@ -137,7 +131,7 @@ def get_calltree_from_parity_trace(
Returns:
:class:`~evm_trace.base.CallTreeNode`
"""
root = root or traces[0]
root = root or traces.root[0]
failed = root.error is not None
node_kwargs: Dict[Any, Any] = {
"call_type": root.call_type,
Expand Down Expand Up @@ -187,7 +181,7 @@ def get_calltree_from_parity_trace(
address=selfdestruct_action.address,
)

trace_list: List[ParityTrace] = list(traces)
trace_list: List[ParityTrace] = traces.root
subtraces: List[ParityTrace] = [
sub
for sub in trace_list
Expand All @@ -196,5 +190,4 @@ def get_calltree_from_parity_trace(
]
node_kwargs["calls"] = [get_calltree_from_parity_trace(traces, root=sub) for sub in subtraces]
node_kwargs = {**node_kwargs, **root_kwargs}
node = CallTreeNode.parse_obj(node_kwargs)
return node
return CallTreeNode.model_validate(node_kwargs)
7 changes: 0 additions & 7 deletions evm_trace/utils.py

This file was deleted.

15 changes: 5 additions & 10 deletions evm_trace/vmtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

from eth.vm.memory import Memory
from eth.vm.stack import Stack
from eth_pydantic_types import Address, HexBytes
from eth_utils import to_int
from ethpm_types import HexBytes
from msgspec import Struct
from msgspec.json import Decoder

from evm_trace.utils import to_address

# opcodes grouped by the number of items they pop from the stack
# fmt: off
POP_OPCODES = {
Expand Down Expand Up @@ -141,7 +139,7 @@ def to_trace_frames(
)

if op.op in ["CALL", "DELEGATECALL", "STATICCALL"]:
call_address = to_address(stack.values[-2][1])
call_address = Address.__eth_pydantic_validate__(stack.values[-2][1])

if op.ex:
if op.ex.mem:
Expand Down Expand Up @@ -188,9 +186,6 @@ def from_rpc_response(buffer: bytes) -> Union[VMTrace, List[VMTrace]]:
"""
Decode structured data from a raw `trace_replayTransaction` or `trace_replayBlockTransactions`.
"""
resp = Decoder(RPCResponse, dec_hook=dec_hook).decode(buffer)

if isinstance(resp.result, list):
return [i.vmTrace for i in resp.result]

return resp.result.vmTrace
response = Decoder(RPCResponse, dec_hook=dec_hook).decode(buffer)
result: Union[List[RPCTraceResult], RPCTraceResult] = response.result
return [i.vmTrace for i in result] if isinstance(result, list) else result.vmTrace
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ force_grid_wrap = 0
include_trailing_comma = true
multi_line_output = 3
use_parentheses = true

[tool.mdformat]
number = true
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ exclude =
.eggs
docs
build
evm_trace/version.py
per-file-ignores =
# The traces have to be formatted this way for the tests.
tests/expected_traces.py: E501
Loading

0 comments on commit 54c0a5e

Please sign in to comment.