@@ -75,10 +75,15 @@ const_machine_hash_view hash_tree::get_sparse_node_hash_view(index_type node_ind
7575
7676void hash_tree::get_pristine_proof (int curr_log2_size, proof_type &proof) const {
7777 const auto log2_target_size = proof.get_log2_target_size ();
78- for (int log2_size = curr_log2_size - 1 ; log2_size >= log2_target_size; --log2_size) {
78+ const auto log2_root_size = proof.get_log2_root_size ();
79+ const auto start_log2_size = std::min (log2_root_size, curr_log2_size) - 1 ;
80+ for (int log2_size = start_log2_size; log2_size >= log2_target_size; --log2_size) {
7981 proof.set_sibling_hash (m_pristine_hashes[log2_size], log2_size);
8082 }
8183 proof.set_target_hash (m_pristine_hashes[log2_target_size]);
84+ if (log2_root_size <= curr_log2_size) {
85+ proof.set_root_hash (m_pristine_hashes[log2_root_size]);
86+ }
8287}
8388
8489static inline uint64_t get_sibling_address (uint64_t address, int log2_size) {
@@ -103,46 +108,65 @@ void hash_tree::get_page_proof(address_range &ar, uint64_t address, proof_type &
103108 update_dirty_page (ar, opt_br->get (), changed);
104109 }
105110 const auto log2_target_size = proof.get_log2_target_size ();
111+ const auto log2_root_size = proof.get_log2_root_size ();
106112 assert (log2_target_size >= HASH_TREE_LOG2_WORD_SIZE && " log2_size is too small" );
107113 const auto &entry = opt_br->get ();
108114 const auto node_offset = address & (HASH_TREE_PAGE_SIZE - 1 );
109- for (int log2_size = HASH_TREE_LOG2_PAGE_SIZE - 1 ; log2_size >= log2_target_size; --log2_size) {
115+ const auto start_log2_size = std::min (log2_root_size, HASH_TREE_LOG2_PAGE_SIZE) - 1 ;
116+ for (int log2_size = start_log2_size; log2_size >= log2_target_size; --log2_size) {
110117 proof.set_sibling_hash (entry.node_hash_view (get_sibling_address (node_offset, log2_size), log2_size), log2_size);
111118 }
112119 proof.set_target_hash (entry.node_hash_view (node_offset, log2_target_size));
120+ if (log2_root_size <= HASH_TREE_LOG2_PAGE_SIZE) {
121+ proof.set_root_hash (entry.node_hash_view (0 , log2_root_size));
122+ }
113123 page_hash_tree_cache::return_entry (*opt_br);
114124}
115125
116126void hash_tree::get_dense_proof (address_range &ar, int ar_log2_size, uint64_t address, proof_type &proof) {
117127 const auto &dht = ar.get_dense_hash_tree ();
118128 const auto log2_target_size = proof.get_log2_target_size ();
129+ const auto log2_root_size = proof.get_log2_root_size ();
119130 const auto dht_end = std::max<int >(HASH_TREE_LOG2_PAGE_SIZE, log2_target_size);
120131 const auto node_offset = address - ar.get_start ();
121- for (int log2_size = ar_log2_size - 1 ; log2_size >= dht_end; --log2_size) {
132+ const auto start_log2_size = std::min (log2_root_size, ar_log2_size) - 1 ;
133+ for (int log2_size = start_log2_size; log2_size >= dht_end; --log2_size) {
122134 const auto sibling_offset = get_sibling_address (node_offset, log2_size);
123135 proof.set_sibling_hash (dht.node_hash_view (sibling_offset, log2_size), log2_size);
124136 }
137+ if (log2_root_size >= HASH_TREE_LOG2_PAGE_SIZE && log2_root_size <= ar_log2_size) {
138+ proof.set_root_hash (dht.node_hash_view (0 , log2_root_size));
139+ }
125140 if (log2_target_size >= HASH_TREE_LOG2_PAGE_SIZE) {
126141 proof.set_target_hash (dht.node_hash_view (node_offset, log2_target_size));
127142 } else {
128143 get_page_proof (ar, address, proof);
129144 }
130145}
131146
132- hash_tree::proof_type hash_tree::get_proof (address_ranges ars, uint64_t address, int log2_size) {
133- if (log2_size < HASH_TREE_LOG2_WORD_SIZE || log2_size > HASH_TREE_LOG2_ROOT_SIZE) {
134- throw std::domain_error{" invalid log2_size" };
147+ hash_tree::proof_type hash_tree::get_proof (address_ranges ars, uint64_t address, int log2_target_size,
148+ int log2_root_size) {
149+ if (log2_root_size < HASH_TREE_LOG2_WORD_SIZE) {
150+ throw std::domain_error{" log2_root_size is too small" };
135151 }
136- if (log2_size == HASH_TREE_LOG2_ROOT_SIZE) {
152+ if (log2_root_size > HASH_TREE_LOG2_ROOT_SIZE) {
153+ throw std::domain_error{" log2_root_size is too large" };
154+ }
155+ if (log2_target_size < HASH_TREE_LOG2_WORD_SIZE) {
156+ throw std::domain_error{" log2_target_size is too small" };
157+ }
158+ if (log2_target_size > log2_root_size) {
159+ throw std::domain_error{" log2_target_size is larger than log2_root_size" };
160+ }
161+ if (log2_target_size == HASH_TREE_LOG2_ROOT_SIZE) {
137162 if (address != 0 ) {
138- throw std::domain_error{" address not aligned to log2_size " };
163+ throw std::domain_error{" address not aligned to log2_target_size " };
139164 }
140- } else if (((address >> log2_size ) << log2_size ) != address) {
141- throw std::domain_error{" address not aligned to log2_size " };
165+ } else if (((address >> log2_target_size ) << log2_target_size ) != address) {
166+ throw std::domain_error{" address not aligned to log2_target_size " };
142167 }
143- proof_type proof{HASH_TREE_LOG2_ROOT_SIZE, log2_size };
168+ proof_type proof{log2_root_size, log2_target_size };
144169 proof.set_target_address (address);
145- proof.set_root_hash (get_root_hash ());
146170 index_type node_index = 1 ;
147171 int curr_log2_size = HASH_TREE_LOG2_ROOT_SIZE;
148172 for (;;) {
@@ -152,9 +176,13 @@ hash_tree::proof_type hash_tree::get_proof(address_ranges ars, uint64_t address,
152176 break ;
153177 }
154178 const auto &node = m_sparse_nodes[node_index];
179+ // found node corresponding to root along the way
180+ if (curr_log2_size == proof.get_log2_root_size ()) {
181+ proof.set_root_hash (node.hash );
182+ }
155183 assert (std::cmp_equal (node.log2_size , curr_log2_size) && " incorrect node log2_size" );
156- // hit sparse tree node
157- if (curr_log2_size == log2_size ) {
184+ // hit target at a sparse tree node
185+ if (curr_log2_size == proof. get_log2_target_size () ) {
158186 proof.set_target_hash (node.hash );
159187 break ;
160188 }
@@ -169,10 +197,14 @@ hash_tree::proof_type hash_tree::get_proof(address_ranges ars, uint64_t address,
169197 // go down left or right on sparse tree depending on address
170198 --curr_log2_size;
171199 if ((address & (UINT64_C (1 ) << curr_log2_size)) == 0 ) {
172- proof.set_sibling_hash (get_sparse_node_hash_view (node.right , curr_log2_size), curr_log2_size);
200+ if (curr_log2_size < log2_root_size && curr_log2_size >= log2_target_size) {
201+ proof.set_sibling_hash (get_sparse_node_hash_view (node.right , curr_log2_size), curr_log2_size);
202+ }
173203 node_index = node.left ;
174204 } else {
175- proof.set_sibling_hash (get_sparse_node_hash_view (node.left , curr_log2_size), curr_log2_size);
205+ if (curr_log2_size < log2_root_size && curr_log2_size >= log2_target_size) {
206+ proof.set_sibling_hash (get_sparse_node_hash_view (node.left , curr_log2_size), curr_log2_size);
207+ }
176208 node_index = node.right ;
177209 }
178210 }
@@ -1061,6 +1093,10 @@ hash_tree::nodes_type hash_tree::create_nodes(const_address_ranges ars) {
10611093}
10621094
10631095// LCOV_EXCL_START
1096+ #if defined(__GNUC__) && __GNUC__ >= 13
1097+ #pragma GCC diagnostic push
1098+ #pragma GCC diagnostic ignored "-Wdangling-reference"
1099+ #endif
10641100void hash_tree::dump (const_address_ranges ars, std::ostream &out) {
10651101 out << " digraph HashTree {\n " ;
10661102 out << " rankdir=TB;\n " ;
@@ -1130,6 +1166,9 @@ void hash_tree::dump(const_address_ranges ars, std::ostream &out) {
11301166 }
11311167 out << " }\n " ;
11321168}
1169+ #if defined(__GNUC__) && __GNUC__ >= 13
1170+ #pragma GCC diagnostic pop
1171+ #endif
11331172// LCOV_EXCL_STOP
11341173
11351174} // namespace cartesi
0 commit comments