34 changes: 24 additions & 10 deletions pinn/pinn_1d.py
@@ -285,6 +285,8 @@ def __init__(self, loss_type, loss_func=nn.MSELoss(), bc_weight=1.0):
             self.name = "PINN Loss"
         elif self.type == 1:
             self.name = "DRM Loss"
+        elif self.type == 2:
+            self.name = "DCGD(PINN+DRM) Loss"
         else:
             raise ValueError(f"Unknown loss type: {self.type}")
         self.bc_weight = bc_weight
@@ -294,7 +296,7 @@ def super_loss(self, model, mesh, loss_func):
         x = mesh.x_train
         u = model.get_solution(x)
         loss = loss_func(u, mesh.u_ex)
-        return loss
+        return loss, ()

     # "PINN" loss
     def pinn_loss(self, model, mesh, loss_func):
@@ -306,15 +308,16 @@ def pinn_loss(self, model, mesh, loss_func):

         # Internal loss
         pde = mesh.pde
-        loss = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1])
+        loss = loss_pinn = loss_func(d2u_dx2[1:-1] + mesh.f[1:-1], pde.r * u[1:-1])
         # Boundary loss
         if not model.enforce_bc:
             u_bc = u[[0, -1]]
             u_ex_bc = mesh.u_ex[[0, -1]]
             loss_b = loss_func(u_bc, u_ex_bc)
             loss += self.bc_weight * loss_b

-        return loss
+            return loss, (loss_pinn, loss_b)
+        return loss, (loss_pinn,)

     def drm_loss(self, model, mesh: Mesh):
         """Deep Ritz Method loss"""
@@ -332,7 +335,7 @@ def drm_loss(self, model, mesh: Mesh):
         fu_prod = f_val * u

         integrand_values = 0.5 * grad_u_pred_sq[1:-1] + 0.5 * mesh.pde.r * u_pred_sq[1:-1] - fu_prod[1:-1]
-        loss = torch.mean(integrand_values)
+        loss = loss_drm = torch.mean(integrand_values)

         # Boundary loss
         u_bc = u[[0,-1]]
@@ -342,7 +345,7 @@ def drm_loss(self, model, mesh: Mesh):


         xs.requires_grad_(False)  # Disable gradient tracking for x
-        return loss
+        return loss, (loss_drm, loss_b)

     def loss(self, model, mesh):
         if self.type == -1:
@@ -351,6 +354,10 @@ def loss(self, model, mesh):
             loss_value = self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func)
         elif self.type == 1:
             loss_value = self.drm_loss(model=model, mesh=mesh)
+        elif self.type == 2:
+            loss_p, pinn_losses = self.pinn_loss(model=model, mesh=mesh, loss_func=self.loss_func)
+            _, drm_losses = self.drm_loss(model=model, mesh=mesh)
+            loss_value = pinn_losses[0] + drm_losses[0], [pinn_losses[0], drm_losses[0]]
         else:
             raise ValueError(f"Unknown loss type: {self.type}")
         return loss_value
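For loss type 2 the optimizer receives the two interior objectives separately; the boundary components are excluded because the BC is hard-enforced in this mode (see main below). For the model problem -u'' + r*u = f, the PINN term penalizes the strong-form residual while the DRM term minimizes the Ritz energy of the same equation. A rough sketch of the pair on a toy grid, with a stand-in for the network output (all names local to this example):

import torch

r = 1.0
x = torch.linspace(0.0, 1.0, 101, requires_grad=True).unsqueeze(1)
u = torch.sin(torch.pi * x) * x * (1 - x)  # stand-in for the network output
f = torch.ones_like(x)                     # stand-in source term

du = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
d2u = torch.autograd.grad(du.sum(), x, create_graph=True)[0]

loss_pinn = ((-d2u + r * u - f)[1:-1] ** 2).mean()               # strong-form residual
loss_drm = (0.5 * du**2 + 0.5 * r * u**2 - f * u)[1:-1].mean()   # Ritz energy
total, components = loss_pinn + loss_drm, [loss_pinn, loss_drm]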
@@ -359,8 +366,10 @@
 # %%
 # Define the training loop
 def train(model, mesh, criterion, iterations, adam_iterations, learning_rate,
-          num_check, num_plots, sweep_idx, level_idx, frame_dir):
+          num_check, num_plots, sweep_idx, level_idx, frame_dir, loss_type=0):
     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    if loss_type == 2:
+        optimizer = DCGD(optimizer, 1, type="center")
     # optimizer = SOAP(model.parameters(), lr = 3e-3, betas=(.95, .95), weight_decay=.01,
     #                  precondition_frequency=10)
     scheduler = StepLR(optimizer, step_size=1000, gamma=0.9)
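DCGD here wraps the base Adam optimizer, taking the number of PDE losses and the combination rule. The wrapper's internals are not part of this diff; the following is an illustrative, non-library sketch of a two-loss "center"-style wrapper matching the optimizer.step(losses) call used below, based on one reading of the Dual Cone Gradient Descent paper's center rule:

import torch

class DCGDSketch:
    """Illustrative stand-in for the DCGD wrapper, NOT the real library.

    Center rule sketch: step along the bisector of the two per-loss
    gradient directions, scaled by the projection of their sum onto it.
    The signature mirrors the call above; only two losses are supported.
    """

    def __init__(self, optimizer, num_pde, type="center"):
        self.optimizer = optimizer
        self.params = [p for g in optimizer.param_groups for p in g["params"]]

    def zero_grad(self):
        self.optimizer.zero_grad()

    def _flat_grad(self, loss):
        grads = torch.autograd.grad(loss, self.params, retain_graph=True, allow_unused=True)
        return torch.cat([(g if g is not None else torch.zeros_like(p)).reshape(-1)
                          for g, p in zip(grads, self.params)])

    def step(self, losses):
        g1, g2 = (self._flat_grad(l) for l in losses)
        k = g1 / (g1.norm() + 1e-12) + g2 / (g2.norm() + 1e-12)
        k = k / (k.norm() + 1e-12)         # unit bisector direction
        g_center = ((g1 + g2) @ k) * k     # projected total gradient
        offset = 0
        for p in self.params:              # write back and delegate the update
            n = p.numel()
            p.grad = g_center[offset:offset + n].view_as(p).clone()
            offset += n
        self.optimizer.step()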
@@ -390,12 +399,16 @@ def closure():
             # we need to set to zero the gradients of all model parameters (PyTorch accumulates gradients by default)
             optimizer.zero_grad()
             # compute the loss value for the current batch of data
-            loss = criterion.loss(model=model, mesh=mesh)
+            loss, losses = criterion.loss(model=model, mesh=mesh)
             # backpropagation to compute gradients of the model parameters with respect to the loss:
             # computes dloss/dx for every parameter x which has requires_grad=True
-            loss.backward()
             # update the model parameters by taking an optimizer step with the computed gradients and learning rate
-            optimizer.step()
+            if loss_type != 2:
+                loss.backward()
+                optimizer.step()
+            else:
+                # Dual Cone GD optimizer
+                optimizer.step(losses)
             #
             scheduler.step()
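The reason the type-2 branch routes the raw components to the optimizer instead of calling loss.backward() on the sum: the PINN and DRM gradients can conflict, and the dual-cone update guarantees that, to first order, neither objective increases. A quick numeric check of that property with made-up gradients:

import torch

g1 = torch.tensor([1.0, 0.0])   # made-up PINN gradient
g2 = torch.tensor([-0.9, 0.1])  # made-up DRM gradient, nearly opposed

g_sum = g1 + g2                       # plain summed update: [0.1, 0.1]
k = g1 / g1.norm() + g2 / g2.norm()
k = k / k.norm()                      # bisector of the two directions
g_center = (g_sum @ k) * k            # "center"-style dual-cone update

# non-negative inner product with BOTH gradients, so neither loss
# increases to first order when stepping along -g_center
print((g_center @ g1).item() >= 0, (g_center @ g2).item() >= 0)  # True True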

@@ -444,12 +457,13 @@ def main(args=None):
     # Input and output dimension: x -> u(x)
     dim_inputs = 1
     dim_outputs = 1
+    enforce_bc = args.enforce_bc if args.loss_type != 2 else True  # Dual Cone GD enforces a hard constraint on the BC
     model = MultiLevelNN(mesh=mesh,
                          num_levels=args.levels,
                          dim_inputs=dim_inputs, dim_outputs=dim_outputs,
                          dim_hidden=args.hidden_dims,
                          act=get_activation(args.activation),
-                         enforce_bc=args.enforce_bc)
+                         enforce_bc=enforce_bc)
     print(model)
     model.to(device)
     # Plotting
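Forcing enforce_bc=True for loss type 2 means the boundary condition is built into the network output rather than penalized. The diff does not show how MultiLevelNN implements this; a common ansatz for homogeneous Dirichlet data on [a, b], shown here purely as an illustration, multiplies the raw network output by a factor that vanishes at the endpoints:

import torch
import torch.nn as nn

class HardBCWrapper(nn.Module):
    """Illustrative hard Dirichlet BC: u(a) = u(b) = 0 by construction.

    A common ansatz, not necessarily what MultiLevelNN.enforce_bc does.
    """

    def __init__(self, net, a=0.0, b=1.0):
        super().__init__()
        self.net, self.a, self.b = net, a, b

    def forward(self, x):
        # (x - a)(b - x) vanishes at both endpoints, so the boundary
        # residual is exactly zero regardless of the network weights
        return (x - self.a) * (self.b - x) * self.net(x)

net = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
u = HardBCWrapper(net)
x = torch.tensor([[0.0], [0.5], [1.0]])
print(u(x))  # first and last entries are exactly 0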
4 changes: 2 additions & 2 deletions pinn/utils.py
@@ -56,8 +56,8 @@ def parse_args(args=None):
                         help="Learning rate for the optimizer.")
     parser.add_argument('--levels', type=int, default=4,
                         help="Number of levels in multilevel training.")
-    parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0],
-                        help="Loss type: -1 for supervised (true solution), 0 for PINN loss.")
+    parser.add_argument('--loss_type', type=int, default=0, choices=[-1, 0, 1, 2],
+                        help="Loss type: -1 for supervised (true solution), 0 for PINN loss, 1 for DRM loss, 2 for DCGD loss.")
     parser.add_argument('--activation', type=str, default='tanh',
                         choices=['tanh', 'silu', 'relu', 'gelu', 'softmax'],
                         help="Activation function to use.")
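With the widened choices the new modes are selected from the command line. Assuming parse_args forwards its argument to argparse.ArgumentParser.parse_args in the usual way (and the import path below is hypothetical), the flag can be exercised like this:

from utils import parse_args  # hypothetical import path within the pinn package

# equivalent to: python pinn/pinn_1d.py --loss_type 2 --levels 4
args = parse_args(["--loss_type", "2", "--levels", "4"])
assert args.loss_type == 2 and args.levels == 4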