Skip to content

Commit 16e9be2

Browse files
mpolitzervfusco
authored and committed
feat(merkle): add builder on merkle trees data structure
1 parent b07d998 commit 16e9be2

File tree

2 files changed

+538
-0
lines changed

2 files changed

+538
-0
lines changed

internal/merkle/builder.go

Lines changed: 395 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
// (c) Cartesi and individual authors (see AUTHORS)
2+
// SPDX-License-Identifier: Apache-2.0 (see LICENSE)
3+
4+
package merkle
5+
6+
import (
7+
"fmt"
8+
"math/big"
9+
"slices"
10+
11+
"github.com/ethereum/go-ethereum/common"
12+
"github.com/ethereum/go-ethereum/crypto"
13+
)
14+
15+
// MerkleProof: dave/common-rs/merkle/src/tree.rs
16+
type Proof struct {
17+
Pos *big.Int
18+
Node common.Hash
19+
Siblings []common.Hash
20+
}
21+
22+
func Leaf(node common.Hash, pos *big.Int) *Proof {
23+
return &Proof{
24+
Node: node,
25+
Pos: pos,
26+
Siblings: nil,
27+
}
28+
}
29+
30+
func (proof *Proof) BuildRoot() common.Hash {
31+
zero := big.NewInt(0)
32+
two := big.NewInt(2)
33+
rootHash := proof.Node
34+
35+
for i, s := range proof.Siblings {
36+
37+
// ((pos >> i) % 2) == 0
38+
if new(big.Int).Rem(new(big.Int).Rsh(proof.Pos, uint(i)), two).Cmp(zero) == 0 {
39+
rootHash = crypto.Keccak256Hash(rootHash[:], s[:])
40+
} else {
41+
rootHash = crypto.Keccak256Hash(s[:], rootHash[:])
42+
}
43+
}
44+
return rootHash
45+
}
46+
47+
func (proof *Proof) VerifyRoot(other common.Hash) bool {
48+
return proof.BuildRoot() == other
49+
}
50+
51+
func (proof *Proof) PushHash(h common.Hash) {
52+
proof.Siblings = append(proof.Siblings, h)
53+
}
54+
55+
////////////////////////////////////////////////////////////////////////////////
56+
57+
// MerkleTree: dave/common-rs/merkle/src/tree.rs
58+
type Tree struct {
59+
RootHash common.Hash
60+
Height uint32
61+
Subtrees *InnerNode
62+
}
63+
64+
// InnerNode: dave/common-rs/merkle/src/tree.rs
65+
// Emulate the rust enum type with a struct containing both {Pair, Iterated}.
66+
type InnerNode struct {
67+
// Pair
68+
LHS, RHS *Tree
69+
70+
// Iterated
71+
Child *Tree
72+
}
73+
74+
func (inner *InnerNode) Valid() bool {
75+
isPair := (inner.LHS != nil && inner.RHS != nil)
76+
isIterated := inner.Child != nil
77+
return (isPair || isIterated) && !(isPair && isIterated) // xor
78+
}
79+
80+
func (inner *InnerNode) Children() (*Tree, *Tree) {
81+
if !inner.Valid() {
82+
panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner))
83+
}
84+
85+
if inner.Child != nil {
86+
return inner.Child, inner.Child
87+
} else {
88+
return inner.LHS, inner.RHS
89+
}
90+
}
91+
92+
func TreeLeaf(hash common.Hash) *Tree {
93+
return &Tree{
94+
Height: 0,
95+
RootHash: hash,
96+
Subtrees: nil,
97+
}
98+
}
99+
100+
func (tree *Tree) GetRootHash() common.Hash {
101+
return tree.RootHash
102+
}
103+
104+
func (tree *Tree) FindChildByHash(hash common.Hash) *InnerNode {
105+
if inner := tree.Subtrees; inner != nil {
106+
if !inner.Valid() {
107+
panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner))
108+
}
109+
110+
if inner.Child != nil {
111+
child := inner.Child.FindChildByHash(hash)
112+
if child != nil {
113+
return child
114+
}
115+
} else {
116+
lhs := inner.LHS.FindChildByHash(hash)
117+
if lhs != nil {
118+
return lhs
119+
}
120+
121+
rhs := inner.LHS.FindChildByHash(hash)
122+
if rhs != nil {
123+
return rhs
124+
}
125+
}
126+
}
127+
return nil // not found
128+
}
129+
130+
func (tree *Tree) Join(other *Tree) *Tree {
131+
return &Tree{
132+
RootHash: crypto.Keccak256Hash(tree.RootHash[:], other.RootHash[:]),
133+
Height: tree.Height + 1,
134+
Subtrees: &InnerNode{
135+
LHS: tree,
136+
RHS: other,
137+
},
138+
}
139+
}
140+
141+
func (tree *Tree) Iterated(rep uint64) *Tree {
142+
root := tree
143+
for range rep {
144+
root = &Tree{
145+
RootHash: crypto.Keccak256Hash(root.RootHash[:], root.RootHash[:]),
146+
Height: tree.Height + 1,
147+
Subtrees: &InnerNode{
148+
Child: tree,
149+
},
150+
}
151+
}
152+
return root
153+
}
154+
155+
func (tree *Tree) ProveLeaf(index *big.Int) *Proof {
156+
return tree.ProveLeafRec(index)
157+
}
158+
159+
func (tree *Tree) ProveLast() *Proof {
160+
one := big.NewInt(1)
161+
162+
// index = (1 << height) - 1
163+
index := new(big.Int).Sub(
164+
new(big.Int).Lsh(
165+
one,
166+
uint(tree.Height),
167+
),
168+
one,
169+
)
170+
return tree.ProveLeaf(index)
171+
}
172+
173+
func (tree *Tree) ProveLeafRec(index *big.Int) *Proof {
174+
one := big.NewInt(1)
175+
zero := big.NewInt(0)
176+
numLeafs := new(big.Int).Lsh(one, uint(tree.Height))
177+
if numLeafs.Cmp(index) <= 0 {
178+
panic(fmt.Sprintf("index out of bounds: %v, %v", numLeafs, index))
179+
}
180+
181+
subtree := tree.Subtrees
182+
if subtree == nil {
183+
if index.Cmp(zero) != 0 {
184+
panic(fmt.Sprintf("invalid Tree state: %v", tree))
185+
}
186+
if tree.Height != 0 {
187+
panic(fmt.Sprintf("invalid Tree state: %v", tree))
188+
}
189+
return Leaf(tree.RootHash, index)
190+
}
191+
192+
shiftAmount := uint(tree.Height - 1)
193+
isLeftLeaf := new(big.Int).Rsh(index, shiftAmount).Cmp(zero) == 0
194+
195+
// innerIndex = index & !(1 << shiftAmount)
196+
innerIndex := new(big.Int).And(
197+
index,
198+
new(big.Int).Not(
199+
new(big.Int).Lsh(
200+
one,
201+
shiftAmount,
202+
),
203+
),
204+
)
205+
206+
lhs, rhs := subtree.Children()
207+
if isLeftLeaf {
208+
proof := lhs.ProveLeafRec(innerIndex)
209+
proof.PushHash(rhs.RootHash)
210+
proof.Pos = index
211+
return proof
212+
} else {
213+
proof := rhs.ProveLeafRec(innerIndex)
214+
proof.PushHash(lhs.RootHash)
215+
proof.Pos = index
216+
return proof
217+
}
218+
}
219+
220+
////////////////////////////////////////////////////////////////////////////////
221+
222+
// Node: common-rs/merkle/src/tree_builder.rs
223+
type Node struct {
224+
Tree *Tree
225+
AccumulatedCount *big.Int
226+
}
227+
228+
type Builder struct {
229+
Trees []Node
230+
}
231+
232+
func (b *Builder) Height() (uint32, bool) {
233+
n := len(b.Trees)
234+
if n == 0 {
235+
return 0, false
236+
}
237+
return b.Trees[n-1].Tree.Height, true
238+
}
239+
240+
func (b *Builder) Count() (*big.Int, bool) {
241+
n := len(b.Trees)
242+
if n == 0 {
243+
return nil, false
244+
}
245+
return b.Trees[n-1].AccumulatedCount, true
246+
}
247+
248+
func (b *Builder) CanBuild() bool {
249+
n := len(b.Trees)
250+
if n == 0 {
251+
return false
252+
}
253+
return isPow2(b.Trees[n-1].AccumulatedCount)
254+
}
255+
256+
func (b *Builder) Append(leaf *Tree) {
257+
b.AppendRepeated(leaf, big.NewInt(1))
258+
}
259+
260+
func (b *Builder) AppendRepeatedUint64(leaf *Tree, reps uint64) {
261+
b.AppendRepeated(leaf, new(big.Int).SetUint64(reps))
262+
}
263+
264+
func (b *Builder) AppendRepeated(leaf *Tree, reps *big.Int) {
265+
zero := big.NewInt(0)
266+
if reps.Cmp(zero) <= 0 {
267+
panic("invalid repetitions")
268+
}
269+
270+
accumulatedCount := b.CalculateAccumulatedCount(reps)
271+
if height, ok := b.Height(); ok {
272+
if height != leaf.Height {
273+
panic("mismatched tree size")
274+
}
275+
}
276+
b.Trees = append(b.Trees, Node{
277+
Tree: leaf,
278+
AccumulatedCount: accumulatedCount,
279+
})
280+
}
281+
282+
func (b *Builder) Build() *Tree {
283+
if count, ok := b.Count(); ok {
284+
if !isCountPow2(count) {
285+
panic(fmt.Sprintf("builder has %v leafs, which is not a power of two", count))
286+
}
287+
log2Size := countTrailingZeroes(count)
288+
return buildMerkle(b.Trees, log2Size, big.NewInt(0))
289+
} else {
290+
panic("no leafs in the merkle builder")
291+
}
292+
}
293+
294+
func (b *Builder) CalculateAccumulatedCount(reps *big.Int) *big.Int {
295+
n := len(b.Trees)
296+
if n != 0 {
297+
zero := big.NewInt(0)
298+
if reps.Cmp(zero) == 0 {
299+
panic("merkle builder is full")
300+
}
301+
302+
// TODO: warping version
303+
return new(big.Int).Add(reps, b.Trees[n-1].AccumulatedCount)
304+
} else {
305+
return reps
306+
}
307+
}
308+
309+
func buildMerkle(trees []Node, log2Size uint, stride *big.Int) *Tree {
310+
one := big.NewInt(1)
311+
size := new(big.Int).Lsh(one, log2Size) // TODO: warping version
312+
313+
firstTime := new(big.Int).Add(new(big.Int).Mul(stride, size), one)
314+
lastTime := new(big.Int).Mul(new(big.Int).Add(stride, one), size)
315+
316+
firstCell := findCellContaining(trees, firstTime)
317+
lastCell := findCellContaining(trees, lastTime)
318+
319+
if firstCell == lastCell {
320+
tree := trees[firstCell].Tree
321+
iterated := tree.Iterated(uint64(log2Size))
322+
return iterated
323+
}
324+
325+
left := buildMerkle(trees[firstCell:(lastCell+1)],
326+
log2Size-1,
327+
new(big.Int).Lsh(stride, 1),
328+
)
329+
330+
right := buildMerkle(trees[firstCell:(lastCell+1)],
331+
log2Size-1,
332+
new(big.Int).Add(new(big.Int).Lsh(stride, 1), one),
333+
)
334+
335+
return left.Join(right)
336+
}
337+
338+
func findCellContaining(trees []Node, elem *big.Int) uint {
339+
one := big.NewInt(1)
340+
left := uint(0)
341+
right := uint(len(trees) - 1)
342+
343+
for left < right {
344+
needle := left + (right-left)/2
345+
346+
// TODO: wrapping version
347+
x := new(big.Int).Sub(trees[needle].AccumulatedCount, one)
348+
y := new(big.Int).Sub(elem, one)
349+
if x.Cmp(y) < 0 {
350+
left = needle + 1
351+
} else {
352+
right = needle
353+
}
354+
}
355+
return left
356+
}
357+
358+
////////////////////////////////////////////////////////////////////////////////
359+
360+
// isPow2 reports whether x is a positive power of two.
// Zero and negative values yield false.
func isPow2(x *big.Int) bool {
	if x.Sign() < 1 {
		return false
	}

	// A power of two has a single bit set, so x & (x-1) == 0.
	pred := new(big.Int).Sub(x, big.NewInt(1))
	return pred.And(pred, x).Sign() == 0
}
376+
377+
func isCountPow2(x *big.Int) bool {
378+
return x.Cmp(big.NewInt(0)) == 0 || isPow2(x)
379+
}
380+
381+
func countTrailingZeroes(x *big.Int) uint {
382+
count := uint(0)
383+
384+
// each byte from least to most significant
385+
brk:
386+
for _, b := range slices.Backward(x.Bytes()) {
387+
for i := range 8 {
388+
if b>>i&1 != 0 {
389+
break brk
390+
}
391+
count++
392+
}
393+
}
394+
return count
395+
}

0 commit comments

Comments
 (0)