diff --git a/runtime/Makefile b/runtime/Makefile index fe4480f00d..8b8b0b7dce 100644 --- a/runtime/Makefile +++ b/runtime/Makefile @@ -40,8 +40,8 @@ ifeq ($(ENABLE_ASAN), ON) ASAN_COMMAND = $(ASAN_FLAGS) endif -BUILD_TARGETS := rt_capi rtd_null_qubit rt_rsdecomp -TEST_TARGETS := runner_tests_qir_runtime runner_tests_mbqc_runtime runner_tests_rsdecomp_runtime +BUILD_TARGETS := rt_capi rtd_null_qubit rt_rsdecomp rt_decoder +TEST_TARGETS := runner_tests_qir_runtime runner_tests_mbqc_runtime runner_tests_rsdecomp_runtime runner_tests_decoder_runtime ifeq ($(ENABLE_OPENQASM), ON) BUILD_TARGETS += rtd_openqasm diff --git a/runtime/lib/CMakeLists.txt b/runtime/lib/CMakeLists.txt index 0eae794f8e..bd23437220 100644 --- a/runtime/lib/CMakeLists.txt +++ b/runtime/lib/CMakeLists.txt @@ -45,6 +45,7 @@ add_subdirectory(capi) add_subdirectory(backend) add_subdirectory(registry) add_subdirectory(RSDecompRuntime) +add_subdirectory(Decoder) if(ENABLE_OQD) add_subdirectory(OQDcapi) diff --git a/runtime/lib/Decoder/CMakeLists.txt b/runtime/lib/Decoder/CMakeLists.txt new file mode 100644 index 0000000000..955e50ec4a --- /dev/null +++ b/runtime/lib/Decoder/CMakeLists.txt @@ -0,0 +1,12 @@ +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_library(rt_decoder SHARED LUTDecoder.cpp) + +target_include_directories(rt_decoder + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../../include +) + +set_property(TARGET rt_decoder PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/runtime/lib/Decoder/LUTDecoder.cpp b/runtime/lib/Decoder/LUTDecoder.cpp new file mode 100644 index 0000000000..1b72aa5054 --- /dev/null +++ b/runtime/lib/Decoder/LUTDecoder.cpp @@ -0,0 +1,75 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "LUTDecoder.hpp" + +#include +#include +#include +#include + +#include "DataView.hpp" +#include "LUTDecoderUtils.hpp" + +namespace Catalyst::Runtime::QEC { +/** + * @brief A runtime lookup table based decoder. + * + * NOTE: As CAPI does not support setting default values for args, as discussed, we hardcode the + * required args in the beginning of the function body. Those values are specifically for the [[7, + * 1, 3]] Steane code. We expect those values are from args inputs later. + * @param row_idx_tanner Pointer to the row_idx data of a Tanner graph. + * @param col_ptr_tanner Pointer to the col_ptr data of a Tanner graph. + * @param syndrome_results Pointer to the syndrome measurement data. + * @param err_idx Pointer to the error qubit indices data. + */ +void __catalyst__qecp__lut_decoder(MemRefT_int64_1d *row_idx_tanner, + MemRefT_int64_1d *col_ptr_tanner, + MemRefT_int8_1d *current_syndromes, MemRefT_int64_1d *err_idx) +{ + // TODOs: We should expect the following const value from args. + // The default values here only work for the [[7, 1, 3]] Steane code. + const size_t code_size = 7; + const size_t code_distance = 3; + // The following parameter depends on the design choice of tanner graph would + // change. + const size_t aux_col_offset = 7; + + DataView row_idx(row_idx_tanner->data_aligned, row_idx_tanner->offset, + row_idx_tanner->sizes, row_idx_tanner->strides); + DataView col_ptr(col_ptr_tanner->data_aligned, col_ptr_tanner->offset, + col_ptr_tanner->sizes, col_ptr_tanner->strides); + + auto current_lut = + LUTs::getInstance().get_lut(aux_col_offset, code_size, code_distance, row_idx, col_ptr); + + DataView syndromes_res(current_syndromes->data_aligned, current_syndromes->offset, + current_syndromes->sizes, current_syndromes->strides); + + auto syndrome_str = convert_syndrome_res_to_bitstr(syndromes_res); + + std::vector error_indices = current_lut[syndrome_str]; + + // We use `-1` to full fill the err_idx array if the number of + // errors is less than (code_distance - 1)/2 + for (size_t i = 0; i < (code_distance - 1) / 2; i++) { + if (i < error_indices.size()) { + err_idx->data_allocated[i] = error_indices[i]; + } + else { + err_idx->data_allocated[i] = -1; + } + } +} +} // namespace Catalyst::Runtime::QEC diff --git a/runtime/lib/Decoder/LUTDecoder.hpp b/runtime/lib/Decoder/LUTDecoder.hpp new file mode 100644 index 0000000000..a77ce2e3e7 --- /dev/null +++ b/runtime/lib/Decoder/LUTDecoder.hpp @@ -0,0 +1,29 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "Types.h" + +namespace Catalyst::Runtime::QEC { + +extern "C" { + +void __catalyst__qecp__lut_decoder(/*row_idx*/ MemRefT_int64_1d *row_idx_tanner, + /*col_ptr*/ MemRefT_int64_1d *col_ptr_tanner, + /*syndrome*/ MemRefT_int8_1d *syndrome_res, + /*err_idx*/ MemRefT_int64_1d *err_idx); + +} // extern "C" +} // namespace Catalyst::Runtime::QEC diff --git a/runtime/lib/Decoder/LUTDecoderUtils.hpp b/runtime/lib/Decoder/LUTDecoderUtils.hpp new file mode 100644 index 0000000000..2562998a21 --- /dev/null +++ b/runtime/lib/Decoder/LUTDecoderUtils.hpp @@ -0,0 +1,255 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include "DataView.hpp" +#include "Exception.hpp" +#include "Types.h" + +namespace Catalyst::Runtime::QEC { +/** + * @brief Convert a vector of syndrome results to a bit string representation. + * + * @tparam IntegerType + * @param syndrome_res A dataview of syndrome results. + * @return std::string A bit string representation of the given syndrome results. + */ +template +std::string convert_syndrome_res_to_bitstr(DataView &syndrome_res) +{ + std::string syndrom_str; + for (const auto &bit : syndrome_res) { + RT_ASSERT(bit == 0 || bit == 1) + syndrom_str += (bit ? '1' : '0'); + } + + // Return results + return syndrom_str; +} + +/** + * @brief Get a parity check matrix from Tanner graph data. + * + * NOTE: With the current design of Tanner graph operation in MLIR, the first $n$ or $code_size$ + * columns represent the physical data qubits, while the last $n-1$ columns represent auxillary + * qubits (more details here + * https://github.com/PennyLaneAI/catalyst/blob/ab97f982539b31ab802a63020292595476f22d15/mlir/include/QecPhysical/IR/QecPhysicalTypes.td). + * + * @param tanner_row_idx The dataview of row indices of non-zero elements with a length of $nnz$. + * @param tanner_col_ptr The column offsets dataview of a length number of $num_col + 1$ that + * represents the starting position of each column. + * @param aux_cols A vector of column indices for the corresponding type of auxillary qubits. + * + * @return std::pair, std::vector> The corresponding parity check + * matrix in the CSS format. Each column represents an auxillary qubit. + */ + +std::pair, std::vector> +get_parity_check_matrix(DataView &tanner_row_idx, DataView &tanner_col_ptr, + const std::vector &aux_cols) +{ + std::vector row_idx_parity; + std::vector col_ptr_parity{0}; + + for (const auto &col : aux_cols) { + auto offset_start = tanner_col_ptr(col); + auto offset_end = tanner_col_ptr(col + 1); + + for (int i = offset_start; i < offset_end; i++) { + row_idx_parity.push_back(tanner_row_idx(i)); + } + size_t new_offset = col_ptr_parity.back() + offset_end - offset_start; + col_ptr_parity.push_back(new_offset); + } + return {row_idx_parity, col_ptr_parity}; +} + +/** + * @brief Get the bit representation of a syndrome from errors object + * The syndrome $s$ is calculated using a CSC parity check matrix $H$ and the + * error vector $e$ according to the linear relation: + * + * $$s = He \pmod 2$$ + * + * @param row_idx The row_idx vector of $H$. + * @param col_ptr The col_ptr vector of $H$. + * @param num_rows Number of rows of $H$. + * @param num_cols Number of columns of $H$. + * @param err_vec A vector of qubit errors. + * @return std::string The syndrome string corresponds to the err_vec. + */ +std::string get_syndrome_from_errors(const std::vector &row_idx, + const std::vector &col_ptr, const size_t num_rows, + const size_t num_cols, std::vector &err_vec) +{ + + std::vector syndrome_res(num_cols, 0); + + for (size_t col = 0; col < num_cols; col++) { + for (size_t idx = col_ptr[col]; idx < col_ptr[col + 1]; idx++) { + size_t row = row_idx[idx]; + syndrome_res[col] += err_vec[row]; + } + syndrome_res[col] = syndrome_res[col] % 2; + } + DataView syndrome_res_data_view(syndrome_res); + + return convert_syndrome_res_to_bitstr(syndrome_res_data_view); +} + +/** + * @brief Get the error indices of a vector of qubit errors. + * + * @param err_vec A vector of qubit errors. + * @return std::vector Indices of qubit errors. + */ +std::vector get_error_indices(std::vector &err_vec) +{ + std::vector error_indices; + + error_indices.reserve(err_vec.size()); + + for (size_t i = 0; i < err_vec.size(); ++i) { + if (err_vec[i] != 0) { + error_indices.push_back(i); + } + } + + return error_indices; +} + +/** + * @brief Generates a look up table with a CSC parity check matrix $H$ and QEC code information. + * + * NOTE: Note that this function has a combinatorial time complexity of $O(n^k)$, where $n$ + * represents the number of data qubits and $k$ represents the maximum error weight. Consequently, + * it is computationally intractable for large-scale codes. + * + * @param parity_mat_row_idx The row vector of length nnz that contains row indices of the + * corresponding elements. Each column corresponds to an auxillary qubit. + * @param parity_mat_col_ptr The column offsets vector of length number of num_col + 1 that + * represents the starting position of each row. + * @param code_size The number of data qubits in the QEC code. This param is for safe guard only + * purpose. + * @param code_distance The code distance, which represents the number of quantum errors can be + * corrected. + * @return std::unordered_map>& The result lookup table. + */ +std::unordered_map> +generate_lookup_table(const std::vector &parity_mat_row_idx, + const std::vector &parity_mat_col_ptr, const size_t code_size, + const size_t code_distance) +{ + // The key here is the bitstr representation of the syndrome results, e.g., "0101" + // The value is the corresponding indices of qubits to correct, e.g., {0, 2}. + std::unordered_map> lut; + + const size_t nnz = parity_mat_row_idx.size(); + const size_t num_aux_qubits = + parity_mat_col_ptr.size() - 1; // number of cols or number of auxillary qubits + const size_t num_data_qubits = + *std::max_element(parity_mat_row_idx.begin(), parity_mat_row_idx.end()) + + 1; // number of rows or number of data qubits + + RT_ASSERT(num_aux_qubits == (code_size - 1) >> 1); + RT_ASSERT(nnz > 0); + RT_ASSERT(num_data_qubits == code_size); + + // Get number of errors can be detected from code distance + const size_t num_errors = (code_distance - 1) / 2; + + // Traverse all possible quantum error combinations + for (int i = 0; i <= num_errors; i++) { + // create a base error vector + std::vector err_vector(num_data_qubits, 0); + std::fill(err_vector.end() - i, err_vector.end(), 1); + + do { + std::string syndrome_str = + get_syndrome_from_errors(parity_mat_row_idx, parity_mat_col_ptr, num_data_qubits, + num_aux_qubits, err_vector); + std::vector error_indices = get_error_indices(err_vector); + // We assume that 1:1 mapping for the syndrome and err_vector + lut[syndrome_str] = error_indices; + } while (std::next_permutation(err_vector.begin(), err_vector.end())); + } + + return lut; +} + +class LUTs final { + private: + std::unordered_map>> luts_; + + mutable std::mutex mutex_; + + explicit LUTs() = default; + + public: + LUTs(const LUTs &) = delete; + LUTs &operator=(const LUTs &) = delete; + LUTs(LUTs &&) = delete; + LUTs &operator=(LUTs &&) = delete; + + static auto getInstance() -> LUTs & + { + static LUTs instance; + return instance; + } + + /** + * @brief Get a lookup table. + * + * @param aux_col_offset The offset of the first X-check or Z-check column in a Tanner graph. + * @param code_size Number of data qubits in a QEC code. + * @param code_distance Code distance of a QEC code. + * @param row_idx Dataview of the row_idx of a Tanner graph. + * @param col_ptr Dataview of the col_ptr of a Tanner graph. + * @return const std::unordered_map>& The corresponding lookup + * table. + */ + auto get_lut(size_t aux_col_offset, size_t code_size, size_t code_distance, + DataView &row_idx, DataView &col_ptr) + -> const std::unordered_map> & + { + std::lock_guard lock(mutex_); + + auto it = luts_.find(aux_col_offset); + + if (it == luts_.end()) { + std::vector aux_cols((code_size - 1) / 2); + std::iota(aux_cols.begin(), aux_cols.end(), aux_col_offset); + + auto csc_parity_matrix = get_parity_check_matrix(row_idx, col_ptr, aux_cols); + + auto lut = generate_lookup_table(csc_parity_matrix.first, csc_parity_matrix.second, + code_size, code_distance); + + luts_[aux_col_offset] = std::move(lut); + return luts_[aux_col_offset]; + } + + return it->second; + } +}; + +} // namespace Catalyst::Runtime::QEC diff --git a/runtime/tests/CMakeLists.txt b/runtime/tests/CMakeLists.txt index b6323d9007..bd9f9ba4d0 100644 --- a/runtime/tests/CMakeLists.txt +++ b/runtime/tests/CMakeLists.txt @@ -134,6 +134,20 @@ target_link_libraries(runner_tests_mbqc_runtime PRIVATE catch_discover_tests(runner_tests_mbqc_runtime) +# Decoder test suite +add_executable(runner_tests_decoder_runtime) +target_sources(runner_tests_decoder_runtime PRIVATE + Test_DecoderLUTDecoderUtils.cpp + Test_DecoderLUTDecoder.cpp +) +target_link_libraries(runner_tests_decoder_runtime PRIVATE + Catch2WithMain + catalyst_runtime_testing + rt_decoder +) + +catch_discover_tests(runner_tests_decoder_runtime) + # RS Decomposition test suite add_executable(runner_tests_rsdecomp_runtime) target_sources(runner_tests_rsdecomp_runtime PRIVATE diff --git a/runtime/tests/TestUtils.hpp b/runtime/tests/TestUtils.hpp index 171a34a743..e867c8e98d 100644 --- a/runtime/tests/TestUtils.hpp +++ b/runtime/tests/TestUtils.hpp @@ -19,6 +19,7 @@ #pragma once #include +#include #include "ExecutionContext.hpp" #include "QuantumDevice.hpp" @@ -47,3 +48,47 @@ static inline Catalyst::Runtime::QuantumDevice *loadDevice(const std::string &de : nullptr; // LCOV_EXCL_STOP } + +template struct tanner_graph_steane { + /* Tanner graph representation for the [[7, 1, 3]] Steane code + The shape of dense matrix that the [[7, 1, 3]] Steane code is (10, 10). + The first 7 columns represent data qubits, while the last 3 columns + represent auxillary qubits. The full dense matrix is: + | 0 0 0 0 0 0 0 1 0 0| + | 0 0 0 0 0 0 0 1 1 0| + | 0 0 0 0 0 0 0 1 1 1| + | 0 0 0 0 0 0 0 1 0 1| + | 0 0 0 0 0 0 0 0 1 0| + | 0 0 0 0 0 0 0 0 1 1| + | 0 0 0 0 0 0 0 0 0 1| + | 1 1 1 1 0 0 0 0 0 0| + | 0 1 1 0 1 1 0 0 0 0| + | 0 0 1 1 0 1 1 0 0 0| + */ + size_t code_size = 7; + size_t code_distance = 3; + std::vector row_idx = {7, 7, 8, 7, 8, 9, 7, 9, 8, 8, 9, 9, + 0, 1, 2, 3, 1, 2, 4, 5, 2, 3, 5, 6}; + std::vector col_ptr = {0, 1, 3, 6, 8, 9, 11, 12, 16, 20, 24}; + + std::vector row_idx_parity_matrix_transpose = {0, 1, 2, 3, 1, 2, 4, 5, 2, 3, 5, 6}; + std::vector col_ptr_parity_matrix_transpose = {0, 4, 8, 12}; + + std::unordered_map> lookup_table_syndrome_to_error = { + {"000", std::vector({0, 0, 0, 0, 0, 0, 0})}, + {"001", std::vector({0, 0, 0, 0, 0, 0, 1})}, + {"010", std::vector({0, 0, 0, 0, 1, 0, 0})}, + {"011", std::vector({0, 0, 0, 0, 0, 1, 0})}, + {"100", std::vector({1, 0, 0, 0, 0, 0, 0})}, + {"101", std::vector({0, 0, 0, 1, 0, 0, 0})}, + {"110", std::vector({0, 1, 0, 0, 0, 0, 0})}, + {"111", std::vector({0, 0, 1, 0, 0, 0, 0})}, + }; + + std::unordered_map> lookup_table_error_idx_to_syndrome = { + {-1, std::vector({0, 0, 0})}, {6, std::vector({0, 0, 1})}, + {4, std::vector({0, 1, 0})}, {5, std::vector({0, 1, 1})}, + {0, std::vector({1, 0, 0})}, {3, std::vector({1, 0, 1})}, + {1, std::vector({1, 1, 0})}, {2, std::vector({1, 1, 1})}, + }; +}; diff --git a/runtime/tests/Test_DecoderLUTDecoder.cpp b/runtime/tests/Test_DecoderLUTDecoder.cpp new file mode 100644 index 0000000000..24b86bf801 --- /dev/null +++ b/runtime/tests/Test_DecoderLUTDecoder.cpp @@ -0,0 +1,69 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "catch2/catch_test_macros.hpp" +#include "catch2/matchers/catch_matchers_string.hpp" + +#include "LUTDecoder.hpp" +#include "TestUtils.hpp" +#include "Types.h" + +using namespace Catalyst::Runtime::QEC; + +TEST_CASE("Test C-API Wrapper (Memref Interface)", "[LUTDecoder][lut_decoder]") +{ + tanner_graph_steane tanner_graph; + + std::vector row_idx_tanner = tanner_graph.row_idx; + std::vector col_ptr_tanner = tanner_graph.col_ptr; + std::vector err_idx = std::vector((tanner_graph.code_distance - 1) / 2, -1); + + int64_t *buffer_row_idx_tanner_memref = row_idx_tanner.data(); + int64_t *buffer_col_ptr_tanner_memref = col_ptr_tanner.data(); + int64_t *buffer_err_idx_memref = err_idx.data(); + + MemRefT_int64_1d row_idx_tanner_memref = {buffer_row_idx_tanner_memref, + buffer_row_idx_tanner_memref, + 0, + {row_idx_tanner.size()}, + {1}}; + MemRefT_int64_1d col_ptr_tanner_memref = {buffer_col_ptr_tanner_memref, + buffer_col_ptr_tanner_memref, + 0, + {col_ptr_tanner.size()}, + {1}}; + MemRefT_int64_1d err_idx_memref = { + buffer_err_idx_memref, buffer_err_idx_memref, 0, {err_idx.size()}, {1} + }; + + for (auto it = tanner_graph.lookup_table_error_idx_to_syndrome.begin(); + it != tanner_graph.lookup_table_error_idx_to_syndrome.end(); ++it) { + int64_t expected_res = it->first; + + auto syndrome_res = it->second; + + int8_t *buffer_syndrome_res_memref = syndrome_res.data(); + MemRefT_int8_1d syndrome_res_memref = { + buffer_syndrome_res_memref, buffer_syndrome_res_memref, 0, {syndrome_res.size()}, {1}}; + + __catalyst__qecp__lut_decoder(&row_idx_tanner_memref, &col_ptr_tanner_memref, + &syndrome_res_memref, &err_idx_memref); + + REQUIRE(err_idx_memref.data_allocated[0] == expected_res); + } +} diff --git a/runtime/tests/Test_DecoderLUTDecoderUtils.cpp b/runtime/tests/Test_DecoderLUTDecoderUtils.cpp new file mode 100644 index 0000000000..a037f7a405 --- /dev/null +++ b/runtime/tests/Test_DecoderLUTDecoderUtils.cpp @@ -0,0 +1,113 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "catch2/catch_test_macros.hpp" +#include "catch2/matchers/catch_matchers_string.hpp" + +#include "DataView.hpp" +#include "LUTDecoderUtils.hpp" +#include "TestUtils.hpp" + +using namespace Catalyst::Runtime::QEC; + +TEST_CASE("Test convert_sydrome_res_to_bitstr", "[LUTDecoderUtils::convert_syndrome_res_to_bitstr]") +{ + std::vector bad_syndrome_inputs = {1, 2, 3}; + DataView bad_syndrome_inputs_dv(bad_syndrome_inputs); + REQUIRE_THROWS_WITH(convert_syndrome_res_to_bitstr(bad_syndrome_inputs_dv), + Catch::Matchers::ContainsSubstring("Assertion: bit == 0 || bit == 1")); + + std::vector syndromes_size_t = {0, 1, 0}; + std::vector syndromes_int8_t = {0, 1, 0}; + DataView syndromes_size_t_dv(syndromes_size_t); + DataView syndromes_int8_t_dv(syndromes_int8_t); + std::string expected_syndrome_str = "010"; + std::string syndrome_str_size_t = convert_syndrome_res_to_bitstr(syndromes_size_t_dv); + std::string syndrome_str_int8_t = convert_syndrome_res_to_bitstr(syndromes_int8_t_dv); + + REQUIRE(syndrome_str_size_t == expected_syndrome_str); + REQUIRE(syndrome_str_int8_t == expected_syndrome_str); +} + +TEST_CASE("Test get_error_indices", "[LUTDecoderUtils::get_error_indices]") +{ + std::vector error_vector = {0, 1, 0, 1, 0, 0, 0}; + std::vector expected_indices = {1, 3}; + + auto error_indices = get_error_indices(error_vector); + + REQUIRE(error_indices == expected_indices); +} + +TEST_CASE("Test get_parity_check_matrix", "[LUTDecoderUtils::get_parity_check_matrix]") +{ + tanner_graph_steane tanner_graph; + + std::vector aux_cols = {7, 8, 9}; + + DataView row_idx(tanner_graph.row_idx); + DataView col_ptr(tanner_graph.col_ptr); + + auto parity_mat_csc = get_parity_check_matrix(row_idx, col_ptr, aux_cols); + + REQUIRE(parity_mat_csc.first == tanner_graph.row_idx_parity_matrix_transpose); + REQUIRE(parity_mat_csc.second == tanner_graph.col_ptr_parity_matrix_transpose); +} + +TEST_CASE("Test get_syndrome_from_errors", "[LUTDecoderUtils::get_syndrome_from_errors]") +{ + tanner_graph_steane tanner_graph; + + std::vector aux_cols = {7, 8, 9}; + + DataView row_idx(tanner_graph.row_idx); + DataView col_ptr(tanner_graph.col_ptr); + + auto parity_mat_csc = get_parity_check_matrix(row_idx, col_ptr, aux_cols); + + const size_t num_data_qubits = tanner_graph.code_size; + const size_t num_aux_qubits = 3; + for (auto it = tanner_graph.lookup_table_syndrome_to_error.begin(); + it != tanner_graph.lookup_table_syndrome_to_error.end(); ++it) { + auto err_vec = it->second; + std::string expected_str = it->first; + std::string syndrome_bitstr = get_syndrome_from_errors( + parity_mat_csc.first, parity_mat_csc.second, num_data_qubits, num_aux_qubits, err_vec); + REQUIRE(syndrome_bitstr == expected_str); + } +} + +TEST_CASE("Test generate_lookup_table", "[LUTDecoderUtils::generate_lookup_table]") +{ + tanner_graph_steane tanner_graph; + auto lut = generate_lookup_table(tanner_graph.row_idx_parity_matrix_transpose, + tanner_graph.col_ptr_parity_matrix_transpose, + tanner_graph.code_size, tanner_graph.code_distance); + + std::unordered_map> expected_lut = { + {"000", std::vector({})}, {"001", std::vector({6})}, + {"010", std::vector({4})}, {"011", std::vector({5})}, + {"100", std::vector({0})}, {"101", std::vector({3})}, + {"110", std::vector({1})}, {"111", std::vector({2})}, + }; + + for (auto it = expected_lut.begin(); it != expected_lut.end(); ++it) { + auto key = it->first; + REQUIRE(lut[key] == it->second); + } +}