Skip to content

Commit 073427a

Browse files
Improve fast-path batch replay handling and harden pairindex slot allocation.
Track and roll back per-batch checkpoints when replaying a failed fast batch, and switch pairindex slot allocation to unsigned atomics to avoid negative modulo indexing after counter wraparound. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 3755be2 commit 073427a

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

src/c_bindings/fast_wrapper.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ class StreamingOneWesolowskiCallback final : public WesolowskiCallback {
243243
}
244244
SetForm(type, data, &checkpoint);
245245
process_checkpoint(pos, checkpoint, /*record_stats=*/true);
246+
if (iteration >= batch_start_iteration && iteration <= batch_end_iteration) {
247+
current_batch_checkpoints.push_back(BatchCheckpoint{pos, checkpoint});
248+
}
246249
if (stats_enabled) {
247250
checkpoint_event_total_ns += static_cast<uint64_t>(
248251
std::chrono::duration_cast<std::chrono::nanoseconds>(
@@ -258,7 +261,44 @@ class StreamingOneWesolowskiCallback final : public WesolowskiCallback {
258261
}
259262
}
260263

264+
/// Starts tracking a new batch of iterations.
///
/// Drops every checkpoint recorded for the previous batch and sets the
/// inclusive iteration window [batch_start_iteration, batch_end_iteration]
/// to [base_iteration + 1, base_iteration + batch_size]. The upper bound
/// saturates at UINT64_MAX instead of wrapping. An empty window is encoded
/// as start = 1, end = 0, so the range test in OnIteration can never match.
///
/// @param base_iteration Iteration number immediately preceding the batch.
/// @param batch_size     Number of iterations in the batch; 0 means "no batch".
void OnBatchStart(uint64_t base_iteration, uint64_t batch_size) override {
    constexpr uint64_t kMaxIteration = std::numeric_limits<uint64_t>::max();

    current_batch_checkpoints.clear();

    // The window is empty when the batch has no iterations, or when the base
    // is already the largest representable iteration: base_iteration + 1
    // would wrap to 0 and the window would then cover every iteration.
    if (batch_size == 0 || base_iteration == kMaxIteration) {
        batch_start_iteration = 1;
        batch_end_iteration = 0;
        return;
    }

    batch_start_iteration = base_iteration + 1;

    // Clamp the end of the window rather than letting the addition overflow.
    const bool end_overflows = (kMaxIteration - base_iteration) < batch_size;
    batch_end_iteration = end_overflows ? kMaxIteration
                                        : base_iteration + batch_size;
}
278+
279+
void OnBatchReplay(uint64_t base_iteration, uint64_t batch_size) override {
280+
for (const BatchCheckpoint& entry : current_batch_checkpoints) {
281+
rollback_checkpoint(entry.index, entry.checkpoint);
282+
}
283+
OnBatchStart(base_iteration, batch_size);
284+
}
285+
261286
void process_checkpoint(uint64_t i, const form& checkpoint, bool record_stats) {
287+
apply_checkpoint(i, checkpoint, record_stats);
288+
}
289+
290+
private:
291+
struct BatchCheckpoint {
292+
uint64_t index;
293+
form checkpoint;
294+
};
295+
296+
void rollback_checkpoint(uint64_t i, const form& checkpoint) {
297+
form inverse_checkpoint = checkpoint.inverse();
298+
apply_checkpoint(i, inverse_checkpoint, /*record_stats=*/false);
299+
}
300+
301+
void apply_checkpoint(uint64_t i, const form& checkpoint, bool record_stats) {
262302
const bool do_stats = stats_enabled && record_stats;
263303
auto started_at = std::chrono::steady_clock::time_point{};
264304
if (do_stats) {
@@ -359,7 +399,6 @@ class StreamingOneWesolowskiCallback final : public WesolowskiCallback {
359399
return out;
360400
}
361401

362-
private:
363402
form& bucket(uint32_t j, uint64_t b) {
364403
size_t idx = static_cast<size_t>(j) * (1ULL << k) + static_cast<size_t>(b);
365404
return buckets[idx];
@@ -391,6 +430,9 @@ class StreamingOneWesolowskiCallback final : public WesolowskiCallback {
391430
integer getblock_inv_2k;
392431
integer getblock_r;
393432
integer getblock_tmp;
433+
uint64_t batch_start_iteration = 1;
434+
uint64_t batch_end_iteration = 0;
435+
std::vector<BatchCheckpoint> current_batch_checkpoints;
394436

395437
bool stats_enabled;
396438
uint64_t checkpoint_total_ns = 0;

src/callback.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ class WesolowskiCallback :public INUDUPLListener {
7373
}
7474

7575
virtual void OnIteration(int type, void *data, uint64_t iteration) = 0;
76+
/// Hook called with the base iteration and size of a batch that is about to
/// run. The default implementation does nothing; subclasses may override it.
///
/// @param base_iteration Ignored by the default implementation.
/// @param batch_size     Ignored by the default implementation.
virtual void OnBatchStart(uint64_t base_iteration, uint64_t batch_size) {
    // Intentionally a no-op; the casts only silence unused-parameter warnings.
    static_cast<void>(base_iteration);
    static_cast<void>(batch_size);
}
80+
/// Hook called with the base iteration and size of a batch that is about to
/// be replayed. The default implementation does nothing; subclasses may
/// override it.
///
/// @param base_iteration Ignored by the default implementation.
/// @param batch_size     Ignored by the default implementation.
virtual void OnBatchReplay(uint64_t base_iteration, uint64_t batch_size) {
    // Intentionally a no-op; the casts only silence unused-parameter warnings.
    static_cast<void>(base_iteration);
    static_cast<void>(batch_size);
}
7684

7785
std::unique_ptr<form[]> forms;
7886
size_t forms_capacity = 0;

src/vdf.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ bool quiet_mode = false;
9595
// run concurrently in the same process; they must not share a pairindex.
9696
inline int vdf_fast_pairindex() {
9797
#if (defined(ARCH_X86) || defined(ARCH_X64)) && !defined(CHIA_DISABLE_ASM)
98-
constexpr int kSlots = int(sizeof(master_counter) / sizeof(master_counter[0]));
99-
static std::atomic<int> next_slot{0};
100-
thread_local int slot = next_slot.fetch_add(1, std::memory_order_relaxed) % kSlots;
98+
constexpr unsigned int kSlots = unsigned(sizeof(master_counter) / sizeof(master_counter[0]));
99+
static std::atomic<unsigned int> next_slot{0};
100+
thread_local int slot = int(next_slot.fetch_add(1u, std::memory_order_relaxed) % kSlots);
101101
return slot;
102102
#else
103103
return 0;
@@ -201,6 +201,9 @@ void repeated_square(uint64_t iterations, form f, const integer& D, const intege
201201
#endif
202202

203203
uint64 batch_size=c_checkpoint_interval;
204+
if (weso != NULL) {
205+
weso->OnBatchStart(num_iterations, batch_size);
206+
}
204207

205208
#ifdef ENABLE_TRACK_CYCLES
206209
print( "track cycles enabled; results will be wrong" );
@@ -231,6 +234,9 @@ void repeated_square(uint64_t iterations, form f, const integer& D, const intege
231234

232235
if (actual_iterations==~uint64(0)) {
233236
//corruption; f is unchanged. do the entire batch with the slow algorithm
237+
if (weso != NULL) {
238+
weso->OnBatchReplay(num_iterations, batch_size);
239+
}
234240
repeated_square_original(*weso->vdfo, f, D, L, num_iterations, batch_size, weso);
235241
actual_iterations=batch_size;
236242

0 commit comments

Comments
 (0)