@@ -243,6 +243,9 @@ class StreamingOneWesolowskiCallback final : public WesolowskiCallback {
243243 }
244244 SetForm (type, data, &checkpoint);
245245 process_checkpoint (pos, checkpoint, /* record_stats=*/ true );
246+ if (iteration >= batch_start_iteration && iteration <= batch_end_iteration) {
247+ current_batch_checkpoints.push_back (BatchCheckpoint{pos, checkpoint});
248+ }
246249 if (stats_enabled) {
247250 checkpoint_event_total_ns += static_cast <uint64_t >(
248251 std::chrono::duration_cast<std::chrono::nanoseconds>(
@@ -258,7 +261,44 @@ class StreamingOneWesolowskiCallback final : public WesolowskiCallback {
258261 }
259262 }
260263
264+ void OnBatchStart (uint64_t base_iteration, uint64_t batch_size) override {
265+ current_batch_checkpoints.clear ();
266+ if (batch_size == 0 ) {
267+ batch_start_iteration = 1 ;
268+ batch_end_iteration = 0 ;
269+ return ;
270+ }
271+ batch_start_iteration = base_iteration + 1 ;
272+ if (std::numeric_limits<uint64_t >::max () - base_iteration < batch_size) {
273+ batch_end_iteration = std::numeric_limits<uint64_t >::max ();
274+ } else {
275+ batch_end_iteration = base_iteration + batch_size;
276+ }
277+ }
278+
279+ void OnBatchReplay (uint64_t base_iteration, uint64_t batch_size) override {
280+ for (const BatchCheckpoint& entry : current_batch_checkpoints) {
281+ rollback_checkpoint (entry.index , entry.checkpoint );
282+ }
283+ OnBatchStart (base_iteration, batch_size);
284+ }
285+
261286 void process_checkpoint (uint64_t i, const form& checkpoint, bool record_stats) {
287+ apply_checkpoint (i, checkpoint, record_stats);
288+ }
289+
290+ private:
291+ struct BatchCheckpoint {
292+ uint64_t index;
293+ form checkpoint;
294+ };
295+
296+ void rollback_checkpoint (uint64_t i, const form& checkpoint) {
297+ form inverse_checkpoint = checkpoint.inverse ();
298+ apply_checkpoint (i, inverse_checkpoint, /* record_stats=*/ false );
299+ }
300+
301+ void apply_checkpoint (uint64_t i, const form& checkpoint, bool record_stats) {
262302 const bool do_stats = stats_enabled && record_stats;
263303 auto started_at = std::chrono::steady_clock::time_point{};
264304 if (do_stats) {
@@ -359,7 +399,6 @@ class StreamingOneWesolowskiCallback final : public WesolowskiCallback {
359399 return out;
360400 }
361401
362- private:
363402 form& bucket (uint32_t j, uint64_t b) {
364403 size_t idx = static_cast <size_t >(j) * (1ULL << k) + static_cast <size_t >(b);
365404 return buckets[idx];
@@ -391,6 +430,9 @@ class StreamingOneWesolowskiCallback final : public WesolowskiCallback {
391430 integer getblock_inv_2k;
392431 integer getblock_r;
393432 integer getblock_tmp;
433+ uint64_t batch_start_iteration = 1 ;
434+ uint64_t batch_end_iteration = 0 ;
435+ std::vector<BatchCheckpoint> current_batch_checkpoints;
394436
395437 bool stats_enabled;
396438 uint64_t checkpoint_total_ns = 0 ;
0 commit comments