|
| 1 | +// (c) Cartesi and individual authors (see AUTHORS) |
| 2 | +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) |
| 3 | + |
| 4 | +package merkle |
| 5 | + |
| 6 | +import ( |
| 7 | + "fmt" |
| 8 | + "math/big" |
| 9 | + "slices" |
| 10 | + |
| 11 | + "github.com/ethereum/go-ethereum/common" |
| 12 | + "github.com/ethereum/go-ethereum/crypto" |
| 13 | +) |
| 14 | + |
| 15 | +// MerkleProof: dave/common-rs/merkle/src/tree.rs |
| 16 | +type Proof struct { |
| 17 | + Pos *big.Int |
| 18 | + Node common.Hash |
| 19 | + Siblings []common.Hash |
| 20 | +} |
| 21 | + |
| 22 | +func Leaf(node common.Hash, pos *big.Int) *Proof { |
| 23 | + return &Proof{ |
| 24 | + Node: node, |
| 25 | + Pos: pos, |
| 26 | + Siblings: nil, |
| 27 | + } |
| 28 | +} |
| 29 | + |
| 30 | +func (proof *Proof) BuildRoot() common.Hash { |
| 31 | + zero := big.NewInt(0) |
| 32 | + two := big.NewInt(2) |
| 33 | + rootHash := proof.Node |
| 34 | + |
| 35 | + for i, s := range proof.Siblings { |
| 36 | + |
| 37 | + // ((pos >> i) % 2) == 0 |
| 38 | + if new(big.Int).Rem(new(big.Int).Rsh(proof.Pos, uint(i)), two).Cmp(zero) == 0 { |
| 39 | + rootHash = crypto.Keccak256Hash(rootHash[:], s[:]) |
| 40 | + } else { |
| 41 | + rootHash = crypto.Keccak256Hash(s[:], rootHash[:]) |
| 42 | + } |
| 43 | + } |
| 44 | + return rootHash |
| 45 | +} |
| 46 | + |
| 47 | +func (proof *Proof) VerifyRoot(other common.Hash) bool { |
| 48 | + return proof.BuildRoot() == other |
| 49 | +} |
| 50 | + |
| 51 | +func (proof *Proof) PushHash(h common.Hash) { |
| 52 | + proof.Siblings = append(proof.Siblings, h) |
| 53 | +} |
| 54 | + |
| 55 | +//////////////////////////////////////////////////////////////////////////////// |
| 56 | + |
| 57 | +// MerkleTree: dave/common-rs/merkle/src/tree.rs |
| 58 | +type Tree struct { |
| 59 | + RootHash common.Hash |
| 60 | + Height uint32 |
| 61 | + Subtrees *InnerNode |
| 62 | +} |
| 63 | + |
| 64 | +// InnerNode: dave/common-rs/merkle/src/tree.rs |
| 65 | +// Emulate the rust enum type with a struct containing both {Pair, Iterated}. |
| 66 | +type InnerNode struct { |
| 67 | + // Pair |
| 68 | + LHS, RHS *Tree |
| 69 | + |
| 70 | + // Iterated |
| 71 | + Child *Tree |
| 72 | +} |
| 73 | + |
| 74 | +func (inner *InnerNode) Valid() bool { |
| 75 | + isPair := (inner.LHS != nil && inner.RHS != nil) |
| 76 | + isIterated := inner.Child != nil |
| 77 | + return (isPair || isIterated) && !(isPair && isIterated) // xor |
| 78 | +} |
| 79 | + |
| 80 | +func (inner *InnerNode) Children() (*Tree, *Tree) { |
| 81 | + if !inner.Valid() { |
| 82 | + panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner)) |
| 83 | + } |
| 84 | + |
| 85 | + if inner.Child != nil { |
| 86 | + return inner.Child, inner.Child |
| 87 | + } else { |
| 88 | + return inner.LHS, inner.RHS |
| 89 | + } |
| 90 | +} |
| 91 | + |
| 92 | +func TreeLeaf(hash common.Hash) *Tree { |
| 93 | + return &Tree{ |
| 94 | + Height: 0, |
| 95 | + RootHash: hash, |
| 96 | + Subtrees: nil, |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +func (tree *Tree) GetRootHash() common.Hash { |
| 101 | + return tree.RootHash |
| 102 | +} |
| 103 | + |
| 104 | +func (tree *Tree) FindChildByHash(hash common.Hash) *InnerNode { |
| 105 | + if inner := tree.Subtrees; inner != nil { |
| 106 | + if !inner.Valid() { |
| 107 | + panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner)) |
| 108 | + } |
| 109 | + |
| 110 | + if inner.Child != nil { |
| 111 | + child := inner.Child.FindChildByHash(hash) |
| 112 | + if child != nil { |
| 113 | + return child |
| 114 | + } |
| 115 | + } else { |
| 116 | + lhs := inner.LHS.FindChildByHash(hash) |
| 117 | + if lhs != nil { |
| 118 | + return lhs |
| 119 | + } |
| 120 | + |
| 121 | + rhs := inner.LHS.FindChildByHash(hash) |
| 122 | + if rhs != nil { |
| 123 | + return rhs |
| 124 | + } |
| 125 | + } |
| 126 | + } |
| 127 | + return nil // not found |
| 128 | +} |
| 129 | + |
| 130 | +func (tree *Tree) Join(other *Tree) *Tree { |
| 131 | + return &Tree{ |
| 132 | + RootHash: crypto.Keccak256Hash(tree.RootHash[:], other.RootHash[:]), |
| 133 | + Height: tree.Height + 1, |
| 134 | + Subtrees: &InnerNode{ |
| 135 | + LHS: tree, |
| 136 | + RHS: other, |
| 137 | + }, |
| 138 | + } |
| 139 | +} |
| 140 | + |
| 141 | +func (tree *Tree) Iterated(rep uint64) *Tree { |
| 142 | + root := tree |
| 143 | + for range rep { |
| 144 | + root = &Tree{ |
| 145 | + RootHash: crypto.Keccak256Hash(root.RootHash[:], root.RootHash[:]), |
| 146 | + Height: tree.Height + 1, |
| 147 | + Subtrees: &InnerNode{ |
| 148 | + Child: tree, |
| 149 | + }, |
| 150 | + } |
| 151 | + } |
| 152 | + return root |
| 153 | +} |
| 154 | + |
| 155 | +func (tree *Tree) ProveLeaf(index *big.Int) *Proof { |
| 156 | + return tree.ProveLeafRec(index) |
| 157 | +} |
| 158 | + |
| 159 | +func (tree *Tree) ProveLast() *Proof { |
| 160 | + one := big.NewInt(1) |
| 161 | + |
| 162 | + // index = (1 << height) - 1 |
| 163 | + index := new(big.Int).Sub( |
| 164 | + new(big.Int).Lsh( |
| 165 | + one, |
| 166 | + uint(tree.Height), |
| 167 | + ), |
| 168 | + one, |
| 169 | + ) |
| 170 | + return tree.ProveLeaf(index) |
| 171 | +} |
| 172 | + |
| 173 | +func (tree *Tree) ProveLeafRec(index *big.Int) *Proof { |
| 174 | + one := big.NewInt(1) |
| 175 | + zero := big.NewInt(0) |
| 176 | + numLeafs := new(big.Int).Lsh(one, uint(tree.Height)) |
| 177 | + if numLeafs.Cmp(index) <= 0 { |
| 178 | + panic(fmt.Sprintf("index out of bounds: %v, %v", numLeafs, index)) |
| 179 | + } |
| 180 | + |
| 181 | + subtree := tree.Subtrees |
| 182 | + if subtree == nil { |
| 183 | + if index.Cmp(zero) != 0 { |
| 184 | + panic(fmt.Sprintf("invalid Tree state: %v", tree)) |
| 185 | + } |
| 186 | + if tree.Height != 0 { |
| 187 | + panic(fmt.Sprintf("invalid Tree state: %v", tree)) |
| 188 | + } |
| 189 | + return Leaf(tree.RootHash, index) |
| 190 | + } |
| 191 | + |
| 192 | + shiftAmount := uint(tree.Height - 1) |
| 193 | + isLeftLeaf := new(big.Int).Rsh(index, shiftAmount).Cmp(zero) == 0 |
| 194 | + |
| 195 | + // innerIndex = index & !(1 << shiftAmount) |
| 196 | + innerIndex := new(big.Int).And( |
| 197 | + index, |
| 198 | + new(big.Int).Not( |
| 199 | + new(big.Int).Lsh( |
| 200 | + one, |
| 201 | + shiftAmount, |
| 202 | + ), |
| 203 | + ), |
| 204 | + ) |
| 205 | + |
| 206 | + lhs, rhs := subtree.Children() |
| 207 | + if isLeftLeaf { |
| 208 | + proof := lhs.ProveLeafRec(innerIndex) |
| 209 | + proof.PushHash(rhs.RootHash) |
| 210 | + proof.Pos = index |
| 211 | + return proof |
| 212 | + } else { |
| 213 | + proof := rhs.ProveLeafRec(innerIndex) |
| 214 | + proof.PushHash(lhs.RootHash) |
| 215 | + proof.Pos = index |
| 216 | + return proof |
| 217 | + } |
| 218 | +} |
| 219 | + |
| 220 | +//////////////////////////////////////////////////////////////////////////////// |
| 221 | + |
| 222 | +// Node: common-rs/merkle/src/tree_builder.rs |
| 223 | +type Node struct { |
| 224 | + Tree *Tree |
| 225 | + AccumulatedCount *big.Int |
| 226 | +} |
| 227 | + |
| 228 | +type Builder struct { |
| 229 | + Trees []Node |
| 230 | +} |
| 231 | + |
| 232 | +func (b *Builder) Height() (uint32, bool) { |
| 233 | + n := len(b.Trees) |
| 234 | + if n == 0 { |
| 235 | + return 0, false |
| 236 | + } |
| 237 | + return b.Trees[n-1].Tree.Height, true |
| 238 | +} |
| 239 | + |
| 240 | +func (b *Builder) Count() (*big.Int, bool) { |
| 241 | + n := len(b.Trees) |
| 242 | + if n == 0 { |
| 243 | + return nil, false |
| 244 | + } |
| 245 | + return b.Trees[n-1].AccumulatedCount, true |
| 246 | +} |
| 247 | + |
| 248 | +func (b *Builder) CanBuild() bool { |
| 249 | + n := len(b.Trees) |
| 250 | + if n == 0 { |
| 251 | + return false |
| 252 | + } |
| 253 | + return isPow2(b.Trees[n-1].AccumulatedCount) |
| 254 | +} |
| 255 | + |
| 256 | +func (b *Builder) Append(leaf *Tree) { |
| 257 | + b.AppendRepeated(leaf, big.NewInt(1)) |
| 258 | +} |
| 259 | + |
| 260 | +func (b *Builder) AppendRepeatedUint64(leaf *Tree, reps uint64) { |
| 261 | + b.AppendRepeated(leaf, new(big.Int).SetUint64(reps)) |
| 262 | +} |
| 263 | + |
| 264 | +func (b *Builder) AppendRepeated(leaf *Tree, reps *big.Int) { |
| 265 | + zero := big.NewInt(0) |
| 266 | + if reps.Cmp(zero) <= 0 { |
| 267 | + panic("invalid repetitions") |
| 268 | + } |
| 269 | + |
| 270 | + accumulatedCount := b.CalculateAccumulatedCount(reps) |
| 271 | + if height, ok := b.Height(); ok { |
| 272 | + if height != leaf.Height { |
| 273 | + panic("mismatched tree size") |
| 274 | + } |
| 275 | + } |
| 276 | + b.Trees = append(b.Trees, Node{ |
| 277 | + Tree: leaf, |
| 278 | + AccumulatedCount: accumulatedCount, |
| 279 | + }) |
| 280 | +} |
| 281 | + |
| 282 | +func (b *Builder) Build() *Tree { |
| 283 | + if count, ok := b.Count(); ok { |
| 284 | + if !isCountPow2(count) { |
| 285 | + panic(fmt.Sprintf("builder has %v leafs, which is not a power of two", count)) |
| 286 | + } |
| 287 | + log2Size := countTrailingZeroes(count) |
| 288 | + return buildMerkle(b.Trees, log2Size, big.NewInt(0)) |
| 289 | + } else { |
| 290 | + panic("no leafs in the merkle builder") |
| 291 | + } |
| 292 | +} |
| 293 | + |
| 294 | +func (b *Builder) CalculateAccumulatedCount(reps *big.Int) *big.Int { |
| 295 | + n := len(b.Trees) |
| 296 | + if n != 0 { |
| 297 | + zero := big.NewInt(0) |
| 298 | + if reps.Cmp(zero) == 0 { |
| 299 | + panic("merkle builder is full") |
| 300 | + } |
| 301 | + |
| 302 | + // TODO: warping version |
| 303 | + return new(big.Int).Add(reps, b.Trees[n-1].AccumulatedCount) |
| 304 | + } else { |
| 305 | + return reps |
| 306 | + } |
| 307 | +} |
| 308 | + |
| 309 | +func buildMerkle(trees []Node, log2Size uint, stride *big.Int) *Tree { |
| 310 | + one := big.NewInt(1) |
| 311 | + size := new(big.Int).Lsh(one, log2Size) // TODO: warping version |
| 312 | + |
| 313 | + firstTime := new(big.Int).Add(new(big.Int).Mul(stride, size), one) |
| 314 | + lastTime := new(big.Int).Mul(new(big.Int).Add(stride, one), size) |
| 315 | + |
| 316 | + firstCell := findCellContaining(trees, firstTime) |
| 317 | + lastCell := findCellContaining(trees, lastTime) |
| 318 | + |
| 319 | + if firstCell == lastCell { |
| 320 | + tree := trees[firstCell].Tree |
| 321 | + iterated := tree.Iterated(uint64(log2Size)) |
| 322 | + return iterated |
| 323 | + } |
| 324 | + |
| 325 | + left := buildMerkle(trees[firstCell:(lastCell+1)], |
| 326 | + log2Size-1, |
| 327 | + new(big.Int).Lsh(stride, 1), |
| 328 | + ) |
| 329 | + |
| 330 | + right := buildMerkle(trees[firstCell:(lastCell+1)], |
| 331 | + log2Size-1, |
| 332 | + new(big.Int).Add(new(big.Int).Lsh(stride, 1), one), |
| 333 | + ) |
| 334 | + |
| 335 | + return left.Join(right) |
| 336 | +} |
| 337 | + |
| 338 | +func findCellContaining(trees []Node, elem *big.Int) uint { |
| 339 | + one := big.NewInt(1) |
| 340 | + left := uint(0) |
| 341 | + right := uint(len(trees) - 1) |
| 342 | + |
| 343 | + for left < right { |
| 344 | + needle := left + (right-left)/2 |
| 345 | + |
| 346 | + // TODO: wrapping version |
| 347 | + x := new(big.Int).Sub(trees[needle].AccumulatedCount, one) |
| 348 | + y := new(big.Int).Sub(elem, one) |
| 349 | + if x.Cmp(y) < 0 { |
| 350 | + left = needle + 1 |
| 351 | + } else { |
| 352 | + right = needle |
| 353 | + } |
| 354 | + } |
| 355 | + return left |
| 356 | +} |
| 357 | + |
| 358 | +//////////////////////////////////////////////////////////////////////////////// |
| 359 | + |
| 360 | +func isPow2(x *big.Int) bool { |
| 361 | + if x.Sign() <= 0 { |
| 362 | + return false |
| 363 | + } |
| 364 | + |
| 365 | + // x & (x-1) == 0 |
| 366 | + zero := big.NewInt(0) |
| 367 | + one := big.NewInt(1) |
| 368 | + return new(big.Int).And( |
| 369 | + x, |
| 370 | + new(big.Int).Sub( |
| 371 | + x, |
| 372 | + one, |
| 373 | + ), |
| 374 | + ).Cmp(zero) == 0 |
| 375 | +} |
| 376 | + |
| 377 | +func isCountPow2(x *big.Int) bool { |
| 378 | + return x.Cmp(big.NewInt(0)) == 0 || isPow2(x) |
| 379 | +} |
| 380 | + |
| 381 | +func countTrailingZeroes(x *big.Int) uint { |
| 382 | + count := uint(0) |
| 383 | + |
| 384 | + // each byte from least to most significant |
| 385 | +brk: |
| 386 | + for _, b := range slices.Backward(x.Bytes()) { |
| 387 | + for i := range 8 { |
| 388 | + if b>>i&1 != 0 { |
| 389 | + break brk |
| 390 | + } |
| 391 | + count++ |
| 392 | + } |
| 393 | + } |
| 394 | + return count |
| 395 | +} |
0 commit comments