Skip to content

Commit

Permalink
perf: RPC / provider level speed optimizations (#2193)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Aug 2, 2024
1 parent 26d762a commit b046e1b
Show file tree
Hide file tree
Showing 21 changed files with 241 additions and 200 deletions.
1 change: 0 additions & 1 deletion src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def serialize_transaction(self) -> bytes:
Returns:
bytes
"""

if not self.signature:
raise SignatureError("The transaction is not signed.")

Expand Down
33 changes: 24 additions & 9 deletions src/ape/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional, Union

from eth_pydantic_types import HexBytes
from eth_utils import is_0x_prefixed, is_hex, to_int
from eth_utils import is_0x_prefixed, is_hex, to_hex, to_int
from ethpm_types.abi import EventABI, MethodABI
from pydantic import ConfigDict, field_validator
from pydantic.fields import Field
Expand Down Expand Up @@ -74,6 +74,20 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._raise_on_revert = raise_on_revert

@field_validator("nonce", mode="before")
@classmethod
def validate_nonce(cls, value):
if value is None or isinstance(value, int):
return value

elif isinstance(value, str) and value.startswith("0x"):
return to_int(hexstr=value)

elif isinstance(value, str):
return int(value)

return to_int(value)

@field_validator("gas_limit", mode="before")
@classmethod
def validate_gas_limit(cls, value):
Expand Down Expand Up @@ -161,12 +175,12 @@ def receipt(self) -> Optional["ReceiptAPI"]:
"""

try:
txn_hash = self.txn_hash.hex()
txn_hash = to_hex(self.txn_hash)
except SignatureError:
return None

try:
return self.provider.get_receipt(txn_hash, required_confirmations=0, timeout=0)
return self.chain_manager.get_receipt(txn_hash)
except (TransactionNotFoundError, ProviderNotConnectedError):
return None

Expand Down Expand Up @@ -355,7 +369,6 @@ def failed(self) -> bool:
Ecosystem plugins override this property when their receipts
are able to be failing.
"""

return False

@property
Expand Down Expand Up @@ -450,6 +463,13 @@ def await_confirmations(self) -> "ReceiptAPI":
Returns:
:class:`~ape.api.ReceiptAPI`: The receipt that is now confirmed.
"""
# perf: avoid *everything* if required_confirmations is 0, as this is likely a
# dev environment or the user doesn't care.
if self.required_confirmations == 0:
# The transaction might not yet be confirmed but
# the user is aware of this. Or, this is a development environment.
return self

try:
self.raise_for_status()
except TransactionError:
Expand All @@ -472,11 +492,6 @@ def await_confirmations(self) -> "ReceiptAPI":
if self.transaction.raise_on_revert:
raise tx_err

if self.required_confirmations == 0:
# The transaction might not yet be confirmed but
# the user is aware of this. Or, this is a development environment.
return self

confirmations_occurred = self._confirmations_occurred
if self.required_confirmations and confirmations_occurred >= self.required_confirmations:
return self
Expand Down
15 changes: 5 additions & 10 deletions src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,12 +554,7 @@ def __len__(self):
def __call__(self, *args: Any, **kwargs: Any) -> MockContractLog:
# Create a dictionary from the positional arguments
event_args: dict[Any, Any] = dict(zip((ipt.name for ipt in self.abi.inputs), args))

overlapping_keys = set(k for k in event_args.keys() if k is not None) & set(
k for k in kwargs.keys() if k is not None
)

if overlapping_keys:
if overlapping_keys := set(event_args).intersection(kwargs):
raise ValueError(
f"Overlapping keys found in arguments: '{', '.join(overlapping_keys)}'."
)
Expand Down Expand Up @@ -1132,8 +1127,8 @@ def get_event_by_signature(self, signature: str) -> ContractEvent:
:class:`~ape.contracts.base.ContractEvent`
"""

name_from_sig = signature.split("(")[0].strip()
options = self._events_.get(name_from_sig, [])
name_from_sig = signature.partition("(")[0].strip()
options = self._events_.get(name_from_sig.strip(), [])

err = ContractDataError(f"No event found with signature '{signature}'.")
if not options:
Expand All @@ -1157,7 +1152,7 @@ def get_error_by_signature(self, signature: str) -> type[CustomError]:
:class:`~ape.exceptions.CustomError`
"""

name_from_sig = signature.split("(")[0].strip()
name_from_sig = signature.partition("(")[0].strip()
options = self._errors_.get(name_from_sig, [])
err = ContractDataError(f"No error found with signature '{signature}'.")
if not options:
Expand Down Expand Up @@ -1605,7 +1600,7 @@ def _get_name(cc: ContractContainer) -> str:
return contract

elif "." in search_name:
next_node = search_name.split(".")[0]
next_node = search_name.partition(".")[0]
if next_node != item:
continue

Expand Down
1 change: 0 additions & 1 deletion src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def head(self) -> BlockAPI:
"""
The latest block.
"""

return self.provider.get_block("latest")

@property
Expand Down
16 changes: 7 additions & 9 deletions src/ape/pytest/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from functools import cached_property
from typing import Any, Optional, Union

from _pytest.config import Config as PytestConfig

from ape.types import ContractFunctionPath
from ape.utils import ManagerAccessMixin, cached_property
from ape.utils.basemodel import ManagerAccessMixin


def _get_config_exclusions(config) -> list[ContractFunctionPath]:
Expand Down Expand Up @@ -76,15 +77,12 @@ def gas_exclusions(self) -> list[ContractFunctionPath]:
"""
The combination of both CLI values and config values.
"""

cli_value = self.pytest_config.getoption("--gas-exclude")
exclusions: list[ContractFunctionPath] = []
if cli_value:
items = cli_value.split(",")
for item in items:
exclusion = ContractFunctionPath.from_str(item)
exclusions.append(exclusion)

exclusions = (
[ContractFunctionPath.from_str(item) for item in cli_value.split(",")]
if cli_value
else []
)
paths = _get_config_exclusions(self.ape_test_config.gas)
exclusions.extend(paths)
return exclusions
Expand Down
2 changes: 1 addition & 1 deletion src/ape/pytest/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ethpm_types.source import ContractSource

from ape.logging import logger
from ape.managers import ProjectManager
from ape.managers.project import ProjectManager
from ape.pytest.config import ConfigWrapper
from ape.types import (
ContractFunctionPath,
Expand Down
40 changes: 12 additions & 28 deletions src/ape/pytest/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import copy
from collections.abc import Iterator
from fnmatch import fnmatch
from functools import cached_property
from typing import Optional

import pytest

from ape.api import ReceiptAPI, TestAccountAPI
from ape.api.accounts import TestAccountAPI
from ape.api.transactions import ReceiptAPI
from ape.exceptions import BlockNotFoundError, ChainError
from ape.logging import logger
from ape.managers.chain import ChainManager
from ape.managers.networks import NetworkManager
from ape.managers.project import ProjectManager
from ape.pytest.config import ConfigWrapper
from ape.types import SnapshotID
from ape.utils import ManagerAccessMixin, allow_disconnected, cached_property
from ape.utils.basemodel import ManagerAccessMixin
from ape.utils.misc import allow_disconnected


class PytestApeFixtures(ManagerAccessMixin):
Expand All @@ -30,49 +32,45 @@ def __init__(self, config_wrapper: ConfigWrapper, receipt_capture: "ReceiptCaptu

@cached_property
def _track_transactions(self) -> bool:
has_reason = self.config_wrapper.track_gas or self.config_wrapper.track_coverage
return (
self.network_manager.provider is not None and self.provider.is_connected and has_reason
self.network_manager.provider is not None
and self.provider.is_connected
and (self.config_wrapper.track_gas or self.config_wrapper.track_coverage)
)

@pytest.fixture(scope="session")
def accounts(self) -> list[TestAccountAPI]:
"""
A collection of pre-funded accounts.
"""

return self.account_manager.test_accounts

@pytest.fixture(scope="session")
def compilers(self):
"""
Access compiler manager directly.
"""

return self.compiler_manager

@pytest.fixture(scope="session")
def chain(self) -> ChainManager:
"""
Manipulate the blockchain, such as mine or change the pending timestamp.
"""

return self.chain_manager

@pytest.fixture(scope="session")
def networks(self) -> NetworkManager:
"""
Connect to other networks in your tests.
"""

return self.network_manager

@pytest.fixture(scope="session")
def project(self) -> ProjectManager:
"""
Access contract types and dependencies.
"""

return self.local_project

@pytest.fixture(scope="session")
Expand All @@ -88,7 +86,6 @@ def _isolation(self) -> Iterator[None]:
Isolation logic used to implement isolation fixtures for each pytest scope.
When tracing support is available, will also assist in capturing receipts.
"""

try:
snapshot_id = self._snapshot()
except BlockNotFoundError:
Expand Down Expand Up @@ -174,7 +171,7 @@ def capture_range(self, start_block: int, stop_block: int):
txn_hash = txn.txn_hash.hex()
except Exception:
# Might have been from an impersonated account.
# Those txns need to be added separatly, same as tracing calls.
# Those txns need to be added separately, same as tracing calls.
# Likely, it was already accounted before this point.
continue

Expand All @@ -189,14 +186,14 @@ def capture(self, transaction_hash: str):
if not receipt:
return

if not (contract_address := (receipt.receiver or receipt.contract_address)):
elif not (contract_address := (receipt.receiver or receipt.contract_address)):
return

if not (contract_type := self.chain_manager.contracts.get(contract_address)):
elif not (contract_type := self.chain_manager.contracts.get(contract_address)):
# Not an invoke-transaction or a known address
return

if not (source_id := (contract_type.source_id or None)):
elif not (source_id := (contract_type.source_id or None)):
# Not a local or known contract type.
return

Expand Down Expand Up @@ -229,7 +226,6 @@ def _exclude_from_gas_report(
Helper method to determine if a certain contract / method combination should be
excluded from the gas report.
"""

for exclusion in self.config_wrapper.gas_exclusions:
# Default to looking at all contracts
contract_pattern = exclusion.contract_name
Expand All @@ -241,15 +237,3 @@ def _exclude_from_gas_report(
return True

return False


def _build_report(report: dict, contract: str, method: str, usages: list) -> dict:
new_dict = copy.deepcopy(report)
if contract not in new_dict:
new_dict[contract] = {method: usages}
elif method not in new_dict[contract]:
new_dict[contract][method] = usages
else:
new_dict[contract][method].extend(usages)

return new_dict
5 changes: 2 additions & 3 deletions src/ape/pytest/gas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from ethpm_types.source import ContractSource
from evm_trace.gas import merge_reports

from ape.api import TraceAPI
from ape.api.trace import TraceAPI
from ape.pytest.config import ConfigWrapper
from ape.types import AddressType, ContractFunctionPath, GasReport
from ape.utils import parse_gas_table
from ape.utils.basemodel import ManagerAccessMixin
from ape.utils.trace import _exclude_gas
from ape.utils.trace import _exclude_gas, parse_gas_table


class GasTracker(ManagerAccessMixin):
Expand Down
44 changes: 9 additions & 35 deletions src/ape/pytest/plugin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import sys
from pathlib import Path

import pytest

from ape.exceptions import ConfigError
from ape.logging import LogLevel, logger
from ape.pytest.config import ConfigWrapper
from ape.pytest.coverage import CoverageTracker
from ape.pytest.fixtures import PytestApeFixtures, ReceiptCapture
Expand Down Expand Up @@ -64,17 +61,19 @@ def add_option(*names, **kwargs):
def pytest_configure(config):
# Do not include ape internals in tracebacks unless explicitly asked
if not config.getoption("--show-internal"):
path_str = sys.modules["ape"].__file__
if path_str:
base_path = Path(path_str).parent.as_posix()
if path_str := sys.modules["ape"].__file__:
base_path = str(Path(path_str).parent)

def is_module(v):
return getattr(v, "__file__", None) and v.__file__.startswith(base_path)

modules = [v for v in sys.modules.values() if is_module(v)]
for module in modules:
if hasattr(module, "__tracebackhide__"):
setattr(module, "__tracebackhide__", True)
for module in (v for v in sys.modules.values() if is_module(v)):
# NOTE: Using try/except w/ type:ignore (over checking for attr)
# for performance reasons!
try:
module.__tracebackhide__ = True # type: ignore[attr-defined]
except AttributeError:
pass

config_wrapper = ConfigWrapper(config)
receipt_capture = ReceiptCapture(config_wrapper)
Expand All @@ -99,28 +98,3 @@ def is_module(v):
config.addinivalue_line(
"markers", "use_network(choice): Run this test using the given network choice."
)


def pytest_load_initial_conftests(early_config):
"""
Compile contracts before loading ``conftest.py``s.
"""
capture_manager = early_config.pluginmanager.get_plugin("capturemanager")
pm = ManagerAccessMixin.local_project

# Suspend stdout capture to display compilation data
capture_manager.suspend()
try:
pm.load_contracts()
except Exception as err:
logger.log_debug_stack_trace()
message = "Unable to load project. "
if logger.level > LogLevel.DEBUG:
message = f"{message}Use `-v DEBUG` to see more info.\n"

err_type_name = getattr(type(err), "__name__", "Exception")
message = f"{message}Failure reason: ({err_type_name}) {err}"
raise pytest.UsageError(message)

finally:
capture_manager.resume()
Loading

0 comments on commit b046e1b

Please sign in to comment.