Skip to content

Commit

Permalink
feat: improved contract receipt related logic (#1548)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jul 19, 2023
1 parent 5fb3edf commit 6fa531a
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 21 deletions.
3 changes: 1 addition & 2 deletions src/ape/api/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ def txn_hash(self) -> HexBytes:
@property
def receipt(self) -> Optional["ReceiptAPI"]:
"""
This transaction's associated published receipt,
if it exists.
This transaction's associated published receipt, if it exists.
"""

try:
Expand Down
46 changes: 31 additions & 15 deletions src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,22 +584,34 @@ def range(
Iterator[:class:`~ape.contracts.base.ContractLog`]
"""

if not hasattr(self.contract, "address"):
return

start_block = None
stop_block = None

if stop is None:
start_block = 0
contract = None
try:
contract = self.chain_manager.contracts.instance_at(self.contract.address)
except Exception:
pass

if contract:
start_block = contract.receipt.block_number
else:
start_block = self.chain_manager.contracts.get_creation_receipt(
self.contract.address
).block_number

stop_block = start_or_stop
elif start_or_stop is not None and stop is not None:
start_block = start_or_stop
stop_block = stop - 1

stop_block = min(stop_block, self.chain_manager.blocks.height)

addresses = set(
([self.contract.address] if hasattr(self.contract, "address") else [])
+ (extra_addresses or [])
)
addresses = set([self.contract.address] + (extra_addresses or []))
contract_event_query = ContractEventQuery(
columns=list(ContractLog.__fields__.keys()),
contract=addresses,
Expand Down Expand Up @@ -822,25 +834,29 @@ def from_receipt(cls, receipt: ReceiptAPI, contract_type: ContractType) -> "Cont
return instance

@property
def receipt(self) -> Optional[ReceiptAPI]:
def receipt(self) -> ReceiptAPI:
"""
The receipt associated with deploying the contract instance,
if it is known and exists.
"""

if not self._cached_receipt and self.txn_hash:
if self._cached_receipt:
return self._cached_receipt

if self.txn_hash:
# Hash is known. Use that to get the receipt.
try:
receipt = self.chain_manager.get_receipt(self.txn_hash)
except (TransactionNotFoundError, ValueError, ChainError):
return None

self._cached_receipt = receipt
return receipt

elif self._cached_receipt:
return self._cached_receipt
pass
else:
self._cached_receipt = receipt
return receipt

return None
# Brute force find the receipt.
receipt = self.chain_manager.contracts.get_creation_receipt(self.address)
self._cached_receipt = receipt
return receipt

def __repr__(self) -> str:
contract_name = self.contract_type.name or "Unnamed contract"
Expand Down
54 changes: 54 additions & 0 deletions src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,51 @@ def _write_deployments_mapping(self, deployments_map: Dict):
with self._deployments_mapping_cache.open("w") as fp:
json.dump(deployments_map, fp, sort_keys=True, indent=2, default=sorted)

def get_creation_receipt(
self, address: AddressType, start_block: int = 0, stop_block: Optional[int] = None
) -> ReceiptAPI:
"""
Get the receipt responsible for the initial creation of the contract.
Args:
address (``AddressType``): The address of the contract.
start_block (int): The block to start looking from.
stop_block (Optional[int]): The block to stop looking at.
Returns:
:class:`~ape.apt.transactions.ReceiptAPI`
"""
if stop_block is None and (stop := self.chain_manager.blocks.head.number):
stop_block = stop
elif stop_block is None:
raise ChainError("Chain missing blocks.")

mid_block = (stop_block - start_block) // 2 + start_block
# NOTE: biased towards mid_block == start_block

if start_block == mid_block:
for tx in self.chain_manager.blocks[mid_block].transactions:
if (receipt := tx.receipt) and receipt.contract_address == address:
return receipt

if mid_block + 1 <= stop_block:
return self.get_creation_receipt(
address, start_block=mid_block + 1, stop_block=stop_block
)
else:
raise ChainError(f"Failed to find a contract-creation receipt for '{address}'.")

elif self.provider.get_code(address, block_id=mid_block):
return self.get_creation_receipt(address, start_block=start_block, stop_block=mid_block)

elif start_block + 1 <= mid_block:
return self.get_creation_receipt(
address, start_block=start_block + 1, stop_block=stop_block
)

else:
raise ChainError(f"Failed to find a contract-creation receipt for '{address}'.")


class ReportManager(BaseManager):
"""
Expand Down Expand Up @@ -1645,4 +1690,13 @@ def set_balance(self, account: Union[BaseAddress, AddressType], amount: Union[in
return self.provider.set_balance(account, amount)

def get_receipt(self, transaction_hash: str) -> ReceiptAPI:
"""
Get a transaction receipt from the chain.
Args:
transaction_hash (str): The hash of the transaction.
Returns:
:class:`~ape.apt.transactions.ReceiptAPI`
"""
return self.chain_manager.history[transaction_hash]
11 changes: 7 additions & 4 deletions src/ape/managers/project/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ape.api import DependencyAPI, ProjectAPI
from ape.api.networks import LOCAL_NETWORK_NAME
from ape.contracts import ContractContainer, ContractInstance, ContractNamespace
from ape.exceptions import ApeAttributeError, APINotImplementedError, ProjectError
from ape.exceptions import ApeAttributeError, APINotImplementedError, ChainError, ProjectError
from ape.logging import logger
from ape.managers.base import BaseManager
from ape.managers.project.types import ApeProject, BrownieProject
Expand Down Expand Up @@ -730,9 +730,12 @@ def track_deployment(self, contract: ContractInstance):
raise ProjectError("Can only publish deployments on a live network.")

contract_name = contract.contract_type.name
receipt = contract.receipt
if not receipt:
raise ProjectError(f"Contract '{contract_name}' transaction receipt is unknown.")
try:
receipt = contract.receipt
except ChainError as err:
raise ProjectError(
f"Contract '{contract_name}' transaction receipt is unknown."
) from err

block_number = receipt.block_number
block_hash_bytes = self.provider.get_block(block_number).hash
Expand Down
10 changes: 10 additions & 0 deletions tests/functional/test_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,13 @@ def test_cache_non_checksum_address(chain, vyper_contract_instance):
lowered_address = vyper_contract_instance.address.lower()
chain.contracts[lowered_address] = vyper_contract_instance.contract_type
assert chain.contracts[vyper_contract_instance.address] == vyper_contract_instance.contract_type


def test_get_contract_receipt(chain, vyper_contract_instance):
address = vyper_contract_instance.address
receipt = chain.contracts.get_creation_receipt(address)
assert receipt.contract_address == address

chain.mine()
receipt = chain.contracts.get_creation_receipt(address)
assert receipt.contract_address == address
10 changes: 10 additions & 0 deletions tests/functional/test_contract_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,16 @@ def test_receipt(contract_instance, owner):
assert receipt.sender == owner


def test_receipt_when_needs_brute_force(vyper_contract_instance, owner):
# Force it to use the brute-force approach.
vyper_contract_instance._cached_receipt = None
vyper_contract_instance.txn_hash = None

actual = vyper_contract_instance.receipt.contract_address
expected = vyper_contract_instance.address
assert actual == expected


def test_from_receipt_when_receipt_not_deploy(contract_instance, owner):
receipt = contract_instance.setNumber(555, sender=owner)
expected_err = (
Expand Down
10 changes: 10 additions & 0 deletions tests/functional/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ def test_transaction_contract_event_query(contract_instance, owner, eth_tester_p
assert df_events.event_name[0] == "FooHappened"


def test_transaction_contract_event_query_starts_query_at_deploy_tx(
contract_instance, owner, eth_tester_provider
):
contract_instance.fooAndBar(sender=owner)
time.sleep(0.1)
df_events = contract_instance.FooHappened.query("*")
assert isinstance(df_events, pd.DataFrame)
assert df_events.event_name[0] == "FooHappened"


class Model(BaseInterfaceModel):
number: int
timestamp: int
Expand Down

0 comments on commit 6fa531a

Please sign in to comment.