
Commit cd672a0

Optim-wip: Add JIT support to all transforms & some image parameterizations (#821)
* Add JIT support to most transforms
* Additional improvements
* JIT support for `center_crop`.
* Improve some transform tests.
* Fix `RandomCrop` transform bug.
* Fix Mypy bug
* Interpolation-based RandomScale & other improvements
* Replace Affine `RandomScale` with interpolation-based variant. Renamed old variant to `RandomScaleAffine`.
* `CenterCrop` & `center_crop` now use padding if the crop size is larger than the input dimensions.
* Add distributions support to both versions of `RandomScale`.
* Improve transform tests.
* NumSeqOrTensorType -> NumSeqOrTensorOrProbDistType
* Add `torch.distributions.distribution.Distribution` to `NumSeqOrTensorType` type hint.
* Add TransformationRobustness transform & fix bug
* Added `TransformationRobustness()` transform.
* Fixed bug with `center_crop` padding code, and added related tests to `center_crop` & `CenterCrop`.
* Fix center crop JIT tests
* Add asserts & more tests for RandomScale transforms
* Add JIT support for ToRGB, NaturalImage, & FFTImage
* Add JIT support for `NaturalImage`, `FFTImage`, & `PixelImage`.
* Added proper JIT support for `ToRGB`.
* Improved `NaturalImage` & `FFTImage` tests, and test coverage.
* Add ImageParameterization instance support for NaturalImage
* Added `ImageParameterization` instance support for `NaturalImage`. This improvement should make it easier to use parameterization enhancements like SharedImage, and will be helpful for custom parameterizations that don't use the standard input variable set (size, channels, batch, & init).
* Added asserts to verify `NaturalImage` parameterization inputs are instances or types of `ImageParameterization`.
* Support ToRGB with no named dimensions. This should make it easier to work with the ToRGB module, as many PyTorch functions still don't work with named dimensions yet.
* Allow more than 4 channels in ToRGB. The maximum of 4 channels isn't required, as we ignore all channels after 3.
* Add assert check to `RandomScale`'s mode variable. The `linear` mode only supports 3D inputs, and `trilinear` only supports 5D inputs. RandomScale only uses 4D inputs, so only `nearest`, `bilinear`, `bicubic`, & `area` are supported.
* Change assert to check for unsupported RandomScale mode options
* Change `RandomRotation` type hint & add `RandomRotation` to `TransformationRobustness`
* Change `RandomRotation` type hint from `NumSeqOrTensorType` to `NumSeqOrTensorOrProbDistType`.
* Uncomment `RandomRotation` from `TransformationRobustness` & tests.
1 parent cbc6aac commit cd672a0
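
For context, here is a minimal usage sketch (not part of this commit) of what the new JIT support enables. It assumes `NaturalImage` is importable from `captum.optim._param.image.images` (the file shown in the diff below) and that its constructor accepts a `size` argument with default channels and batch values; under scripting, the parameterizations skip the named-dimension (`refine_names`) calls that TorchScript does not support.

    import torch
    from captum.optim._param.image.images import NaturalImage  # assumed import path

    image_param = NaturalImage(size=(224, 224))

    # Eager mode: forward() works with named dimensions internally and returns an ImageTensor.
    eager_output = image_param()

    # Scripted mode: the new torch.jit.is_scripting() guards skip refine_names(),
    # so the module can now be compiled with TorchScript per the commit message.
    scripted_param = torch.jit.script(image_param)
    scripted_output = scripted_param()

    assert eager_output.shape == scripted_output.shape  # e.g. [1, 3, 224, 224]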

6 files changed: +2112 additions, -136 deletions

captum/optim/_param/image/images.py

Lines changed: 61 additions & 20 deletions
@@ -138,6 +138,8 @@ class FFTImage(ImageParameterization):
     Parameterize an image using inverse real 2D FFT
     """

+    __constants__ = ["size", "_supports_is_scripting"]
+
     def __init__(
         self,
         size: Tuple[int, int] = None,
@@ -197,6 +199,9 @@ def __init__(
         self.register_buffer("spectrum_scale", spectrum_scale)
         self.fourier_coeffs = nn.Parameter(fourier_coeffs)

+        # Check & store whether or not we can use torch.jit.is_scripting()
+        self._supports_is_scripting = torch.__version__ >= "1.6.0"
+
     def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
         """
         Computes 2D spectrum frequencies.
@@ -214,6 +219,12 @@ def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
         fx = self.torch_fftfreq(width)[: width // 2 + 1]
         return torch.sqrt((fx * fx) + (fy * fy))

+    @torch.jit.export
+    def torch_irfftn(self, x: torch.Tensor) -> torch.Tensor:
+        if x.dtype != torch.complex64:
+            x = torch.view_as_complex(x)
+        return torch.fft.irfftn(x, s=self.size)  # type: ignore
+
     def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
         """
         Support older versions of PyTorch. This function ensures that the same FFT
@@ -226,26 +237,24 @@ def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
         """

         if TORCH_VERSION >= "1.7.0":
-            import torch.fft
+            if TORCH_VERSION < "1.8.0":
+                global torch
+                import torch.fft

             def torch_rfft(x: torch.Tensor) -> torch.Tensor:
                 return torch.view_as_real(torch.fft.rfftn(x, s=self.size))

-            def torch_irfft(x: torch.Tensor) -> torch.Tensor:
-                if type(x) is not torch.complex64:
-                    x = torch.view_as_complex(x)
-                return torch.fft.irfftn(x, s=self.size)  # type: ignore
+            torch_irfftn = self.torch_irfftn

             def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
                 return torch.fft.fftfreq(v, d)

         else:
-            import torch

             def torch_rfft(x: torch.Tensor) -> torch.Tensor:
                 return torch.rfft(x, signal_ndim=2)

-            def torch_irfft(x: torch.Tensor) -> torch.Tensor:
+            def torch_irfftn(x: torch.Tensor) -> torch.Tensor:
                 return torch.irfft(x, signal_ndim=2)[
                     :, :, : self.size[0], : self.size[1]
                 ]
@@ -258,7 +267,7 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
             results[s:] = torch.arange(-(v // 2), 0)
             return results * (1.0 / (v * d))

-        return torch_rfft, torch_irfft, torch_fftfreq
+        return torch_rfft, torch_irfftn, torch_fftfreq

     def forward(self) -> torch.Tensor:
         """
@@ -268,6 +277,9 @@ def forward(self) -> torch.Tensor:

         scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
         output = self.torch_irfft(scaled_spectrum)
+        if self._supports_is_scripting:
+            if torch.jit.is_scripting():
+                return output
         return output.refine_names("B", "C", "H", "W")


@@ -276,6 +288,8 @@ class PixelImage(ImageParameterization):
     Parameterize a simple pixel image tensor that requires no additional transforms.
     """

+    __constants__ = ["_supports_is_scripting"]
+
     def __init__(
         self,
         size: Tuple[int, int] = None,
@@ -309,7 +323,13 @@ def __init__(
             f"input has {init.shape[1]} channels."
         self.image = nn.Parameter(init)

+        # Check & store whether or not we can use torch.jit.is_scripting()
+        self._supports_is_scripting = torch.__version__ >= "1.6.0"
+
     def forward(self) -> torch.Tensor:
+        if self._supports_is_scripting:
+            if torch.jit.is_scripting():
+                return self.image
         return self.image.refine_names("B", "C", "H", "W")


@@ -600,7 +620,7 @@ def __init__(
                 nn.Parameter tensor, or stacking init images.
                 Default: 1
             parameterization (ImageParameterization, optional): An image
-                parameterization class.
+                parameterization class, or instance of an image parameterization class.
                 Default: FFTImage
             squash_func (Callable[[torch.Tensor], torch.Tensor]], optional): The squash
                 function to use after color recorrelation. A funtion or lambda function.
@@ -612,8 +632,14 @@ def __init__(
                 Default: True
         """
         super().__init__()
+        if not isinstance(parameterization, ImageParameterization):
+            # Verify uninitialized class is correct type
+            assert issubclass(parameterization, ImageParameterization)
+        else:
+            assert isinstance(parameterization, ImageParameterization)
+
         self.decorrelate = decorrelation_module
-        if init is not None:
+        if init is not None and not isinstance(parameterization, ImageParameterization):
             assert init.dim() == 3 or init.dim() == 4
             if decorrelate_init and self.decorrelate is not None:
                 init = (
@@ -622,27 +648,42 @@ def __init__(
                     else init.refine_names("C", "H", "W")
                 )
                 init = self.decorrelate(init, inverse=True).rename(None)
+
             if squash_func is None:
+                squash_func = self._clamp_image

-                def squash_func(x: torch.Tensor) -> torch.Tensor:
-                    return x.clamp(0, 1)
+        self.squash_func = torch.sigmoid if squash_func is None else squash_func
+        if not isinstance(parameterization, ImageParameterization):
+            parameterization = parameterization(
+                size=size, channels=channels, batch=batch, init=init
+            )
+        self.parameterization = parameterization

-        else:
-            if squash_func is None:
+    @torch.jit.export
+    def _clamp_image(self, x: torch.Tensor) -> torch.Tensor:
+        """JIT supported squash function."""
+        return x.clamp(0, 1)

-                squash_func = torch.sigmoid
+    @torch.jit.ignore
+    def _to_image_tensor(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Wrap ImageTensor in torch.jit.ignore for JIT support.

-            self.squash_func = squash_func
-            self.parameterization = parameterization(
-                size=size, channels=channels, batch=batch, init=init
-            )
+        Args:
+
+            x (torch.tensor): An input tensor.
+
+        Returns:
+            x (ImageTensor): An instance of ImageTensor with the input tensor.
+        """
+        return ImageTensor(x)

     def forward(self) -> torch.Tensor:
         image = self.parameterization()
         if self.decorrelate is not None:
             image = self.decorrelate(image)
         image = image.rename(None)  # TODO: the world is not yet ready
-        return ImageTensor(self.squash_func(image))
+        return self._to_image_tensor(self.squash_func(image))


 __all__ = [
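
The `NaturalImage` changes above also allow passing an already-constructed `ImageParameterization` instance instead of a class, which the commit message says is aimed at wrappers like SharedImage and at custom parameterizations whose constructors don't take the standard `size`/`channels`/`batch`/`init` arguments. A rough sketch of both call styles, assuming the same import path as above and that `FFTImage` and `NaturalImage` default to 3 channels and a batch size of 1:

    from captum.optim._param.image.images import FFTImage, NaturalImage  # assumed import path

    # Passing a class: NaturalImage constructs it with size, channels, batch, & init.
    image_from_class = NaturalImage(size=(128, 128), parameterization=FFTImage)

    # Passing an instance: NaturalImage uses it as-is, so it can carry any custom
    # constructor arguments (size, channels, batch, & init are then ignored).
    fft_param = FFTImage(size=(128, 128))
    image_from_instance = NaturalImage(parameterization=fft_param)

    assert image_from_class().shape == image_from_instance().shape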
