Skip to content

Commit

Permalink
feat: account type filtering improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Oct 31, 2023
1 parent be57c14 commit 0880b44
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 38 deletions.
28 changes: 28 additions & 0 deletions docs/userguides/clis.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,31 @@ def create_account(alias):
# We know the alias is not yet used in Ape at this point.
click.echo(alias)
```

You can control additional filtering of the accounts by using the `account_type` kwarg.
Use `account_type` to filter the choices by specific types of [AccountAPI](../methoddocs/api.html#ape.api.accounts.AccountAPI), or you can give it a list of already known accounts, or you can provide a callable-filter that takes an account and returns a boolean.

```python
import click
from ape import accounts
from ape.cli import existing_alias_argument, get_user_selected_account
from ape_accounts.accounts import KeyfileAccount

# NOTE: This is just an example and not anything specific or recommended.
APPLICATION_PREFIX = "<FOO_BAR>"

@click.command()
@existing_alias_argument(account_type=KeyfileAccount)
def cli_0(alias):
pass

@click.command()
@existing_alias_argument(account_type=lambda a: a.alias.startswith(APPLICATION_PREFIX))
def cli_1(alias):
pass


# Select from the given accounts directly.
my_accounts = [accounts.load("me"), accounts.load("me2")]
selected_account = get_user_selected_account(account_type=my_accounts)
```
6 changes: 2 additions & 4 deletions src/ape/cli/arguments.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from itertools import chain
from typing import Optional, Type

import click
from eth_utils import is_hex

from ape import accounts, project
from ape.api import AccountAPI
from ape.cli.choices import Alias
from ape.cli.choices import _ACCOUNT_TYPE_FILTER, Alias
from ape.cli.paramtype import AllFilePaths
from ape.exceptions import AccountsError, AliasAlreadyInUseError

Expand All @@ -26,7 +24,7 @@ def _alias_callback(ctx, param, value):
return value


def existing_alias_argument(account_type: Optional[Type[AccountAPI]] = None):
def existing_alias_argument(account_type: _ACCOUNT_TYPE_FILTER = None):
"""
A ``click.argument`` for an existing account alias.
Expand Down
80 changes: 55 additions & 25 deletions src/ape/cli/choices.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from enum import Enum
from functools import lru_cache
from typing import Any, Iterator, List, Optional, Sequence, Type, Union
from typing import Any, Callable, Iterator, List, Optional, Sequence, Type, Union

import click
from click import BadParameter, Choice, Context, Parameter
Expand All @@ -12,14 +12,36 @@
from ape.types import _LazySequence

ADHOC_NETWORK_PATTERN = re.compile(r"\w*:\w*:https?://\w*.*")
_ACCOUNT_TYPE_FILTER = Union[
None, Sequence[AccountAPI], Type[AccountAPI], Callable[[AccountAPI], bool]
]


def _get_account_by_type(account_type: Optional[Type[AccountAPI]] = None) -> List[AccountAPI]:
account_list = (
list(accounts) if not account_type else accounts.get_accounts_by_type(account_type)
)
account_list.sort(key=lambda a: a.alias or "")
return account_list
def _get_accounts(account_type: _ACCOUNT_TYPE_FILTER) -> List[AccountAPI]:
add_test_accounts = False
if account_type is None:
account_list = list(accounts)

# Include test accounts at end.
add_test_accounts = True

elif isinstance(account_type, type):
# Filtering by type.
account_list = accounts.get_accounts_by_type(account_type)

elif isinstance(account_type, (list, tuple, set)):
# Given an account list.
account_list = account_type # type: ignore

else:
# Filtering by callable.
account_list = [a for a in accounts if account_type(a)] # type: ignore

sorted_accounts = sorted(account_list, key=lambda a: a.alias or "")
if add_test_accounts:
sorted_accounts.extend(accounts.test_accounts)

return sorted_accounts


class Alias(click.Choice):
Expand All @@ -32,15 +54,15 @@ class Alias(click.Choice):

name = "alias"

def __init__(self, account_type: Optional[Type[AccountAPI]] = None):
def __init__(self, account_type: _ACCOUNT_TYPE_FILTER = None):
# NOTE: we purposely skip the constructor of `Choice`
self.case_sensitive = False
self._account_type = account_type
self.choices = _LazySequence(self._choices_iterator)

@property
def _choices_iterator(self) -> Iterator[str]:
for acct in _get_account_by_type(self._account_type):
for acct in _get_accounts(account_type=self._account_type):
if acct.alias is None:
continue

Expand Down Expand Up @@ -117,8 +139,7 @@ def get_user_selected_choice(self) -> str:


def get_user_selected_account(
prompt_message: Optional[str] = None,
account_type: Optional[Type[AccountAPI]] = None,
prompt_message: Optional[str] = None, account_type: _ACCOUNT_TYPE_FILTER = None
) -> AccountAPI:
"""
Prompt the user to pick from their accounts and return that account.
Expand All @@ -128,14 +149,16 @@ def get_user_selected_account(
Args:
prompt_message (Optional[str]): Customize the prompt message.
account_type (Optional[Type[:class:`~ape.api.accounts.AccountAPI`]]]):
If given, the user may only select an account of this type.
account_type (Union[None, Type[AccountAPI], Callable[[AccountAPI], bool]]):
If given, the user may only select a matching account. You can provide
a list of accounts, an account class type, or a callable for filtering
the accounts.
Returns:
:class:`~ape.api.accounts.AccountAPI`
"""

if account_type and not issubclass(account_type, AccountAPI):
if account_type and isinstance(account_type, type) and not issubclass(account_type, AccountAPI):
raise AccountsError(f"Cannot return accounts with type '{account_type}'.")

prompt = AccountAliasPromptChoice(prompt_message=prompt_message, account_type=account_type)
Expand All @@ -150,7 +173,7 @@ class AccountAliasPromptChoice(PromptChoice):

def __init__(
self,
account_type: Optional[Type[AccountAPI]] = None,
account_type: _ACCOUNT_TYPE_FILTER = None,
prompt_message: Optional[str] = None,
name: str = "account",
):
Expand All @@ -163,7 +186,15 @@ def __init__(
def convert(
self, value: Any, param: Optional[Parameter], ctx: Optional[Context]
) -> Optional[AccountAPI]:
if isinstance(value, str) and value.startswith("TEST::"):
if value is None:
return None

if isinstance(value, str) and value.isnumeric():
alias = super().convert(value, param, ctx)
else:
alias = value

if isinstance(alias, str) and alias.startswith("TEST::"):
idx_str = value.replace("TEST::", "")
if not idx_str.isnumeric():
self.fail(f"Cannot reference test account by '{value}'.", param=param)
Expand All @@ -174,12 +205,10 @@ def convert(

self.fail(f"Index '{idx_str}' is not valid.", param=param)

if value and value in accounts.aliases:
return accounts.load(value)
elif alias and alias in accounts.aliases:
return accounts.load(alias)

# Prompt the user if they didn't provide a value.
alias = super().convert(value, param, ctx)
return accounts.load(alias) if alias else None
return None

def print_choices(self):
choices = dict(enumerate(self.choices, 0))
Expand All @@ -203,13 +232,14 @@ def print_choices(self):
@property
def _choices_iterator(self) -> Iterator[str]:
# Yield real accounts.
for account in _get_account_by_type(self._account_type):
for account in _get_accounts(account_type=self._account_type):
if account and (alias := account.alias):
yield alias

# Yield test accounts (at the end).
for idx, _ in enumerate(accounts.test_accounts):
yield f"TEST::{idx}"
# Yield test accounts.
if self._account_type is None:
for idx, _ in enumerate(accounts.test_accounts):
yield f"TEST::{idx}"

def get_user_selected_account(self) -> AccountAPI:
"""
Expand Down
5 changes: 3 additions & 2 deletions src/ape/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ape import networks, project
from ape.cli.choices import (
_ACCOUNT_TYPE_FILTER,
AccountAliasPromptChoice,
NetworkChoice,
OutputFormat,
Expand Down Expand Up @@ -182,15 +183,15 @@ def _account_callback(ctx, param, value):
return value


def account_option():
def account_option(account_type: _ACCOUNT_TYPE_FILTER = None):
"""
A CLI option that accepts either the account alias or the account number.
If not given anything, it will prompt the user to select an account.
"""

return click.option(
"--account",
type=AccountAliasPromptChoice(),
type=AccountAliasPromptChoice(account_type=account_type),
callback=_account_callback,
)

Expand Down
2 changes: 2 additions & 0 deletions src/ape/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def generate_dev_accounts(
number_of_accounts (int): Number of accounts. Defaults to ``10``.
hd_path_format (str): Hard Wallets/HD Keys derivation path format.
Defaults to ``"m/44'/60'/0'/{}"``.
start_index (int): The index to start from in the path. Defaults
to 0.
Returns:
List[:class:`~ape.utils.GeneratedDevAccount`]: List of development accounts.
Expand Down
4 changes: 2 additions & 2 deletions src/ape_test/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _dev_accounts(self) -> List[GeneratedDevAccount]:
@property
def aliases(self) -> Iterator[str]:
for index in range(self._num_of_accounts):
yield f"dev_{index}"
yield f"TEST::{index}"

def _is_config_changed(self):
current_mnemonic = self.config["mnemonic"]
Expand Down Expand Up @@ -95,7 +95,7 @@ class TestAccount(TestAccountAPI):

@property
def alias(self) -> str:
return f"dev_{self.index}"
return f"TEST::{self.index}"

@property
def address(self) -> AddressType:
Expand Down
22 changes: 17 additions & 5 deletions tests/functional/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_get_user_selected_account_no_accounts_found(no_accounts):

def test_get_user_selected_account_one_account(runner, one_account):
# No input needed when only one account
with runner.isolation():
with runner.isolation("0\n"):
account = get_user_selected_account()

assert account == one_account
Expand All @@ -133,9 +133,7 @@ def test_get_user_selected_account_custom_prompt(runner, keyfile_account, second


def test_get_user_selected_account_specify_type(runner, one_keyfile_account):
with runner.isolation():
account = get_user_selected_account(account_type=type(one_keyfile_account))

account = get_user_selected_account(account_type=type(one_keyfile_account))
assert account == one_keyfile_account


Expand All @@ -146,6 +144,20 @@ def test_get_user_selected_account_unknown_type(runner, keyfile_account):
assert "Cannot return accounts with type '<class 'str'>'" in str(err.value)


def test_get_user_selected_account_with_account_list(
runner, keyfile_account, second_keyfile_account
):
account = get_user_selected_account(account_type=[keyfile_account])
assert account == keyfile_account

account = get_user_selected_account(account_type=[second_keyfile_account])
assert account == second_keyfile_account

with runner.isolation(input="1\n"):
account = get_user_selected_account(account_type=[keyfile_account, second_keyfile_account])
assert account == second_keyfile_account


def test_network_option_default(runner, network_cmd):
result = runner.invoke(network_cmd)
assert result.exit_code == 0, result.output
Expand Down Expand Up @@ -233,7 +245,7 @@ def test_account_option_uses_single_account_as_default(runner, one_account):
"""

@click.command()
@account_option()
@account_option(account_type=[one_account])
def cmd(account):
_expected = get_expected_account_str(account)
click.echo(_expected)
Expand Down

0 comments on commit 0880b44

Please sign in to comment.