diff --git a/lib/runtime/storage/trie.go b/lib/runtime/storage/trie.go index ac2767814a..d2ac93363e 100644 --- a/lib/runtime/storage/trie.go +++ b/lib/runtime/storage/trie.go @@ -214,8 +214,7 @@ func (t *TrieState) ClearPrefix(prefix []byte) error { if currentTx := t.getCurrentTransaction(); currentTx != nil { keysOnState := make([]string, 0) - iter := t.state.PrefixedIter(prefix) - for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() { + for key := range t.state.PrefixedKeys(prefix) { keysOnState = append(keysOnState, string(key)) } @@ -235,8 +234,7 @@ func (t *TrieState) ClearPrefixLimit(prefix []byte, limit uint32) ( if currentTx := t.getCurrentTransaction(); currentTx != nil { keysOnState := make([]string, 0) - iter := t.state.PrefixedIter(prefix) - for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() { + for key := range t.state.PrefixedKeys(prefix) { keysOnState = append(keysOnState, string(key)) } @@ -430,8 +428,7 @@ func (t *TrieState) ClearPrefixInChild(keyToChild, prefix []byte) error { } var onStateKeys []string - iter := child.PrefixedIter(prefix) - for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() { + for key := range child.PrefixedKeys(prefix) { onStateKeys = append(onStateKeys, string(key)) } @@ -466,8 +463,7 @@ func (t *TrieState) ClearPrefixInChildWithLimit(keyToChild, prefix []byte, limit } var onStateKeys []string - iter := child.PrefixedIter(prefix) - for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() { + for key := range child.PrefixedKeys(prefix) { onStateKeys = append(onStateKeys, string(key)) } diff --git a/pkg/trie/inmemory/iterator.go b/pkg/trie/inmemory/iterator.go index 284ff2df40..b4e0293dbf 100644 --- a/pkg/trie/inmemory/iterator.go +++ b/pkg/trie/inmemory/iterator.go @@ -6,6 +6,7 @@ package inmemory import ( "bytes" "fmt" + "iter" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/ChainSafe/gossamer/pkg/trie/codec" @@ -87,6 +88,23 @@ func (t *InMemoryTrie) Entries() (keyValueMap map[string][]byte) { return keyValueMap } +func (t *InMemoryTrie) PrefixedKeys(prefix []byte) iter.Seq[[]byte] { + iter := NewInMemoryTrieIterator(WithTrie(t), WithCursorAt(codec.KeyLEToNibbles(prefix))) + + return func(yield func([]byte) bool) { + // Return same prefix as first key if it's present in trie + if t.Get(prefix) != nil && !yield(prefix) { + return + } + + for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() { + if !yield(key) { + return + } + } + } +} + // NextKey returns the next key in the trie in lexicographic order. // It returns nil if no next key is found. func (t *InMemoryTrie) NextKey(keyLE []byte) (nextKeyLE []byte) { diff --git a/pkg/trie/inmemory/iterator_test.go b/pkg/trie/inmemory/iterator_test.go index e5a5937a88..2be4751c34 100644 --- a/pkg/trie/inmemory/iterator_test.go +++ b/pkg/trie/inmemory/iterator_test.go @@ -4,7 +4,6 @@ package inmemory import ( - "bytes" "testing" "github.com/ChainSafe/gossamer/pkg/trie/codec" @@ -42,10 +41,9 @@ func TestInMemoryIteratorGetAllKeysWithPrefix(t *testing.T) { tt.Put([]byte("account_storage:JJK:EEE"), []byte("0x10")) prefix := []byte("account_storage") - iter := tt.PrefixedIter(prefix) keys := make([][]byte, 0) - for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() { + for key := range tt.PrefixedKeys(prefix) { keys = append(keys, key) } @@ -58,3 +56,32 @@ func TestInMemoryIteratorGetAllKeysWithPrefix(t *testing.T) { require.Equal(t, expectedKeys, keys) } + +func TestInMemoryIteratorGetAllKeysWithPrefixIncluded(t *testing.T) { + tt := NewEmptyTrie() + + tt.Put([]byte("services_storage:serviceA:19090"), []byte("0x10")) + tt.Put([]byte("services_storage:serviceB:22222"), []byte("0x10")) + tt.Put([]byte("account_storage"), []byte("0x10")) + tt.Put([]byte("account_storage:ABC:AAA"), []byte("0x10")) + tt.Put([]byte("account_storage:ABC:CCC"), []byte("0x10")) + tt.Put([]byte("account_storage:ABC:DDD"), []byte("0x10")) + tt.Put([]byte("account_storage:JJK:EEE"), []byte("0x10")) + + prefix := []byte("account_storage") + + keys := make([][]byte, 0) + for key := range tt.PrefixedKeys(prefix) { + keys = append(keys, key) + } + + expectedKeys := [][]byte{ + []byte("account_storage"), + []byte("account_storage:ABC:AAA"), + []byte("account_storage:ABC:CCC"), + []byte("account_storage:ABC:DDD"), + []byte("account_storage:JJK:EEE"), + } + + require.Equal(t, expectedKeys, keys) +} diff --git a/pkg/trie/trie.go b/pkg/trie/trie.go index 038880d8f7..223cfae915 100644 --- a/pkg/trie/trie.go +++ b/pkg/trie/trie.go @@ -5,6 +5,7 @@ package trie import ( "fmt" + "iter" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/trie/tracking" @@ -84,6 +85,7 @@ type TrieRead interface { Entries() (keyValueMap map[string][]byte) NextKey(key []byte) []byte GetKeysWithPrefix(prefix []byte) (keysLE [][]byte) + PrefixedKeys(prefix []byte) iter.Seq[[]byte] } type Trie interface {