diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py index 738e803..d9a4233 100644 --- a/pytorch_ssim/__init__.py +++ b/pytorch_ssim/__init__.py @@ -1,11 +1,10 @@ import torch import torch.nn.functional as F from torch.autograd import Variable -import numpy as np -from math import exp def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + gauss = torch.exp(torch.tensor( + [-(x - window_size//2)**2/float(2*sigma**2) for x in range(window_size)])) return gauss/gauss.sum() def create_window(window_size, channel):