|
17 | 17 | #include <memory.h> |
18 | 18 |
|
19 | 19 | #include <algorithm> |
| 20 | +#include <cstddef> |
20 | 21 | #include <functional> |
21 | 22 | #include <span> |
22 | 23 | #include <utility> |
@@ -1181,17 +1182,31 @@ inline vector<Ciphertext> SealPirServer::ExpandQuery( |
1181 | 1182 | inline void SealPirServer::MultiplyPowerOfX(const Ciphertext &encrypted, |
1182 | 1183 | Ciphertext &destination, |
1183 | 1184 | uint32_t index) const { |
1184 | | - int N = enc_params_->poly_modulus_degree(); |
| 1185 | + size_t N = enc_params_->poly_modulus_degree(); |
1185 | 1186 | size_t coeff_mod_cnt = enc_params_->coeff_modulus().size() - 1; |
1186 | 1187 | size_t encrypted_cnt = encrypted.size(); |
1187 | 1188 |
|
1188 | 1189 | destination = encrypted; |
1189 | | - |
| 1190 | + uint32_t actual_index = index % N; |
| 1191 | + // avoid exception in SEAL debug mode |
1190 | 1192 | for (size_t i = 0; i < encrypted_cnt; ++i) { |
1191 | 1193 | for (size_t j = 0; j < coeff_mod_cnt; ++j) { |
1192 | | - seal::util::negacyclic_shift_poly_coeffmod( |
1193 | | - encrypted.data(i) + (j * N), N, index, |
1194 | | - enc_params_->coeff_modulus()[j], destination.data(i) + (j * N)); |
| 1194 | + if (index >= N) { |
| 1195 | + // negate coefficients when index >= N |
| 1196 | + vector<uint64_t> temp_data(N); |
| 1197 | + for (size_t k = 0; k < N; ++k) { |
| 1198 | + temp_data[k] = enc_params_->coeff_modulus()[j].value() - |
| 1199 | + encrypted.data(i)[j * N + k]; |
| 1200 | + } |
| 1201 | + seal::util::negacyclic_shift_poly_coeffmod( |
| 1202 | + temp_data.data(), N, actual_index, enc_params_->coeff_modulus()[j], |
| 1203 | + destination.data(i) + (j * N)); |
| 1204 | + } else { |
| 1205 | + // direct shift when index < N |
| 1206 | + seal::util::negacyclic_shift_poly_coeffmod( |
| 1207 | + encrypted.data(i) + (j * N), N, actual_index, |
| 1208 | + enc_params_->coeff_modulus()[j], destination.data(i) + (j * N)); |
| 1209 | + } |
1195 | 1210 | } |
1196 | 1211 | } |
1197 | 1212 | } |
|
0 commit comments