diff --git a/prover/backend/aggregation/craft.go b/prover/backend/aggregation/craft.go index 13059dbb4..5c6d71e16 100644 --- a/prover/backend/aggregation/craft.go +++ b/prover/backend/aggregation/craft.go @@ -39,9 +39,9 @@ const ( func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { var ( - l2MessageHashes []string - l2MsgBlockOffsets []bool - cf = &CollectedFields{ + allL2MessageHashes []string + l2MsgBlockOffsets []bool + cf = &CollectedFields{ L2MsgTreeDepth: l2MsgMerkleTreeDepth, ParentAggregationLastBlockTimestamp: uint(req.ParentAggregationLastBlockTimestamp), LastFinalizedL1RollingHash: req.ParentAggregationLastL1RollingHash, @@ -55,9 +55,10 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { for i, execReqFPath := range req.ExecutionProofs { var ( - po = &execution.Response{} - fpath = path.Join(cfg.Execution.DirTo(), execReqFPath) - f = files.MustRead(fpath) + po = &execution.Response{} + l2MessageHashes []string + fpath = path.Join(cfg.Execution.DirTo(), execReqFPath) + f = files.MustRead(fpath) ) if err := json.NewDecoder(f).Decode(po); err != nil { @@ -112,7 +113,7 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { finalBlock := &po.BlocksData[len(po.BlocksData)-1] piq, err := public_input.ExecutionSerializable{ L2MsgHashes: l2MessageHashes, - FinalStateRootHash: po.PublicInput.Hex(), // TODO @tabaie make sure this is the right value + FinalStateRootHash: finalBlock.RootHash.Hex(), FinalBlockNumber: uint64(cf.FinalBlockNumber), FinalBlockTimestamp: finalBlock.TimeStamp, FinalRollingHash: cf.L1RollingHash, @@ -123,6 +124,8 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { } cf.ExecutionPI = append(cf.ExecutionPI, piq) } + + allL2MessageHashes = append(allL2MessageHashes, l2MessageHashes...) } cf.DecompressionPI = make([]blobsubmission.Response, 0, len(req.DecompressionProofs)) @@ -166,7 +169,7 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) { } cf.L2MessagingBlocksOffsets = utils.HexEncodeToString(PackOffsets(l2MsgBlockOffsets)) - cf.L2MsgRootHashes = PackInMiniTrees(l2MessageHashes) + cf.L2MsgRootHashes = PackInMiniTrees(allL2MessageHashes) return cf, nil diff --git a/prover/circuits/execution/circuit.go b/prover/circuits/execution/circuit.go index 4412872a2..9ff547cd5 100644 --- a/prover/circuits/execution/circuit.go +++ b/prover/circuits/execution/circuit.go @@ -61,9 +61,13 @@ func assign( funcInputs FunctionalPublicInput, ) CircuitExecution { wizardVerifier := wizard.GetWizardVerifierCircuitAssignment(comp, proof) + fpiSnark, err := funcInputs.ToSnarkType() + if err != nil { + panic(err) // TODO error handling + } return CircuitExecution{ WizardVerifier: *wizardVerifier, - FuncInputs: funcInputs.ToSnarkType(), + FuncInputs: fpiSnark, PublicInput: new(big.Int).SetBytes(funcInputs.Sum()), } } diff --git a/prover/circuits/execution/pi.go b/prover/circuits/execution/pi.go index 926d9f333..79b70e631 100644 --- a/prover/circuits/execution/pi.go +++ b/prover/circuits/execution/pi.go @@ -2,6 +2,7 @@ package execution import ( "encoding/binary" + "fmt" "hash" "slices" @@ -147,7 +148,7 @@ func (pi *FunctionalPublicInputSnark) Sum(api frontend.API, hsh gnarkHash.FieldH return hsh.Sum() } -func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark { +func (pi *FunctionalPublicInput) ToSnarkType() (FunctionalPublicInputSnark, error) { res := FunctionalPublicInputSnark{ FunctionalPublicInputQSnark: FunctionalPublicInputQSnark{ DataChecksum: slices.Clone(pi.DataChecksum[:]), @@ -167,7 +168,12 @@ func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark { utils.Copy(res.FinalRollingHash[:], pi.FinalRollingHash[:]) utils.Copy(res.InitialRollingHash[:], pi.InitialRollingHash[:]) - return res + var err error + if nbMsg := len(pi.L2MessageHashes); nbMsg > pi.MaxNbL2MessageHashes { + err = fmt.Errorf("has %d L2 message hashes but a maximum of %d is allowed", nbMsg, pi.MaxNbL2MessageHashes) + } + + return res, err } func (pi *FunctionalPublicInput) Sum() []byte { // all mimc; no need to provide a keccak hasher diff --git a/prover/circuits/execution/pi_test.go b/prover/circuits/execution/pi_test.go index 4c7b54ab8..3d4523863 100644 --- a/prover/circuits/execution/pi_test.go +++ b/prover/circuits/execution/pi_test.go @@ -1,6 +1,7 @@ package execution import ( + "github.com/stretchr/testify/require" "testing" "github.com/consensys/gnark/frontend" @@ -35,7 +36,8 @@ func TestPIConsistency(t *testing.T) { pi.InitialStateRootHash[0] &= 0x0f pi.FinalStateRootHash[0] &= 0x0f - snarkPi := pi.ToSnarkType() + snarkPi, err := pi.ToSnarkType() + require.NoError(t, err) piSum := pi.Sum() snarkTestUtils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable { diff --git a/prover/circuits/pi-interconnection/assign.go b/prover/circuits/pi-interconnection/assign.go index 4c1139f99..ece51d38c 100644 --- a/prover/circuits/pi-interconnection/assign.go +++ b/prover/circuits/pi-interconnection/assign.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/base64" "errors" + "fmt" "hash" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -121,7 +122,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { } if prevShnarf = shnarf.Compute(); !bytes.Equal(prevShnarf, shnarfs[i]) { - err = errors.New("shnarf mismatch") + err = fmt.Errorf("shnarf mismatch, i:%d, shnarf: %x, prevShnarf: %x, ", i, shnarfs[i], prevShnarf) return } } @@ -163,7 +164,6 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { } // Aggregation FPI - aggregationFPI, err := public_input.NewAggregationFPI(&r.Aggregation) if err != nil { return @@ -178,6 +178,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { merkleNbLeaves := 1 << config.L2MsgMerkleDepth maxNbL2MessageHashes := config.L2MsgMaxNbMerkle * merkleNbLeaves l2MessageHashes := make([][32]byte, 0, maxNbL2MessageHashes) + // Execution FPI executionFPI := execution.FunctionalPublicInput{ FinalStateRootHash: aggregationFPI.InitialStateRootHash, @@ -213,21 +214,48 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { } a.ExecutionPublicInput[i] = executionFPI.Sum() - a.ExecutionFPIQ[i] = executionFPI.ToSnarkType().FunctionalPublicInputQSnark + if snarkFPI, _err := executionFPI.ToSnarkType(); _err != nil { + err = fmt.Errorf("execution #%d: %w", i, _err) + return + } else { + a.ExecutionFPIQ[i] = snarkFPI.FunctionalPublicInputQSnark + } } // consistency check - if executionFPI.FinalBlockTimestamp != aggregationFPI.FinalBlockTimestamp || - executionFPI.FinalBlockNumber != aggregationFPI.FinalBlockNumber || - executionFPI.FinalRollingHash != aggregationFPI.FinalRollingHash || - executionFPI.FinalRollingHashNumber != aggregationFPI.FinalRollingHashNumber { - err = errors.New("final execution values not matching final aggregation values") + if executionFPI.FinalBlockTimestamp != aggregationFPI.FinalBlockTimestamp { + err = fmt.Errorf("final block timestamps do not match: execution=%x, aggregation=%x", + executionFPI.FinalBlockTimestamp, aggregationFPI.FinalBlockTimestamp) return } + if executionFPI.FinalBlockNumber != aggregationFPI.FinalBlockNumber { + err = fmt.Errorf("final block numbers do not match: execution=%v, aggregation=%x", + executionFPI.FinalBlockNumber, aggregationFPI.FinalBlockNumber) + return + } + if executionFPI.FinalRollingHash != [32]byte{} { + if executionFPI.FinalRollingHash != aggregationFPI.FinalRollingHash { + err = fmt.Errorf("final rolling hashes do not match: execution=%x, aggregation=%x", + executionFPI.FinalRollingHash, aggregationFPI.FinalRollingHash) + return + } + + if executionFPI.FinalRollingHashNumber != aggregationFPI.FinalRollingHashNumber { + err = fmt.Errorf("final rolling hash numbers do not match: execution=%v, aggregation=%v", + executionFPI.FinalRollingHashNumber, aggregationFPI.FinalRollingHashNumber) + return + } + } + if len(l2MessageHashes) > maxNbL2MessageHashes { err = errors.New("too many L2 messages") return } + if minNbRoots := (len(l2MessageHashes) + merkleNbLeaves - 1) / merkleNbLeaves; len(r.Aggregation.L2MsgRootHashes) < minNbRoots { + err = fmt.Errorf("the %d merkle roots provided are too few to accommodate all %d execution messages. A minimum of %d is needed", len(r.Aggregation.L2MsgRootHashes), len(l2MessageHashes), minNbRoots) + return + } + for i := range r.Aggregation.L2MsgRootHashes { var expectedRoot []byte if expectedRoot, err = utils.HexDecodeString(r.Aggregation.L2MsgRootHashes[i]); err != nil { @@ -239,6 +267,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) { return } } + // padding merkle root hashes emptyTree := make([][]byte, config.L2MsgMerkleDepth+1) emptyTree[0] = make([]byte, 64) diff --git a/prover/circuits/pi-interconnection/circuit.go b/prover/circuits/pi-interconnection/circuit.go index 4e87bd29e..28ad9007e 100644 --- a/prover/circuits/pi-interconnection/circuit.go +++ b/prover/circuits/pi-interconnection/circuit.go @@ -172,16 +172,35 @@ func (c *Circuit) Define(api frontend.API) error { } rExecution := internal.NewRange(api, nbExecution, maxNbExecution) + twoPow8 := big.NewInt(256) + hi16B := func(block [32]frontend.Variable) frontend.Variable { + return compress.ReadNum(api, block[:16], twoPow8) + } + lo16B := func(block [32]frontend.Variable) frontend.Variable { + return compress.ReadNum(api, block[16:], twoPow8) + } + + { // if rolling hash values are present in the last execution, they must match those of aggregation + finalRollingHashFromExec := rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }) + finalRollingHashNumFromExec := rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }) + + h, l := hi16B(finalRollingHashFromExec), lo16B(finalRollingHashFromExec) + + finalRollingHashPresent := api.Sub(1, api.Mul(api.IsZero(h), api.IsZero(l))) + + internal.AssertEqualIf(api, finalRollingHashPresent, h, hi16B(c.FinalRollingHash)) + internal.AssertEqualIf(api, finalRollingHashPresent, l, lo16B(c.FinalRollingHash)) + internal.AssertEqualIf(api, finalRollingHashPresent, finalRollingHashNumFromExec, c.FinalRollingHashNumber) + } + pi := public_input.AggregationFPISnark{ - AggregationFPIQSnark: c.AggregationFPIQSnark, - NbL2Messages: merkleLeavesConcat.Length, - L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle), - FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }), - FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }), - FinalRollingHash: rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }), - FinalRollingHashNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }), - FinalShnarf: rDecompression.LastArray32(shnarfs), - L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth, + AggregationFPIQSnark: c.AggregationFPIQSnark, + NbL2Messages: merkleLeavesConcat.Length, + L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle), + FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }), + FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }), + FinalShnarf: rDecompression.LastArray32(shnarfs), + L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth, } for i := range pi.L2MsgMerkleTreeRoots { @@ -190,7 +209,6 @@ func (c *Circuit) Define(api frontend.API) error { // "open" aggregation public input aggregationPIBytes := pi.Sum(api, &hshK) - twoPow8 := big.NewInt(256) api.AssertIsEqual(c.AggregationPublicInput[0], compress.ReadNum(api, aggregationPIBytes[:16], twoPow8)) api.AssertIsEqual(c.AggregationPublicInput[1], compress.ReadNum(api, aggregationPIBytes[16:], twoPow8)) diff --git a/prover/circuits/pi-interconnection/e2e_test.go b/prover/circuits/pi-interconnection/e2e_test.go index 4b0e187e2..763000c28 100644 --- a/prover/circuits/pi-interconnection/e2e_test.go +++ b/prover/circuits/pi-interconnection/e2e_test.go @@ -15,7 +15,7 @@ import ( "github.com/consensys/linea-monorepo/prover/backend/aggregation" "github.com/consensys/linea-monorepo/prover/backend/blobsubmission" "github.com/consensys/linea-monorepo/prover/circuits/internal" - "github.com/consensys/linea-monorepo/prover/circuits/internal/test_utils" + circuittesting "github.com/consensys/linea-monorepo/prover/circuits/internal/test_utils" pi_interconnection "github.com/consensys/linea-monorepo/prover/circuits/pi-interconnection" pitesting "github.com/consensys/linea-monorepo/prover/circuits/pi-interconnection/test_utils" "github.com/consensys/linea-monorepo/prover/config" @@ -101,7 +101,7 @@ func TestTinyTwoBatchBlob(t *testing.T) { blobResp, err := blobsubmission.CraftResponse(&blobReq) assert.NoError(t, err) - merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes)) + merkleRoots := aggregation.PackInMiniTrees(circuittesting.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes)) req := pi_interconnection.Request{ Decompressions: []blobsubmission.Response{*blobResp}, @@ -182,7 +182,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) { blobResp1, err := blobsubmission.CraftResponse(&blobReq1) assert.NoError(t, err) - merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes, execReq[2].L2MsgHashes, execReq[3].L2MsgHashes)) + merkleRoots := aggregation.PackInMiniTrees(circuittesting.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes, execReq[2].L2MsgHashes, execReq[3].L2MsgHashes)) req := pi_interconnection.Request{ Decompressions: []blobsubmission.Response{*blobResp0, *blobResp1}, diff --git a/prover/public-input/aggregation.go b/prover/public-input/aggregation.go index 7b3b8b37e..5fc3b61c4 100644 --- a/prover/public-input/aggregation.go +++ b/prover/public-input/aggregation.go @@ -123,15 +123,15 @@ func (pi *AggregationFPI) ToSnarkType() AggregationFPISnark { InitialRollingHashNumber: pi.InitialRollingHashNumber, InitialStateRootHash: pi.InitialStateRootHash[:], - NbDecompression: pi.NbDecompression, - ChainID: pi.ChainID, - L2MessageServiceAddr: pi.L2MessageServiceAddr[:], + NbDecompression: pi.NbDecompression, + ChainID: pi.ChainID, + L2MessageServiceAddr: pi.L2MessageServiceAddr[:], + FinalRollingHashNumber: pi.FinalRollingHashNumber, }, - L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)), - FinalBlockNumber: pi.FinalBlockNumber, - FinalBlockTimestamp: pi.FinalBlockTimestamp, - FinalRollingHashNumber: pi.FinalRollingHashNumber, - L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth, + L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)), + FinalBlockNumber: pi.FinalBlockNumber, + FinalBlockTimestamp: pi.FinalBlockTimestamp, + L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth, } utils.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:]) @@ -154,8 +154,12 @@ type AggregationFPIQSnark struct { InitialBlockTimestamp frontend.Variable InitialRollingHash [32]frontend.Variable InitialRollingHashNumber frontend.Variable - ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID - L2MessageServiceAddr frontend.Variable // 20 bytes + // Ideally, FinalRollingHash and FinalRollingHashNumber would be inferred from the executions + // but sometimes executions are missing those values + FinalRollingHash [32]frontend.Variable + FinalRollingHashNumber frontend.Variable + ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID + L2MessageServiceAddr frontend.Variable // 20 bytes } type AggregationFPISnark struct { @@ -163,12 +167,10 @@ type AggregationFPISnark struct { NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary L2MsgMerkleTreeRoots [][32]frontend.Variable // FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf - FinalBlockNumber frontend.Variable - FinalBlockTimestamp frontend.Variable - FinalRollingHash [32]frontend.Variable - FinalRollingHashNumber frontend.Variable - FinalShnarf [32]frontend.Variable - L2MsgMerkleTreeDepth int + FinalBlockNumber frontend.Variable + FinalBlockTimestamp frontend.Variable + FinalShnarf [32]frontend.Variable + L2MsgMerkleTreeDepth int } // NewAggregationFPI does NOT set all fields, only the ones covered in public_input.Aggregation diff --git a/prover/utils/test_utils/test_utils.go b/prover/utils/test_utils/test_utils.go index fdf66fb5e..dde00a424 100644 --- a/prover/utils/test_utils/test_utils.go +++ b/prover/utils/test_utils/test_utils.go @@ -3,9 +3,12 @@ package test_utils import ( "crypto/rand" "encoding/binary" + "encoding/json" "fmt" + "github.com/stretchr/testify/require" "math" "os" + "testing" ) type FakeTestingT struct{} @@ -37,3 +40,9 @@ func RandIntSliceN(length, n int) []int { } return res } + +func LoadJson(t *testing.T, path string, v any) { + in, err := os.Open(path) + require.NoError(t, err) + require.NoError(t, json.NewDecoder(in).Decode(v)) +}