diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index 022eeac8a5..93a525bb42 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -1,4 +1,5 @@ import json +from collections import defaultdict from collections.abc import Collection, Iterator from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -1497,7 +1498,7 @@ class ChainManager(BaseManager): from ape import chain """ - _snapshots: list[SnapshotID] = [] + _snapshots: defaultdict = defaultdict(list) # chain_id -> snapshots _chain_id_map: dict[str, int] = {} _block_container_map: dict[int, BlockContainer] = {} _transaction_history_map: dict[int, TransactionHistory] = {} @@ -1602,9 +1603,10 @@ def snapshot(self) -> SnapshotID: Returns: :class:`~ape.types.SnapshotID`: The snapshot ID. """ + chain_id = self.provider.chain_id snapshot_id = self.provider.snapshot() - if snapshot_id not in self._snapshots: - self._snapshots.append(snapshot_id) + if snapshot_id not in self._snapshots[chain_id]: + self._snapshots[chain_id].append(snapshot_id) return snapshot_id @@ -1623,15 +1625,16 @@ def restore(self, snapshot_id: Optional[SnapshotID] = None): snapshot_id (Optional[:class:`~ape.types.SnapshotID`]): The snapshot ID. Defaults to the most recent snapshot ID. """ - if snapshot_id is None and not self._snapshots: + chain_id = self.provider.chain_id + if snapshot_id is None and not self._snapshots[chain_id]: raise ChainError("There are no snapshots to revert to.") elif snapshot_id is None: - snapshot_id = self._snapshots.pop() - elif snapshot_id not in self._snapshots: + snapshot_id = self._snapshots[chain_id].pop() + elif snapshot_id not in self._snapshots[chain_id]: raise UnknownSnapshotError(snapshot_id) else: - snapshot_index = self._snapshots.index(snapshot_id) - self._snapshots = self._snapshots[:snapshot_index] + snapshot_index = self._snapshots[chain_id].index(snapshot_id) + self._snapshots[chain_id] = self._snapshots[chain_id][:snapshot_index] self.provider.restore(snapshot_id) self.history.revert_to_block(self.blocks.height) diff --git a/src/ape/pytest/fixtures.py b/src/ape/pytest/fixtures.py index 9f4e5587ea..fe3b997542 100644 --- a/src/ape/pytest/fixtures.py +++ b/src/ape/pytest/fixtures.py @@ -136,7 +136,7 @@ def _snapshot(self) -> Optional[SnapshotID]: @allow_disconnected def _restore(self, snapshot_id: SnapshotID): - if snapshot_id not in self.chain_manager._snapshots: + if snapshot_id not in self.chain_manager._snapshots[self.provider.chain_id]: return try: self.chain_manager.restore(snapshot_id) diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index 4dca610d33..68067359f7 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -332,6 +332,7 @@ def connect(self): if self.is_connected: self._complete_connect() else: + # Starting the process. self.start() def start(self, timeout: int = 20): @@ -390,6 +391,9 @@ def disconnect(self): # NOTE: Type ignore is wrong; TODO: figure out why. self.process = None # type: ignore[assignment] + # Clear any snapshots. + self.chain_manager._snapshots[self.chain_id] = [] + super().disconnect() def snapshot(self) -> SnapshotID: diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index 74282b988f..673f1a38a2 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -111,6 +111,9 @@ def disconnect(self): self._evm_backend = None self.provider_settings = {} + # Invalidate snapshots. + self.chain_manager._snapshots[self.chain_id] = [] + def update_settings(self, new_settings: dict): self.provider_settings = {**self.provider_settings, **new_settings} self.disconnect() diff --git a/tests/functional/test_chain.py b/tests/functional/test_chain.py index a380cc62bb..cae7e2f7d3 100644 --- a/tests/functional/test_chain.py +++ b/tests/functional/test_chain.py @@ -1,8 +1,9 @@ +from collections import defaultdict from datetime import datetime, timedelta import pytest -from ape.exceptions import APINotImplementedError, ChainError +from ape.exceptions import APINotImplementedError, ChainError, UnknownSnapshotError from ape.managers.chain import AccountHistory from ape.types import AddressType @@ -61,18 +62,32 @@ def test_snapshot_and_restore_unknown_snapshot_id(chain): # After restoring to the second ID, the third ID is now invalid. chain.restore(snapshot_id_2) - with pytest.raises(ChainError) as err: + with pytest.raises(UnknownSnapshotError) as err: chain.restore(snapshot_id_3) assert "Unknown snapshot ID" in str(err.value) def test_snapshot_and_restore_no_snapshots(chain): - chain._snapshots = [] # Ensure empty (gets set in test setup) + chain._snapshots = defaultdict(list) # Ensure empty (gets set in test setup) with pytest.raises(ChainError, match="There are no snapshots to revert to."): chain.restore() +def test_snapshot_and_restore_switched_chains(networks, chain): + """ + Ensuring things work as expected when we switch chains after snapshotting + and before restoring. + """ + snapshot = chain.snapshot() + # Switch chains. + with networks.ethereum.local.use_provider( + "test", provider_settings={"chain_id": 11191919191991918223773} + ): + with pytest.raises(UnknownSnapshotError): + chain.restore(snapshot) + + def test_isolate(chain, vyper_contract_instance, owner): number_at_start = 444 vyper_contract_instance.setNumber(number_at_start, sender=owner) diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index ed8f8ebe03..97b15c58ab 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -589,3 +589,10 @@ def test_ipc_per_network(project, key): # TODO: 0.9 investigate not using random if ipc set. assert node.ipc_path == Path(ipc) + + +def test_update_settings_invalidates_snapshots(eth_tester_provider, chain): + snapshot = chain.snapshot() + assert snapshot in chain._snapshots[eth_tester_provider.chain_id] + eth_tester_provider.update_settings({}) + assert snapshot not in chain._snapshots[eth_tester_provider.chain_id]