From fe06651a291d48cd5fe27981602b4bac80b64228 Mon Sep 17 00:00:00 2001 From: KPrasch Date: Fri, 16 Feb 2024 19:46:39 +0100 Subject: [PATCH] formalizes python package; extracts tests from nucypher --- MANIFEST.in | 1 + tests/__init__.py | 0 tests/conftest.py | 2 + tests/package/__init__.py | 0 tests/package/conftest.py | 9 ++ tests/package/constants.py | 17 ++++ tests/package/test_registry_basics.py | 85 ++++++++++++++++ tests/package/test_registry_soures.py | 138 ++++++++++++++++++++++++++ tests/package/test_taco_domains.py | 107 ++++++++++++++++++++ tests/utils.py | 95 ++++++++++++++++++ 10 files changed, 454 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/package/__init__.py create mode 100644 tests/package/conftest.py create mode 100644 tests/package/constants.py create mode 100644 tests/package/test_registry_basics.py create mode 100644 tests/package/test_registry_soures.py create mode 100644 tests/package/test_taco_domains.py create mode 100644 tests/utils.py diff --git a/MANIFEST.in b/MANIFEST.in index e8bb72d8..251bdce7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,7 @@ include README.md include requirements.txt include dev-requirements.txt recursive-include deployment/artifacts *.json +recursive-exclude nucypher_contracts/tests * recursive-exclude tests * recursive-exclude * __pycache__ global-exclude *.py[cod] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 6b8bb8d4..690dc0c7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ import pytest from ape import convert, project +from nucypher_contracts.domains import TACoDomain, ChainInfo + @pytest.fixture(scope="session") def oz_dependency(): diff --git a/tests/package/__init__.py b/tests/package/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/package/conftest.py b/tests/package/conftest.py new file mode 100644 index 00000000..f1ba7106 --- /dev/null +++ b/tests/package/conftest.py @@ -0,0 +1,9 @@ +import pytest + +from tests.utils import mock_registry_sources + + +@pytest.fixture(scope="module", autouse=True) +def auto_mock_registry_sources(module_mocker): + with mock_registry_sources(mocker=module_mocker): + yield diff --git a/tests/package/constants.py b/tests/package/constants.py new file mode 100644 index 00000000..f2f7ae0f --- /dev/null +++ b/tests/package/constants.py @@ -0,0 +1,17 @@ +from nucypher_contracts.domains import ChainInfo, TACoDomain + +TEMPORARY_DOMAIN_NAME = ":temporary-domain:" + +TESTERCHAIN_CHAIN_ID = 131277322940537 + +TESTERCHAIN_CHAIN_INFO = ChainInfo( + TESTERCHAIN_CHAIN_ID, + "eth-tester" +) + +TEMPORARY_DOMAIN = TACoDomain( + name=TEMPORARY_DOMAIN_NAME, + eth_chain=TESTERCHAIN_CHAIN_INFO, + polygon_chain=TESTERCHAIN_CHAIN_INFO, + condition_chains=(TESTERCHAIN_CHAIN_INFO,), +) diff --git a/tests/package/test_registry_basics.py b/tests/package/test_registry_basics.py new file mode 100644 index 00000000..c19efb9c --- /dev/null +++ b/tests/package/test_registry_basics.py @@ -0,0 +1,85 @@ +import pytest + +from nucypher_contracts.registry import ContractRegistry +from tests.utils import MockRegistrySource +from tests.package.constants import TEMPORARY_DOMAIN, TESTERCHAIN_CHAIN_ID + + +@pytest.fixture(scope="function") +def name(): + return "TestContract" + + +@pytest.fixture(scope="function") +def address(): + return "0xdeadbeef" + + +@pytest.fixture(scope="function") +def abi(): + return ["fake", "data"] + + +@pytest.fixture(scope="function") +def record(name, address, abi): + record_data = {name: {"address": address, "abi": abi}} + return record_data + + +@pytest.fixture(scope="function") +def data(record): + _data = {TESTERCHAIN_CHAIN_ID: record} + return _data + + +@pytest.fixture(scope="function") +def source(data): + source = MockRegistrySource(domain=TEMPORARY_DOMAIN) + source.data = data + return source + + +@pytest.fixture(scope="function") +def registry(record, source): + registry = ContractRegistry(source=source) + return registry + + +def test_registry_id_consistency(registry, source): + new_registry = ContractRegistry(source=source) + new_registry._data = registry._data + assert new_registry.id == registry.id + + +def test_registry_name_search(registry, name, address, abi): + record = registry.search(chain_id=TESTERCHAIN_CHAIN_ID, contract_name=name) + assert len(record) == 4, "Registry record is the wrong length" + assert record.chain_id == TESTERCHAIN_CHAIN_ID + assert record.name == name + assert record.address == address + assert record.abi == abi + + +def test_registry_address_search(registry, name, address, abi): + record = registry.search(chain_id=TESTERCHAIN_CHAIN_ID, contract_address=address) + assert len(record) == 4, "Registry record is the wrong length" + assert record.chain_id == TESTERCHAIN_CHAIN_ID + assert record.name == name + assert record.address == address + assert record.abi == abi + + +def test_local_registry_unknown_contract_name_search(registry): + with pytest.raises(ContractRegistry.UnknownContract): + registry.search( + chain_id=TESTERCHAIN_CHAIN_ID, contract_name="this does not exist" + ) + + +def test_local_contract_registry_ambiguous_search_terms(data, name, record, address): + data[TESTERCHAIN_CHAIN_ID]["fakeContract"] = record[name] + source = MockRegistrySource(domain=TEMPORARY_DOMAIN) + source.data = data + registry = ContractRegistry(source=source) + with pytest.raises(ContractRegistry.AmbiguousSearchTerms): + registry.search(chain_id=TESTERCHAIN_CHAIN_ID, contract_address=address) diff --git a/tests/package/test_registry_soures.py b/tests/package/test_registry_soures.py new file mode 100644 index 00000000..4d9b112e --- /dev/null +++ b/tests/package/test_registry_soures.py @@ -0,0 +1,138 @@ +import json + +import pytest +import requests +from requests import Response + +from nucypher_contracts import domains +from nucypher_contracts.registry import ( + EmbeddedRegistrySource, + GithubRegistrySource, + LocalRegistrySource, + RegistrySource, + RegistrySourceManager, +) +from tests.package.constants import ( + TEMPORARY_DOMAIN, + TEMPORARY_DOMAIN_NAME +) + + +@pytest.fixture(scope="function") +def registry_data(): + _registry_data = { + "2958363635247": { + "TestContract": {"address": "0xdeadbeef", "abi": ["fake", "data"]}, + "AnotherTestContract": {"address": "0xdeadbeef", "abi": ["fake", "data"]}, + }, + "393742274944474": { + "YetAnotherContract": {"address": "0xdeadbeef", "abi": ["fake", "data"]} + }, + } + return _registry_data + + +@pytest.fixture(scope="function") +def mock_200_response(mocker, registry_data): + mock_response = Response() + mock_response.status_code = 200 + mock_response._content = json.dumps(registry_data).encode("utf-8") + mocker.patch.object(requests, "get", return_value=mock_response) + + +@pytest.fixture(scope="function") +def test_registry_filepath(tmpdir, registry_data): + filepath = tmpdir.join("registry.json") + with open(filepath, "w") as f: + json.dump(registry_data, f) + yield filepath + filepath.remove() + + +@pytest.mark.usefixtures("mock_200_response") +def test_github_registry_source(registry_data): + source = GithubRegistrySource(domain=TEMPORARY_DOMAIN) + assert source.domain.name == TEMPORARY_DOMAIN_NAME + assert str(source.domain) == TEMPORARY_DOMAIN_NAME + assert bytes(source.domain) == TEMPORARY_DOMAIN_NAME.encode("utf-8") + data = source.get() + assert data == registry_data + assert source.data == registry_data + assert data == source.data + + +@pytest.mark.parametrize("domain", list(domains.SUPPORTED_DOMAINS.values())) +def test_get_actual_github_registry_file(domain): + source = GithubRegistrySource(domain=domain) + assert str(domain.eth_chain.id) in source.data + assert str(domain.polygon_chain.id) in source.data + + +def test_local_registry_source(registry_data, test_registry_filepath): + source = LocalRegistrySource( + filepath=test_registry_filepath, domain=TEMPORARY_DOMAIN + ) + assert source.domain.name == TEMPORARY_DOMAIN_NAME + assert str(source.domain) == TEMPORARY_DOMAIN_NAME + assert bytes(source.domain) == TEMPORARY_DOMAIN_NAME.encode("utf-8") + data = source.get() + assert data == registry_data + assert source.data == registry_data + assert data == source.data + + +def test_embedded_registry_source(registry_data, test_registry_filepath, mocker): + mocker.patch.object( + EmbeddedRegistrySource, + "get_publication_endpoint", + return_value=test_registry_filepath, + ) + source = EmbeddedRegistrySource(domain=TEMPORARY_DOMAIN) + assert source.domain.name == TEMPORARY_DOMAIN_NAME + assert str(source.domain) == TEMPORARY_DOMAIN_NAME + assert bytes(source.domain) == TEMPORARY_DOMAIN_NAME.encode("utf-8") + data = source.get() + assert data == registry_data + assert source.data == registry_data + assert data == source.data + + +def test_registry_source_manager_fallback( + registry_data, test_registry_filepath, mocker +): + github_source_get = mocker.patch.object( + GithubRegistrySource, "get", side_effect=RegistrySource.Unavailable + ) + mocker.patch.object( + EmbeddedRegistrySource, + "get_publication_endpoint", + return_value=test_registry_filepath, + ) + embedded_source_get = mocker.spy(EmbeddedRegistrySource, "get") + RegistrySourceManager._FALLBACK_CHAIN = ( + GithubRegistrySource, + EmbeddedRegistrySource, + ) + source_manager = RegistrySourceManager(domain=TEMPORARY_DOMAIN) + assert source_manager.domain.name == TEMPORARY_DOMAIN_NAME + assert str(source_manager.domain) == TEMPORARY_DOMAIN_NAME + assert bytes(source_manager.domain) == TEMPORARY_DOMAIN_NAME.encode("utf-8") + + primary_sources = source_manager.get_primary_sources() + assert len(primary_sources) == 1 + assert primary_sources[0] == GithubRegistrySource + + source = source_manager.fetch_latest_publication() + github_source_get.assert_called_once() + embedded_source_get.assert_called_once() + assert source.data == registry_data + assert isinstance(source, EmbeddedRegistrySource) + + mocker.patch.object( + EmbeddedRegistrySource, + "get_publication_endpoint", + side_effect=RegistrySource.Unavailable, + ) + + with pytest.raises(RegistrySourceManager.NoSourcesAvailable): + source_manager.fetch_latest_publication() diff --git a/tests/package/test_taco_domains.py b/tests/package/test_taco_domains.py new file mode 100644 index 00000000..516e27b1 --- /dev/null +++ b/tests/package/test_taco_domains.py @@ -0,0 +1,107 @@ +import pytest + +from nucypher_contracts import domains +from nucypher_contracts.domains import EthChain, PolygonChain + + +@pytest.fixture(scope="module") +def test_registry(module_mocker): + # override fixture which mocks domains.SUPPORTED_DOMAINS + yield + + +@pytest.fixture(scope="module", autouse=True) +def mock_condition_blockchains(module_mocker): + # override fixture which mocks domains.get_domain + yield + + +@pytest.mark.parametrize( + "eth_chain_test", + ( + (EthChain.MAINNET, "mainnet", 1), + (EthChain.SEPOLIA, "sepolia", 11155111), + ), +) +def test_eth_chains(eth_chain_test): + eth_chain, expected_name, expected_id = eth_chain_test + assert eth_chain.name == expected_name + assert eth_chain.id == expected_id + + +@pytest.mark.parametrize( + "poly_chain_test", + ( + (PolygonChain.MAINNET, "polygon", 137), + (PolygonChain.MUMBAI, "mumbai", 80001), + ), +) +def test_polygon_chains(poly_chain_test): + eth_chain, expected_name, expected_id = poly_chain_test + assert eth_chain.name == expected_name + assert eth_chain.id == expected_id + + +@pytest.mark.parametrize( + "taco_domain_test", + ( + ( + domains.MAINNET, + "mainnet", + EthChain.MAINNET, + PolygonChain.MAINNET, + (EthChain.MAINNET, PolygonChain.MAINNET), + ), + ( + domains.LYNX, + "lynx", + EthChain.SEPOLIA, + PolygonChain.MUMBAI, + ( + EthChain.MAINNET, + EthChain.SEPOLIA, + PolygonChain.MUMBAI, + PolygonChain.MAINNET, + ), + ), + ( + domains.TAPIR, + "tapir", + EthChain.SEPOLIA, + PolygonChain.MUMBAI, + (EthChain.SEPOLIA, PolygonChain.MUMBAI), + ), + ), +) +def test_taco_domain_info(taco_domain_test): + ( + domain_info, + expected_name, + expected_eth_chain, + expected_polygon_chain, + expected_condition_chains, + ) = taco_domain_test + assert domain_info.name == expected_name + assert domain_info.eth_chain == expected_eth_chain + assert domain_info.polygon_chain == expected_polygon_chain + assert domain_info.condition_chains == expected_condition_chains + + assert domain_info.is_testnet == (expected_name != "mainnet") + + +@pytest.mark.parametrize( + "domain_name_test", + ( + ("mainnet", domains.MAINNET), + ("lynx", domains.LYNX), + ("tapir", domains.TAPIR), + ), +) +def test_get_domain(domain_name_test): + domain_name, expected_domain_info = domain_name_test + assert domains.get_domain(domain_name) == expected_domain_info + + +def test_get_domain_unrecognized_domain_name(): + with pytest.raises(domains.UnrecognizedTacoDomain): + domains.get_domain("5am_In_Toronto") diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..c2777123 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,95 @@ +from collections import defaultdict +from contextlib import contextmanager +from typing import List + +from ape.contracts import ContractInstance +from eth_utils import to_checksum_address + +from nucypher_contracts.domains import TACoDomain +from nucypher_contracts.registry import ( + RegistryData, + RegistrySource, + RegistrySourceManager, +) +from tests.package.constants import ( + TEMPORARY_DOMAIN, + TEMPORARY_DOMAIN_NAME +) + + +@contextmanager +def mock_registry_sources(mocker, _domains: List[TACoDomain] = None): + if not _domains: + _domains = [TEMPORARY_DOMAIN] + _supported_domains = mocker.patch.dict( + "nucypher_contracts.domains.SUPPORTED_DOMAINS", + {str(domain): domain for domain in _domains}, + ) + mocker.patch.object(MockRegistrySource, "ALLOWED_DOMAINS", list(map(str, _domains))) + mocker.patch.object(RegistrySourceManager, "_FALLBACK_CHAIN", (MockRegistrySource,)) + yield + + +class MockRegistrySource(RegistrySource): + ALLOWED_DOMAINS = [TEMPORARY_DOMAIN_NAME] + + name = "Mock Registry Source" + is_primary = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if str(self.domain) not in self.ALLOWED_DOMAINS: + raise ValueError( + f"Somehow, MockRegistrySource is trying to get a registry for '{self.domain}'. " + f"Only '{','.join(self.ALLOWED_DOMAINS)}' are supported.'" + ) + + @property + def registry_name(self) -> str: + return str(self.domain) + + def get_publication_endpoint(self) -> str: + return f":mock-registry-source:/{self.registry_name}" + + def get(self) -> RegistryData: + self.logger.debug(f"Reading registry at {self.get_publication_endpoint()}") + data = dict() + return data + + +class ApeRegistrySource(RegistrySource): + name = "Ape Registry Source" + is_primary = False + + _DEPLOYMENTS = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if str(self.domain) != TEMPORARY_DOMAIN_NAME: + raise ValueError( + f"Somehow, ApeRegistrySource is trying to get a registry for '{self.domain}'. " + f"Only '{TEMPORARY_DOMAIN_NAME}' is supported.'" + ) + if self._DEPLOYMENTS is None: + raise ValueError( + "ApeRegistrySource has not been initialized with deployments." + ) + + @classmethod + def set_deployments(cls, deployments: List[ContractInstance]): + cls._DEPLOYMENTS = deployments + + def get_publication_endpoint(self) -> str: + return "ape" + + def get(self) -> RegistryData: + data = defaultdict(dict) + for contract_instance in self._DEPLOYMENTS: + entry = { + "address": to_checksum_address(contract_instance.address), + "abi": [abi.dict() for abi in contract_instance.contract_type.abi], + } + chain_id = contract_instance.chain_manager.chain_id + contract_name = contract_instance.contract_type.name + data[chain_id][contract_name] = entry + return data