Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 91 additions & 11 deletions libs/qec/lib/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaqx::tensor<uint8_t> &)
INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaqx::tensor<uint8_t> &,
const cudaqx::heterogeneous_map &)

// Include decoder implementations AFTER registry instantiation
#include "decoders/sliding_window.h"

namespace cudaq::qec {

struct decoder::rt_impl {
Expand Down Expand Up @@ -50,6 +53,18 @@ struct decoder::rt_impl {

/// The id of the decoder (for instrumentation)
uint32_t decoder_id = 0;

bool is_sliding_window = false;

/// The number of syndromes per round. Only used for sliding window decoder.
size_t num_syndromes_per_round = 0;

/// Whether the first round detectors are included. Only used for sliding
/// window decoder.
bool has_first_round_detectors = false;

/// The current round. Only used for sliding window decoder.
uint32_t current_round = 0;
};

void decoder::rt_impl_deleter::operator()(rt_impl *p) const { delete p; }
Expand Down Expand Up @@ -175,20 +190,41 @@ void decoder::set_decoder_id(uint32_t decoder_id) {

uint32_t decoder::get_decoder_id() const { return pimpl->decoder_id; }

void decoder::set_D_sparse(const std::vector<std::vector<uint32_t>> &D_sparse) {
this->D_sparse = D_sparse;
template <typename PimplType>
void set_D_sparse_common(decoder *decoder,
const std::vector<std::vector<uint32_t>> &D_sparse,
PimplType *pimpl) {
auto *sw_decoder = dynamic_cast<sliding_window *>(decoder);

if (sw_decoder != nullptr) {
pimpl->is_sliding_window = true;
pimpl->num_syndromes_per_round = sw_decoder->get_num_syndromes_per_round();
// Check if first row is a first-round detector (single syndrome index)
pimpl->has_first_round_detectors =
(D_sparse.size() > 0 && D_sparse[0].size() == 1);
pimpl->current_round = 0;
pimpl->persistent_detector_buffer.resize(pimpl->num_syndromes_per_round);
pimpl->persistent_soft_detector_buffer.resize(
pimpl->num_syndromes_per_round);

} else {
pimpl->is_sliding_window = false;
}

pimpl->num_msyn_per_decode = calculate_num_msyn_per_decode(D_sparse);
pimpl->msyn_buffer.clear();
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
pimpl->msyn_buffer_index = 0;
}

void decoder::set_D_sparse(const std::vector<std::vector<uint32_t>> &D_sparse) {
this->D_sparse = D_sparse;
set_D_sparse_common(this, D_sparse, pimpl.get());
}

void decoder::set_D_sparse(const std::vector<int64_t> &D_sparse_vec_in) {
set_sparse_from_vec(D_sparse_vec_in, this->D_sparse);
pimpl->num_msyn_per_decode = calculate_num_msyn_per_decode(D_sparse);
pimpl->msyn_buffer.clear();
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
pimpl->msyn_buffer_index = 0;
set_D_sparse_common(this, this->D_sparse, pimpl.get());
}

bool decoder::enqueue_syndrome(const uint8_t *syndrome,
Expand All @@ -198,12 +234,23 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
printf("Syndrome buffer overflow. Syndrome will be ignored.\n");
return false;
}

pimpl->current_round++;
bool did_decode = false;
for (std::size_t i = 0; i < syndrome_length; i++) {
pimpl->msyn_buffer[pimpl->msyn_buffer_index] = syndrome[i];
pimpl->msyn_buffer_index++;
}
if (pimpl->msyn_buffer_index == pimpl->msyn_buffer.size()) {

bool should_decode = false;
if (!pimpl->is_sliding_window) {
should_decode = (pimpl->msyn_buffer_index == pimpl->msyn_buffer.size());
} else {
should_decode =
(pimpl->current_round >= 2) ||
(pimpl->current_round == 1 && pimpl->has_first_round_detectors);
}
if (should_decode) {
// These are just for logging. They are initialized in such a way to avoid
// dynamic memory allocation if logging is disabled.
std::vector<uint32_t> log_msyn;
Expand All @@ -226,11 +273,34 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
}

// Decode now.
for (std::size_t i = 0; i < this->D_sparse.size(); i++) {
pimpl->persistent_detector_buffer[i] = 0;
for (auto col : this->D_sparse[i])
pimpl->persistent_detector_buffer[i] ^= pimpl->msyn_buffer[col];
if (!pimpl->is_sliding_window) {
for (std::size_t i = 0; i < this->D_sparse.size(); i++) {
pimpl->persistent_detector_buffer[i] = 0;
for (auto col : this->D_sparse[i])
pimpl->persistent_detector_buffer[i] ^= pimpl->msyn_buffer[col];
}
} else {
// For sliding window decoder, syndrome_length must equal
// num_syndromes_per_round
assert(syndrome_length == pimpl->num_syndromes_per_round);
if (pimpl->current_round == 1 && pimpl->has_first_round_detectors) {
// First round: only compute first-round detectors (direct copy)
for (std::size_t i = 0; i < pimpl->num_syndromes_per_round; i++) {
pimpl->persistent_detector_buffer[i] = pimpl->msyn_buffer[i];
}
} else {
// Buffer is full with 2 rounds: compute timelike detectors (XOR of two
// rounds)
std::size_t index =
(pimpl->current_round - 2) * pimpl->num_syndromes_per_round;
for (std::size_t i = 0; i < pimpl->num_syndromes_per_round; i++) {
pimpl->persistent_detector_buffer[i] =
pimpl->msyn_buffer[index + i] ^
pimpl->msyn_buffer[index + i + pimpl->num_syndromes_per_round];
}
}
}

if (should_log) {
log_msyn.reserve(pimpl->msyn_buffer.size());
for (std::size_t d = 0, D = pimpl->msyn_buffer.size(); d < D; d++) {
Expand All @@ -249,6 +319,14 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
convert_vec_hard_to_soft(pimpl->persistent_detector_buffer,
pimpl->persistent_soft_detector_buffer);
auto decoded_result = decode(pimpl->persistent_soft_detector_buffer);

// If we didn't get a decoded result, just return
if (pimpl->is_sliding_window) {
if (decoded_result.result.size() == 0) {
return false;
}
}

if (should_log) {
log_t2 = std::chrono::high_resolution_clock::now();
for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++)
Expand Down Expand Up @@ -300,6 +378,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
did_decode = true;
// Prepare for more data.
pimpl->msyn_buffer_index = 0;
pimpl->current_round = 0;
}
return did_decode;
}
Expand Down Expand Up @@ -348,6 +427,7 @@ std::size_t decoder::get_num_observables() const { return O_sparse.size(); }
void decoder::reset_decoder() {
// Zero out all data that is considered "per-shot" memory.
pimpl->msyn_buffer_index = 0;
pimpl->current_round = 0;
pimpl->msyn_buffer.clear();
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
pimpl->corrections.clear();
Expand Down
Loading