Skip to content

Commit d921f55

Browse files
committed
fix a bug for repeated axes in N-D c2c FFT
1 parent 230d8c1 commit d921f55

File tree

7 files changed

+105
-165
lines changed

7 files changed

+105
-165
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616

1717
### Fixed
1818
* Fixed a bug for N-D FFTs when both `s` and `out` are given [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)
19+
* Fixed a bug for a case when a repeated indices is passed for axes keyword in N-dimensional FFT [gh-215](https://github.com/IntelPython/mkl_fft/pull/215)
1920

2021
## [2.0.0] - 2025-06-03
2122

mkl_fft/_fft_utils.py

Lines changed: 40 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,19 @@ def _check_norm(norm):
4444
)
4545

4646

47-
def _check_shapes_for_direct(xs, shape, axes):
47+
def _check_shapes_for_direct(s, shape, axes):
4848
if len(axes) > 7: # Intel MKL supports up to 7D
4949
return False
50-
if not (len(xs) == len(shape)):
51-
# full-dimensional transform
50+
if len(s) != len(shape):
51+
# not a full-dimensional transform
5252
return False
53-
if not (len(set(axes)) == len(axes)):
53+
if len(set(axes)) != len(axes):
5454
# repeated axes
5555
return False
56-
for xsi, ai in zip(xs, axes):
57-
try:
58-
sh_ai = shape[ai]
59-
except IndexError:
60-
raise ValueError("Invalid axis (%d) specified" % ai)
61-
62-
if not (xsi == sh_ai):
63-
return False
56+
new_shape = tuple(shape[ax] for ax in axes)
57+
if tuple(s) != new_shape:
58+
# trimming or padding is needed
59+
return False
6460
return True
6561

6662

@@ -78,30 +74,6 @@ def _compute_fwd_scale(norm, n, shape):
7874
return np.sqrt(fsc)
7975

8076

81-
def _cook_nd_args(a, s=None, axes=None, invreal=False):
82-
if s is None:
83-
shapeless = True
84-
if axes is None:
85-
s = list(a.shape)
86-
else:
87-
try:
88-
s = [a.shape[i] for i in axes]
89-
except IndexError:
90-
# fake s designed to trip the ValueError further down
91-
s = range(len(axes) + 1)
92-
pass
93-
else:
94-
shapeless = False
95-
s = list(s)
96-
if axes is None:
97-
axes = list(range(-len(s), 0))
98-
if len(s) != len(axes):
99-
raise ValueError("Shape and axes have different lengths.")
100-
if invreal and shapeless:
101-
s[-1] = (a.shape[axes[-1]] - 1) * 2
102-
return s, axes
103-
104-
10577
# copied from scipy.fft module
10678
# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
10779
def _datacopied(arr, original):
@@ -129,89 +101,7 @@ def _flat_to_multi(ind, shape):
129101
return m_ind
130102

131103

132-
# copied from scipy.fftpack.helper
133-
def _init_nd_shape_and_axes(x, shape, axes):
134-
"""Handle shape and axes arguments for n-dimensional transforms.
135-
Returns the shape and axes in a standard form, taking into account negative
136-
values and checking for various potential errors.
137-
Parameters
138-
----------
139-
x : array_like
140-
The input array.
141-
shape : int or array_like of ints or None
142-
The shape of the result. If both `shape` and `axes` (see below) are
143-
None, `shape` is ``x.shape``; if `shape` is None but `axes` is
144-
not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``.
145-
If `shape` is -1, the size of the corresponding dimension of `x` is
146-
used.
147-
axes : int or array_like of ints or None
148-
Axes along which the calculation is computed.
149-
The default is over all axes.
150-
Negative indices are automatically converted to their positive
151-
counterpart.
152-
Returns
153-
-------
154-
shape : array
155-
The shape of the result. It is a 1D integer array.
156-
axes : array
157-
The shape of the result. It is a 1D integer array.
158-
"""
159-
x = np.asarray(x)
160-
noshape = shape is None
161-
noaxes = axes is None
162-
163-
if noaxes:
164-
axes = np.arange(x.ndim, dtype=np.intc)
165-
else:
166-
axes = np.atleast_1d(axes)
167-
168-
if axes.size == 0:
169-
axes = axes.astype(np.intc)
170-
171-
if not axes.ndim == 1:
172-
raise ValueError("when given, axes values must be a scalar or vector")
173-
if not np.issubdtype(axes.dtype, np.integer):
174-
raise ValueError("when given, axes values must be integers")
175-
176-
axes = np.where(axes < 0, axes + x.ndim, axes)
177-
178-
if axes.size != 0 and (axes.max() >= x.ndim or axes.min() < 0):
179-
raise ValueError("axes exceeds dimensionality of input")
180-
if axes.size != 0 and np.unique(axes).shape != axes.shape:
181-
raise ValueError("all axes must be unique")
182-
183-
if not noshape:
184-
shape = np.atleast_1d(shape)
185-
elif np.isscalar(x):
186-
shape = np.array([], dtype=np.intc)
187-
elif noaxes:
188-
shape = np.array(x.shape, dtype=np.intc)
189-
else:
190-
shape = np.take(x.shape, axes)
191-
192-
if shape.size == 0:
193-
shape = shape.astype(np.intc)
194-
195-
if shape.ndim != 1:
196-
raise ValueError("when given, shape values must be a scalar or vector")
197-
if not np.issubdtype(shape.dtype, np.integer):
198-
raise ValueError("when given, shape values must be integers")
199-
if axes.shape != shape.shape:
200-
raise ValueError(
201-
"when given, axes and shape arguments have to be of the same length"
202-
)
203-
204-
shape = np.where(shape == -1, np.array(x.shape)[axes], shape)
205-
if shape.size != 0 and (shape < 1).any():
206-
raise ValueError(f"invalid number of data points ({shape}) specified")
207-
208-
return shape, axes
209-
210-
211104
def _iter_complementary(x, axes, func, kwargs, result):
212-
if axes is None:
213-
# s and axes are None, direct N-D FFT
214-
return func(x, **kwargs, out=result)
215105
x_shape = x.shape
216106
nd = x.ndim
217107
r = list(range(nd))
@@ -260,9 +150,6 @@ def _iter_fftnd(
260150
direction=+1,
261151
scale_function=lambda ind: 1.0,
262152
):
263-
a = np.asarray(a)
264-
s, axes = _init_nd_shape_and_axes(a, s, axes)
265-
266153
# Combine the two, but in reverse, to end with the first axis given.
267154
axes_and_s = list(zip(axes, s))[::-1]
268155
# We try to use in-place calculations where possible, which is
@@ -309,13 +196,14 @@ def _output_dtype(dt):
309196
def _pad_array(arr, s, axes):
310197
"""Pads array arr with zeros to attain shape s associated with axes"""
311198
arr_shape = arr.shape
199+
new_shape = tuple(arr_shape[ax] for ax in axes)
200+
if tuple(s) == new_shape:
201+
return arr
202+
312203
no_padding = True
313204
pad_widths = [(0, 0)] * len(arr_shape)
314205
for si, ai in zip(s, axes):
315-
try:
316-
shp_i = arr_shape[ai]
317-
except IndexError:
318-
raise ValueError(f"Invalid axis {ai} specified")
206+
shp_i = arr_shape[ai]
319207
if si > shp_i:
320208
no_padding = False
321209
pad_widths[ai] = (0, si - shp_i)
@@ -345,14 +233,14 @@ def _trim_array(arr, s, axes):
345233
"""
346234

347235
arr_shape = arr.shape
236+
new_shape = tuple(arr_shape[ax] for ax in axes)
237+
if tuple(s) == new_shape:
238+
return arr
239+
348240
no_trim = True
349241
ind = [slice(None, None, None)] * len(arr_shape)
350242
for si, ai in zip(s, axes):
351-
try:
352-
shp_i = arr_shape[ai]
353-
except IndexError:
354-
raise ValueError(f"Invalid axis {ai} specified")
355-
if si < shp_i:
243+
if si < arr_shape[ai]:
356244
no_trim = False
357245
ind[ai] = slice(None, si, None)
358246
if no_trim:
@@ -383,16 +271,11 @@ def _c2c_fftnd_impl(
383271
if direction not in [-1, +1]:
384272
raise ValueError("Direction of FFT should +1 or -1")
385273

274+
x = np.asarray(x)
386275
valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64]
387276
# _direct_fftnd requires complex type, and full-dimensional transform
388-
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
389-
_direct = s is None and axes is None
390-
if _direct:
391-
_direct = x.ndim <= 7 # Intel MKL only supports FFT up to 7D
392-
if not _direct:
393-
xs, xa = _cook_nd_args(x, s, axes)
394-
if _check_shapes_for_direct(xs, x.shape, xa):
395-
_direct = True
277+
if x.size != 0 and x.ndim > 1:
278+
_direct = _check_shapes_for_direct(s, x.shape, axes)
396279
_direct = _direct and x.dtype in valid_dtypes
397280
else:
398281
_direct = False
@@ -405,14 +288,23 @@ def _c2c_fftnd_impl(
405288
out=out,
406289
)
407290
else:
408-
if s is None and x.dtype in valid_dtypes:
409-
x = np.asarray(x)
291+
new_shape = tuple(x.shape[ax] for ax in axes)
292+
if (
293+
tuple(s) == new_shape
294+
and x.dtype in valid_dtypes
295+
and len(set(axes)) == len(axes)
296+
):
410297
if out is None:
411298
res = np.empty_like(x, dtype=_output_dtype(x.dtype))
412299
else:
413300
_validate_out_array(out, x, _output_dtype(x.dtype))
414301
res = out
415302

303+
# MKL is capable of doing batch N-D FFT, it is not required to
304+
# manually loop over the batches as done in _iter_complementary and
305+
# it is the reason for bad performance mentioned in the gh-issue-#67
306+
# TODO: implement a batch N-D FFT using MKL
307+
# _iter_complementary performs batches of N-D FFT
416308
return _iter_complementary(
417309
x,
418310
axes,
@@ -434,14 +326,9 @@ def _c2c_fftnd_impl(
434326

435327
def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
436328
a = np.asarray(x)
437-
no_trim = (s is None) and (axes is None)
438-
s, axes = _cook_nd_args(a, s, axes)
439-
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
440329
la = axes[-1]
441-
442330
# trim array, so that rfft avoids doing unnecessary computations
443-
if not no_trim:
444-
a = _trim_array(a, s, axes)
331+
a = _trim_array(a, s, axes)
445332

446333
# last axis is not included since we calculate r2c FFT separately
447334
# and not in the loop
@@ -453,13 +340,11 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
453340
a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=res)
454341
res = a
455342
if len(s) > 1:
456-
457343
len_axes = len(axes)
458344
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
459-
if not no_trim:
460-
ss = list(s)
461-
ss[-1] = a.shape[la]
462-
a = _pad_array(a, tuple(ss), axes)
345+
ss = list(s)
346+
ss[-1] = a.shape[la]
347+
a = _pad_array(a, tuple(ss), axes)
463348
# a series of ND c2c FFTs along last axis
464349
ss, aa = _remove_axis(s, axes, -1)
465350
ind = [slice(None, None, 1)] * len(s)
@@ -494,17 +379,12 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
494379

495380
def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
496381
a = np.asarray(x)
497-
no_trim = (s is None) and (axes is None)
498-
s, axes = _cook_nd_args(a, s, axes, invreal=True)
499-
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
500382
la = axes[-1]
501-
if not no_trim:
502-
a = _trim_array(a, s, axes)
503383
if len(s) > 1:
504384
len_axes = len(axes)
505385
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
506-
if not no_trim:
507-
a = _pad_array(a, s, axes)
386+
a = _trim_array(a, s, axes)
387+
a = _pad_array(a, s, axes)
508388
# a series of ND c2c FFTs along last axis
509389
# due to need to write into a, we must copy
510390
a = a if _datacopied(a, x) else a.copy()
@@ -521,8 +401,8 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
521401
tind = tuple(ind)
522402
a_inp = a[tind]
523403
# out has real dtype and cannot be used in intermediate steps
524-
# ss and aa are reversed since np.irfftn uses forward order but
525-
# np.ifftn uses reverse order see numpy-gh-28950
404+
# ss and aa are reversed since np.fft.irfftn uses forward order
405+
# but np.fft.ifftn uses reverse order see numpy-gh-28950
526406
_ = _c2c_fftnd_impl(
527407
a_inp, s=ss[::-1], axes=aa[::-1], out=a_inp, direction=-1
528408
)

mkl_fft/_mkl_fft.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27+
import numpy as np
28+
2729
from ._fft_utils import (
2830
_c2c_fftnd_impl,
2931
_c2r_fftnd_impl,
@@ -50,6 +52,36 @@
5052
]
5153

5254

55+
# copied with modifications from:
56+
# https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py
57+
def _cook_nd_args(a, s=None, axes=None, invreal=False):
58+
if s is None:
59+
shapeless = True
60+
if axes is None:
61+
s = list(a.shape)
62+
else:
63+
s = np.take(a.shape, axes)
64+
else:
65+
shapeless = False
66+
s = list(s)
67+
if axes is None:
68+
if not shapeless:
69+
raise ValueError("If s is not None, axes must not be None either.")
70+
axes = list(range(-len(s), 0))
71+
if len(s) != len(axes):
72+
raise ValueError("Shape and axes have different lengths.")
73+
if invreal and shapeless:
74+
s[-1] = (a.shape[axes[-1]] - 1) * 2
75+
if None in s:
76+
raise ValueError("s must contain only int.")
77+
# use the whole input array along axis `i` if `s[i] == -1`
78+
s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)]
79+
80+
# make axes positive
81+
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
82+
return s, axes
83+
84+
5385
def fft(x, n=None, axis=-1, norm=None, out=None):
5486
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
5587
return _c2c_fft1d_impl(x, n=n, axis=axis, out=out, direction=+1, fsc=fsc)
@@ -70,11 +102,13 @@ def ifft2(x, s=None, axes=(-2, -1), norm=None, out=None):
70102

71103
def fftn(x, s=None, axes=None, norm=None, out=None):
72104
fsc = _compute_fwd_scale(norm, s, x.shape)
105+
s, axes = _cook_nd_args(x, s, axes)
73106
return _c2c_fftnd_impl(x, s=s, axes=axes, out=out, direction=+1, fsc=fsc)
74107

75108

76109
def ifftn(x, s=None, axes=None, norm=None, out=None):
77110
fsc = _compute_fwd_scale(norm, s, x.shape)
111+
s, axes = _cook_nd_args(x, s, axes)
78112
return _c2c_fftnd_impl(x, s=s, axes=axes, out=out, direction=-1, fsc=fsc)
79113

80114

@@ -98,9 +132,11 @@ def irfft2(x, s=None, axes=(-2, -1), norm=None, out=None):
98132

99133
def rfftn(x, s=None, axes=None, norm=None, out=None):
100134
fsc = _compute_fwd_scale(norm, s, x.shape)
135+
s, axes = _cook_nd_args(x, s, axes)
101136
return _r2c_fftnd_impl(x, s=s, axes=axes, out=out, fsc=fsc)
102137

103138

104139
def irfftn(x, s=None, axes=None, norm=None, out=None):
105140
fsc = _compute_fwd_scale(norm, s, x.shape)
141+
s, axes = _cook_nd_args(x, s, axes, invreal=True)
106142
return _c2r_fftnd_impl(x, s=s, axes=axes, out=out, fsc=fsc)

0 commit comments

Comments
 (0)