diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index b7b423db..525dce9a 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -70,6 +70,8 @@ contract Coordinator is AccessControlDefaultAdminRules { bool public isInitiationPublic; IFeeModel feeModel; // TODO: Consider making feeModel specific to each ritual IReimbursementPool reimbursementPool; + uint256 public totalPendingFees; + mapping(uint256 => uint256) public pendingFees; constructor( IAccessControlApplication _stakes, @@ -162,8 +164,6 @@ contract Coordinator is AccessControlDefaultAdminRules { require(2 <= length && length <= maxDkgSize, "Invalid number of nodes"); require(duration > 0, "Invalid ritual duration"); // TODO: We probably want to restrict it more - processRitualPayment(providers, duration); - uint32 id = uint32(rituals.length); Ritual storage ritual = rituals.push(); ritual.initiator = msg.sender; @@ -186,6 +186,8 @@ contract Coordinator is AccessControlDefaultAdminRules { newParticipant.provider = current; previous = current; } + + processRitualPayment(id, providers, duration); // TODO: Include cohort fingerprint in StartRitual event? emit StartRitual(id, ritual.authority, providers); @@ -286,6 +288,7 @@ contract Coordinator is AccessControlDefaultAdminRules { if(!ritual.aggregationMismatch){ ritual.totalAggregations++; if (ritual.totalAggregations == ritual.dkgSize){ + processPendingFee(ritualId); emit EndRitual({ ritualId: ritualId, successful: true @@ -293,6 +296,7 @@ contract Coordinator is AccessControlDefaultAdminRules { // TODO: Consider including public key in event } } + processReimbursement(initialGasLeft); } @@ -318,14 +322,45 @@ contract Coordinator is AccessControlDefaultAdminRules { return getParticipantFromProvider(rituals[ritualID], provider); } - function processRitualPayment(address[] calldata providers, uint32 duration) internal { + function processRitualPayment(uint256 ritualID, address[] calldata providers, uint32 duration) internal { uint256 ritualCost = feeModel.getRitualInitiationCost(providers, duration); if (ritualCost > 0){ + totalPendingFees += ritualCost; + assert(pendingFees[ritualID] == 0); // TODO: This is an invariant, not sure if actually needed + pendingFees[ritualID] += ritualCost; IERC20 currency = IERC20(feeModel.currency()); currency.transferFrom(msg.sender, address(this), ritualCost); // TODO: Define methods to manage these funds } } + + function processPendingFee(uint256 ritualID) public { + Ritual storage ritual = rituals[ritualID]; + RitualState state = getRitualState(ritual); + require( + state == RitualState.TIMEOUT || + state == RitualState.INVALID || + state == RitualState.FINALIZED, + "Ritual is not ended" + ); + uint256 pending = pendingFees[ritualID]; + require(pending > 0, "No pending fees for this ritual"); + + // Finalize fees for this ritual + totalPendingFees -= pending; + delete pendingFees[ritualID]; + // Transfer fees back to initiator if failed + if(state == RitualState.TIMEOUT || state == RitualState.INVALID){ + // Amount to refund depends on how much work nodes did for the ritual. + // TODO: Validate if this is enough to remove griefing attacks + uint256 executedTransactions = ritual.totalTranscripts + ritual.totalAggregations; + uint256 expectedTransactions = 2 * ritual.dkgSize; + uint256 consumedFee = pending * executedTransactions / expectedTransactions; + uint256 refundableFee = pending - consumedFee; + IERC20 currency = IERC20(feeModel.currency()); + currency.transferFrom(address(this), ritual.initiator, refundableFee); + } + } function processReimbursement(uint256 initialGasLeft) internal { if(address(reimbursementPool) != address(0)){ // TODO: Consider defining a method