@@ -199,7 +199,6 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
199
199
std::vector<mint> convolution_fft (std::vector<mint> a, std::vector<mint> b) {
200
200
int n = int (a.size ()), m = int (b.size ());
201
201
int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
202
- assert (mint::mod () % z == 1 );
203
202
a.resize (z);
204
203
internal::butterfly (a);
205
204
b.resize (z);
@@ -220,6 +219,10 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
220
219
std::vector<mint> convolution (std::vector<mint>&& a, std::vector<mint>&& b) {
221
220
int n = int (a.size ()), m = int (b.size ());
222
221
if (!n || !m) return {};
222
+
223
+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
224
+ assert (mint::mod () % z == 1 );
225
+
223
226
if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
224
227
return internal::convolution_fft (a, b);
225
228
}
@@ -229,6 +232,10 @@ std::vector<mint> convolution(const std::vector<mint>& a,
229
232
const std::vector<mint>& b) {
230
233
int n = int (a.size ()), m = int (b.size ());
231
234
if (!n || !m) return {};
235
+
236
+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
237
+ assert (mint::mod () % z == 1 );
238
+
232
239
if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
233
240
return internal::convolution_fft (a, b);
234
241
}
@@ -241,6 +248,10 @@ std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
241
248
if (!n || !m) return {};
242
249
243
250
using mint = static_modint<mod>;
251
+
252
+ int z = (int )internal::bit_ceil ((unsigned int )(n + m - 1 ));
253
+ assert (mint::mod () % z == 1 );
254
+
244
255
std::vector<mint> a2 (n), b2 (m);
245
256
for (int i = 0 ; i < n; i++) {
246
257
a2[i] = mint (a[i]);
@@ -280,7 +291,7 @@ std::vector<long long> convolution_ll(const std::vector<long long>& a,
280
291
static_assert (MOD1 % (1ull << MAX_AB_BIT) == 1 , " MOD1 isn't enough to support an array length of 2^24." );
281
292
static_assert (MOD2 % (1ull << MAX_AB_BIT) == 1 , " MOD2 isn't enough to support an array length of 2^24." );
282
293
static_assert (MOD3 % (1ull << MAX_AB_BIT) == 1 , " MOD3 isn't enough to support an array length of 2^24." );
283
- assert (a. size () + b. size () - 1 <= (1ull << MAX_AB_BIT));
294
+ assert (n + m - 1 <= (1 << MAX_AB_BIT));
284
295
285
296
auto c1 = convolution<MOD1>(a, b);
286
297
auto c2 = convolution<MOD2>(a, b);
0 commit comments