@@ -44,23 +44,19 @@ def _check_norm(norm):
44
44
)
45
45
46
46
47
- def _check_shapes_for_direct (xs , shape , axes ):
47
+ def _check_shapes_for_direct (s , shape , axes ):
48
48
if len (axes ) > 7 : # Intel MKL supports up to 7D
49
49
return False
50
- if not ( len (xs ) == len (shape ) ):
51
- # full-dimensional transform
50
+ if len (s ) != len (shape ):
51
+ # not a full-dimensional transform
52
52
return False
53
- if not ( len (set (axes )) == len (axes ) ):
53
+ if len (set (axes )) != len (axes ):
54
54
# repeated axes
55
55
return False
56
- for xsi , ai in zip (xs , axes ):
57
- try :
58
- sh_ai = shape [ai ]
59
- except IndexError :
60
- raise ValueError ("Invalid axis (%d) specified" % ai )
61
-
62
- if not (xsi == sh_ai ):
63
- return False
56
+ new_shape = tuple (shape [ax ] for ax in axes )
57
+ if tuple (s ) != new_shape :
58
+ # trimming or padding is needed
59
+ return False
64
60
return True
65
61
66
62
@@ -78,30 +74,6 @@ def _compute_fwd_scale(norm, n, shape):
78
74
return np .sqrt (fsc )
79
75
80
76
81
- def _cook_nd_args (a , s = None , axes = None , invreal = False ):
82
- if s is None :
83
- shapeless = True
84
- if axes is None :
85
- s = list (a .shape )
86
- else :
87
- try :
88
- s = [a .shape [i ] for i in axes ]
89
- except IndexError :
90
- # fake s designed to trip the ValueError further down
91
- s = range (len (axes ) + 1 )
92
- pass
93
- else :
94
- shapeless = False
95
- s = list (s )
96
- if axes is None :
97
- axes = list (range (- len (s ), 0 ))
98
- if len (s ) != len (axes ):
99
- raise ValueError ("Shape and axes have different lengths." )
100
- if invreal and shapeless :
101
- s [- 1 ] = (a .shape [axes [- 1 ]] - 1 ) * 2
102
- return s , axes
103
-
104
-
105
77
# copied from scipy.fft module
106
78
# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
107
79
def _datacopied (arr , original ):
@@ -129,89 +101,7 @@ def _flat_to_multi(ind, shape):
129
101
return m_ind
130
102
131
103
132
- # copied from scipy.fftpack.helper
133
- def _init_nd_shape_and_axes (x , shape , axes ):
134
- """Handle shape and axes arguments for n-dimensional transforms.
135
- Returns the shape and axes in a standard form, taking into account negative
136
- values and checking for various potential errors.
137
- Parameters
138
- ----------
139
- x : array_like
140
- The input array.
141
- shape : int or array_like of ints or None
142
- The shape of the result. If both `shape` and `axes` (see below) are
143
- None, `shape` is ``x.shape``; if `shape` is None but `axes` is
144
- not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``.
145
- If `shape` is -1, the size of the corresponding dimension of `x` is
146
- used.
147
- axes : int or array_like of ints or None
148
- Axes along which the calculation is computed.
149
- The default is over all axes.
150
- Negative indices are automatically converted to their positive
151
- counterpart.
152
- Returns
153
- -------
154
- shape : array
155
- The shape of the result. It is a 1D integer array.
156
- axes : array
157
- The shape of the result. It is a 1D integer array.
158
- """
159
- x = np .asarray (x )
160
- noshape = shape is None
161
- noaxes = axes is None
162
-
163
- if noaxes :
164
- axes = np .arange (x .ndim , dtype = np .intc )
165
- else :
166
- axes = np .atleast_1d (axes )
167
-
168
- if axes .size == 0 :
169
- axes = axes .astype (np .intc )
170
-
171
- if not axes .ndim == 1 :
172
- raise ValueError ("when given, axes values must be a scalar or vector" )
173
- if not np .issubdtype (axes .dtype , np .integer ):
174
- raise ValueError ("when given, axes values must be integers" )
175
-
176
- axes = np .where (axes < 0 , axes + x .ndim , axes )
177
-
178
- if axes .size != 0 and (axes .max () >= x .ndim or axes .min () < 0 ):
179
- raise ValueError ("axes exceeds dimensionality of input" )
180
- if axes .size != 0 and np .unique (axes ).shape != axes .shape :
181
- raise ValueError ("all axes must be unique" )
182
-
183
- if not noshape :
184
- shape = np .atleast_1d (shape )
185
- elif np .isscalar (x ):
186
- shape = np .array ([], dtype = np .intc )
187
- elif noaxes :
188
- shape = np .array (x .shape , dtype = np .intc )
189
- else :
190
- shape = np .take (x .shape , axes )
191
-
192
- if shape .size == 0 :
193
- shape = shape .astype (np .intc )
194
-
195
- if shape .ndim != 1 :
196
- raise ValueError ("when given, shape values must be a scalar or vector" )
197
- if not np .issubdtype (shape .dtype , np .integer ):
198
- raise ValueError ("when given, shape values must be integers" )
199
- if axes .shape != shape .shape :
200
- raise ValueError (
201
- "when given, axes and shape arguments have to be of the same length"
202
- )
203
-
204
- shape = np .where (shape == - 1 , np .array (x .shape )[axes ], shape )
205
- if shape .size != 0 and (shape < 1 ).any ():
206
- raise ValueError (f"invalid number of data points ({ shape } ) specified" )
207
-
208
- return shape , axes
209
-
210
-
211
104
def _iter_complementary (x , axes , func , kwargs , result ):
212
- if axes is None :
213
- # s and axes are None, direct N-D FFT
214
- return func (x , ** kwargs , out = result )
215
105
x_shape = x .shape
216
106
nd = x .ndim
217
107
r = list (range (nd ))
@@ -260,9 +150,6 @@ def _iter_fftnd(
260
150
direction = + 1 ,
261
151
scale_function = lambda ind : 1.0 ,
262
152
):
263
- a = np .asarray (a )
264
- s , axes = _init_nd_shape_and_axes (a , s , axes )
265
-
266
153
# Combine the two, but in reverse, to end with the first axis given.
267
154
axes_and_s = list (zip (axes , s ))[::- 1 ]
268
155
# We try to use in-place calculations where possible, which is
@@ -309,13 +196,14 @@ def _output_dtype(dt):
309
196
def _pad_array (arr , s , axes ):
310
197
"""Pads array arr with zeros to attain shape s associated with axes"""
311
198
arr_shape = arr .shape
199
+ new_shape = tuple (arr_shape [ax ] for ax in axes )
200
+ if tuple (s ) == new_shape :
201
+ return arr
202
+
312
203
no_padding = True
313
204
pad_widths = [(0 , 0 )] * len (arr_shape )
314
205
for si , ai in zip (s , axes ):
315
- try :
316
- shp_i = arr_shape [ai ]
317
- except IndexError :
318
- raise ValueError (f"Invalid axis { ai } specified" )
206
+ shp_i = arr_shape [ai ]
319
207
if si > shp_i :
320
208
no_padding = False
321
209
pad_widths [ai ] = (0 , si - shp_i )
@@ -345,14 +233,14 @@ def _trim_array(arr, s, axes):
345
233
"""
346
234
347
235
arr_shape = arr .shape
236
+ new_shape = tuple (arr_shape [ax ] for ax in axes )
237
+ if tuple (s ) == new_shape :
238
+ return arr
239
+
348
240
no_trim = True
349
241
ind = [slice (None , None , None )] * len (arr_shape )
350
242
for si , ai in zip (s , axes ):
351
- try :
352
- shp_i = arr_shape [ai ]
353
- except IndexError :
354
- raise ValueError (f"Invalid axis { ai } specified" )
355
- if si < shp_i :
243
+ if si < arr_shape [ai ]:
356
244
no_trim = False
357
245
ind [ai ] = slice (None , si , None )
358
246
if no_trim :
@@ -383,16 +271,11 @@ def _c2c_fftnd_impl(
383
271
if direction not in [- 1 , + 1 ]:
384
272
raise ValueError ("Direction of FFT should +1 or -1" )
385
273
274
+ x = np .asarray (x )
386
275
valid_dtypes = [np .complex64 , np .complex128 , np .float32 , np .float64 ]
387
276
# _direct_fftnd requires complex type, and full-dimensional transform
388
- if isinstance (x , np .ndarray ) and x .size != 0 and x .ndim > 1 :
389
- _direct = s is None and axes is None
390
- if _direct :
391
- _direct = x .ndim <= 7 # Intel MKL only supports FFT up to 7D
392
- if not _direct :
393
- xs , xa = _cook_nd_args (x , s , axes )
394
- if _check_shapes_for_direct (xs , x .shape , xa ):
395
- _direct = True
277
+ if x .size != 0 and x .ndim > 1 :
278
+ _direct = _check_shapes_for_direct (s , x .shape , axes )
396
279
_direct = _direct and x .dtype in valid_dtypes
397
280
else :
398
281
_direct = False
@@ -405,14 +288,23 @@ def _c2c_fftnd_impl(
405
288
out = out ,
406
289
)
407
290
else :
408
- if s is None and x .dtype in valid_dtypes :
409
- x = np .asarray (x )
291
+ new_shape = tuple (x .shape [ax ] for ax in axes )
292
+ if (
293
+ tuple (s ) == new_shape
294
+ and x .dtype in valid_dtypes
295
+ and len (set (axes )) == len (axes )
296
+ ):
410
297
if out is None :
411
298
res = np .empty_like (x , dtype = _output_dtype (x .dtype ))
412
299
else :
413
300
_validate_out_array (out , x , _output_dtype (x .dtype ))
414
301
res = out
415
302
303
+ # MKL can compute batched N-D FFTs natively, so looping over the
304
+ # batches manually, as _iter_complementary does, is not required;
305
+ # that loop is the reason for the bad performance in gh-issue-#67.
306
+ # TODO: implement a batch N-D FFT using MKL
307
+ # For now, _iter_complementary performs the batches of N-D FFT.
416
308
return _iter_complementary (
417
309
x ,
418
310
axes ,
@@ -434,14 +326,9 @@ def _c2c_fftnd_impl(
434
326
435
327
def _r2c_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
436
328
a = np .asarray (x )
437
- no_trim = (s is None ) and (axes is None )
438
- s , axes = _cook_nd_args (a , s , axes )
439
- axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
440
329
la = axes [- 1 ]
441
-
442
330
# trim array, so that rfft avoids doing unnecessary computations
443
- if not no_trim :
444
- a = _trim_array (a , s , axes )
331
+ a = _trim_array (a , s , axes )
445
332
446
333
# last axis is not included since we calculate r2c FFT separately
447
334
# and not in the loop
@@ -453,13 +340,11 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
453
340
a = _r2c_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = res )
454
341
res = a
455
342
if len (s ) > 1 :
456
-
457
343
len_axes = len (axes )
458
344
if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
459
- if not no_trim :
460
- ss = list (s )
461
- ss [- 1 ] = a .shape [la ]
462
- a = _pad_array (a , tuple (ss ), axes )
345
+ ss = list (s )
346
+ ss [- 1 ] = a .shape [la ]
347
+ a = _pad_array (a , tuple (ss ), axes )
463
348
# a series of ND c2c FFTs along last axis
464
349
ss , aa = _remove_axis (s , axes , - 1 )
465
350
ind = [slice (None , None , 1 )] * len (s )
@@ -494,17 +379,12 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
494
379
495
380
def _c2r_fftnd_impl (x , s = None , axes = None , fsc = 1.0 , out = None ):
496
381
a = np .asarray (x )
497
- no_trim = (s is None ) and (axes is None )
498
- s , axes = _cook_nd_args (a , s , axes , invreal = True )
499
- axes = [ax + a .ndim if ax < 0 else ax for ax in axes ]
500
382
la = axes [- 1 ]
501
- if not no_trim :
502
- a = _trim_array (a , s , axes )
503
383
if len (s ) > 1 :
504
384
len_axes = len (axes )
505
385
if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
506
- if not no_trim :
507
- a = _pad_array (a , s , axes )
386
+ a = _trim_array ( a , s , axes )
387
+ a = _pad_array (a , s , axes )
508
388
# a series of ND c2c FFTs along last axis
509
389
# due to need to write into a, we must copy
510
390
a = a if _datacopied (a , x ) else a .copy ()
@@ -521,8 +401,8 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
521
401
tind = tuple (ind )
522
402
a_inp = a [tind ]
523
403
# out has real dtype and cannot be used in intermediate steps
524
- # ss and aa are reversed since np.irfftn uses forward order but
525
- # np.ifftn uses reverse order see numpy-gh-28950
404
+ # ss and aa are reversed since np.fft.irfftn uses forward order
405
+ # but np.fft.ifftn uses reverse order; see numpy-gh-28950
526
406
_ = _c2c_fftnd_impl (
527
407
a_inp , s = ss [::- 1 ], axes = aa [::- 1 ], out = a_inp , direction = - 1
528
408
)
0 commit comments