Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: issue where snapshots lingered on disconnect #2332

Merged
merged 2 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/ape/managers/chain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections import defaultdict
from collections.abc import Collection, Iterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
Expand Down Expand Up @@ -1497,7 +1498,7 @@ class ChainManager(BaseManager):
from ape import chain
"""

_snapshots: list[SnapshotID] = []
_snapshots: defaultdict = defaultdict(list) # chain_id -> snapshots
_chain_id_map: dict[str, int] = {}
_block_container_map: dict[int, BlockContainer] = {}
_transaction_history_map: dict[int, TransactionHistory] = {}
Expand Down Expand Up @@ -1602,9 +1603,10 @@ def snapshot(self) -> SnapshotID:
Returns:
:class:`~ape.types.SnapshotID`: The snapshot ID.
"""
chain_id = self.provider.chain_id
snapshot_id = self.provider.snapshot()
if snapshot_id not in self._snapshots:
self._snapshots.append(snapshot_id)
if snapshot_id not in self._snapshots[chain_id]:
self._snapshots[chain_id].append(snapshot_id)

return snapshot_id

Expand All @@ -1623,15 +1625,16 @@ def restore(self, snapshot_id: Optional[SnapshotID] = None):
snapshot_id (Optional[:class:`~ape.types.SnapshotID`]): The snapshot ID. Defaults
to the most recent snapshot ID.
"""
if snapshot_id is None and not self._snapshots:
chain_id = self.provider.chain_id
if snapshot_id is None and not self._snapshots[chain_id]:
raise ChainError("There are no snapshots to revert to.")
elif snapshot_id is None:
snapshot_id = self._snapshots.pop()
elif snapshot_id not in self._snapshots:
snapshot_id = self._snapshots[chain_id].pop()
elif snapshot_id not in self._snapshots[chain_id]:
raise UnknownSnapshotError(snapshot_id)
else:
snapshot_index = self._snapshots.index(snapshot_id)
self._snapshots = self._snapshots[:snapshot_index]
snapshot_index = self._snapshots[chain_id].index(snapshot_id)
self._snapshots[chain_id] = self._snapshots[chain_id][:snapshot_index]

self.provider.restore(snapshot_id)
self.history.revert_to_block(self.blocks.height)
Expand Down
2 changes: 1 addition & 1 deletion src/ape/pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _snapshot(self) -> Optional[SnapshotID]:

@allow_disconnected
def _restore(self, snapshot_id: SnapshotID):
if snapshot_id not in self.chain_manager._snapshots:
if snapshot_id not in self.chain_manager._snapshots[self.provider.chain_id]:
return
try:
self.chain_manager.restore(snapshot_id)
Expand Down
4 changes: 4 additions & 0 deletions src/ape_node/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def connect(self):
if self.is_connected:
self._complete_connect()
else:
# Starting the process.
self.start()

def start(self, timeout: int = 20):
Expand Down Expand Up @@ -390,6 +391,9 @@ def disconnect(self):
# NOTE: Type ignore is wrong; TODO: figure out why.
self.process = None # type: ignore[assignment]

# Clear any snapshots.
self.chain_manager._snapshots[self.chain_id] = []

super().disconnect()

def snapshot(self) -> SnapshotID:
Expand Down
3 changes: 3 additions & 0 deletions src/ape_test/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def disconnect(self):
self._evm_backend = None
self.provider_settings = {}

# Invalidate snapshots.
self.chain_manager._snapshots[self.chain_id] = []

def update_settings(self, new_settings: dict):
self.provider_settings = {**self.provider_settings, **new_settings}
self.disconnect()
Expand Down
21 changes: 18 additions & 3 deletions tests/functional/test_chain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections import defaultdict
from datetime import datetime, timedelta

import pytest

from ape.exceptions import APINotImplementedError, ChainError
from ape.exceptions import APINotImplementedError, ChainError, UnknownSnapshotError
from ape.managers.chain import AccountHistory
from ape.types import AddressType

Expand Down Expand Up @@ -61,18 +62,32 @@ def test_snapshot_and_restore_unknown_snapshot_id(chain):
# After restoring to the second ID, the third ID is now invalid.
chain.restore(snapshot_id_2)

with pytest.raises(ChainError) as err:
with pytest.raises(UnknownSnapshotError) as err:
chain.restore(snapshot_id_3)

assert "Unknown snapshot ID" in str(err.value)


def test_snapshot_and_restore_no_snapshots(chain):
chain._snapshots = [] # Ensure empty (gets set in test setup)
chain._snapshots = defaultdict(list) # Ensure empty (gets set in test setup)
with pytest.raises(ChainError, match="There are no snapshots to revert to."):
chain.restore()


def test_snapshot_and_restore_switched_chains(networks, chain):
"""
Ensuring things work as expected when we switch chains after snapshotting
and before restoring.
"""
snapshot = chain.snapshot()
# Switch chains.
with networks.ethereum.local.use_provider(
"test", provider_settings={"chain_id": 11191919191991918223773}
):
with pytest.raises(UnknownSnapshotError):
chain.restore(snapshot)


def test_isolate(chain, vyper_contract_instance, owner):
number_at_start = 444
vyper_contract_instance.setNumber(number_at_start, sender=owner)
Expand Down
7 changes: 7 additions & 0 deletions tests/functional/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,3 +589,10 @@ def test_ipc_per_network(project, key):
# TODO: 0.9 investigate not using random if ipc set.

assert node.ipc_path == Path(ipc)


def test_update_settings_invalidates_snapshots(eth_tester_provider, chain):
snapshot = chain.snapshot()
assert snapshot in chain._snapshots[eth_tester_provider.chain_id]
eth_tester_provider.update_settings({})
assert snapshot not in chain._snapshots[eth_tester_provider.chain_id]
Loading