Skip to content

Commit

Permalink
fix: persistent fork (#1664)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Sep 15, 2023
1 parent 3633d9e commit 2bda5be
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
34 changes: 26 additions & 8 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,11 @@ class ProviderContextManager(ManagerAccessMixin):

connected_providers: Dict[str, "ProviderAPI"] = {}
provider_stack: List[str] = []
disconnect_map: Dict[str, bool] = {}

def __init__(self, provider: "ProviderAPI"):
def __init__(self, provider: "ProviderAPI", disconnect_after: bool = False):
self._provider = provider
self._disconnect_after = disconnect_after

@property
def empty(self) -> bool:
Expand All @@ -589,6 +591,7 @@ def push_provider(self):
raise ProviderNotConnectedError()

self.provider_stack.append(provider_id)
self.disconnect_map[provider_id] = self._disconnect_after
if provider_id in self.connected_providers:
# Using already connected instance
if must_connect:
Expand All @@ -609,6 +612,12 @@ def pop_provider(self):

# Clear last provider
exiting_provider_id = self.provider_stack.pop()

# Disconnect the provider in same cases.
if self.disconnect_map[exiting_provider_id]:
if provider := self.network_manager.active_provider:
provider.disconnect()

if not self.provider_stack:
self.network_manager.active_provider = None
return
Expand All @@ -619,8 +628,7 @@ def pop_provider(self):
# Active provider is not changing
return

previous_provider = self.connected_providers[previous_provider_id]
if previous_provider:
if previous_provider := self.connected_providers[previous_provider_id]:
self.network_manager.active_provider = previous_provider

def disconnect_all(self):
Expand Down Expand Up @@ -889,6 +897,7 @@ def use_provider(
self,
provider_name: str,
provider_settings: Optional[Dict] = None,
disconnect_after: bool = False,
) -> ProviderContextManager:
"""
Use and connect to a provider in a temporary context. When entering the context, it calls
Expand All @@ -905,6 +914,9 @@ def use_provider(
Args:
provider_name (str): The name of the provider to use.
disconnect_after (bool): Set to ``True`` to force a disconnect after ending
the context. This defaults to ``False`` so you can re-connect to the
same network, such as in a multi-chain testing scenario.
provider_settings (dict, optional): Settings to apply to the provider.
Defaults to ``None``.
Expand All @@ -913,9 +925,8 @@ def use_provider(
"""

settings = provider_settings or {}
return ProviderContextManager(
provider=self.get_provider(provider_name=provider_name, provider_settings=settings),
)
provider = self.get_provider(provider_name=provider_name, provider_settings=settings)
return ProviderContextManager(provider=provider, disconnect_after=disconnect_after)

@property
def default_provider(self) -> Optional[str]:
Expand Down Expand Up @@ -955,7 +966,9 @@ def set_default_provider(self, provider_name: str):
raise NetworkError(f"Provider '{provider_name}' not found in network '{self.choice}'.")

def use_default_provider(
self, provider_settings: Optional[Dict] = None
self,
provider_settings: Optional[Dict] = None,
disconnect_after: bool = False,
) -> ProviderContextManager:
"""
Temporarily connect and use the default provider. When entering the context, it calls
Expand All @@ -973,13 +986,18 @@ def use_default_provider(
Args:
provider_settings (dict, optional): Settings to override the provider.
disconnect_after (bool): Set to ``True`` to force a disconnect after ending
the context. This defaults to ``False`` so you can re-connect to the
same network, such as in a multi-chain testing scenario.
Returns:
:class:`~ape.api.networks.ProviderContextManager`
"""
if self.default_provider:
settings = provider_settings or {}
return self.use_provider(self.default_provider, provider_settings=settings)
return self.use_provider(
self.default_provider, provider_settings=settings, disconnect_after=disconnect_after
)

raise NetworkError(f"No providers for network '{self.name}'.")

Expand Down
9 changes: 6 additions & 3 deletions src/ape/managers/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ def fork(
provider_settings = provider_settings or {}

if provider_name:
return forked_network.use_provider(provider_name, provider_settings)
return forked_network.use_provider(
provider_name, provider_settings, disconnect_after=True
)

return forked_network.use_default_provider(provider_settings)
return forked_network.use_default_provider(provider_settings, disconnect_after=True)

@property
def ecosystem_names(self) -> Set[str]:
Expand Down Expand Up @@ -441,6 +443,7 @@ def parse_network_choice(
self,
network_choice: Optional[str] = None,
provider_settings: Optional[Dict] = None,
disconnect_after: bool = False,
) -> ProviderContextManager:
"""
Parse a network choice into a context manager for managing a temporary
Expand All @@ -465,7 +468,7 @@ def parse_network_choice(
provider = self.get_provider_from_choice(
network_choice=network_choice, provider_settings=provider_settings
)
return ProviderContextManager(provider=provider)
return ProviderContextManager(provider=provider, disconnect_after=disconnect_after)

@property
def default_ecosystem(self) -> EcosystemAPI:
Expand Down

0 comments on commit 2bda5be

Please sign in to comment.