Skip to content

Commit

Permalink
Fix: test PI Interconnection with Sepolia data (#208)
Browse files Browse the repository at this point in the history
* feat better pi assignment error finding
* fix append to allL2MessageHashes
* feat relax rolling hash verification in circuit
* PR feedback: remove unit test

---------

Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com>
Co-authored-by: AlexandreBelling <alexandrebelling8@gmail.com>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 8bb5fe7 commit 333bc77
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 49 deletions.
19 changes: 11 additions & 8 deletions prover/backend/aggregation/craft.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ const (
func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {

var (
l2MessageHashes []string
l2MsgBlockOffsets []bool
cf = &CollectedFields{
allL2MessageHashes []string
l2MsgBlockOffsets []bool
cf = &CollectedFields{
L2MsgTreeDepth: l2MsgMerkleTreeDepth,
ParentAggregationLastBlockTimestamp: uint(req.ParentAggregationLastBlockTimestamp),
LastFinalizedL1RollingHash: req.ParentAggregationLastL1RollingHash,
Expand All @@ -55,9 +55,10 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
for i, execReqFPath := range req.ExecutionProofs {

var (
po = &execution.Response{}
fpath = path.Join(cfg.Execution.DirTo(), execReqFPath)
f = files.MustRead(fpath)
po = &execution.Response{}
l2MessageHashes []string
fpath = path.Join(cfg.Execution.DirTo(), execReqFPath)
f = files.MustRead(fpath)
)

if err := json.NewDecoder(f).Decode(po); err != nil {
Expand Down Expand Up @@ -112,7 +113,7 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
finalBlock := &po.BlocksData[len(po.BlocksData)-1]
piq, err := public_input.ExecutionSerializable{
L2MsgHashes: l2MessageHashes,
FinalStateRootHash: po.PublicInput.Hex(), // TODO @tabaie make sure this is the right value
FinalStateRootHash: finalBlock.RootHash.Hex(),
FinalBlockNumber: uint64(cf.FinalBlockNumber),
FinalBlockTimestamp: finalBlock.TimeStamp,
FinalRollingHash: cf.L1RollingHash,
Expand All @@ -123,6 +124,8 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
}
cf.ExecutionPI = append(cf.ExecutionPI, piq)
}

allL2MessageHashes = append(allL2MessageHashes, l2MessageHashes...)
}

cf.DecompressionPI = make([]blobsubmission.Response, 0, len(req.DecompressionProofs))
Expand Down Expand Up @@ -166,7 +169,7 @@ func collectFields(cfg *config.Config, req *Request) (*CollectedFields, error) {
}

cf.L2MessagingBlocksOffsets = utils.HexEncodeToString(PackOffsets(l2MsgBlockOffsets))
cf.L2MsgRootHashes = PackInMiniTrees(l2MessageHashes)
cf.L2MsgRootHashes = PackInMiniTrees(allL2MessageHashes)

return cf, nil

Expand Down
6 changes: 5 additions & 1 deletion prover/circuits/execution/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ func assign(
funcInputs FunctionalPublicInput,
) CircuitExecution {
wizardVerifier := wizard.GetWizardVerifierCircuitAssignment(comp, proof)
fpiSnark, err := funcInputs.ToSnarkType()
if err != nil {
panic(err) // TODO error handling
}
return CircuitExecution{
WizardVerifier: *wizardVerifier,
FuncInputs: funcInputs.ToSnarkType(),
FuncInputs: fpiSnark,
PublicInput: new(big.Int).SetBytes(funcInputs.Sum()),
}
}
Expand Down
10 changes: 8 additions & 2 deletions prover/circuits/execution/pi.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package execution

import (
"encoding/binary"
"fmt"
"hash"
"slices"

Expand Down Expand Up @@ -147,7 +148,7 @@ func (pi *FunctionalPublicInputSnark) Sum(api frontend.API, hsh gnarkHash.FieldH
return hsh.Sum()
}

func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark {
func (pi *FunctionalPublicInput) ToSnarkType() (FunctionalPublicInputSnark, error) {
res := FunctionalPublicInputSnark{
FunctionalPublicInputQSnark: FunctionalPublicInputQSnark{
DataChecksum: slices.Clone(pi.DataChecksum[:]),
Expand All @@ -167,7 +168,12 @@ func (pi *FunctionalPublicInput) ToSnarkType() FunctionalPublicInputSnark {
utils.Copy(res.FinalRollingHash[:], pi.FinalRollingHash[:])
utils.Copy(res.InitialRollingHash[:], pi.InitialRollingHash[:])

return res
var err error
if nbMsg := len(pi.L2MessageHashes); nbMsg > pi.MaxNbL2MessageHashes {
err = fmt.Errorf("has %d L2 message hashes but a maximum of %d is allowed", nbMsg, pi.MaxNbL2MessageHashes)
}

return res, err
}

func (pi *FunctionalPublicInput) Sum() []byte { // all mimc; no need to provide a keccak hasher
Expand Down
4 changes: 3 additions & 1 deletion prover/circuits/execution/pi_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package execution

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/consensys/gnark/frontend"
Expand Down Expand Up @@ -35,7 +36,8 @@ func TestPIConsistency(t *testing.T) {
pi.InitialStateRootHash[0] &= 0x0f
pi.FinalStateRootHash[0] &= 0x0f

snarkPi := pi.ToSnarkType()
snarkPi, err := pi.ToSnarkType()
require.NoError(t, err)
piSum := pi.Sum()

snarkTestUtils.SnarkFunctionTest(func(api frontend.API) []frontend.Variable {
Expand Down
45 changes: 37 additions & 8 deletions prover/circuits/pi-interconnection/assign.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"hash"

"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
Expand Down Expand Up @@ -121,7 +122,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
}

if prevShnarf = shnarf.Compute(); !bytes.Equal(prevShnarf, shnarfs[i]) {
err = errors.New("shnarf mismatch")
err = fmt.Errorf("shnarf mismatch, i:%d, shnarf: %x, prevShnarf: %x, ", i, shnarfs[i], prevShnarf)
return
}
}
Expand Down Expand Up @@ -163,7 +164,6 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
}

// Aggregation FPI

aggregationFPI, err := public_input.NewAggregationFPI(&r.Aggregation)
if err != nil {
return
Expand All @@ -178,6 +178,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
merkleNbLeaves := 1 << config.L2MsgMerkleDepth
maxNbL2MessageHashes := config.L2MsgMaxNbMerkle * merkleNbLeaves
l2MessageHashes := make([][32]byte, 0, maxNbL2MessageHashes)

// Execution FPI
executionFPI := execution.FunctionalPublicInput{
FinalStateRootHash: aggregationFPI.InitialStateRootHash,
Expand Down Expand Up @@ -213,21 +214,48 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
}

a.ExecutionPublicInput[i] = executionFPI.Sum()
a.ExecutionFPIQ[i] = executionFPI.ToSnarkType().FunctionalPublicInputQSnark
if snarkFPI, _err := executionFPI.ToSnarkType(); _err != nil {
err = fmt.Errorf("execution #%d: %w", i, _err)
return
} else {
a.ExecutionFPIQ[i] = snarkFPI.FunctionalPublicInputQSnark
}
}
// consistency check
if executionFPI.FinalBlockTimestamp != aggregationFPI.FinalBlockTimestamp ||
executionFPI.FinalBlockNumber != aggregationFPI.FinalBlockNumber ||
executionFPI.FinalRollingHash != aggregationFPI.FinalRollingHash ||
executionFPI.FinalRollingHashNumber != aggregationFPI.FinalRollingHashNumber {
err = errors.New("final execution values not matching final aggregation values")
if executionFPI.FinalBlockTimestamp != aggregationFPI.FinalBlockTimestamp {
err = fmt.Errorf("final block timestamps do not match: execution=%x, aggregation=%x",
executionFPI.FinalBlockTimestamp, aggregationFPI.FinalBlockTimestamp)
return
}
if executionFPI.FinalBlockNumber != aggregationFPI.FinalBlockNumber {
err = fmt.Errorf("final block numbers do not match: execution=%v, aggregation=%x",
executionFPI.FinalBlockNumber, aggregationFPI.FinalBlockNumber)
return
}
if executionFPI.FinalRollingHash != [32]byte{} {
if executionFPI.FinalRollingHash != aggregationFPI.FinalRollingHash {
err = fmt.Errorf("final rolling hashes do not match: execution=%x, aggregation=%x",
executionFPI.FinalRollingHash, aggregationFPI.FinalRollingHash)
return
}

if executionFPI.FinalRollingHashNumber != aggregationFPI.FinalRollingHashNumber {
err = fmt.Errorf("final rolling hash numbers do not match: execution=%v, aggregation=%v",
executionFPI.FinalRollingHashNumber, aggregationFPI.FinalRollingHashNumber)
return
}
}

if len(l2MessageHashes) > maxNbL2MessageHashes {
err = errors.New("too many L2 messages")
return
}

if minNbRoots := (len(l2MessageHashes) + merkleNbLeaves - 1) / merkleNbLeaves; len(r.Aggregation.L2MsgRootHashes) < minNbRoots {
err = fmt.Errorf("the %d merkle roots provided are too few to accommodate all %d execution messages. A minimum of %d is needed", len(r.Aggregation.L2MsgRootHashes), len(l2MessageHashes), minNbRoots)
return
}

for i := range r.Aggregation.L2MsgRootHashes {
var expectedRoot []byte
if expectedRoot, err = utils.HexDecodeString(r.Aggregation.L2MsgRootHashes[i]); err != nil {
Expand All @@ -239,6 +267,7 @@ func (c *Compiled) Assign(r Request) (a Circuit, err error) {
return
}
}

// padding merkle root hashes
emptyTree := make([][]byte, config.L2MsgMerkleDepth+1)
emptyTree[0] = make([]byte, 64)
Expand Down
38 changes: 28 additions & 10 deletions prover/circuits/pi-interconnection/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,16 +172,35 @@ func (c *Circuit) Define(api frontend.API) error {
}
rExecution := internal.NewRange(api, nbExecution, maxNbExecution)

twoPow8 := big.NewInt(256)
hi16B := func(block [32]frontend.Variable) frontend.Variable {
return compress.ReadNum(api, block[:16], twoPow8)
}
lo16B := func(block [32]frontend.Variable) frontend.Variable {
return compress.ReadNum(api, block[16:], twoPow8)
}

{ // if rolling hash values are present in the last execution, they must match those of aggregation
finalRollingHashFromExec := rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash })
finalRollingHashNumFromExec := rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber })

h, l := hi16B(finalRollingHashFromExec), lo16B(finalRollingHashFromExec)

finalRollingHashPresent := api.Sub(1, api.Mul(api.IsZero(h), api.IsZero(l)))

internal.AssertEqualIf(api, finalRollingHashPresent, h, hi16B(c.FinalRollingHash))
internal.AssertEqualIf(api, finalRollingHashPresent, l, lo16B(c.FinalRollingHash))
internal.AssertEqualIf(api, finalRollingHashPresent, finalRollingHashNumFromExec, c.FinalRollingHashNumber)
}

pi := public_input.AggregationFPISnark{
AggregationFPIQSnark: c.AggregationFPIQSnark,
NbL2Messages: merkleLeavesConcat.Length,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle),
FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }),
FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }),
FinalRollingHash: rExecution.LastArray32F(func(i int) [32]frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHash }),
FinalRollingHashNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalRollingHashNumber }),
FinalShnarf: rDecompression.LastArray32(shnarfs),
L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth,
AggregationFPIQSnark: c.AggregationFPIQSnark,
NbL2Messages: merkleLeavesConcat.Length,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, c.L2MessageMaxNbMerkle),
FinalBlockNumber: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockNumber }),
FinalBlockTimestamp: rExecution.LastF(func(i int) frontend.Variable { return c.ExecutionFPIQ[i].FinalBlockTimestamp }),
FinalShnarf: rDecompression.LastArray32(shnarfs),
L2MsgMerkleTreeDepth: c.L2MessageMerkleDepth,
}

for i := range pi.L2MsgMerkleTreeRoots {
Expand All @@ -190,7 +209,6 @@ func (c *Circuit) Define(api frontend.API) error {

// "open" aggregation public input
aggregationPIBytes := pi.Sum(api, &hshK)
twoPow8 := big.NewInt(256)
api.AssertIsEqual(c.AggregationPublicInput[0], compress.ReadNum(api, aggregationPIBytes[:16], twoPow8))
api.AssertIsEqual(c.AggregationPublicInput[1], compress.ReadNum(api, aggregationPIBytes[16:], twoPow8))

Expand Down
6 changes: 3 additions & 3 deletions prover/circuits/pi-interconnection/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/consensys/linea-monorepo/prover/backend/aggregation"
"github.com/consensys/linea-monorepo/prover/backend/blobsubmission"
"github.com/consensys/linea-monorepo/prover/circuits/internal"
"github.com/consensys/linea-monorepo/prover/circuits/internal/test_utils"
circuittesting "github.com/consensys/linea-monorepo/prover/circuits/internal/test_utils"
pi_interconnection "github.com/consensys/linea-monorepo/prover/circuits/pi-interconnection"
pitesting "github.com/consensys/linea-monorepo/prover/circuits/pi-interconnection/test_utils"
"github.com/consensys/linea-monorepo/prover/config"
Expand Down Expand Up @@ -101,7 +101,7 @@ func TestTinyTwoBatchBlob(t *testing.T) {
blobResp, err := blobsubmission.CraftResponse(&blobReq)
assert.NoError(t, err)

merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes))
merkleRoots := aggregation.PackInMiniTrees(circuittesting.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes))

req := pi_interconnection.Request{
Decompressions: []blobsubmission.Response{*blobResp},
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestTwoTwoBatchBlobs(t *testing.T) {
blobResp1, err := blobsubmission.CraftResponse(&blobReq1)
assert.NoError(t, err)

merkleRoots := aggregation.PackInMiniTrees(test_utils.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes, execReq[2].L2MsgHashes, execReq[3].L2MsgHashes))
merkleRoots := aggregation.PackInMiniTrees(circuittesting.BlocksToHex(execReq[0].L2MsgHashes, execReq[1].L2MsgHashes, execReq[2].L2MsgHashes, execReq[3].L2MsgHashes))

req := pi_interconnection.Request{
Decompressions: []blobsubmission.Response{*blobResp0, *blobResp1},
Expand Down
34 changes: 18 additions & 16 deletions prover/public-input/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ func (pi *AggregationFPI) ToSnarkType() AggregationFPISnark {
InitialRollingHashNumber: pi.InitialRollingHashNumber,
InitialStateRootHash: pi.InitialStateRootHash[:],

NbDecompression: pi.NbDecompression,
ChainID: pi.ChainID,
L2MessageServiceAddr: pi.L2MessageServiceAddr[:],
NbDecompression: pi.NbDecompression,
ChainID: pi.ChainID,
L2MessageServiceAddr: pi.L2MessageServiceAddr[:],
FinalRollingHashNumber: pi.FinalRollingHashNumber,
},
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)),
FinalBlockNumber: pi.FinalBlockNumber,
FinalBlockTimestamp: pi.FinalBlockTimestamp,
FinalRollingHashNumber: pi.FinalRollingHashNumber,
L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth,
L2MsgMerkleTreeRoots: make([][32]frontend.Variable, len(pi.L2MsgMerkleTreeRoots)),
FinalBlockNumber: pi.FinalBlockNumber,
FinalBlockTimestamp: pi.FinalBlockTimestamp,
L2MsgMerkleTreeDepth: pi.L2MsgMerkleTreeDepth,
}

utils.Copy(s.FinalRollingHash[:], pi.FinalRollingHash[:])
Expand All @@ -154,21 +154,23 @@ type AggregationFPIQSnark struct {
InitialBlockTimestamp frontend.Variable
InitialRollingHash [32]frontend.Variable
InitialRollingHashNumber frontend.Variable
ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr frontend.Variable // 20 bytes
// Ideally, FinalRollingHash and FinalRollingHashNumber would be inferred from the executions
// but sometimes executions are missing those values
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
ChainID frontend.Variable // for now we're forcing all executions to have the same chain ID
L2MessageServiceAddr frontend.Variable // 20 bytes
}

type AggregationFPISnark struct {
AggregationFPIQSnark
NbL2Messages frontend.Variable // TODO not used in hash. delete if not necessary
L2MsgMerkleTreeRoots [][32]frontend.Variable
// FinalStateRootHash frontend.Variable redundant: incorporated into final shnarf
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalRollingHash [32]frontend.Variable
FinalRollingHashNumber frontend.Variable
FinalShnarf [32]frontend.Variable
L2MsgMerkleTreeDepth int
FinalBlockNumber frontend.Variable
FinalBlockTimestamp frontend.Variable
FinalShnarf [32]frontend.Variable
L2MsgMerkleTreeDepth int
}

// NewAggregationFPI does NOT set all fields, only the ones covered in public_input.Aggregation
Expand Down
9 changes: 9 additions & 0 deletions prover/utils/test_utils/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package test_utils
import (
"crypto/rand"
"encoding/binary"
"encoding/json"
"fmt"
"github.com/stretchr/testify/require"
"math"
"os"
"testing"
)

type FakeTestingT struct{}
Expand Down Expand Up @@ -37,3 +40,9 @@ func RandIntSliceN(length, n int) []int {
}
return res
}

func LoadJson(t *testing.T, path string, v any) {
in, err := os.Open(path)
require.NoError(t, err)
require.NoError(t, json.NewDecoder(in).Decode(v))
}

0 comments on commit 333bc77

Please sign in to comment.