Skip to content

fix a bug for repeated axes in N-D c2c FFT #215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: revisit_overwrite_x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
* Fixed a bug for N-D FFTs when both `s` and `out` are given [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)
* Fixed a bug in N-dimensional FFTs when repeated indices are passed to the `axes` keyword [gh-215](https://github.com/IntelPython/mkl_fft/pull/215)

## [2.0.0] - 2025-06-03

Expand Down
200 changes: 40 additions & 160 deletions mkl_fft/_fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,19 @@ def _check_norm(norm):
)


def _check_shapes_for_direct(xs, shape, axes):
def _check_shapes_for_direct(s, shape, axes):
if len(axes) > 7: # Intel MKL supports up to 7D
return False
if not (len(xs) == len(shape)):
# full-dimensional transform
if len(s) != len(shape):
# not a full-dimensional transform
return False
if not (len(set(axes)) == len(axes)):
if len(set(axes)) != len(axes):
# repeated axes
return False
for xsi, ai in zip(xs, axes):
try:
sh_ai = shape[ai]
except IndexError:
raise ValueError("Invalid axis (%d) specified" % ai)

if not (xsi == sh_ai):
return False
new_shape = tuple(shape[ax] for ax in axes)
if tuple(s) != new_shape:
# trimming or padding is needed
return False
return True


Expand All @@ -78,30 +74,6 @@ def _compute_fwd_scale(norm, n, shape):
return np.sqrt(fsc)


def _cook_nd_args(a, s=None, axes=None, invreal=False):
if s is None:
shapeless = True
if axes is None:
s = list(a.shape)
else:
try:
s = [a.shape[i] for i in axes]
except IndexError:
# fake s designed to trip the ValueError further down
s = range(len(axes) + 1)
pass
else:
shapeless = False
s = list(s)
if axes is None:
axes = list(range(-len(s), 0))
if len(s) != len(axes):
raise ValueError("Shape and axes have different lengths.")
if invreal and shapeless:
s[-1] = (a.shape[axes[-1]] - 1) * 2
return s, axes


# copied from scipy.fft module
# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
def _datacopied(arr, original):
Expand Down Expand Up @@ -129,89 +101,7 @@ def _flat_to_multi(ind, shape):
return m_ind


# copied from scipy.fftpack.helper
def _init_nd_shape_and_axes(x, shape, axes):
"""Handle shape and axes arguments for n-dimensional transforms.
Returns the shape and axes in a standard form, taking into account negative
values and checking for various potential errors.
Parameters
----------
x : array_like
The input array.
shape : int or array_like of ints or None
The shape of the result. If both `shape` and `axes` (see below) are
None, `shape` is ``x.shape``; if `shape` is None but `axes` is
not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``.
If `shape` is -1, the size of the corresponding dimension of `x` is
used.
axes : int or array_like of ints or None
Axes along which the calculation is computed.
The default is over all axes.
Negative indices are automatically converted to their positive
counterpart.
Returns
-------
shape : array
The shape of the result. It is a 1D integer array.
axes : array
The shape of the result. It is a 1D integer array.
"""
x = np.asarray(x)
noshape = shape is None
noaxes = axes is None

if noaxes:
axes = np.arange(x.ndim, dtype=np.intc)
else:
axes = np.atleast_1d(axes)

if axes.size == 0:
axes = axes.astype(np.intc)

if not axes.ndim == 1:
raise ValueError("when given, axes values must be a scalar or vector")
if not np.issubdtype(axes.dtype, np.integer):
raise ValueError("when given, axes values must be integers")

axes = np.where(axes < 0, axes + x.ndim, axes)

if axes.size != 0 and (axes.max() >= x.ndim or axes.min() < 0):
raise ValueError("axes exceeds dimensionality of input")
if axes.size != 0 and np.unique(axes).shape != axes.shape:
raise ValueError("all axes must be unique")

if not noshape:
shape = np.atleast_1d(shape)
elif np.isscalar(x):
shape = np.array([], dtype=np.intc)
elif noaxes:
shape = np.array(x.shape, dtype=np.intc)
else:
shape = np.take(x.shape, axes)

if shape.size == 0:
shape = shape.astype(np.intc)

if shape.ndim != 1:
raise ValueError("when given, shape values must be a scalar or vector")
if not np.issubdtype(shape.dtype, np.integer):
raise ValueError("when given, shape values must be integers")
if axes.shape != shape.shape:
raise ValueError(
"when given, axes and shape arguments have to be of the same length"
)

shape = np.where(shape == -1, np.array(x.shape)[axes], shape)
if shape.size != 0 and (shape < 1).any():
raise ValueError(f"invalid number of data points ({shape}) specified")

return shape, axes


def _iter_complementary(x, axes, func, kwargs, result):
if axes is None:
# s and axes are None, direct N-D FFT
return func(x, **kwargs, out=result)
x_shape = x.shape
nd = x.ndim
r = list(range(nd))
Expand Down Expand Up @@ -260,9 +150,6 @@ def _iter_fftnd(
direction=+1,
scale_function=lambda ind: 1.0,
):
a = np.asarray(a)
s, axes = _init_nd_shape_and_axes(a, s, axes)

# Combine the two, but in reverse, to end with the first axis given.
axes_and_s = list(zip(axes, s))[::-1]
# We try to use in-place calculations where possible, which is
Expand Down Expand Up @@ -309,13 +196,14 @@ def _output_dtype(dt):
def _pad_array(arr, s, axes):
"""Pads array arr with zeros to attain shape s associated with axes"""
arr_shape = arr.shape
new_shape = tuple(arr_shape[ax] for ax in axes)
if tuple(s) == new_shape:
return arr

no_padding = True
pad_widths = [(0, 0)] * len(arr_shape)
for si, ai in zip(s, axes):
try:
shp_i = arr_shape[ai]
except IndexError:
raise ValueError(f"Invalid axis {ai} specified")
shp_i = arr_shape[ai]
if si > shp_i:
no_padding = False
pad_widths[ai] = (0, si - shp_i)
Expand Down Expand Up @@ -345,14 +233,14 @@ def _trim_array(arr, s, axes):
"""

arr_shape = arr.shape
new_shape = tuple(arr_shape[ax] for ax in axes)
if tuple(s) == new_shape:
return arr

no_trim = True
ind = [slice(None, None, None)] * len(arr_shape)
for si, ai in zip(s, axes):
try:
shp_i = arr_shape[ai]
except IndexError:
raise ValueError(f"Invalid axis {ai} specified")
if si < shp_i:
if si < arr_shape[ai]:
no_trim = False
ind[ai] = slice(None, si, None)
if no_trim:
Expand Down Expand Up @@ -383,16 +271,11 @@ def _c2c_fftnd_impl(
if direction not in [-1, +1]:
raise ValueError("Direction of FFT should +1 or -1")

x = np.asarray(x)
valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64]
# _direct_fftnd requires complex type, and full-dimensional transform
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
_direct = s is None and axes is None
if _direct:
_direct = x.ndim <= 7 # Intel MKL only supports FFT up to 7D
if not _direct:
xs, xa = _cook_nd_args(x, s, axes)
if _check_shapes_for_direct(xs, x.shape, xa):
_direct = True
if x.size != 0 and x.ndim > 1:
_direct = _check_shapes_for_direct(s, x.shape, axes)
_direct = _direct and x.dtype in valid_dtypes
else:
_direct = False
Expand All @@ -405,14 +288,23 @@ def _c2c_fftnd_impl(
out=out,
)
else:
if s is None and x.dtype in valid_dtypes:
x = np.asarray(x)
new_shape = tuple(x.shape[ax] for ax in axes)
if (
tuple(s) == new_shape
and x.dtype in valid_dtypes
and len(set(axes)) == len(axes)
):
if out is None:
res = np.empty_like(x, dtype=_output_dtype(x.dtype))
else:
_validate_out_array(out, x, _output_dtype(x.dtype))
res = out

# MKL is capable of doing batch N-D FFT, it is not required to
# manually loop over the batches as done in _iter_complementary and
# it is the reason for bad performance mentioned in the gh-issue-#67
# TODO: implement a batch N-D FFT using MKL
# _iter_complementary performs batches of N-D FFT
return _iter_complementary(
x,
axes,
Expand All @@ -434,14 +326,9 @@ def _c2c_fftnd_impl(

def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
a = np.asarray(x)
no_trim = (s is None) and (axes is None)
s, axes = _cook_nd_args(a, s, axes)
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
la = axes[-1]

# trim array, so that rfft avoids doing unnecessary computations
if not no_trim:
a = _trim_array(a, s, axes)
a = _trim_array(a, s, axes)

# last axis is not included since we calculate r2c FFT separately
# and not in the loop
Expand All @@ -453,13 +340,11 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=res)
res = a
if len(s) > 1:

len_axes = len(axes)
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
if not no_trim:
ss = list(s)
ss[-1] = a.shape[la]
a = _pad_array(a, tuple(ss), axes)
ss = list(s)
ss[-1] = a.shape[la]
a = _pad_array(a, tuple(ss), axes)
# a series of ND c2c FFTs along last axis
ss, aa = _remove_axis(s, axes, -1)
ind = [slice(None, None, 1)] * len(s)
Expand Down Expand Up @@ -494,17 +379,12 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):

def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
a = np.asarray(x)
no_trim = (s is None) and (axes is None)
s, axes = _cook_nd_args(a, s, axes, invreal=True)
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
la = axes[-1]
if not no_trim:
a = _trim_array(a, s, axes)
if len(s) > 1:
len_axes = len(axes)
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
if not no_trim:
a = _pad_array(a, s, axes)
a = _trim_array(a, s, axes)
a = _pad_array(a, s, axes)
# a series of ND c2c FFTs along last axis
# due to need to write into a, we must copy
a = a if _datacopied(a, x) else a.copy()
Expand All @@ -521,8 +401,8 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
tind = tuple(ind)
a_inp = a[tind]
# out has real dtype and cannot be used in intermediate steps
# ss and aa are reversed since np.irfftn uses forward order but
# np.ifftn uses reverse order see numpy-gh-28950
# ss and aa are reversed since np.fft.irfftn uses forward order
# but np.fft.ifftn uses reverse order see numpy-gh-28950
_ = _c2c_fftnd_impl(
a_inp, s=ss[::-1], axes=aa[::-1], out=a_inp, direction=-1
)
Expand Down
36 changes: 36 additions & 0 deletions mkl_fft/_mkl_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np

from ._fft_utils import (
_c2c_fftnd_impl,
_c2r_fftnd_impl,
Expand All @@ -50,6 +52,36 @@
]


# copied with modifications from:
# https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py
def _cook_nd_args(a, s=None, axes=None, invreal=False):
if s is None:
shapeless = True
if axes is None:
s = list(a.shape)
else:
s = np.take(a.shape, axes)
else:
shapeless = False
s = list(s)
if axes is None:
if not shapeless:
raise ValueError("If s is not None, axes must not be None either.")
axes = list(range(-len(s), 0))
if len(s) != len(axes):
raise ValueError("Shape and axes have different lengths.")
if invreal and shapeless:
s[-1] = (a.shape[axes[-1]] - 1) * 2
if None in s:
raise ValueError("s must contain only int.")
# use the whole input array along axis `i` if `s[i] == -1`
s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)]

# make axes positive
axes = [ax + a.ndim if ax < 0 else ax for ax in axes]
return s, axes


def fft(x, n=None, axis=-1, norm=None, out=None):
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
return _c2c_fft1d_impl(x, n=n, axis=axis, out=out, direction=+1, fsc=fsc)
Expand All @@ -70,11 +102,13 @@ def ifft2(x, s=None, axes=(-2, -1), norm=None, out=None):

def fftn(x, s=None, axes=None, norm=None, out=None):
fsc = _compute_fwd_scale(norm, s, x.shape)
s, axes = _cook_nd_args(x, s, axes)
return _c2c_fftnd_impl(x, s=s, axes=axes, out=out, direction=+1, fsc=fsc)


def ifftn(x, s=None, axes=None, norm=None, out=None):
fsc = _compute_fwd_scale(norm, s, x.shape)
s, axes = _cook_nd_args(x, s, axes)
return _c2c_fftnd_impl(x, s=s, axes=axes, out=out, direction=-1, fsc=fsc)


Expand All @@ -98,9 +132,11 @@ def irfft2(x, s=None, axes=(-2, -1), norm=None, out=None):

def rfftn(x, s=None, axes=None, norm=None, out=None):
fsc = _compute_fwd_scale(norm, s, x.shape)
s, axes = _cook_nd_args(x, s, axes)
return _r2c_fftnd_impl(x, s=s, axes=axes, out=out, fsc=fsc)


def irfftn(x, s=None, axes=None, norm=None, out=None):
fsc = _compute_fwd_scale(norm, s, x.shape)
s, axes = _cook_nd_args(x, s, axes, invreal=True)
return _c2r_fftnd_impl(x, s=s, axes=axes, out=out, fsc=fsc)
Loading
Loading