From a4de3c9bda7496390190c122cb874e0208bf371c Mon Sep 17 00:00:00 2001
From: Piotr Roslaniec
Date: Thu, 6 Jul 2023 17:02:43 +0200
Subject: [PATCH] apply pr suggestions
---
.../contracts/coordination/Coordinator.sol | 22 +++++++++----------
tests/test_coordinator.py | 16 +++++++++-----
2 files changed, 21 insertions(+), 17 deletions(-)
diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol
index 1856ceec..895eb636 100644
--- a/contracts/contracts/coordination/Coordinator.sol
+++ b/contracts/contracts/coordination/Coordinator.sol
@@ -30,7 +30,7 @@ contract Coordinator is AccessControlDefaultAdminRules {
event TimeoutChanged(uint32 oldTimeout, uint32 newTimeout);
event MaxDkgSizeChanged(uint16 oldSize, uint16 newSize);
- event ParticipantPublicKeySet(address indexed participant, BLS12381.G1Point publicKey);
+ event ParticipantPublicKeySet(uint32 indexed ritualId, address indexed participant, BLS12381.G2Point publicKey);
enum RitualState {
NON_INITIATED,
@@ -64,8 +64,8 @@ contract Coordinator is AccessControlDefaultAdminRules {
}
struct ParticipantKey {
- uint32 ritualId;
- BLS12381.G1Point publicKey;
+ uint32 lastRitualId;
+ BLS12381.G2Point publicKey;
}
using SafeERC20 for IERC20;
@@ -133,21 +133,21 @@ contract Coordinator is AccessControlDefaultAdminRules {
_setRoleAdmin(INITIATOR_ROLE, bytes32(0));
}
- function setProviderPublicKey(BLS12381.G1Point calldata _publicKey) public {
+ function setProviderPublicKey(BLS12381.G2Point calldata _publicKey) public {
uint32 lastRitualId = uint32(rituals.length);
address provider = application.stakingProviderFromOperator(msg.sender);
ParticipantKey memory newRecord = ParticipantKey(lastRitualId, _publicKey);
keysHistory[provider].push(newRecord);
- emit ParticipantPublicKeySet(provider, _publicKey);
+ emit ParticipantPublicKeySet(lastRitualId, provider, _publicKey);
}
- function getProviderPublicKey(address _address, uint _ritualId) public view returns (BLS12381.G1Point memory) {
- ParticipantKey[] storage participantHistory = keysHistory[_address];
+ function getProviderPublicKey(address _provider, uint _ritualId) external view returns (BLS12381.G2Point memory) {
+ ParticipantKey[] storage participantHistory = keysHistory[_provider];
for (uint i = participantHistory.length - 1; i >= 0; i--) {
- if (participantHistory[i].ritualId <= _ritualId) {
+ if (participantHistory[i].lastRitualId <= _ritualId) {
return participantHistory[i].publicKey;
}
}
@@ -275,7 +275,7 @@ contract Coordinator is AccessControlDefaultAdminRules {
function postAggregation(
uint32 ritualId,
bytes calldata aggregatedTranscript,
- BLS12381.G1Point calldata publicKey,
+ BLS12381.G1Point calldata dkgPublicKey,
bytes calldata decryptionRequestStaticKey
) external {
uint256 initialGasLeft = gasleft();
@@ -316,9 +316,9 @@ contract Coordinator is AccessControlDefaultAdminRules {
if (ritual.aggregatedTranscript.length == 0) {
ritual.aggregatedTranscript = aggregatedTranscript;
- ritual.publicKey = publicKey;
+ ritual.publicKey = dkgPublicKey;
} else if (
- !BLS12381.eqG1Point(ritual.publicKey, publicKey) ||
+ !BLS12381.eqG1Point(ritual.publicKey, dkgPublicKey) ||
keccak256(ritual.aggregatedTranscript) != aggregatedTranscriptDigest
) {
ritual.aggregationMismatch = true;
diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py
index 3fa4adcd..9ce4cbac 100644
--- a/tests/test_coordinator.py
+++ b/tests/test_coordinator.py
@@ -31,6 +31,10 @@ def transcript_size(shares, threshold):
return int(424 + 240 * (shares / 2) + 50 * (threshold))
+def gen_public_key():
+ return (os.urandom(32), os.urandom(32), os.urandom(32))
+
+
@pytest.fixture(scope="module")
def nodes(accounts):
return sorted(accounts[:MAX_DKG_SIZE], key=lambda x: x.address.lower())
@@ -110,7 +114,7 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator):
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
for node in nodes:
- public_key = (os.urandom(32), os.urandom(16))
+ public_key = gen_public_key()
coordinator.setProviderPublicKey(public_key, sender=node)
with ape.reverts("Providers must be sorted"):
coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator)
@@ -122,7 +126,7 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator):
def initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes):
for node in nodes:
- public_key = (os.urandom(32), os.urandom(16))
+ public_key = gen_public_key()
coordinator.setProviderPublicKey(public_key, sender=node)
cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=initiator)
@@ -143,9 +147,9 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_mod
assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS
-def test_test_provider_public_key(coordinator, nodes):
+def test_provider_public_key(coordinator, nodes):
selected_provider = nodes[0]
- public_key = (os.urandom(32), os.urandom(16))
+ public_key = gen_public_key()
tx = coordinator.setProviderPublicKey(public_key, sender=selected_provider)
ritual_id = coordinator.numberOfRituals()
@@ -220,11 +224,11 @@ def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_mo
aggregated = transcript # has the same size as transcript
decryption_request_static_keys = [os.urandom(42) for _ in nodes]
- public_key = (os.urandom(32), os.urandom(16))
+ dkg_public_key = (os.urandom(32), os.urandom(16))
for i, node in enumerate(nodes):
assert coordinator.getRitualState(0) == RitualState.AWAITING_AGGREGATIONS
tx = coordinator.postAggregation(
- 0, aggregated, public_key, decryption_request_static_keys[i], sender=node
+ 0, aggregated, dkg_public_key, decryption_request_static_keys[i], sender=node
)
events = list(coordinator.AggregationPosted.from_receipt(tx))