diff --git a/ape-config.yaml b/ape-config.yaml index 2a9c00c2..9727fd79 100644 --- a/ape-config.yaml +++ b/ape-config.yaml @@ -52,8 +52,6 @@ deployments: pre_base_penalty: 2 pre_penalty_history_coefficient: 0 pre_percentage_penalty_coefficient: 100000 - pre_min_authorization: 40000000000000000000000 - pre_min_operator_seconds: 86400 # one day in seconds reward_duration: 604800 # one week in seconds deauthorization_duration: 5184000 # 60 days in seconds verify: False diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index 895eb636..b49a127b 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -9,6 +9,7 @@ import "./IFeeModel.sol"; import "./IReimbursementPool.sol"; import "../lib/BLS12381.sol"; import "../../threshold/IAccessControlApplication.sol"; +import "./IEncryptionAuthorizer.sol"; /** * @title Coordinator @@ -58,6 +59,7 @@ contract Coordinator is AccessControlDefaultAdminRules { address authority; uint16 dkgSize; bool aggregationMismatch; + IEncryptionAuthorizer accessController; BLS12381.G1Point publicKey; bytes aggregatedTranscript; Participant[] participant; @@ -100,11 +102,15 @@ contract Coordinator is AccessControlDefaultAdminRules { feeModel = IFeeModel(_feeModel); } - function getRitualState(uint256 ritualId) external view returns (RitualState){ - // TODO: restrict to ritualID < rituals.length? + function getRitualState(uint32 ritualId) external view returns (RitualState){ + // TODO: restrict to ritualId < rituals.length? return getRitualState(rituals[ritualId]); } + function isRitualFinalized(uint32 ritualId) external view returns (bool){ + return getRitualState(rituals[ritualId]) == RitualState.FINALIZED; + } + function getRitualState(Ritual storage ritual) internal view returns (RitualState){ uint32 t0 = ritual.initTimestamp; uint32 deadline = t0 + timeout; @@ -125,6 +131,7 @@ contract Coordinator is AccessControlDefaultAdminRules { // - No public key // - All transcripts and all aggregations // - Still within the deadline + revert("Invalid ritual state"); } } @@ -175,6 +182,13 @@ contract Coordinator is AccessControlDefaultAdminRules { // TODO: Events } + function setRitualAuthority(uint32 ritualId, address authority) external { + Ritual storage ritual = rituals[ritualId]; + require(getRitualState(ritual) == RitualState.FINALIZED, "Ritual not finalized"); + require(msg.sender == ritual.authority, "Sender not ritual authority"); + ritual.authority = authority; + } + function numberOfRituals() external view returns (uint256) { return rituals.length; } @@ -187,8 +201,12 @@ contract Coordinator is AccessControlDefaultAdminRules { function initiateRitual( address[] calldata providers, address authority, - uint32 duration + uint32 duration, + IEncryptionAuthorizer accessController ) external returns (uint32) { + + require(authority != address(0), "Invalid authority"); + require( isInitiationPublic || hasRole(INITIATOR_ROLE, msg.sender), "Sender can't initiate ritual" @@ -205,6 +223,7 @@ contract Coordinator is AccessControlDefaultAdminRules { ritual.dkgSize = uint16(length); ritual.initTimestamp = uint32(block.timestamp); ritual.endTimestamp = ritual.initTimestamp + duration; + ritual.accessController = accessController; address previous = address(0); for (uint256 i = 0; i < length; i++) { @@ -272,6 +291,10 @@ contract Coordinator is AccessControlDefaultAdminRules { processReimbursement(initialGasLeft); } + function getAuthority(uint32 ritualId) external view returns (address) { + return rituals[ritualId].authority; + } + function postAggregation( uint32 ritualId, bytes calldata aggregatedTranscript, @@ -359,26 +382,26 @@ contract Coordinator is AccessControlDefaultAdminRules { } function getParticipantFromProvider( - uint256 ritualID, + uint32 ritualId, address provider ) external view returns (Participant memory) { - return getParticipantFromProvider(rituals[ritualID], provider); + return getParticipantFromProvider(rituals[ritualId], provider); } - function processRitualPayment(uint256 ritualID, address[] calldata providers, uint32 duration) internal { + function processRitualPayment(uint32 ritualId, address[] calldata providers, uint32 duration) internal { uint256 ritualCost = feeModel.getRitualInitiationCost(providers, duration); if (ritualCost > 0) { totalPendingFees += ritualCost; - assert(pendingFees[ritualID] == 0); // TODO: This is an invariant, not sure if actually needed - pendingFees[ritualID] += ritualCost; + assert(pendingFees[ritualId] == 0); // TODO: This is an invariant, not sure if actually needed + pendingFees[ritualId] += ritualCost; IERC20 currency = IERC20(feeModel.currency()); currency.safeTransferFrom(msg.sender, address(this), ritualCost); // TODO: Define methods to manage these funds } } - function processPendingFee(uint256 ritualID) public { - Ritual storage ritual = rituals[ritualID]; + function processPendingFee(uint32 ritualId) public { + Ritual storage ritual = rituals[ritualId]; RitualState state = getRitualState(ritual); require( state == RitualState.TIMEOUT || @@ -386,12 +409,12 @@ contract Coordinator is AccessControlDefaultAdminRules { state == RitualState.FINALIZED, "Ritual is not ended" ); - uint256 pending = pendingFees[ritualID]; + uint256 pending = pendingFees[ritualId]; require(pending > 0, "No pending fees for this ritual"); // Finalize fees for this ritual totalPendingFees -= pending; - delete pendingFees[ritualID]; + delete pendingFees[ritualId]; // Transfer fees back to initiator if failed if (state == RitualState.TIMEOUT || state == RitualState.INVALID) { // Amount to refund depends on how much work nodes did for the ritual. diff --git a/contracts/contracts/coordination/GlobalAllowList.sol b/contracts/contracts/coordination/GlobalAllowList.sol new file mode 100644 index 00000000..24da89fd --- /dev/null +++ b/contracts/contracts/coordination/GlobalAllowList.sol @@ -0,0 +1,58 @@ +pragma solidity ^0.8.0; +import "@openzeppelin/contracts/access/AccessControlDefaultAdminRules.sol"; +import "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import "./IEncryptionAuthorizer.sol"; +import "./Coordinator.sol"; + + +contract GlobalAllowList is AccessControlDefaultAdminRules, IEncryptionAuthorizer { + using ECDSA for bytes32; + + Coordinator public coordinator; + mapping(uint256 => mapping(address => bool)) public authorizations; + + constructor( + Coordinator _coordinator, + address _admin + ) AccessControlDefaultAdminRules(0, _admin) { + require(address(_coordinator) != address(0), "Coordinator cannot be zero address"); + require(_coordinator.numberOfRituals() >= 0, "Invalid coordinator"); + coordinator = _coordinator; + } + + modifier onlyAuthority(uint32 ritualId) { + require(coordinator.getAuthority(ritualId) == msg.sender, + "Only ritual authority is permitted"); + _; + } + + function setCoordinator(Coordinator _coordinator) public { + require(hasRole(DEFAULT_ADMIN_ROLE, msg.sender), "Only admin can set coordinator"); + coordinator = _coordinator; + } + + function isAuthorized( + uint32 ritualId, + bytes memory evidence, + bytes32 digest + ) public view override returns(bool) { + address recovered_address = digest.toEthSignedMessageHash().recover(evidence); + return authorizations[ritualId][recovered_address]; + } + + function authorize(uint32 ritualId, address[] calldata addresses) public onlyAuthority(ritualId) { + require(coordinator.isRitualFinalized(ritualId), + "Only active rituals can add authorizations"); + for (uint256 i=0; i < addresses.length; i++) { + authorizations[ritualId][addresses[i]] = true; + } + } + + function deauthorize(uint32 ritualId, address[] calldata addresses) public onlyAuthority(ritualId) { + require(coordinator.isRitualFinalized(ritualId), + "Only active rituals can add authorizations"); + for (uint256 i=0; i < addresses.length; i++) { + authorizations[ritualId][addresses[i]] = false; + } + } +} \ No newline at end of file diff --git a/contracts/contracts/coordination/IEncryptionAuthorizer.sol b/contracts/contracts/coordination/IEncryptionAuthorizer.sol new file mode 100644 index 00000000..dd1a036b --- /dev/null +++ b/contracts/contracts/coordination/IEncryptionAuthorizer.sol @@ -0,0 +1,9 @@ +pragma solidity ^0.8.0; + +interface IEncryptionAuthorizer { + function isAuthorized( + uint32 ritualID, + bytes memory evidence, // signature + bytes32 digest // signed message hash + ) external view returns(bool); +} diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 9ce4cbac..6dd9be84 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -3,6 +3,7 @@ import ape import pytest +from eth_account.messages import encode_defunct from web3 import Web3 TIMEOUT = 1000 @@ -93,49 +94,93 @@ def coordinator(project, deployer, stake_info, flat_rate_fee_model, initiator): return contract +@pytest.fixture() +def global_allow_list(project, deployer, coordinator): + contract = project.GlobalAllowList.deploy( + coordinator.address, + deployer, # admin + sender=deployer + ) + return contract + + def test_initial_parameters(coordinator): assert coordinator.maxDkgSize() == MAX_DKG_SIZE assert coordinator.timeout() == TIMEOUT assert coordinator.numberOfRituals() == 0 -def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): +def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator, global_allow_list): with ape.reverts("Sender can't initiate ritual"): sender = accounts[3] - coordinator.initiateRitual(nodes, sender, DURATION, sender=sender) + coordinator.initiateRitual( + nodes, sender, DURATION, global_allow_list.address, sender=sender + ) with ape.reverts("Invalid number of nodes"): - coordinator.initiateRitual(nodes[:5] * 20, initiator, DURATION, sender=initiator) + coordinator.initiateRitual( + nodes[:5] * 20, + initiator, + DURATION, + global_allow_list.address, + sender=initiator + ) with ape.reverts("Invalid ritual duration"): - coordinator.initiateRitual(nodes, initiator, 0, sender=initiator) + coordinator.initiateRitual(nodes, initiator, 0, global_allow_list.address, sender=initiator) with ape.reverts("Provider has not set their public key"): - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + coordinator.initiateRitual(nodes, initiator, DURATION, global_allow_list.address, sender=initiator) for node in nodes: public_key = gen_public_key() coordinator.setProviderPublicKey(public_key, sender=node) + with ape.reverts("Providers must be sorted"): - coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator) + coordinator.initiateRitual( + nodes[1:] + [nodes[0]], + initiator, + DURATION, + global_allow_list.address, + sender=initiator + ) with ape.reverts("ERC20: insufficient allowance"): # Sender didn't approve enough tokens - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + coordinator.initiateRitual( + nodes, + initiator, + DURATION, + global_allow_list.address, + sender=initiator + ) -def initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes): +def initiate_ritual(coordinator, erc20, fee_model, allow_logic, authority, nodes): for node in nodes: public_key = gen_public_key() coordinator.setProviderPublicKey(public_key, sender=node) - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - tx = coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) - return initiator, tx + cost = fee_model.getRitualInitiationCost(nodes, DURATION) + erc20.approve(coordinator.address, cost, sender=authority) + tx = coordinator.initiateRitual( + nodes, + authority, + DURATION, + allow_logic.address, + sender=authority + ) + return authority, tx -def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - authority, tx = initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) +def test_initiate_ritual(coordinator, nodes, initiator, erc20, global_allow_list, flat_rate_fee_model): + authority, tx = initiate_ritual( + coordinator=coordinator, + erc20=erc20, + fee_model=flat_rate_fee_model, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list + ) events = list(coordinator.StartRitual.from_receipt(tx)) assert len(events) == 1 @@ -161,8 +206,15 @@ def test_provider_public_key(coordinator, nodes): assert coordinator.getProviderPublicKey(selected_provider, ritual_id) == public_key -def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) +def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model, global_allow_list): + initiate_ritual( + coordinator=coordinator, + erc20=erc20, + fee_model=flat_rate_fee_model, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list + ) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) for node in nodes: @@ -186,18 +238,33 @@ def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_mod def test_post_transcript_but_not_part_of_ritual( - coordinator, nodes, initiator, erc20, flat_rate_fee_model + coordinator, nodes, initiator, erc20, flat_rate_fee_model, global_allow_list ): - initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) + initiate_ritual( + coordinator=coordinator, + erc20=erc20, + fee_model=flat_rate_fee_model, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list + ) + transcript = os.urandom(transcript_size(len(nodes), len(nodes))) with ape.reverts("Participant not part of ritual"): coordinator.postTranscript(0, transcript, sender=initiator) def test_post_transcript_but_already_posted_transcript( - coordinator, nodes, initiator, erc20, flat_rate_fee_model + coordinator, nodes, initiator, erc20, flat_rate_fee_model, global_allow_list ): - initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) + initiate_ritual( + coordinator=coordinator, + erc20=erc20, + fee_model=flat_rate_fee_model, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list + ) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) coordinator.postTranscript(0, transcript, sender=nodes[0]) with ape.reverts("Node already posted transcript"): @@ -205,9 +272,16 @@ def test_post_transcript_but_already_posted_transcript( def test_post_transcript_but_not_waiting_for_transcripts( - coordinator, nodes, initiator, erc20, flat_rate_fee_model + coordinator, nodes, initiator, erc20, flat_rate_fee_model, global_allow_list ): - initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) + initiate_ritual( + coordinator=coordinator, + erc20=erc20, + fee_model=flat_rate_fee_model, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list + ) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) for node in nodes: coordinator.postTranscript(0, transcript, sender=node) @@ -216,8 +290,15 @@ def test_post_transcript_but_not_waiting_for_transcripts( coordinator.postTranscript(0, transcript, sender=nodes[1]) -def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) +def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_model, global_allow_list): + initiate_ritual( + coordinator=coordinator, + erc20=erc20, + fee_model=flat_rate_fee_model, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list + ) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) for node in nodes: coordinator.postTranscript(0, transcript, sender=node) @@ -249,3 +330,72 @@ def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_mo event = events[0] assert event["ritualId"] == 0 assert event["successful"] + + +def test_authorize_using_global_allow_list( + coordinator, + nodes, + deployer, + initiator, + erc20, + flat_rate_fee_model, + global_allow_list +): + + initiate_ritual( + coordinator=coordinator, + erc20=erc20, + fee_model=flat_rate_fee_model, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list + ) + + global_allow_list.setCoordinator(coordinator.address, sender=deployer) + + # This block mocks the signature of a threshold decryption request + w3 = Web3() + data = os.urandom(32) + digest = Web3.keccak(data) + signable_message = encode_defunct(digest) + signed_digest = w3.eth.account.sign_message(signable_message, private_key=deployer.private_key) + signature = signed_digest.signature + + # Not authorized + assert not global_allow_list.isAuthorized(0, bytes(signature), bytes(digest)) + + # Negative test cases for authorization + with ape.reverts("Only ritual authority is permitted"): + global_allow_list.authorize(0, [deployer.address], sender=deployer) + + with ape.reverts("Only active rituals can add authorizations"): + global_allow_list.authorize(0, [deployer.address], sender=initiator) + + # Finalize ritual + transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + for node in nodes: + coordinator.postTranscript(0, transcript, sender=node) + + aggregated = transcript + decryption_request_static_keys = [os.urandom(42) for _ in nodes] + dkg_public_key = (os.urandom(32), os.urandom(16)) + for i, node in enumerate(nodes): + coordinator.postAggregation(0, aggregated, dkg_public_key, decryption_request_static_keys[i], sender=node) + + # Actually authorize + global_allow_list.authorize(0, [deployer.address], sender=initiator) + + # Authorized + assert global_allow_list.isAuthorized(0, bytes(signature), bytes(digest)) + + # Deauthorize + global_allow_list.deauthorize(0, [deployer.address], sender=initiator) + assert not global_allow_list.isAuthorized(0, bytes(signature), bytes(digest)) + + # Reauthorize in batch + addresses_to_authorize = [deployer.address, initiator.address] + global_allow_list.authorize(0, addresses_to_authorize, sender=initiator) + signed_digest = w3.eth.account.sign_message(signable_message, private_key=initiator.private_key) + initiator_signature = signed_digest.signature + assert global_allow_list.isAuthorized(0, bytes(initiator_signature), bytes(digest)) + assert global_allow_list.isAuthorized(0, bytes(signature), bytes(digest))