diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index 125a559f5..e58337a71 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -93,7 +93,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable ITACoChildApplication public immutable application; uint96 private immutable minAuthorization; // TODO use child app for checking eligibility - Ritual[] public rituals; + Ritual[] internal ritualsStub; // former rituals, "internal" for testing only uint32 public timeout; uint16 public maxDkgSize; bool private stub1; // former isInitiationPublic @@ -106,10 +106,13 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable mapping(address => ParticipantKey[]) internal participantKeysHistory; mapping(bytes32 => uint32) internal ritualPublicKeyRegistry; mapping(IFeeModel => bool) public feeModelsRegistry; + + mapping(uint256 index => Ritual ritual) internal _rituals; + uint256 public numberOfRituals; // Note: Adjust the __preSentinelGap size if more contract variables are added // Storage area for sentinel values - uint256[19] internal __preSentinelGap; + uint256[17] internal __preSentinelGap; Participant internal __sentinelParticipant; uint256[20] internal __postSentinelGap; @@ -128,27 +131,50 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable __AccessControlDefaultAdminRules_init(0, _admin); } + /// @dev use `upgradeAndCall` for upgrading together with re-initialization + function initializeNumberOfRituals() external reinitializer(2) { + if (numberOfRituals == 0) { + numberOfRituals = ritualsStub.length; + } + } + + function rituals(uint32 ritualId) public view returns (Ritual memory) { + return storageRitual(ritualId); + } + + // for backward compatibility + function storageRitual(uint32 ritualId) internal view returns (Ritual storage) { + if (ritualId < ritualsStub.length) { + return ritualsStub[ritualId]; + } + require(ritualId < numberOfRituals, "Ritual id out of bounds"); + return _rituals[ritualId]; + } + function getInitiator(uint32 ritualId) external view returns (address) { - return rituals[ritualId].initiator; + return rituals(ritualId).initiator; } function getTimestamps( uint32 ritualId ) external view returns (uint32 initTimestamp, uint32 endTimestamp) { - initTimestamp = rituals[ritualId].initTimestamp; - endTimestamp = rituals[ritualId].endTimestamp; + Ritual storage ritual = storageRitual(ritualId); + initTimestamp = ritual.initTimestamp; + endTimestamp = ritual.endTimestamp; } function getAccessController(uint32 ritualId) external view returns (IEncryptionAuthorizer) { - return rituals[ritualId].accessController; + Ritual storage ritual = storageRitual(ritualId); + return ritual.accessController; } function getFeeModel(uint32 ritualId) external view returns (IFeeModel) { - return rituals[ritualId].feeModel; + Ritual storage ritual = storageRitual(ritualId); + return ritual.feeModel; } function getRitualState(uint32 ritualId) external view returns (RitualState) { - return getRitualState(rituals[ritualId]); + return getRitualState(storageRitual(ritualId)); } function isRitualActive(Ritual storage ritual) internal view returns (bool) { @@ -156,7 +182,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable } function isRitualActive(uint32 ritualId) external view returns (bool) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); return isRitualActive(ritual); } @@ -194,7 +220,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable } function setProviderPublicKey(BLS12381.G2Point calldata publicKey) external { - uint32 lastRitualId = uint32(rituals.length); + uint32 lastRitualId = uint32(numberOfRituals); address stakingProvider = application.operatorToStakingProvider(msg.sender); require(stakingProvider != address(0), "Operator has no bond with staking provider"); @@ -247,7 +273,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable } function transferRitualAuthority(uint32 ritualId, address newAuthority) external { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); require(isRitualActive(ritual), "Ritual is not active"); address previousAuthority = ritual.authority; require(msg.sender == previousAuthority, "Sender not ritual authority"); @@ -255,12 +281,8 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable emit RitualAuthorityTransferred(ritualId, previousAuthority, newAuthority); } - function numberOfRituals() external view returns (uint256) { - return rituals.length; - } - function getParticipants(uint32 ritualId) external view returns (Participant[] memory) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); return ritual.participant; } @@ -283,8 +305,9 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable require(2 <= length && length <= maxDkgSize, "Invalid number of nodes"); require(duration >= 24 hours, "Invalid ritual duration"); // TODO: Define minimum duration #106 - uint32 id = uint32(rituals.length); - Ritual storage ritual = rituals.push(); + uint32 id = uint32(numberOfRituals); + Ritual storage ritual = _rituals[id]; + numberOfRituals += 1; ritual.initiator = msg.sender; ritual.authority = authority; ritual.dkgSize = length; @@ -326,7 +349,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable function postTranscript(uint32 ritualId, bytes calldata transcript) external { uint256 initialGasLeft = gasleft(); - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); require( getRitualState(ritual) == RitualState.DKG_AWAITING_TRANSCRIPTS, "Not waiting for transcripts" @@ -354,7 +377,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable } function getAuthority(uint32 ritualId) external view returns (address) { - return rituals[ritualId].authority; + return rituals(ritualId).authority; } function postAggregation( @@ -365,7 +388,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable ) external { uint256 initialGasLeft = gasleft(); - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); require( getRitualState(ritual) == RitualState.DKG_AWAITING_AGGREGATIONS, "Not waiting for aggregations" @@ -433,7 +456,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable function getPublicKeyFromRitualId( uint32 ritualId ) external view returns (BLS12381.G1Point memory) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); RitualState state = getRitualState(ritual); require( state == RitualState.ACTIVE || state == RitualState.EXPIRED, @@ -484,7 +507,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable address provider, bool transcript ) external view returns (Participant memory) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); Participant memory participant = getParticipant(ritual, provider); if (!transcript) { participant.transcript = ""; @@ -496,7 +519,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable uint32 ritualId, address provider ) external view returns (Participant memory) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); Participant memory participant = getParticipant(ritual, provider); return participant; } @@ -507,7 +530,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable uint256 maxParticipants, bool includeTranscript ) external view returns (Participant[] memory) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); uint256 endIndex = ritual.participant.length; require(startIndex >= 0, "Invalid start index"); require(startIndex < endIndex, "Wrong start index"); @@ -529,7 +552,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable } function getProviders(uint32 ritualId) external view returns (address[] memory) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); address[] memory providers = new address[](ritual.participant.length); for (uint256 i = 0; i < ritual.participant.length; i++) { providers[i] = ritual.participant[i].provider; @@ -538,7 +561,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable } function isParticipant(uint32 ritualId, address provider) external view returns (bool) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); (bool found, ) = findParticipant(ritual, provider); return found; } @@ -549,7 +572,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable bytes memory evidence, bytes memory ciphertextHeader ) external view returns (bool) { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); require(getRitualState(ritual) == RitualState.ACTIVE, "Ritual not active"); return ritual.accessController.isAuthorized(ritualId, evidence, ciphertextHeader); } @@ -572,7 +595,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable } function extendRitual(uint32 ritualId, uint32 duration) external { - Ritual storage ritual = rituals[ritualId]; + Ritual storage ritual = storageRitual(ritualId); require(msg.sender == ritual.initiator, "Only initiator can extend ritual"); require(getRitualState(ritual) == RitualState.ACTIVE, "Only active ritual can be extended"); ritual.endTimestamp += duration; diff --git a/contracts/test/CoordinatorTestSet.sol b/contracts/test/CoordinatorTestSet.sol index d7664553f..6b777d116 100644 --- a/contracts/test/CoordinatorTestSet.sol +++ b/contracts/test/CoordinatorTestSet.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "../threshold/ITACoChildApplication.sol"; +import "../contracts/coordination/Coordinator.sol"; /** * @notice Contract for testing Coordinator contract @@ -33,3 +34,37 @@ contract ChildApplicationForCoordinatorMock is ITACoChildApplication { // solhint-disable-next-line no-empty-blocks function penalize(address _stakingProvider) external {} } + +contract ExtendedCoordinator is Coordinator { + constructor(ITACoChildApplication _application) Coordinator(_application) {} + + function initiateOldRitual( + IFeeModel feeModel, + address[] calldata providers, + address authority, + uint32 duration, + IEncryptionAuthorizer accessController + ) external returns (uint32) { + uint16 length = uint16(providers.length); + + uint32 id = uint32(ritualsStub.length); + Ritual storage ritual = ritualsStub.push(); + ritual.initiator = msg.sender; + ritual.authority = authority; + ritual.dkgSize = length; + ritual.threshold = getThresholdForRitualSize(length); + ritual.initTimestamp = uint32(block.timestamp); + ritual.endTimestamp = ritual.initTimestamp + duration; + ritual.accessController = accessController; + ritual.feeModel = feeModel; + + address previous = address(0); + for (uint256 i = 0; i < length; i++) { + Participant storage newParticipant = ritual.participant.push(); + address current = providers[i]; + newParticipant.provider = current; + previous = current; + } + return id; + } +} diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 22b88b092..0d896319d 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -3,6 +3,7 @@ import ape import pytest +from ape.utils import ZERO_ADDRESS from eth_account import Account from hexbytes import HexBytes from web3 import Web3 @@ -86,9 +87,9 @@ def erc20(project, initiator): @pytest.fixture() -def coordinator(project, deployer, application, initiator, oz_dependency): +def coordinator(project, deployer, application, oz_dependency): admin = deployer - contract = project.Coordinator.deploy( + contract = project.ExtendedCoordinator.deploy( application.address, sender=deployer, ) @@ -100,7 +101,7 @@ def coordinator(project, deployer, application, initiator, oz_dependency): encoded_initializer_function, sender=deployer, ) - proxy_contract = project.Coordinator.at(proxy.address) + proxy_contract = project.ExtendedCoordinator.at(proxy.address) return proxy_contract @@ -219,17 +220,20 @@ def test_initiate_ritual( ritual_struct = coordinator.rituals(ritualID) assert ritual_struct[0] == initiator - init, end = ritual_struct[1], ritual_struct[2] + init, end = ritual_struct[1], ritual_struct["endTimestamp"] assert end - init == DURATION - total_transcripts, total_aggregations = ritual_struct[3], ritual_struct[4] + total_transcripts, total_aggregations = ( + ritual_struct["totalTranscripts"], + ritual_struct["totalAggregations"], + ) assert total_transcripts == total_aggregations == 0 - assert ritual_struct[5] == authority - assert ritual_struct[6] == len(nodes) - assert ritual_struct[7] == 1 + len(nodes) // 2 # threshold - assert not ritual_struct[8] # aggregationMismatch - assert ritual_struct[9] == global_allow_list.address # accessController - assert ritual_struct[10] == (b"\x00" * 32, b"\x00" * 16) # publicKey - assert not ritual_struct[11] # aggregatedTranscript + assert ritual_struct["authority"] == authority + assert ritual_struct["dkgSize"] == len(nodes) + assert ritual_struct["threshold"] == 1 + len(nodes) // 2 # threshold + assert not ritual_struct["aggregationMismatch"] # aggregationMismatch + assert ritual_struct["accessController"] == global_allow_list.address # accessController + assert ritual_struct["publicKey"] == (b"\x00" * 32, b"\x00" * 16) # publicKey + assert not ritual_struct["aggregatedTranscript"] # aggregatedTranscript fee = fee_model.getRitualCost(len(nodes), DURATION) assert erc20.balanceOf(fee_model) == fee @@ -564,3 +568,85 @@ def test_post_aggregation_fails( assert fee_model.totalPendingFees() == 0 assert fee_model.pendingFees(ritualID) == 0 fee_model.withdrawTokens(fee_model_balance_after_refund, sender=deployer) + + +def test_upgrade( + coordinator, nodes, initiator, erc20, fee_model, treasury, deployer, global_allow_list +): + coordinator.initiateOldRitual( + fee_model, nodes, initiator, DURATION, global_allow_list.address, sender=initiator + ) + coordinator.initiateOldRitual( + ZERO_ADDRESS, [nodes[0]], treasury, DURATION // 2, deployer, sender=initiator + ) + assert coordinator.numberOfRituals() == 0 + coordinator.initializeNumberOfRituals(sender=deployer) + assert coordinator.numberOfRituals() == 2 + + initiate_ritual( + coordinator=coordinator, + fee_model=fee_model, + erc20=erc20, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list, + ) + assert coordinator.numberOfRituals() == 3 + + assert coordinator.getRitualState(0) == RitualState.DKG_AWAITING_TRANSCRIPTS + assert coordinator.getRitualState(1) == RitualState.DKG_AWAITING_TRANSCRIPTS + assert coordinator.getRitualState(2) == RitualState.DKG_AWAITING_TRANSCRIPTS + + ritual_struct = coordinator.rituals(0) + assert ritual_struct["initiator"] == initiator + init, end = ritual_struct["initTimestamp"], ritual_struct["endTimestamp"] + assert end - init == DURATION + total_transcripts, total_aggregations = ( + ritual_struct["totalTranscripts"], + ritual_struct["totalAggregations"], + ) + assert total_transcripts == total_aggregations == 0 + assert ritual_struct["authority"] == initiator + assert ritual_struct["dkgSize"] == len(nodes) + assert ritual_struct["threshold"] == 1 + len(nodes) // 2 + assert not ritual_struct["aggregationMismatch"] + assert ritual_struct["accessController"] == global_allow_list.address + assert ritual_struct["publicKey"] == (b"\x00" * 32, b"\x00" * 16) + assert not ritual_struct["aggregatedTranscript"] + assert ritual_struct["feeModel"] == fee_model.address + + ritual_struct = coordinator.rituals(1) + assert ritual_struct["initiator"] == initiator + init, end = ritual_struct["initTimestamp"], ritual_struct["endTimestamp"] + assert end - init == DURATION // 2 + total_transcripts, total_aggregations = ( + ritual_struct["totalTranscripts"], + ritual_struct["totalAggregations"], + ) + assert total_transcripts == total_aggregations == 0 + assert ritual_struct["authority"] == treasury + assert ritual_struct["dkgSize"] == 1 + assert ritual_struct["threshold"] == 1 # threshold + assert not ritual_struct["aggregationMismatch"] # aggregationMismatch + assert ritual_struct["accessController"] == deployer # accessController + assert ritual_struct["publicKey"] == (b"\x00" * 32, b"\x00" * 16) # publicKey + assert not ritual_struct["aggregatedTranscript"] # aggregatedTranscript + assert ritual_struct["feeModel"] == ZERO_ADDRESS # feeModel + + ritual_struct = coordinator.rituals(2) + assert ritual_struct["initiator"] == initiator + init, end = ritual_struct["initTimestamp"], ritual_struct["endTimestamp"] + assert end - init == DURATION + total_transcripts, total_aggregations = ( + ritual_struct["totalTranscripts"], + ritual_struct["totalAggregations"], + ) + assert total_transcripts == total_aggregations == 0 + assert ritual_struct["authority"] == initiator + assert ritual_struct["dkgSize"] == len(nodes) + assert ritual_struct["threshold"] == 1 + len(nodes) // 2 # threshold + assert not ritual_struct["aggregationMismatch"] # aggregationMismatch + assert ritual_struct["accessController"] == global_allow_list.address # accessController + assert ritual_struct["publicKey"] == (b"\x00" * 32, b"\x00" * 16) # publicKey + assert not ritual_struct["aggregatedTranscript"] # aggregatedTranscript + assert ritual_struct["feeModel"] == fee_model.address # feeModel