
Learning the generative model of a periodic process #149

@qtomcatq

I applied the WGAN algorithm implemented in torchsde/examples/sde_gan.py to a sine function (deterministic, with fixed initial conditions). After 30000 training epochs, the algorithm still struggles to capture the periodic structure of the signal:

[Figure: sine wave]

The sine function was implemented as:

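import torch

class PeriodicSDE(torch.nn.Module):
    sde_type = 'ito'
    noise_type = 'diagonal'

    def __init__(self):
        super().__init__()

    def f(self, t, y):
        # Drift: rotation in the (x1, x2) plane with angular frequency 1/3
        x1, x2 = torch.split(y, split_size_or_sections=(1, 1), dim=1)
        f1 = -x2 / 3
        f2 = x1 / 3
        return torch.cat([f1, f2], dim=1)

    def g(self, t, y):
        # Zero diffusion, so the process is deterministic
        return torch.zeros_like(y)


# device and dataset_size come from the surrounding script
ou_sde = PeriodicSDE().to(device)
# Every sample starts at (1, 1), normalized onto the unit circle
y0 = torch.ones([dataset_size, 2], device=device) * 2 - 1
norm = torch.sqrt(torch.sum(y0**2, dim=1)).unsqueeze(1)
y0 = y0 / norm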
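
As a sanity check on the target data, the drift above is a plane rotation with angular frequency 1/3, so the signal should have period 6π. The sketch below is not part of the original script: it integrates the same drift with a hand-written fixed-step RK4 loop (standing in for the SDE solver, which is reasonable only because the diffusion is zero) and compares the result with the closed-form rotation. The names `drift`, `rk4_solve`, and `exact_solution`, the batch size of 4, and the time grid are illustrative assumptions.

```python
import math
import torch

def drift(y):
    # Same drift as PeriodicSDE.f: dx1/dt = -x2/3, dx2/dt = x1/3
    x1, x2 = torch.split(y, (1, 1), dim=1)
    return torch.cat([-x2 / 3, x1 / 3], dim=1)

def rk4_solve(y0, ts):
    # Classical fixed-step RK4; only appropriate here because g is zero
    ys = [y0]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = t1 - t0
        y = ys[-1]
        k1 = drift(y)
        k2 = drift(y + 0.5 * h * k1)
        k3 = drift(y + 0.5 * h * k2)
        k4 = drift(y + h * k3)
        ys.append(y + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4))
    return torch.stack(ys)  # shape: (len(ts), batch, 2)

def exact_solution(y0, ts):
    # Closed form: rotate the initial point by the angle t/3
    angle = (ts / 3).view(-1, 1, 1)
    x1, x2 = y0[:, :1], y0[:, 1:]
    return torch.cat(
        [x1 * torch.cos(angle) - x2 * torch.sin(angle),
         x1 * torch.sin(angle) + x2 * torch.cos(angle)],
        dim=2,
    )

# Illustrative batch of 4 identical starting points on the unit circle
y0 = torch.ones(4, 2, dtype=torch.float64)
y0 = y0 / y0.norm(dim=1, keepdim=True)

# One full period: 2*pi / (1/3) = 6*pi
ts = torch.linspace(0, 6 * math.pi, 201, dtype=torch.float64)
numeric = rk4_solve(y0, ts)
exact = exact_solution(y0, ts)
max_error = (numeric - exact).abs().max().item()
```

If `max_error` is small and `numeric[-1]` returns to `y0`, the training data really is a clean sinusoid, and the mismatch in the plot comes from the learned model rather than from the data generation.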

I think the poor performance is caused by vanishing/exploding gradients in the discriminator network due to weight clipping. Below are histograms of the weights of the input and output layers of the "f" function of the NCDE discriminator:

[Figure: input_layer]

[Figure: out_layer]
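
The saturation visible in the histograms can also be tracked as a single number per parameter tensor. The sketch below is an assumption-laden illustration rather than code from the example script: `clipped_fraction` is a hypothetical helper, the small MLP merely stands in for the discriminator, and `clip_value = 0.01` is a placeholder for whatever limit the training loop actually uses.

```python
import torch

def clipped_fraction(module, clip_value, tol=1e-6):
    """Map each parameter name to the fraction of its entries with |w| >= clip_value - tol."""
    fractions = {}
    with torch.no_grad():
        for name, param in module.named_parameters():
            at_limit = param.abs() >= clip_value - tol
            fractions[name] = at_limit.float().mean().item()
    return fractions

torch.manual_seed(0)

# Stand-in critic; the real discriminator is a neural CDE
critic = torch.nn.Sequential(
    torch.nn.Linear(2, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)

# Placeholder clipping limit, not taken from the example script
clip_value = 0.01

# Weight clipping as done in a standard WGAN training step
with torch.no_grad():
    for p in critic.parameters():
        p.clamp_(-clip_value, clip_value)

fractions = clipped_fraction(critic, clip_value)
for name, frac in fractions.items():
    print(f"{name}: {frac:.2%} of entries at the clipping limit")
```

Logging these fractions every few hundred steps would show when the discriminator's weights pile up at the limits, without having to plot histograms.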

Most weights are stuck at the limits imposed by clipping, and the discriminator effectively stops learning once this happens. Is it possible to fix this with a gradient penalty?
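
For reference, here is a minimal sketch of what a WGAN-GP style gradient penalty looks like in plain PyTorch. It is only an illustration under assumptions: the critic is a flattening MLP rather than the neural CDE discriminator, `gradient_penalty` is a hypothetical helper, the data are random tensors, and the coefficient 10.0 is the commonly used default rather than a tuned value. The penalty needs a double backward pass through the critic, which may be expensive or unsupported when the critic is solved with an adjoint-based method.

```python
import torch

def gradient_penalty(critic, real, fake):
    # One random interpolation coefficient per sample, broadcast over the other dims
    eps = torch.rand(
        real.size(0), *([1] * (real.dim() - 1)),
        device=real.device, dtype=real.dtype,
    )
    interp = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    score = critic(interp)
    # create_graph=True keeps the graph so the penalty itself can be differentiated
    grad, = torch.autograd.grad(score.sum(), interp, create_graph=True)
    grad_norm = grad.reshape(grad.size(0), -1).norm(2, dim=1)
    # Penalize deviation of the per-sample gradient norm from 1
    return ((grad_norm - 1) ** 2).mean()

torch.manual_seed(0)
batch, steps = 8, 16

# Stand-in critic acting on flattened paths of shape (batch, steps, 2)
critic = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(steps * 2, 32),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 1),
)

# Random tensors standing in for real and generated paths
real = torch.randn(batch, steps, 2)
fake = torch.randn(batch, steps, 2)

penalty = gradient_penalty(critic, real, fake)
# Critic loss with the penalty replacing weight clipping
loss = critic(fake).mean() - critic(real).mean() + 10.0 * penalty
loss.backward()
```

With this loss the clamping step on the critic's parameters is dropped, so the weights are no longer forced to the clipping limits.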
