Description
I've applied the WGAN algorithm implemented in torchsde/examples/sde_gan.py to a sine function (deterministic, with a fixed initial condition). After 30,000 training epochs the algorithm still struggles to capture the periodic structure of the signal:
The sine function was implemented as:
import torch

# `device` and `dataset_size` are defined earlier in the script, as in the example.
class PeriodicSDE(torch.nn.Module):
    sde_type = 'ito'
    noise_type = 'diagonal'

    def __init__(self):
        super().__init__()

    def f(self, t, y):
        # Drift: rotation in the (x1, x2) plane, giving sine/cosine trajectories
        # with angular frequency 1/3.
        x1, x2 = torch.split(y, split_size_or_sections=(1, 1), dim=1)
        f1 = -x2 / 3
        f2 = x1 / 3
        return torch.cat([f1, f2], dim=1)

    def g(self, t, y):
        # Zero diffusion: the SDE is effectively a deterministic ODE.
        return 0 * torch.ones_like(y)

ou_sde = PeriodicSDE().to(device)

# Fixed initial condition (1, 1), normalised onto the unit circle.
y0 = torch.ones([dataset_size, 2], device=device) * 2 - 1
norm = torch.sqrt(torch.sum(y0 ** 2, dim=1)).unsqueeze(1)
y0 = y0 / norm
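
For reference, here is a minimal sketch of how target paths can be sampled from this class with torchsde.sdeint; the time grid and solver step size below are illustrative placeholders rather than the exact values used in the example script:

import torch
import torchsde

t_size = 64                                       # placeholder number of observation times
ts = torch.linspace(0, t_size - 1, t_size, device=device)
# g is zero, so the integration just traces the deterministic sine/cosine orbit.
ys = torchsde.sdeint(ou_sde, y0, ts, dt=1e-1)     # shape (t_size, dataset_size, 2)
xs = ys[..., 0].t()                               # sine component, shape (dataset_size, t_size)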
In my opinion, the poor performance is caused by vanishing/exploding gradients in the discriminator network due to weight clipping. Here are the histograms of the weights of the input and output layers of the "f" function of the NCDE discriminator:
Most weights are stuck at the limits imposed by clipping, and the learning process for the discriminator effectively stops once this happens. Would it be possible to fix this by replacing weight clipping with a gradient penalty?
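
Below is a minimal sketch of the kind of WGAN-GP gradient penalty I have in mind, assuming a discriminator that maps a batch of samples (e.g. generated paths of shape (batch, time, channels)) to one scalar score per sample; the names discriminator, real_samples and fake_samples are placeholders, not identifiers from the example:

import torch

def gradient_penalty(discriminator, real_samples, fake_samples, gp_weight=10.0):
    # WGAN-GP: push the per-sample gradient norm of the critic towards 1.
    batch_size = real_samples.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dims.
    alpha = torch.rand(batch_size, *([1] * (real_samples.dim() - 1)),
                       device=real_samples.device)
    interpolates = alpha * real_samples.detach() + (1 - alpha) * fake_samples.detach()
    interpolates.requires_grad_(True)

    scores = discriminator(interpolates)
    grads, = torch.autograd.grad(outputs=scores.sum(),
                                 inputs=interpolates,
                                 create_graph=True)
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1) ** 2).mean()

In the discriminator update this term would be added to the Wasserstein loss, and the weight-clipping step dropped.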


