Skip to content

Commit f3e7330

Browse files
committed
wip
1 parent c6be96a commit f3e7330

File tree

2 files changed

+285
-0
lines changed

2 files changed

+285
-0
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// SPDX-License-Identifier: BUSL-1.1
2+
pragma solidity ^0.8.27;
3+
4+
import "forge-std/Test.sol";
5+
import "src/contracts/libraries/Merkle.sol";
6+
import "src/test/utils/Murky.sol";
7+
8+
abstract contract MerkleBaseTest is Test, MurkyBase {
9+
bool usingSha;
10+
bytes32[] leaves;
11+
bytes32 root;
12+
bytes[] proofs;
13+
14+
function test_verifyInclusion_ValidProof() public {
15+
assertValidProofs();
16+
}
17+
18+
function test_verifyInclusion_EmptyProofs() public {
19+
proofs = new bytes[](proofs.length);
20+
assertInvalidProofs();
21+
}
22+
23+
function assertValidProofs() internal virtual {
24+
function (bytes memory proof, bytes32 root, bytes32 leaf, uint256 index) returns (bool) verifyInclusion =
25+
usingSha ? Merkle.verifyInclusionSha256 : Merkle.verifyInclusionKeccak;
26+
for (uint i = 0; i < leaves.length; ++i) {
27+
assertTrue(verifyInclusion(proofs[i], root, leaves[i], i), "invalid proof");
28+
}
29+
}
30+
31+
function assertInvalidProofs() internal virtual {
32+
function (bytes memory proof, bytes32 root, bytes32 leaf, uint256 index) returns (bool) verifyInclusion =
33+
usingSha ? Merkle.verifyInclusionSha256 : Merkle.verifyInclusionKeccak;
34+
for (uint i = 0; i < leaves.length; ++i) {
35+
assertFalse(verifyInclusion(proofs[i], root, leaves[i], i), "valid proof");
36+
}
37+
}
38+
39+
function getLeaves(uint numLeaves) internal view virtual returns (bytes32[] memory leaves) {
40+
bytes memory _leavesAsBytes = vm.randomBytes(numLeaves * 32);
41+
/// @solidity memory-safe-assembly
42+
assembly {
43+
leaves := _leavesAsBytes // Typecast bytes -> bytes32[].
44+
mstore(leaves, numLeaves) // Update length n*32 -> n.
45+
}
46+
}
47+
48+
function getProofs(bytes32[] memory leaves) public view virtual returns (bytes[] memory proofs);
49+
}
50+
51+
contract MerkleKeccakTest is MerkleBaseTest, MerkleKeccak {
52+
function setUp() public {
53+
usingSha = false;
54+
leaves = getLeaves(vm.randomBool() ? 9 : 10);
55+
root = Merkle.merkleizeKeccak(leaves);
56+
proofs = getProofs(leaves);
57+
}
58+
59+
function nextPowerOf2(uint v) internal pure returns (uint) {
60+
unchecked {
61+
// Round up to the next power of 2 using the method described here:
62+
// https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
63+
if (v == 0) return 0;
64+
v -= 1;
65+
v |= v >> 1;
66+
v |= v >> 2;
67+
v |= v >> 4;
68+
v |= v >> 8;
69+
v |= v >> 16;
70+
v |= v >> 32;
71+
v |= v >> 64;
72+
v |= v >> 128;
73+
return v + 1;
74+
}
75+
}
76+
77+
function getProofs(bytes32[] memory leaves) public view virtual override returns (bytes[] memory proofs) {
78+
// Merkle.merkleizeKeccak pads to next power of 2, so we need to match that.
79+
uint numLeaves = nextPowerOf2(leaves.length);
80+
bytes32[] memory paddedLeaves = new bytes32[](numLeaves);
81+
for (uint i = 0; i < leaves.length; ++i) {
82+
paddedLeaves[i] = leaves[i];
83+
}
84+
85+
proofs = new bytes[](leaves.length);
86+
for (uint i = 0; i < leaves.length; ++i) {
87+
proofs[i] = abi.encodePacked(getProof(paddedLeaves, i));
88+
}
89+
}
90+
}

src/test/utils/Murky.sol

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
// SPDX-License-Identifier: MIT
2+
pragma solidity ^0.8.4;
3+
4+
/// @notice Modified from https://github.com/dmfxyz/murky/commit/991e371eb1dfa9f86701869eb08ec4e98c3cc0b0.
5+
abstract contract MurkyBase {
6+
function hashLeafPairs(bytes32 left, bytes32 right) public view virtual returns (bytes32 _hash);
7+
8+
function verifyProof(bytes32 root, bytes32[] memory proof, bytes32 valueToProve) external view virtual returns (bool) {
9+
// proof length must be less than max array size
10+
bytes32 rollingHash = valueToProve;
11+
uint length = proof.length;
12+
unchecked {
13+
for (uint i = 0; i < length; ++i) {
14+
rollingHash = hashLeafPairs(rollingHash, proof[i]);
15+
}
16+
}
17+
return root == rollingHash;
18+
}
19+
20+
/**
21+
*
22+
* PROOF GENERATION *
23+
*
24+
*/
25+
function getRoot(bytes32[] memory data) public view virtual returns (bytes32) {
26+
require(data.length > 1, "won't generate root for single leaf");
27+
while (data.length > 1) data = hashLevel(data);
28+
return data[0];
29+
}
30+
31+
function getProof(bytes32[] memory data, uint node) public view virtual returns (bytes32[] memory) {
32+
require(data.length > 1, "won't generate proof for single leaf");
33+
// The size of the proof is equal to the ceiling of log2(numLeaves)
34+
bytes32[] memory result = new bytes32[](log2ceilBitMagic(data.length));
35+
uint pos = 0;
36+
37+
// Two overflow risks: node, pos
38+
// node: max array size is 2**256-1. Largest index in the array will be 1 less than that. Also,
39+
// for dynamic arrays, size is limited to 2**64-1
40+
// pos: pos is bounded by log2(data.length), which should be less than type(uint256).max
41+
while (data.length > 1) {
42+
unchecked {
43+
if (node & 0x1 == 1) result[pos] = data[node - 1];
44+
else if (node + 1 == data.length) result[pos] = bytes32(0);
45+
else result[pos] = data[node + 1];
46+
++pos;
47+
node /= 2;
48+
}
49+
data = hashLevel(data);
50+
}
51+
return result;
52+
}
53+
54+
///@dev function is private to prevent unsafe data from being passed
55+
function hashLevel(bytes32[] memory data) private view returns (bytes32[] memory) {
56+
bytes32[] memory result;
57+
58+
// Function is private, and all internal callers check that data.length >=2.
59+
// Underflow is not possible as lowest possible value for data/result index is 1
60+
// overflow should be safe as length is / 2 always.
61+
unchecked {
62+
uint length = data.length;
63+
if (length & 0x1 == 1) {
64+
result = new bytes32[](length / 2 + 1);
65+
result[result.length - 1] = hashLeafPairs(data[length - 1], bytes32(0));
66+
} else {
67+
result = new bytes32[](length / 2);
68+
}
69+
// pos is upper bounded by data.length / 2, so safe even if array is at max size
70+
uint pos = 0;
71+
for (uint i = 0; i < length - 1; i += 2) {
72+
result[pos] = hashLeafPairs(data[i], data[i + 1]);
73+
++pos;
74+
}
75+
}
76+
return result;
77+
}
78+
79+
/**
80+
*
81+
* MATH "LIBRARY" *
82+
*
83+
*/
84+
85+
/// @dev Note that x is assumed > 0
86+
function log2ceil(uint x) public view returns (uint) {
87+
uint ceil = 0;
88+
uint pOf2;
89+
// If x is a power of 2, then this function will return a ceiling
90+
// that is 1 greater than the actual ceiling. So we need to check if
91+
// x is a power of 2, and subtract one from ceil if so.
92+
assembly {
93+
// we check by seeing if x == (~x + 1) & x. This applies a mask
94+
// to find the lowest set bit of x and then checks it for equality
95+
// with x. If they are equal, then x is a power of 2.
96+
97+
/* Example
98+
x has single bit set
99+
x := 0000_1000
100+
(~x + 1) = (1111_0111) + 1 = 1111_1000
101+
(1111_1000 & 0000_1000) = 0000_1000 == x
102+
103+
x has multiple bits set
104+
x := 1001_0010
105+
(~x + 1) = (0110_1101 + 1) = 0110_1110
106+
(0110_1110 & x) = 0000_0010 != x
107+
*/
108+
109+
// we do some assembly magic to treat the bool as an integer later on
110+
pOf2 := eq(and(add(not(x), 1), x), x)
111+
}
112+
113+
// if x == type(uint256).max, than ceil is capped at 256
114+
// if x == 0, then pO2 == 0, so ceil won't underflow
115+
unchecked {
116+
while (x > 0) {
117+
x >>= 1;
118+
ceil++;
119+
}
120+
ceil -= pOf2; // see above
121+
}
122+
return ceil;
123+
}
124+
125+
/// Original bitmagic adapted from https://github.com/paulrberg/prb-math/blob/main/contracts/PRBMath.sol
126+
/// @dev Note that x assumed > 1
127+
function log2ceilBitMagic(uint x) public view returns (uint) {
128+
if (x <= 1) return 0;
129+
uint msb = 0;
130+
uint _x = x;
131+
if (x >= 2 ** 128) {
132+
x >>= 128;
133+
msb += 128;
134+
}
135+
if (x >= 2 ** 64) {
136+
x >>= 64;
137+
msb += 64;
138+
}
139+
if (x >= 2 ** 32) {
140+
x >>= 32;
141+
msb += 32;
142+
}
143+
if (x >= 2 ** 16) {
144+
x >>= 16;
145+
msb += 16;
146+
}
147+
if (x >= 2 ** 8) {
148+
x >>= 8;
149+
msb += 8;
150+
}
151+
if (x >= 2 ** 4) {
152+
x >>= 4;
153+
msb += 4;
154+
}
155+
if (x >= 2 ** 2) {
156+
x >>= 2;
157+
msb += 2;
158+
}
159+
if (x >= 2 ** 1) msb += 1;
160+
161+
uint lsb = (~_x + 1) & _x;
162+
if ((lsb == _x) && (msb > 0)) return msb;
163+
else return msb + 1;
164+
}
165+
}
166+
167+
contract MerkleKeccak is MurkyBase {
168+
function hashLeafPairs(bytes32 left, bytes32 right) public view override returns (bytes32 _hash) {
169+
assembly {
170+
mstore(0x0, left)
171+
mstore(0x20, right)
172+
_hash := keccak256(0x0, 0x40)
173+
}
174+
}
175+
}
176+
177+
contract MerkleSha is MurkyBase {
178+
address constant SHA256_PRECOMPILE = 0x0000000000000000000000000000000000000002;
179+
180+
function hashLeafPairs(bytes32 left, bytes32 right) public view override returns (bytes32 _hash) {
181+
assembly {
182+
switch lt(left, right)
183+
case 0 {
184+
mstore(0x0, right)
185+
mstore(0x20, left)
186+
}
187+
default {
188+
mstore(0x0, left)
189+
mstore(0x20, right)
190+
}
191+
_hash := mload(iszero(staticcall(gas(), SHA256_PRECOMPILE, 0x0, 0x40, 0x0, 0x20)))
192+
if iszero(returndatasize()) { invalid() }
193+
}
194+
}
195+
}

0 commit comments

Comments
 (0)