Skip to content

Commit 9ddd5e6

Browse files
aymuos15ericspod
andauthored
Fft cleanup/update (#8762)
Remove legacy PyTorch 1.8.0 compatibility code from FFT utilities. MONAI now requires PyTorch ≥ 2.4.1, so the version checks and NumPy fallbacks are no longer needed. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 583d5ca commit 9ddd5e6

File tree

2 files changed

+2
-22
lines changed

2 files changed

+2
-22
lines changed

monai/networks/blocks/fft_utils_t.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@ def roll_1d(x: Tensor, shift: int, shift_dim: int) -> Tensor:
2828
2929
Returns:
3030
1d-shifted version of x
31-
32-
Note:
33-
This function is called when fftshift and ifftshift are not available in the running pytorch version
3431
"""
3532
shift = shift % x.size(shift_dim)
3633
if shift == 0:
@@ -55,9 +52,6 @@ def roll(x: Tensor, shift: list[int], shift_dims: list[int]) -> Tensor:
5552
5653
Returns:
5754
shifted version of x
58-
59-
Note:
60-
This function is called when fftshift and ifftshift are not available in the running pytorch version
6155
"""
6256
if len(shift) != len(shift_dims):
6357
raise ValueError(f"len(shift) != len(shift_dims), got f{len(shift)} and f{len(shift_dims)}.")
@@ -78,9 +72,6 @@ def fftshift(x: Tensor, shift_dims: list[int]) -> Tensor:
7872
7973
Returns:
8074
fft-shifted version of x
81-
82-
Note:
83-
This function is called when fftshift is not available in the running pytorch version
8475
"""
8576
shift = [0] * len(shift_dims)
8677
for i, dim_num in enumerate(shift_dims):
@@ -100,9 +91,6 @@ def ifftshift(x: Tensor, shift_dims: list[int]) -> Tensor:
10091
10192
Returns:
10293
ifft-shifted version of x
103-
104-
Note:
105-
This function is called when ifftshift is not available in the running pytorch version
10694
"""
10795
shift = [0] * len(shift_dims)
10896
for i, dim_num in enumerate(shift_dims):

monai/transforms/utils.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,11 +1878,7 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = F
18781878
dims = tuple(range(-spatial_dims, 0))
18791879
k: NdarrayOrTensor
18801880
if isinstance(x, torch.Tensor):
1881-
if hasattr(torch.fft, "fftshift"): # `fftshift` is new in torch 1.8.0
1882-
k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims)
1883-
else:
1884-
# if using old PyTorch, will convert to numpy array and return
1885-
k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims)
1881+
k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims)
18861882
else:
18871883
k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)
18881884
return ascontiguousarray(k) if as_contiguous else k
@@ -1906,11 +1902,7 @@ def inv_shift_fourier(
19061902
dims = tuple(range(-spatial_dims, 0))
19071903
out: NdarrayOrTensor
19081904
if isinstance(k, torch.Tensor):
1909-
if hasattr(torch.fft, "ifftshift"): # `ifftshift` is new in torch 1.8.0
1910-
out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real
1911-
else:
1912-
# if using old PyTorch, will convert to numpy array and return
1913-
out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real
1905+
out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm="backward").real
19141906
else:
19151907
out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real
19161908
return ascontiguousarray(out) if as_contiguous else out

0 commit comments

Comments
 (0)