Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,11 @@ Install it via ``pip`` with

FFTW
----
Three different "engines" are provided by the :py:class:`pylops.signalprocessing.FFT` operator:
``engine="numpy"`` (default), ``engine="scipy"`` and ``engine="fftw"``.
Four different "engines" are provided by the :py:class:`pylops.signalprocessing.FFT` operator:
``engine="numpy"`` (default), ``engine="scipy"``, ``engine="fftw"`` and ``engine="mkl_fft"``.

The first two engines are part of the required PyLops dependencies.
The latter implements the well-known `FFTW <http://www.fftw.org>`_
The third implements the well-known `FFTW <http://www.fftw.org>`_
via the Python wrapper :py:class:`pyfftw.FFTW`. While this optimized FFT tends to
outperform the other two in many cases, it is not included by default.
To use this library, install it manually either via ``conda``:
Expand All @@ -381,9 +381,24 @@ or via pip:
FFTW is only available for :py:class:`pylops.signalprocessing.FFT`,
not :py:class:`pylops.signalprocessing.FFT2D` or :py:class:`pylops.signalprocessing.FFTND`.

.. warning::
Intel MKL FFT is not supported.
The fourth implements Intel MKL FFT via the Python interface `mkl_fft <https://github.com/IntelPython/mkl_fft>`_.
This provides access to Intel’s oneMKL Fourier Transform routines, enabling efficient FFT computations with performance
close to native C/Intel® oneMKL

To use this library, you can install it using ``conda``:

.. code-block:: bash

>> conda install --channel https://software.repos.intel.com/python/conda --channel conda-forge mkl_fft

or via pip:

.. code-block:: bash

>> pip install --index-url https://software.repos.intel.com/python/pypi --extra-index-url https://pypi.org/simple mkl_fft

.. note::
`mkl_fft` is not supported on macOS

Numba
-----
Expand Down
2 changes: 2 additions & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
name: pylops
channels:
- https://software.repos.intel.com/python/conda
- defaults
- conda-forge
- numba
Expand All @@ -24,6 +25,7 @@ dependencies:
- autopep8
- isort
- black
- mkl_fft
- pip:
- torch
- devito
Expand Down
24 changes: 24 additions & 0 deletions examples/plot_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,30 @@
axs[1].set_xlim([0, 3 * f0])
plt.tight_layout()

###############################################################################
# PyLops also has a third FFT engine (engine='mkl_fft') that uses the well-known
# `Intel MKL FFT <https://github.com/IntelPython/mkl_fft>`_. This is a Python wrapper around
# the `Intel® oneAPI Math Kernel Library (oneMKL) <https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2025-2/fourier-transform-functions.html>`_
# Fourier Transform functions. It lets PyLops run discrete Fourier transforms faster
# by using Intel’s highly optimized math routines.

FFTop = pylops.signalprocessing.FFT(dims=nt, nfft=nfft, sampling=dt, engine="mkl_fft")
D = FFTop * d

# Inverse for FFT
dinv = FFTop.H * D
dinv = FFTop / D

fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].plot(t, d, "k", lw=2, label="True")
axs[0].plot(t, dinv.real, "--r", lw=2, label="Inverted")
axs[0].legend()
axs[0].set_title("Signal")
axs[1].plot(FFTop.f[: int(FFTop.nfft / 2)], np.abs(D[: int(FFTop.nfft / 2)]), "k", lw=2)
axs[1].set_title("Fourier Transform with MKL FFT")
axs[1].set_xlim([0, 3 * f0])
plt.tight_layout()

###############################################################################
# We can also apply the one dimensional FFT to to a two-dimensional
# signal (along one of the first axis)
Expand Down
117 changes: 113 additions & 4 deletions pylops/signalprocessing/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

pyfftw_message = deps.pyfftw_import("the fft module")
mkl_fft_message = deps.mkl_fft_import("the mkl fft module")

if pyfftw_message is None:
import pyfftw

if mkl_fft_message is None:
import mkl_fft.interfaces.numpy_fft as mkl_backend

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -394,6 +398,94 @@ def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike:
return self._rmatvec(y) / self._scale


class _FFT_mklfft(_BaseFFT):
"""One-dimensional Fast-Fourier Transform using mkl_fft"""

def __init__(
self,
dims: Union[int, InputDimsLike],
axis: int = -1,
nfft: Optional[int] = None,
sampling: float = 1.0,
norm: str = "ortho",
real: bool = False,
ifftshift_before: bool = False,
fftshift_after: bool = False,
dtype: DTypeLike = "complex128",
**kwargs_fft,
) -> None:
super().__init__(
dims=dims,
axis=axis,
nfft=nfft,
sampling=sampling,
norm=norm,
real=real,
ifftshift_before=ifftshift_before,
fftshift_after=fftshift_after,
dtype=dtype,
)
self._kwargs_fft = kwargs_fft
self._norm_kwargs = {"norm": None}
if self.norm is _FFTNorms.ORTHO:
self._norm_kwargs["norm"] = "ortho"
self._scale = np.sqrt(1 / self.nfft)
elif self.norm is _FFTNorms.NONE:
self._scale = self.nfft
elif self.norm is _FFTNorms.ONE_OVER_N:
self._scale = 1.0 / self.nfft

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
if self.ifftshift_before:
x = mkl_backend.ifftshift(x, axes=self.axis)
if not self.clinear:
x = np.real(x)
if self.real:
y = mkl_backend.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
y = np.swapaxes(y, -1, self.axis)
y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2)
y = np.swapaxes(y, self.axis, -1)
else:
y = mkl_backend.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
if self.fftshift_after:
y = mkl_backend.fftshift(y, axes=self.axis)
return y

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
if self.fftshift_after:
x = mkl_backend.ifftshift(x, axes=self.axis)
if self.real:
x = x.copy()
x = np.swapaxes(x, -1, self.axis)
x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2)
x = np.swapaxes(x, self.axis, -1)
y = mkl_backend.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
else:
y = mkl_backend.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale

if self.nfft > self.dims[self.axis]:
y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis)
elif self.nfft < self.dims[self.axis]:
y = np.pad(y, self.ifftpad)

if not self.clinear:
y = np.real(y)
if self.ifftshift_before:
y = mkl_backend.fftshift(y, axes=self.axis)
return y

def __truediv__(self, y):
if self.norm is not _FFTNorms.ORTHO:
return self._rmatvec(y) / self._scale
return self._rmatvec(y)


def FFT(
dims: Union[int, InputDimsLike],
axis: int = -1,
Expand Down Expand Up @@ -481,7 +573,7 @@ def FFT(
frequencies are arranged from zero to largest positive, and then from negative
Nyquist to the frequency bin before zero.
engine : :obj:`str`, optional
Engine used for fft computation (``numpy``, ``fftw``, or ``scipy``). Choose
Engine used for fft computation (``numpy``, ``fftw``, ``scipy`` or ``mkl_fft``). Choose
``numpy`` when working with cupy and jax arrays.

.. note:: Since version 1.17.0, accepts "scipy".
Expand Down Expand Up @@ -534,7 +626,7 @@ def FFT(
- If ``dims`` is provided and ``axis`` is bigger than ``len(dims)``.
- If ``norm`` is not one of "ortho", "none", or "1/n".
NotImplementedError
If ``engine`` is neither ``numpy``, ``fftw``, nor ``scipy``.
If ``engine`` is neither ``numpy``, ``fftw``, ``scipy`` nor ``mkl_fft``.

See Also
--------
Expand Down Expand Up @@ -579,7 +671,24 @@ def FFT(
dtype=dtype,
**kwargs_fft,
)
elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None):
elif engine == "mkl_fft" and mkl_fft_message is None:
f = _FFT_mklfft(
dims,
axis=axis,
nfft=nfft,
sampling=sampling,
norm=norm,
real=real,
ifftshift_before=ifftshift_before,
fftshift_after=fftshift_after,
dtype=dtype,
**kwargs_fft,
)
elif (
engine == "numpy"
or (engine == "fftw" and pyfftw_message is not None)
or (engine == "mkl_fft" and mkl_fft_message is not None)
):
if engine == "fftw" and pyfftw_message is not None:
logger.warning(pyfftw_message)
f = _FFT_numpy(
Expand Down Expand Up @@ -608,6 +717,6 @@ def FFT(
**kwargs_fft,
)
else:
raise NotImplementedError("engine must be numpy, fftw or scipy")
raise NotImplementedError("engine must be numpy, fftw, scipy or mkl_fft")
f.name = name
return f
Loading