diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index f55d9352..e9711152 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -8,6 +8,9 @@ TRANSCRIPT_SIZE = 500 TIMEOUT = 1000 MAX_DKG_SIZE = 64 +FEE_RATE = 42 +ERC20_SUPPLY = 10**24 +DURATION = 1234 RitualState = IntEnum( "RitualState", @@ -34,8 +37,12 @@ def initiator(accounts): @pytest.fixture(scope="module") -def stake_info(project, accounts, nodes): - deployer = accounts[8] +def deployer(accounts): + return accounts[8] + + +@pytest.fixture() +def stake_info(project, deployer, nodes): contract = project.StakeInfo.deploy([deployer], sender=deployer) for n in nodes: contract.updateOperator(n, n, sender=deployer) @@ -43,9 +50,37 @@ def stake_info(project, accounts, nodes): return contract -@pytest.fixture(scope="module") -def coordinator(project, accounts, stake_info): - return project.Coordinator.deploy(stake_info.address, TIMEOUT, MAX_DKG_SIZE, sender=accounts[8]) +@pytest.fixture() +def erc20(project, initiator): + # Create an ERC20 token (using NuCypherToken because it's easier, but could be any ERC20) + token = project.NuCypherToken.deploy(ERC20_SUPPLY, sender=initiator) + return token + + +@pytest.fixture() +def flat_rate_fee_model(project, deployer, stake_info, erc20): + contract = project.FlatRateFeeModel.deploy( + erc20.address, + FEE_RATE, + stake_info.address, + sender=deployer + ) + return contract + + +@pytest.fixture() +def coordinator(project, deployer, stake_info, flat_rate_fee_model, initiator): + admin = deployer + contract = project.Coordinator.deploy( + stake_info.address, + TIMEOUT, + MAX_DKG_SIZE, + admin, + flat_rate_fee_model.address, + sender=deployer + ) + contract.grantRole(contract.INITIATOR_ROLE(), initiator, sender=admin) + return contract def test_initial_parameters(coordinator): @@ -54,14 +89,29 @@ def test_initial_parameters(coordinator): assert coordinator.numberOfRituals() == 0 -def test_initiate_ritual(coordinator, nodes, initiator): +def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): + with ape.reverts("Sender can't initiate ritual"): + sender = accounts[3] + coordinator.initiateRitual(nodes, sender, DURATION, sender=sender) + with ape.reverts("Invalid number of nodes"): - coordinator.initiateRitual(nodes[:5] * 20, sender=initiator) + coordinator.initiateRitual(nodes[:5] * 20, initiator, DURATION, sender=initiator) + + with ape.reverts("Invalid ritual duration"): + coordinator.initiateRitual(nodes, initiator, 0, sender=initiator) with ape.reverts("Providers must be sorted"): - coordinator.initiateRitual(nodes[1:] + [nodes[0]], sender=initiator) + coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator) + + with ape.reverts("ERC20: insufficient allowance"): + # Sender didn't approve enough tokens + coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + - tx = coordinator.initiateRitual(nodes, sender=initiator) +def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model): + cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) + erc20.approve(coordinator.address, cost, sender=initiator) + tx = coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) events = list(coordinator.StartRitual.from_receipt(tx)) assert len(events) == 1 @@ -73,8 +123,10 @@ def test_initiate_ritual(coordinator, nodes, initiator): assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS -def test_post_transcript(coordinator, nodes, initiator): - coordinator.initiateRitual(nodes, sender=initiator) +def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model): + cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) + erc20.approve(coordinator.address, cost, sender=initiator) + coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) for node in nodes: assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS @@ -97,21 +149,27 @@ def test_post_transcript(coordinator, nodes, initiator): assert coordinator.getRitualState(0) == RitualState.AWAITING_AGGREGATIONS -def test_post_transcript_but_not_part_of_ritual(coordinator, nodes, initiator): - coordinator.initiateRitual(nodes, sender=initiator) +def test_post_transcript_but_not_part_of_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model): + cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) + erc20.approve(coordinator.address, cost, sender=initiator) + coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) with ape.reverts("Participant not part of ritual"): coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=initiator) -def test_post_transcript_but_already_posted_transcript(coordinator, nodes, initiator): - coordinator.initiateRitual(nodes, sender=initiator) +def test_post_transcript_but_already_posted_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model): + cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) + erc20.approve(coordinator.address, cost, sender=initiator) + coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=nodes[0]) with ape.reverts("Node already posted transcript"): coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=nodes[0]) -def test_post_transcript_but_not_waiting_for_transcripts(coordinator, nodes, initiator): - coordinator.initiateRitual(nodes, sender=initiator) +def test_post_transcript_but_not_waiting_for_transcripts(coordinator, nodes, initiator, erc20, flat_rate_fee_model): + cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) + erc20.approve(coordinator.address, cost, sender=initiator) + coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) for node in nodes: transcript = os.urandom(TRANSCRIPT_SIZE) coordinator.postTranscript(0, transcript, sender=node) @@ -120,8 +178,10 @@ def test_post_transcript_but_not_waiting_for_transcripts(coordinator, nodes, ini coordinator.postTranscript(0, os.urandom(TRANSCRIPT_SIZE), sender=nodes[1]) -def test_post_aggregation(coordinator, nodes, initiator): - coordinator.initiateRitual(nodes, sender=initiator) +def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_model): + cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) + erc20.approve(coordinator.address, cost, sender=initiator) + coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) transcript = os.urandom(TRANSCRIPT_SIZE) for node in nodes: coordinator.postTranscript(0, transcript, sender=node) @@ -152,5 +212,4 @@ def test_post_aggregation(coordinator, nodes, initiator): assert len(events) == 1 event = events[0] assert event["ritualId"] == 0 - assert event["initiator"] == initiator.address assert event["successful"]