diff --git a/.github/workflows/build-mkl.yaml b/.github/workflows/build-mkl.yaml new file mode 100644 index 00000000..c086d826 --- /dev/null +++ b/.github/workflows/build-mkl.yaml @@ -0,0 +1,41 @@ +name: PyLops Testing with Intel oneAPI Math Kernel Library(oneMKL) + +on: [push, pull_request] + +jobs: + build: + strategy: + matrix: + platform: [ubuntu-latest] + python-version: ["3.10", "3.11", "3.12"] + + runs-on: ${{ matrix.platform }} + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v4 + - name: Get history and tags for SCM versioning to work + run: | + git fetch --prune --unshallow + git fetch --depth=1 origin +refs/tags/*:refs/tags/* + - uses: conda-incubator/setup-miniconda@v3.2.0 + with: + use-mamba: true + channels: https://software.repos.intel.com/python/conda, conda-forge + conda-remove-defaults: true + python-version: ${{ matrix.python-version }} + activate-environment: mkl-test-env + - name: Install dependencies + run: | + conda install -y pyfftw + pip install -r requirements-intel-mkl.txt + pip install -r requirements-dev.txt + pip install -r requirements-torch.txt + - name: Install pylops + run: | + python -m setuptools_scm + pip install . 
+ - name: Tests with pytest + run: | + pytest diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index ac5788ad..c430ae7b 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -25,6 +25,7 @@ jobs: python -m pip install --upgrade pip setuptools pip install flake8 pytest pip install -r requirements-dev.txt + pip install -r requirements-pyfftw.txt pip install -r requirements-torch.txt - name: Install pylops run: | diff --git a/.github/workflows/codacy-coverage-reporter.yaml b/.github/workflows/codacy-coverage-reporter.yaml index 113357ad..239af9f0 100644 --- a/.github/workflows/codacy-coverage-reporter.yaml +++ b/.github/workflows/codacy-coverage-reporter.yaml @@ -27,6 +27,7 @@ jobs: python -m pip install --upgrade pip pip install flake8 pytest pip install -r requirements-dev.txt + pip install -r requirements-pyfftw.txt pip install -r requirements-torch.txt - name: Install pylops run: | diff --git a/Makefile b/Makefile index 06c36ca5..64940b37 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ PIP := $(shell command -v pip3 2> /dev/null || command which pip 2> /dev/null) PYTHON := $(shell command -v python3 2> /dev/null || command which python 2> /dev/null) -.PHONY: install dev-install dev-install_gpu install_conda dev-install_conda dev-install_conda_arm tests tests_cpu_ongpu tests_gpu doc docupdate servedoc lint typeannot coverage +.PHONY: install dev-install dev-install_intel_mkl dev-install_gpu install_conda dev-install_conda dev-install_conda_intel_mkl dev-install_conda_arm tests tests_cpu_ongpu tests_gpu doc docupdate servedoc lint typeannot coverage pipcheck: ifndef PIP @@ -22,6 +22,13 @@ install: dev-install: make pipcheck $(PIP) install -r requirements-dev.txt &&\ + $(PIP) install -r requirements-pyfftw.txt &&\ + $(PIP) install -r requirements-torch.txt && $(PIP) install -e . 
+ +dev-install_intel_mkl: + make pipcheck + $(PIP) install -r requirements-intel-mkl.txt &&\ + $(PIP) install -r requirements-dev.txt &&\ $(PIP) install -r requirements-torch.txt && $(PIP) install -e . dev-install_gpu: @@ -35,6 +42,9 @@ install_conda: dev-install_conda: conda env create -f environment-dev.yml && source ${CONDA_PREFIX}/etc/profile.d/conda.sh && conda activate pylops && pip install -e . +dev-install_conda_intel_mkl: + conda env create -f environment-dev-intel-mkl.yml && source ${CONDA_PREFIX}/etc/profile.d/conda.sh && conda activate pylops && pip install -e . + dev-install_conda_arm: conda env create -f environment-dev-arm.yml && source ${CONDA_PREFIX}/etc/profile.d/conda.sh && conda activate pylops && pip install -e . diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 049f19bf..132eedff 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -62,6 +62,7 @@ jobs: - script: | python -m pip install --upgrade pip setuptools wheel django pip install -r requirements-dev.txt + pip install -r requirements-pyfftw.txt pip install -r requirements-torch.txt pip install . displayName: 'Install prerequisites and library' @@ -92,6 +93,7 @@ jobs: - script: | python -m pip install --upgrade pip setuptools wheel django pip install -r requirements-dev.txt + pip install -r requirements-pyfftw.txt pip install -r requirements-torch.txt pip install . displayName: 'Install prerequisites and library' diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 9f148a12..f4ff87fa 100755 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -209,7 +209,7 @@ library will ensure optimal performance of PyLops when using only *required depe We strongly encourage using the Anaconda Python distribution as NumPy and SciPy will, when available, be automatically linked to `Intel MKL `_, the most performant library for basic linear algebra -operations to date (see `Markus Beuckelmann's benchmarks `_). +operations to date. 
The PyPI version installed with ``pip``, however, will default to `OpenBLAS `_. For more information, see `NumPy's section on BLAS `_. @@ -224,14 +224,13 @@ run the following commands in a Python interpreter: print(sp.__config__.show()) -Intel also provides `NumPy `__ and `SciPy `__ replacement packages in PyPI ``intel-numpy`` and ``intel-scipy``, respectively, which link to Intel MKL. +Intel also provides `NumPy `__ and `SciPy `__ replacement packages in PyPI, namely ``intel-numpy`` and ``intel-scipy``, which link to Intel MKL. These are an option for an environment without ``conda`` that needs Intel MKL without requiring manual compilation. .. warning:: ``intel-numpy`` and ``intel-scipy`` not only link against Intel MKL, but also substitute NumPy and - SciPy FFTs for `Intel MKL FFT `_. **MKL FFT is not supported - and may break PyLops**. + SciPy FFTs with `Intel MKL FFT `_. Multithreading @@ -297,7 +296,7 @@ of PyLops in such a way that if an *optional* dependency is not present in your a safe fallback to one of the required dependencies will be enforced. When available in your system, we recommend using the Conda package manager and install all the -required and optional dependencies of PyLops at once using the command: +required and some of the optional dependencies of PyLops at once using the command: .. code-block:: bash @@ -305,17 +304,19 @@ required and optional dependencies of PyLops at once using the command: in this case all dependencies will be installed from their Conda distributions. -Alternatively, from version ``1.4.0`` optional dependencies can also be installed as -part of the pip installation via: +Alternatively, from version ``1.4.0`` some of the optional dependencies can also be +installed as part of the pip installation via: .. code-block:: bash >> pip install pylops[advanced] Dependencies are however installed from their PyPI wheels. -An exception is however represented by CuPy. 
This library is **not** installed + +Finally, note that CuPy and JAX are **not** installed automatically. Users interested to accelerate their computations with the aid -of GPUs should install it prior to installing PyLops as described in :ref:`OptionalGPU`. +of GPUs should install either or both of them prior to installing PyLops as +described in :ref:`OptionalGPU`. .. note:: @@ -324,11 +325,13 @@ of GPUs should install it prior to installing PyLops as described in :ref:`Optio PyLops via ``make dev-install_conda`` (``conda``) or ``make dev-install`` (``pip``). -In alphabetic order: +More details about the installation process for the different optional dependencies are described +in the following (an asterisk is used to indicate those dependencies that are automatically installed +when installing PyLops from conda-forge or via ``pip install pylops[advanced]``): -dtcwt ------ +dtcwt* +------ `dtcwt `_ is a library used to implement the DT-CWT operators. @@ -356,13 +359,16 @@ Install it via ``pip`` with >> pip install devito -FFTW ----- -Three different "engines" are provided by the :py:class:`pylops.signalprocessing.FFT` operator: -``engine="numpy"`` (default), ``engine="scipy"`` and ``engine="fftw"``. +FFTW* and MKL-FFT +----------------- +Four different "engines" are provided by the :py:class:`pylops.signalprocessing.FFT` operator: +``engine="numpy"`` (default), ``engine="scipy"``, ``engine="fftw"`` and ``engine="mkl_fft"``. +Similarly, the :py:class:`pylops.signalprocessing.FFT2D` and +the :py:class:`pylops.signalprocessing.FFTND` operators come with three "engines", namely +``engine="numpy"`` (default), ``engine="scipy"``, and ``engine="mkl_fft"``. The first two engines are part of the required PyLops dependencies. -The latter implements the well-known `FFTW `_ +The third implements the well-known `FFTW `_ via the Python wrapper :py:class:`pyfftw.FFTW`. While this optimized FFT tends to outperform the other two in many cases, it is not included by default. 
To use this library, install it manually either via ``conda``: @@ -377,16 +383,52 @@ or via pip: >> pip install pyfftw +The fourth implements **Intel MKL FFT** via the Python interface `mkl_fft `_. +This provides access to Intel’s oneMKL Fourier Transform routines, enabling efficient FFT computations with performance +close to native C/Intel® oneMKL. + +To use this library, you can install it using ``conda``: + +.. code-block:: bash + + >> conda install --channel https://software.repos.intel.com/python/conda --channel conda-forge mkl_fft + +or via pip: + +.. code-block:: bash + + >> pip install --index-url https://software.repos.intel.com/python/pypi --extra-index-url https://pypi.org/simple mkl_fft + +Installing ``mkl_fft`` triggers the installation of Intel-optimized versions of `NumPy `__ and +`SciPy `__, which redirect ``numpy.fft`` and ``scipy.fft`` to use MKL FFT routines. +As a result, all FFT operations and computational backends leverage Intel MKL for optimal performance. + +Although the library can run without Intel-optimized NumPy and SciPy, maximum performance is achieved when using NumPy and +SciPy built with Intel’s Math Kernel Library (MKL) alongside Intel Python. + .. note:: - FFTW is only available for :py:class:`pylops.signalprocessing.FFT`, - not :py:class:`pylops.signalprocessing.FFT2D` or :py:class:`pylops.signalprocessing.FFTND`. + ``mkl_fft`` is not supported on macOS. .. warning:: - Intel MKL FFT is not supported. + ``pyFFTW`` may not work correctly with NumPy + MKL. To avoid issues, it is recommended to build ``pyFFTW`` from + source after setting the ``STATIC_FFTW_DIR`` environment variable to the absolute path of the static FFTW + libraries. -Numba ------ + If the following environment variables are set before installing ``pyFFTW``, compatibility problems with MKL + should not occur: + + 1. 
``export STATIC_FFTW_DIR=${PREFIX}/lib`` + (where ``${PREFIX}`` is the base of the current Anaconda environment with + the ``fftw`` package installed) + + 2. ``export CFLAGS="$CFLAGS -Wl,-Bsymbolic"`` + + Alternatively, you can install ``pyFFTW`` directly with ``conda``, since the updated recipe is already available + and works without any manual adjustments. + +Numba* +------ Although we always strive to write code for forward and adjoint operators that takes advantage of the perks of NumPy and SciPy (e.g., broadcasting, ufunc), in some case we may end up using for loops that may lead to poor performance. In those cases we may decide to implement alternative (optional) @@ -445,7 +487,6 @@ It can also be checked dynamically with ``numba.config.NUMBA_DEFAULT_NUM_THREADS PyMC and PyTensor ----------------- - `PyTensor `_ is used to allow seamless integration between PyLops and `PyMC `_ operators. Install both of them via ``conda`` with: @@ -464,8 +505,8 @@ or via ``pip`` with OSX users may experience a ``CompileError`` error when using PyTensor. This can be solved by adding ``pytensor.config.gcc__cxxflags = "-Wno-c++11-narrowing"`` after ``import pytensor``. -PyWavelets ----------- +PyWavelets* +----------- `PyWavelets `_ is used to implement the wavelet operators. Install it via ``conda`` with: @@ -480,8 +521,8 @@ or via ``pip`` with >> pip install PyWavelets -scikit-fmm ----------- +scikit-fmm* +----------- `scikit-fmm `_ is a library which implements the fast marching method. It is used in PyLops to compute traveltime tables in the initialization of :py:class:`pylops.waveeqprocessing.Kirchhoff` @@ -499,8 +540,8 @@ or with ``pip`` via >> pip install scikit-fmm -SPGL1 ------ +SPGL1* +------ `SPGL1 `_ is used to solve sparsity-promoting basis pursuit, basis pursuit denoise, and Lasso problems in :py:func:`pylops.optimization.sparsity.SPGL1` solver. 
diff --git a/environment-dev-intel-mkl.yml b/environment-dev-intel-mkl.yml new file mode 100644 index 00000000..7ce56bd6 --- /dev/null +++ b/environment-dev-intel-mkl.yml @@ -0,0 +1,45 @@ +name: pylops +channels: + - https://software.repos.intel.com/python/conda + - conda-forge + - defaults + - numba +dependencies: + - python>=3.10.0 + - pip + - numpy>=2.0.0 + - scipy>=1.13.0 + - pyfftw + - pywavelets + - sympy + - pymc>=5 + - pytensor + - matplotlib + - ipython + - pytest + - Sphinx + - numpydoc + - numba + - icc_rt + - pre-commit + - autopep8 + - isort + - black + - mkl_fft + - pip: + - torch + - devito + # - dtcwt (until numpy>=2.0.0 is supported) + - scikit-fmm + - spgl1 + - jax + - pytest-runner + - setuptools_scm + - pydata-sphinx-theme + - pooch + - sphinx-gallery + - nbsphinx + - sphinxemoji + - image + - flake8 + - mypy diff --git a/examples/plot_fft.py b/examples/plot_fft.py index a8af9da6..a41e8ffe 100644 --- a/examples/plot_fft.py +++ b/examples/plot_fft.py @@ -6,6 +6,7 @@ and :py:class:`pylops.signalprocessing.FFTND` operators to apply the Fourier Transform to the model and the inverse Fourier Transform to the data. """ + import matplotlib.pyplot as plt import numpy as np @@ -15,7 +16,7 @@ ############################################################################### # Let's start by applying the one dimensional FFT to a one dimensional -# sinusoidal signal :math:`d(t)=sin(2 \pi f_0t)` using a time axis of +# sinusoidal signal :math:`d(t)=\sin(2 \pi f_0t)` using a time axis of # lenght :math:`nt` and sampling :math:`dt` dt = 0.005 nt = 100 @@ -67,7 +68,31 @@ plt.tight_layout() ############################################################################### -# We can also apply the one dimensional FFT to to a two-dimensional +# PyLops also has a third FFT engine (``engine='mkl_fft'``) that uses the well-known +# `Intel MKL FFT `_. This is a Python wrapper around +# the `Intel® oneAPI Math Kernel Library (oneMKL) `_ +# Fourier Transform functions. 
It lets PyLops run discrete Fourier transforms faster +# by using Intel’s highly optimized math routines. + +FFTop = pylops.signalprocessing.FFT(dims=nt, nfft=nfft, sampling=dt, engine="mkl_fft") +D = FFTop * d + +# Inverse for FFT +dinv = FFTop.H * D +dinv = FFTop / D + +fig, axs = plt.subplots(1, 2, figsize=(10, 4)) +axs[0].plot(t, d, "k", lw=2, label="True") +axs[0].plot(t, dinv.real, "--r", lw=2, label="Inverted") +axs[0].legend() +axs[0].set_title("Signal") +axs[1].plot(FFTop.f[: int(FFTop.nfft / 2)], np.abs(D[: int(FFTop.nfft / 2)]), "k", lw=2) +axs[1].set_title("Fourier Transform with MKL FFT") +axs[1].set_xlim([0, 3 * f0]) +plt.tight_layout() + +############################################################################### +# We can also apply the one dimensional FFT to a two-dimensional # signal (along one of the first axis) dt = 0.005 nt, nx = 100, 20 @@ -100,7 +125,7 @@ fig.tight_layout() ############################################################################### -# We can also apply the two dimensional FFT to to a two-dimensional signal +# We can also apply the two dimensional FFT to a two-dimensional signal dt, dx = 0.005, 5 nt, nx = 100, 201 t = np.arange(nt) * dt @@ -137,7 +162,7 @@ ############################################################################### -# Finally can apply the three dimensional FFT to to a three-dimensional signal +# Finally, we can apply the three dimensional FFT to a three-dimensional signal dt, dx, dy = 0.005, 5, 3 nt, nx, ny = 30, 21, 11 t = np.arange(nt) * dt @@ -176,3 +201,21 @@ axs[1][1].set_title("Error") axs[1][1].axis("tight") fig.tight_layout() + +############################################################################### +# To conclude, we provide a summary table of the different backends +# supported by :py:class:`pylops.signalprocessing.FFT`, +# :py:class:`pylops.signalprocessing.FFT2D` +# and :py:class:`pylops.signalprocessing.FFTND` operators, along with the +# third-party dependencies required to use 
them. +# +# Supported Backends +# ~~~~~~~~~~~~~~~~~~ +# ========== ==================== ========== +# Backend Supported Dimensions Dependency +# ========== ==================== ========== +# Numpy/CuPy 1D, 2D, ND ``numpy`` (included) +# Scipy 1D, 2D, ND ``scipy`` (included) +# FFTW 1D ``pyfftw`` +# MKL 1D, 2D, ND ``mkl_fft``, or ``intel-numpy``/``intel-scipy`` via standard "numpy"/"scipy" engines +# ========== ==================== ========== diff --git a/pylops/signalprocessing/fft.py b/pylops/signalprocessing/fft.py index fe853771..b459519f 100644 --- a/pylops/signalprocessing/fft.py +++ b/pylops/signalprocessing/fft.py @@ -16,10 +16,15 @@ from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray pyfftw_message = deps.pyfftw_import("the fft module") +mkl_fft_message = deps.mkl_fft_import("the fft module") if pyfftw_message is None: import pyfftw +if mkl_fft_message is None: + import mkl_fft.interfaces.scipy_fft as mkl_backend + from mkl_fft.interfaces import _float_utils + logger = logging.getLogger(__name__) @@ -335,7 +340,7 @@ def _matvec(self, x: NDArray) -> NDArray: elif self.doifftpad: x = np.take(x, range(0, self.nfft), axis=self.axis) - # self.fftplan() always uses byte-alligned self.x as input array and + # self.fftplan() always uses byte-aligned self.x as input array and # returns self.y as output array. As such, self.x must be copied so as # not to be overwritten on a subsequent call to _matvec. np.copyto(self.x, x) @@ -357,7 +362,7 @@ def _rmatvec(self, x: NDArray) -> NDArray: if self.fftshift_after: x = np.fft.ifftshift(x, axes=self.axis) - # self.ifftplan() always uses byte-alligned self.y as input array. + # self.ifftplan() always uses byte-aligned self.y as input array. # We copy here so we don't need to copy again in the case of `real=True`, # which only performs operations that preserve byte-allignment. 
np.copyto(self.y, x) @@ -394,6 +399,98 @@ def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike: return self._rmatvec(y) / self._scale +class _FFT_mklfft(_BaseFFT): + """One-dimensional Fast-Fourier Transform using mkl_fft""" + + def __init__( + self, + dims: Union[int, InputDimsLike], + axis: int = -1, + nfft: Optional[int] = None, + sampling: float = 1.0, + norm: str = "ortho", + real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, + dtype: DTypeLike = "complex128", + **kwargs_fft, + ) -> None: + super().__init__( + dims=dims, + axis=axis, + nfft=nfft, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + ) + self._kwargs_fft = kwargs_fft + self._norm_kwargs = {"norm": None} + if self.norm is _FFTNorms.ORTHO: + self._norm_kwargs["norm"] = "ortho" + self._scale = np.sqrt(1 / self.nfft) + elif self.norm is _FFTNorms.NONE: + self._scale = self.nfft + elif self.norm is _FFTNorms.ONE_OVER_N: + self._scale = 1.0 / self.nfft + + @reshaped + def _matvec(self, x: NDArray) -> NDArray: + x = _float_utils._downcast_float128_array(x) + x = _float_utils._upcast_float16_array(x) + if self.ifftshift_before: + x = scipy.fft.ifftshift(x, axes=self.axis) + if not self.clinear: + x = np.real(x) + if self.real: + y = mkl_backend.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + y = np.swapaxes(y, -1, self.axis) + y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2) + y = np.swapaxes(y, self.axis, -1) + else: + y = mkl_backend.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + if self.norm is _FFTNorms.ONE_OVER_N: + y *= self._scale + if self.fftshift_after: + y = scipy.fft.fftshift(y, axes=self.axis) + return y + + @reshaped + def _rmatvec(self, x: NDArray) -> NDArray: + x = _float_utils._downcast_float128_array(x) + x = _float_utils._upcast_float16_array(x) + if self.fftshift_after: + x = scipy.fft.ifftshift(x, axes=self.axis) + if self.real: + x = 
x.copy() + x = np.swapaxes(x, -1, self.axis) + x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2) + x = np.swapaxes(x, self.axis, -1) + y = mkl_backend.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + else: + y = mkl_backend.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + if self.norm is _FFTNorms.NONE: + y *= self._scale + + if self.nfft > self.dims[self.axis]: + y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis) + elif self.nfft < self.dims[self.axis]: + y = np.pad(y, self.ifftpad) + + if not self.clinear: + y = np.real(y) + if self.ifftshift_before: + y = scipy.fft.fftshift(y, axes=self.axis) + return y + + def __truediv__(self, y): + if self.norm is not _FFTNorms.ORTHO: + return self._rmatvec(y) / self._scale + return self._rmatvec(y) + + def FFT( dims: Union[int, InputDimsLike], axis: int = -1, @@ -481,7 +578,7 @@ def FFT( frequencies are arranged from zero to largest positive, and then from negative Nyquist to the frequency bin before zero. engine : :obj:`str`, optional - Engine used for fft computation (``numpy``, ``fftw``, or ``scipy``). Choose + Engine used for fft computation (``numpy``, ``fftw``, ``scipy`` or ``mkl_fft``). Choose ``numpy`` when working with cupy and jax arrays. .. note:: Since version 1.17.0, accepts "scipy". @@ -534,7 +631,7 @@ def FFT( - If ``dims`` is provided and ``axis`` is bigger than ``len(dims)``. - If ``norm`` is not one of "ortho", "none", or "1/n". NotImplementedError - If ``engine`` is neither ``numpy``, ``fftw``, nor ``scipy``. + If ``engine`` is neither ``numpy``, ``fftw``, ``scipy`` nor ``mkl_fft``. 
See Also -------- @@ -579,9 +676,28 @@ def FFT( dtype=dtype, **kwargs_fft, ) - elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None): + elif engine == "mkl_fft" and mkl_fft_message is None: + f = _FFT_mklfft( + dims, + axis=axis, + nfft=nfft, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + **kwargs_fft, + ) + elif ( + engine == "numpy" + or (engine == "fftw" and pyfftw_message is not None) + or (engine == "mkl_fft" and mkl_fft_message is not None) + ): if engine == "fftw" and pyfftw_message is not None: logger.warning(pyfftw_message) + if engine == "mkl_fft" and mkl_fft_message is not None: + logger.warning(mkl_fft_message) f = _FFT_numpy( dims, axis=axis, @@ -608,6 +724,6 @@ def FFT( **kwargs_fft, ) else: - raise NotImplementedError("engine must be numpy, fftw or scipy") + raise NotImplementedError("engine must be numpy, scipy, fftw, or mkl_fft") f.name = name return f diff --git a/pylops/signalprocessing/fft2d.py b/pylops/signalprocessing/fft2d.py index 5df5df34..34d01d9a 100644 --- a/pylops/signalprocessing/fft2d.py +++ b/pylops/signalprocessing/fft2d.py @@ -8,10 +8,17 @@ from pylops import LinearOperator from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms +from pylops.utils import deps from pylops.utils.backend import get_array_module from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike +mkl_fft_message = deps.mkl_fft_import("the mkl fft module") + +if mkl_fft_message is None: + import mkl_fft.interfaces.scipy_fft as mkl_backend + from mkl_fft.interfaces import _float_utils + class _FFT2D_numpy(_BaseFFTND): """Two dimensional Fast-Fourier Transform using NumPy""" @@ -235,6 +242,118 @@ def __truediv__(self, y): return self._rmatvec(y) +class _FFT2D_mklfft(_BaseFFTND): + """Two-dimensional Fast-Fourier Transform using mkl_fft""" + + def __init__( + self, + dims: InputDimsLike, + axes: 
InputDimsLike = (-2, -1), + nffts: Optional[Union[int, InputDimsLike]] = None, + sampling: Union[float, Sequence[float]] = 1.0, + norm: str = "ortho", + real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, + dtype: DTypeLike = "complex128", + **kwargs_fft, + ) -> None: + super().__init__( + dims=dims, + axes=axes, + nffts=nffts, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + ) + + # checks + if self.ndim < 2: + raise ValueError("FFT2D requires at least two input dimensions") + if self.naxes != 2: + raise ValueError("FFT2D must be applied along exactly two dimensions") + + self.f1, self.f2 = self.fs + del self.fs + + self._kwargs_fft = kwargs_fft + self._norm_kwargs: Dict[str, Union[None, str]] = {"norm": None} + if self.norm is _FFTNorms.ORTHO: + self._norm_kwargs["norm"] = "ortho" + self._scale = np.sqrt(1 / np.prod(np.sqrt(self.nffts))) + elif self.norm is _FFTNorms.NONE: + self._scale = np.sqrt(np.prod(self.nffts)) + elif self.norm is _FFTNorms.ONE_OVER_N: + self._scale = np.sqrt(1.0 / np.prod(self.nffts)) + + @reshaped + def _matvec(self, x): + x = _float_utils._downcast_float128_array(x) + x = _float_utils._upcast_float16_array(x) + if self.ifftshift_before.any(): + x = scipy.fft.ifftshift(x, axes=self.axes[self.ifftshift_before]) + if not self.clinear: + x = np.real(x) + if self.real: + y = mkl_backend.rfft2( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) + y = np.swapaxes(y, -1, self.axes[-1]) + y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2) + y = np.swapaxes(y, self.axes[-1], -1) + else: + y = mkl_backend.fft2( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) + if self.norm is _FFTNorms.ONE_OVER_N: + y *= self._scale + y = y.astype(self.cdtype) + if self.fftshift_after.any(): + y = scipy.fft.fftshift(y, axes=self.axes[self.fftshift_after]) + return y + + @reshaped + def 
_rmatvec(self, x): + x = _float_utils._downcast_float128_array(x) + x = _float_utils._upcast_float16_array(x) + if self.fftshift_after.any(): + x = scipy.fft.ifftshift(x, axes=self.axes[self.fftshift_after]) + if self.real: + x = x.copy() + x = np.swapaxes(x, -1, self.axes[-1]) + x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2) + x = np.swapaxes(x, self.axes[-1], -1) + y = mkl_backend.irfft2( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) + else: + y = mkl_backend.ifft2( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) + if self.norm is _FFTNorms.NONE: + y *= self._scale + if self.nffts[0] > self.dims[self.axes[0]]: + y = np.take(y, np.arange(self.dims[self.axes[0]]), axis=self.axes[0]) + if self.nffts[1] > self.dims[self.axes[1]]: + y = np.take(y, np.arange(self.dims[self.axes[1]]), axis=self.axes[1]) + if self.doifftpad: + y = np.pad(y, self.ifftpad) + if not self.clinear: + y = np.real(y) + y = y.astype(self.rdtype) + if self.ifftshift_before.any(): + y = scipy.fft.fftshift(y, axes=self.axes[self.ifftshift_before]) + return y + + def __truediv__(self, y): + if self.norm is not _FFTNorms.ORTHO: + return self._rmatvec(y) / self._scale / self._scale + return self._rmatvec(y) + + def FFT2D( dims: InputDimsLike, axes: InputDimsLike = (-2, -1), @@ -331,7 +450,7 @@ def FFT2D( engine : :obj:`str`, optional .. versionadded:: 1.17.0 - Engine used for fft computation (``numpy`` or ``scipy``). Choose + Engine used for fft computation (``numpy`` or ``scipy`` or ``mkl_fft``). Choose ``numpy`` when working with cupy and jax arrays. dtype : :obj:`str`, optional Type of elements in input array. Note that the ``dtype`` of the operator @@ -387,7 +506,7 @@ def FFT2D( two elements. - If ``norm`` is not one of "ortho", "none", or "1/n". NotImplementedError - If ``engine`` is neither ``numpy``, nor ``scipy``. + If ``engine`` is neither ``numpy``, ``scipy`` nor ``mkl_fft``. 
See Also -------- @@ -420,7 +539,20 @@ def FFT2D( signals. """ - if engine == "numpy": + if engine == "mkl_fft" and mkl_fft_message is None: + f = _FFT2D_mklfft( + dims=dims, + axes=axes, + nffts=nffts, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + **kwargs_fft, + ) + elif engine == "numpy" or (engine == "mkl_fft" and mkl_fft_message is not None): f = _FFT2D_numpy( dims=dims, axes=axes, @@ -447,6 +579,6 @@ def FFT2D( **kwargs_fft, ) else: - raise NotImplementedError("engine must be numpy or scipy") + raise NotImplementedError("engine must be numpy, scipy or mkl_fft") f.name = name return f diff --git a/pylops/signalprocessing/fftnd.py b/pylops/signalprocessing/fftnd.py index 5608b804..cb46a9f3 100644 --- a/pylops/signalprocessing/fftnd.py +++ b/pylops/signalprocessing/fftnd.py @@ -5,13 +5,21 @@ import numpy as np import numpy.typing as npt +import scipy.fft from pylops import LinearOperator from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms +from pylops.utils import deps from pylops.utils.backend import get_array_module, get_sp_fft from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray +mkl_fft_message = deps.mkl_fft_import("the mkl fft module") + +if mkl_fft_message is None: + import mkl_fft.interfaces.scipy_fft as mkl_backend + from mkl_fft.interfaces import _float_utils + class _FFTND_numpy(_BaseFFTND): """N-dimensional Fast-Fourier Transform using NumPy""" @@ -216,6 +224,106 @@ def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike: return self._rmatvec(y) +class _FFTND_mklfft(_BaseFFTND): + """N-dimensional Fast-Fourier Transform using MKL FFT""" + + def __init__( + self, + dims: Union[int, InputDimsLike], + axes: Union[int, InputDimsLike] = (-3, -2, -1), + nffts: Optional[Union[int, InputDimsLike]] = None, + sampling: Union[float, Sequence[float]] = 1.0, + norm: str = "ortho", + real: bool = False, + 
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
    """Apply the forward N-dimensional FFT using the ``mkl_fft`` backend.

    Parameters
    ----------
    x : :obj:`numpy.ndarray`
        Model array. NOTE(review): the ``reshaped`` decorator is defined
        outside this block; presumably it hands ``x`` over in N-dimensional
        form - confirm against its definition.

    Returns
    -------
    y : :obj:`numpy.ndarray`
        Transformed array, scaled and shifted according to ``self.norm``
        and ``self.fftshift_after``.

    """
    # Precision normalisation is delegated to the ``_float_utils`` helpers
    # (judging by their names: float128 is downcast, float16 is upcast).
    x = _float_utils._upcast_float16_array(
        _float_utils._downcast_float128_array(x)
    )
    if self.ifftshift_before.any():
        x = scipy.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
    if not self.clinear:
        x = np.real(x)

    if not self.real:
        y = mkl_backend.fftn(
            x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
        )
    else:
        y = mkl_backend.rfftn(
            x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
        )
        # Scale bins 1 .. (nfft - 1) // 2 of the last transformed axis by
        # sqrt(2) so that the operator has a correct adjoint. Bin 0 is never
        # touched. The axis is swapped to the end only to allow the
        # ``[..., a:b]`` slice; swapaxes returns a view, so the in-place
        # multiply acts on the transform output itself.
        last_axis = self.axes[-1]
        n_scaled = (self.nffts[-1] - 1) // 2
        y = np.swapaxes(y, -1, last_axis)
        y[..., 1 : 1 + n_scaled] *= np.sqrt(2)
        y = np.swapaxes(y, last_axis, -1)

    # Only the "1/n" norm needs an explicit factor in the forward direction;
    # "ortho" is passed to the backend through ``self._norm_kwargs``.
    if self.norm is _FFTNorms.ONE_OVER_N:
        y *= self._scale
    if self.fftshift_after.any():
        y = scipy.fft.fftshift(y, axes=self.axes[self.fftshift_after])
    return y


@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
    """Apply the adjoint (inverse-type) N-dimensional FFT using ``mkl_fft``.

    Parameters
    ----------
    x : :obj:`numpy.ndarray`
        Data array. NOTE(review): as for ``_matvec``, the exact shape
        received depends on the ``reshaped`` decorator - confirm.

    Returns
    -------
    y : :obj:`numpy.ndarray`
        Array cropped to ``self.dims`` along every transformed axis (and
        padded with ``self.ifftpad`` when ``self.doifftpad`` is set).

    """
    x = _float_utils._downcast_float128_array(x)
    x = _float_utils._upcast_float16_array(x)
    if self.fftshift_after.any():
        # Undo the shift that the forward pass applied after the transform
        x = scipy.fft.ifftshift(x, axes=self.axes[self.fftshift_after])

    if self.real:
        # Mirror of the sqrt(2) scaling in ``_matvec``. Work on a copy so the
        # caller's array is not modified by the in-place division.
        last_axis = self.axes[-1]
        n_scaled = (self.nffts[-1] - 1) // 2
        x = np.swapaxes(x.copy(), -1, last_axis)
        x[..., 1 : 1 + n_scaled] /= np.sqrt(2)
        x = np.swapaxes(x, last_axis, -1)
        inverse_transform = mkl_backend.irfftn
    else:
        inverse_transform = mkl_backend.ifftn
    y = inverse_transform(
        x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
    )

    # With norm "none" the backend's default inverse scaling is compensated
    # by multiplying with ``self._scale``.
    if self.norm is _FFTNorms.NONE:
        y *= self._scale

    # Drop the samples that exist only because nfft exceeded the model size
    for ax, nfft in zip(self.axes, self.nffts):
        if nfft > self.dims[ax]:
            y = np.take(y, range(self.dims[ax]), axis=ax)
    if self.doifftpad:
        y = np.pad(y, self.ifftpad)
    if not self.clinear:
        y = np.real(y)
    if self.ifftshift_before.any():
        # Undo the shift that the forward pass applied before the transform
        y = scipy.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
    return y


def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike:
    """Return ``self._rmatvec(y)``, rescaled unless the norm is "ortho".

    For every norm other than ``_FFTNorms.ORTHO`` the adjoint result is
    divided by ``self._scale``; for ``ORTHO`` it is returned unchanged.
    """
    adjoint = self._rmatvec(y)
    if self.norm is _FFTNorms.ORTHO:
        return adjoint
    return adjoint / self._scale
""" - if engine == "numpy": + if engine == "mkl_fft" and mkl_fft_message is None: + f = _FFTND_mklfft( + dims=dims, + axes=axes, + nffts=nffts, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + **kwargs_fft, + ) + elif engine == "numpy" or (engine == "mkl_fft" and mkl_fft_message is not None): f = _FFTND_numpy( dims=dims, axes=axes, @@ -437,6 +558,6 @@ def FFTND( **kwargs_fft, ) else: - raise NotImplementedError("engine must be numpy or scipy") + raise NotImplementedError("engine must be numpy, scipy or mkl_fft") f.name = name return f diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index df4a0d9e..67a55464 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -11,6 +11,7 @@ "sympy_enabled", "torch_enabled", "pytensor_enabled", + "mkl_fft_enabled", ] import os @@ -19,7 +20,7 @@ # error message at import of available package -def cupy_import(message: Optional[str] = None) -> str: +def cupy_import(message: Optional[str] = None) -> str | None: # detect if cupy is available and the user is expecting to be used cupy_test = ( util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1 @@ -53,7 +54,7 @@ def cupy_import(message: Optional[str] = None) -> str: return cupy_message -def jax_import(message: Optional[str] = None) -> str: +def jax_import(message: Optional[str] = None) -> str | None: jax_test = ( util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1 ) @@ -81,7 +82,7 @@ def jax_import(message: Optional[str] = None) -> str: return jax_message -def devito_import(message: Optional[str] = None) -> str: +def devito_import(message: Optional[str] = None) -> str | None: if devito_enabled: try: import_module("devito") # noqa: F401 @@ -98,7 +99,7 @@ def devito_import(message: Optional[str] = None) -> str: return devito_message -def dtcwt_import(message: Optional[str] = None) -> str: +def dtcwt_import(message: Optional[str] = None) -> 
str | None: if dtcwt_enabled: try: import dtcwt # noqa: F401 @@ -115,7 +116,7 @@ def dtcwt_import(message: Optional[str] = None) -> str: return dtcwt_message -def numba_import(message: Optional[str] = None) -> str: +def numba_import(message: Optional[str] = None) -> str | None: if numba_enabled: try: import_module("numba") # noqa: F401 @@ -134,7 +135,7 @@ def numba_import(message: Optional[str] = None) -> str: return numba_message -def pyfftw_import(message: Optional[str] = None) -> str: +def pyfftw_import(message: Optional[str] = None) -> str | None: if pyfftw_enabled: try: import_module("pyfftw") # noqa: F401 @@ -153,7 +154,7 @@ def pyfftw_import(message: Optional[str] = None) -> str: return pyfftw_message -def pywt_import(message: Optional[str] = None) -> str: +def pywt_import(message: Optional[str] = None) -> str | None: if pywt_enabled: try: import_module("pywt") # noqa: F401 @@ -172,7 +173,7 @@ def pywt_import(message: Optional[str] = None) -> str: return pywt_message -def skfmm_import(message: Optional[str] = None) -> str: +def skfmm_import(message: Optional[str] = None) -> str | None: if skfmm_enabled: try: import_module("skfmm") # noqa: F401 @@ -190,7 +191,7 @@ def skfmm_import(message: Optional[str] = None) -> str: return skfmm_message -def spgl1_import(message: Optional[str] = None) -> str: +def spgl1_import(message: Optional[str] = None) -> str | None: if spgl1_enabled: try: import_module("spgl1") # noqa: F401 @@ -207,7 +208,7 @@ def spgl1_import(message: Optional[str] = None) -> str: return spgl1_message -def sympy_import(message: Optional[str] = None) -> str: +def sympy_import(message: Optional[str] = None) -> str | None: if sympy_enabled: try: import_module("sympy") # noqa: F401 @@ -224,7 +225,7 @@ def sympy_import(message: Optional[str] = None) -> str: return sympy_message -def pytensor_import(message: Optional[str] = None) -> str: +def pytensor_import(message: Optional[str] = None) -> str | None: if pytensor_enabled: try: import_module("pytensor") 
# noqa: F401 @@ -241,6 +242,27 @@ def pytensor_import(message: Optional[str] = None) -> str: return pytensor_message +def mkl_fft_import(message: Optional[str]) -> str | None: + if mkl_fft_enabled: + try: + import_module("mkl_fft") # noqa: F401 + mkl_fft_message = None + except Exception as e: + mkl_fft_message = f"Failed to import mkl_fft (error:{e}), use numpy." + else: + mkl_fft_message = ( + "mkl_fft not available, reverting to numpy. " + "In order to be able to use " + f"{message} run " + '"pip install --index-url ' + "https://software.repos.intel.com/python/pypi " + '--extra-index-url https://pypi.org/simple mkl_fft" ' + 'or "conda install -c https://software.repos.intel.com/python/conda ' + '-c conda-forge mkl_fft".' + ) + return mkl_fft_message + + # Set package availability booleans # cupy and jax: the package is imported to check everything is working correctly, # if not the package is disabled. We do this here as these libraries are used as drop-in @@ -264,3 +286,4 @@ def pytensor_import(message: Optional[str] = None) -> str: sympy_enabled = util.find_spec("sympy") is not None torch_enabled = util.find_spec("torch") is not None pytensor_enabled = util.find_spec("pytensor") is not None +mkl_fft_enabled = util.find_spec("mkl_fft") is not None diff --git a/pytests/test_ffts.py b/pytests/test_ffts.py index 2c8d13c8..c98e85a4 100644 --- a/pytests/test_ffts.py +++ b/pytests/test_ffts.py @@ -16,7 +16,7 @@ from pylops.optimization.basic import lsqr from pylops.signalprocessing import FFT, FFT2D, FFTND -from pylops.utils import dottest +from pylops.utils import dottest, mkl_fft_enabled # Utility function @@ -227,11 +227,66 @@ def _choose_random_axes(ndim, n_choices=2): "dtype": np.complex128, "kwargs": {}, } # nfftnt, complex input, mkl-fft engine +par3t = { + "nt": 41, + "nx": 31, + "ny": 10, + "nfft": None, + "real": True, + "engine": "mkl_fft", + "ifftshift_before": False, + "dtype": np.float64, + "kwargs": {}, +} # nfft=nt, real input, mkl-fft engine +par4t = 
{ + "nt": 41, + "nx": 31, + "ny": 10, + "nfft": 64, + "real": True, + "engine": "mkl_fft", + "ifftshift_before": False, + "dtype": np.float64, + "kwargs": {}, +} # nfft>nt, real input, mkl-fft engine +par5t = { + "nt": 41, + "nx": 31, + "ny": 10, + "nfft": 16, + "real": False, + "engine": "mkl_fft", + "ifftshift_before": False, + "dtype": np.complex128, + "kwargs": {}, +} # nfft=2.0.0 scipy>=1.13.0 jax numba -pyfftw PyWavelets spgl1 scikit-fmm diff --git a/requirements-doc.txt b/requirements-doc.txt index beccb183..1d564081 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -31,3 +31,4 @@ flake8 mypy pytensor>=2.28.0 pymc>=5.21.0 +mkl_fft; sys_platform != "darwin" diff --git a/requirements-intel-mkl.txt b/requirements-intel-mkl.txt new file mode 100644 index 00000000..d407245e --- /dev/null +++ b/requirements-intel-mkl.txt @@ -0,0 +1,4 @@ +--index-url https://software.repos.intel.com/python/pypi +numpy>=2.0.0 +scipy>=1.13.0 +mkl_fft; sys_platform != "darwin" diff --git a/requirements-pyfftw.txt b/requirements-pyfftw.txt new file mode 100644 index 00000000..0ee0a829 --- /dev/null +++ b/requirements-pyfftw.txt @@ -0,0 +1 @@ +pyfftw