Skip to content

Commit

Permalink
fix: issue where snapshots lingered on disconnect (#2332)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Oct 21, 2024
1 parent 6d2944e commit c695988
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 12 deletions.
19 changes: 11 additions & 8 deletions src/ape/managers/chain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections import defaultdict
from collections.abc import Collection, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
Expand Down Expand Up @@ -1497,7 +1498,7 @@ class ChainManager(BaseManager):
from ape import chain
"""

_snapshots: list[SnapshotID] = []
_snapshots: defaultdict = defaultdict(list) # chain_id -> snapshots
_chain_id_map: dict[str, int] = {}
_block_container_map: dict[int, BlockContainer] = {}
_transaction_history_map: dict[int, TransactionHistory] = {}
Expand Down Expand Up @@ -1602,9 +1603,10 @@ def snapshot(self) -> SnapshotID:
Returns:
:class:`~ape.types.SnapshotID`: The snapshot ID.
"""
chain_id = self.provider.chain_id
snapshot_id = self.provider.snapshot()
if snapshot_id not in self._snapshots:
self._snapshots.append(snapshot_id)
if snapshot_id not in self._snapshots[chain_id]:
self._snapshots[chain_id].append(snapshot_id)

return snapshot_id

Expand All @@ -1623,15 +1625,16 @@ def restore(self, snapshot_id: Optional[SnapshotID] = None):
snapshot_id (Optional[:class:`~ape.types.SnapshotID`]): The snapshot ID. Defaults
to the most recent snapshot ID.
"""
if snapshot_id is None and not self._snapshots:
chain_id = self.provider.chain_id
if snapshot_id is None and not self._snapshots[chain_id]:
raise ChainError("There are no snapshots to revert to.")
elif snapshot_id is None:
snapshot_id = self._snapshots.pop()
elif snapshot_id not in self._snapshots:
snapshot_id = self._snapshots[chain_id].pop()
elif snapshot_id not in self._snapshots[chain_id]:
raise UnknownSnapshotError(snapshot_id)
else:
snapshot_index = self._snapshots.index(snapshot_id)
self._snapshots = self._snapshots[:snapshot_index]
snapshot_index = self._snapshots[chain_id].index(snapshot_id)
self._snapshots[chain_id] = self._snapshots[chain_id][:snapshot_index]

self.provider.restore(snapshot_id)
self.history.revert_to_block(self.blocks.height)
Expand Down
2 changes: 1 addition & 1 deletion src/ape/pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _snapshot(self) -> Optional[SnapshotID]:

@allow_disconnected
def _restore(self, snapshot_id: SnapshotID):
if snapshot_id not in self.chain_manager._snapshots:
if snapshot_id not in self.chain_manager._snapshots[self.provider.chain_id]:
return
try:
self.chain_manager.restore(snapshot_id)
Expand Down
4 changes: 4 additions & 0 deletions src/ape_node/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def connect(self):
if self.is_connected:
self._complete_connect()
else:
# Starting the process.
self.start()

def start(self, timeout: int = 20):
Expand Down Expand Up @@ -390,6 +391,9 @@ def disconnect(self):
# NOTE: Type ignore is wrong; TODO: figure out why.
self.process = None # type: ignore[assignment]

# Clear any snapshots.
self.chain_manager._snapshots[self.chain_id] = []

super().disconnect()

def snapshot(self) -> SnapshotID:
Expand Down
3 changes: 3 additions & 0 deletions src/ape_test/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def disconnect(self):
self._evm_backend = None
self.provider_settings = {}

# Invalidate snapshots.
self.chain_manager._snapshots[self.chain_id] = []

def update_settings(self, new_settings: dict):
self.provider_settings = {**self.provider_settings, **new_settings}
self.disconnect()
Expand Down
21 changes: 18 additions & 3 deletions tests/functional/test_chain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections import defaultdict
from datetime import datetime, timedelta

import pytest

from ape.exceptions import APINotImplementedError, ChainError
from ape.exceptions import APINotImplementedError, ChainError, UnknownSnapshotError
from ape.managers.chain import AccountHistory
from ape.types import AddressType

Expand Down Expand Up @@ -61,18 +62,32 @@ def test_snapshot_and_restore_unknown_snapshot_id(chain):
# After restoring to the second ID, the third ID is now invalid.
chain.restore(snapshot_id_2)

with pytest.raises(ChainError) as err:
with pytest.raises(UnknownSnapshotError) as err:
chain.restore(snapshot_id_3)

assert "Unknown snapshot ID" in str(err.value)


def test_snapshot_and_restore_no_snapshots(chain):
chain._snapshots = [] # Ensure empty (gets set in test setup)
chain._snapshots = defaultdict(list) # Ensure empty (gets set in test setup)
with pytest.raises(ChainError, match="There are no snapshots to revert to."):
chain.restore()


def test_snapshot_and_restore_switched_chains(networks, chain):
"""
Ensuring things work as expected when we switch chains after snapshotting
and before restoring.
"""
snapshot = chain.snapshot()
# Switch chains.
with networks.ethereum.local.use_provider(
"test", provider_settings={"chain_id": 11191919191991918223773}
):
with pytest.raises(UnknownSnapshotError):
chain.restore(snapshot)


def test_isolate(chain, vyper_contract_instance, owner):
number_at_start = 444
vyper_contract_instance.setNumber(number_at_start, sender=owner)
Expand Down
7 changes: 7 additions & 0 deletions tests/functional/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,3 +589,10 @@ def test_ipc_per_network(project, key):
# TODO: 0.9 investigate not using random if ipc set.

assert node.ipc_path == Path(ipc)


def test_update_settings_invalidates_snapshots(eth_tester_provider, chain):
snapshot = chain.snapshot()
assert snapshot in chain._snapshots[eth_tester_provider.chain_id]
eth_tester_provider.update_settings({})
assert snapshot not in chain._snapshots[eth_tester_provider.chain_id]

0 comments on commit c695988

Please sign in to comment.