> `blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1),`
Not to be nitpicky, but since a Gaussian kernel is separable, this could be replaced with two 1D convolutions, one along the width and one along the height. For a K×K kernel that costs roughly 2K multiply-adds per output pixel instead of K^2:
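```python
import torch.nn.functional as F
from torch import Tensor

def separable_conv2d(inputs: Tensor, k_row: Tensor, k_col: Tensor, groups: int = 1) -> Tensor:
    # k_row has shape (C_out, C_in // groups, 1, K) and filters along the width;
    # k_col has shape (C_out, C_out // groups, K, 1) and filters along the height.
    # A Gaussian filter is separable, so the two passes reproduce the K x K blur.
    # Padding each pass by K // 2 on its own axis gives 'same'-sized output for odd K.
    out_1 = F.conv2d(inputs, k_row, padding=(0, k_row.shape[-1] // 2), groups=groups)
    out_2 = F.conv2d(out_1, k_col, padding=(k_col.shape[-2] // 2, 0), groups=groups)
    return out_2
```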
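A quick way to sanity-check the suggestion is to compare one K×K blur against the two-pass version on random input. This is a sketch under assumptions not stated in the PR: the blur is applied per channel (`groups=C`), the kernel is an odd-sized normalized Gaussian, and the helpers `gaussian_1d`, `blur_full`, and `blur_separable` are hypothetical names for illustration only.

```python
import torch
import torch.nn.functional as F


def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor:
    # Hypothetical helper: normalized 1D Gaussian taps (kernel_size assumed odd).
    coords = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    return g / g.sum()


def blur_full(x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # Reference: a single depthwise K x K convolution with the outer-product kernel.
    c, k = x.shape[1], g.numel()
    k2d = torch.outer(g, g)[None, None].repeat(c, 1, 1, 1)  # (C, 1, K, K)
    return F.conv2d(x, k2d, padding=k // 2, groups=c)


def blur_separable(x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # Two depthwise 1D passes: 1 x K along the width, then K x 1 along the height.
    c, k = x.shape[1], g.numel()
    k_row = g.view(1, 1, 1, k).repeat(c, 1, 1, 1)  # (C, 1, 1, K)
    k_col = g.view(1, 1, k, 1).repeat(c, 1, 1, 1)  # (C, 1, K, 1)
    out = F.conv2d(x, k_row, padding=(0, k // 2), groups=c)
    return F.conv2d(out, k_col, padding=(k // 2, 0), groups=c)


torch.manual_seed(0)
x = torch.randn(2, 3, 16, 16)
g = gaussian_1d(5, sigma=1.0)
full, sep = blur_full(x, g), blur_separable(x, g)
print(full.shape, torch.allclose(full, sep, atol=1e-5))
# torch.Size([2, 3, 16, 16]) True
```

With zero padding on both passes the results agree up to float32 rounding, because any output row whose vertical window reaches into the padding sees zeros in both formulations. The saving comes from each pass touching only K taps per output pixel.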