Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
975d9e4
Init branch
Chengqian-Zhang Mar 10, 2025
8785ae3
Merge branch 'devel' into merge_denoise
Chengqian-Zhang Mar 10, 2025
2e4d94e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2025
e42a860
Add token loss
Chengqian-Zhang Mar 12, 2025
9cf0849
Solve conflict
Chengqian-Zhang Mar 12, 2025
45841a6
Merge branch 'devel' into merge_denoise
Chengqian-Zhang Mar 12, 2025
81cb7c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2025
8d622bf
Fix pre-commit and Code Scanning
Chengqian-Zhang Mar 13, 2025
36ac73e
Fix conflict
Chengqian-Zhang Mar 13, 2025
34e647d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2025
b2e1419
Add denoise rot,trans,permutation,smooth UT
Chengqian-Zhang Mar 13, 2025
add4005
Merge branch 'merge_denoise' of github.com:Chengqian-Zhang/deepmd-kit…
Chengqian-Zhang Mar 13, 2025
b9e8528
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2025
1246658
Fix pre-commit
Chengqian-Zhang Mar 13, 2025
9712271
Merge branch 'devel' into merge_denoise
Chengqian-Zhang Mar 14, 2025
5e5abaa
Support dpmodel denoise fitting
Chengqian-Zhang Mar 14, 2025
ebef6c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2025
b20dd58
Add universial denoise fitting UT
Chengqian-Zhang Mar 17, 2025
f84d28b
Fix conflict
Chengqian-Zhang Mar 17, 2025
cdcb74e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2025
f7060ba
Merge branch 'devel' into merge_denoise
Chengqian-Zhang Mar 17, 2025
96bd72a
Add dtype and device to strain_components
Chengqian-Zhang Mar 17, 2025
13eba8c
Add denoise universial atommic_model UT
Chengqian-Zhang Mar 17, 2025
ae98c15
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2025
491eabd
Fix pre-commit
Chengqian-Zhang Mar 17, 2025
2a2b707
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2025
9cc9f29
Merge branch 'devel' into merge_denoise
Chengqian-Zhang Mar 19, 2025
993ed2f
Add universial denoise model UT
Chengqian-Zhang Mar 19, 2025
acbbdea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
954bb4e
Merge branch 'devel' into merge_denoise
Chengqian-Zhang Mar 26, 2025
5c58054
delete special precision
Chengqian-Zhang Mar 26, 2025
0897114
Fix conflict
Chengqian-Zhang Mar 26, 2025
ebb2c34
Add universial denoise loss UT
Chengqian-Zhang Mar 26, 2025
9e1e1f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 27, 2025
ecf937d
Fix torch cuda UT
Chengqian-Zhang Mar 28, 2025
dc97854
Fix conflict
Chengqian-Zhang Mar 28, 2025
d468917
Add DeepDenoise part, but not complete
Chengqian-Zhang Mar 28, 2025
0a2f57d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 251 additions & 82 deletions deepmd/pt/loss/denoise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np
import torch
import torch.nn.functional as F

Expand All @@ -8,102 +9,270 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.pt.utils.region import (
phys2inter,
)
from deepmd.utils.data import (
DataRequirementItem,
)


def get_cell_perturb_matrix(cell_pert_fraction: float):
    """Draw a random lower-triangular perturbation matrix for a 3x3 cell.

    Six values are sampled uniformly from
    ``[-cell_pert_fraction, cell_pert_fraction)``. The first three are added
    to the unit diagonal, the remaining three fill the entries below the
    diagonal, and everything above the diagonal stays zero.

    Parameters
    ----------
    cell_pert_fraction : float
        Half-width of the sampling interval. Must be non-negative.

    Returns
    -------
    cell_pert_matrix : torch.Tensor
        The 3x3 matrix, created with ``env.GLOBAL_PT_FLOAT_PRECISION`` on
        ``env.DEVICE``.
    e : torch.Tensor
        The six sampled values, exactly as produced by ``torch.rand``
        (default dtype and device).

    Raises
    ------
    RuntimeError
        If ``cell_pert_fraction`` is negative.
    """
    # TODO: user fix some component
    if cell_pert_fraction < 0:
        raise RuntimeError("cell_pert_fraction can not be negative")

    # Rescale samples from [0, 1) onto [-fraction, fraction).
    e = torch.rand(6) * 2 * cell_pert_fraction - cell_pert_fraction
    d0, d1, d2, low21, low20, low10 = e.unbind(0)
    rows = [
        [1 + d0, 0, 0],
        [low10, 1 + d1, 0],
        [low20, low21, 1 + d2],
    ]
    cell_pert_matrix = torch.tensor(
        rows,
        dtype=env.GLOBAL_PT_FLOAT_PRECISION,
        device=env.DEVICE,
    )
    return cell_pert_matrix, e


class DenoiseLoss(TaskLoss):
    """Denoising loss that perturbs the inputs, runs the model and scores it.

    Unlike losses that receive ready-made predictions, ``forward`` takes the
    model itself: it adds noise to the cell and/or coordinates in
    ``input_dict``, stores the matching targets in ``label``, calls the model
    on the perturbed inputs and compares its outputs with those targets.
    """

    def __init__(
        self,
        mask_token: bool = False,
        mask_coord: bool = True,
        mask_cell: bool = False,
        token_loss: float = 1.0,
        coord_loss: float = 1.0,
        cell_loss: float = 1.0,
        noise_type: str = "gaussian",
        coord_noise: float = 0.2,
        cell_pert_fraction: float = 0.0,
        noise_mode: str = "prob",
        mask_num: int = 1,
        mask_prob: float = 0.2,
        loss_func: str = "rmse",
        **kwargs,
    ) -> None:
        r"""Construct a layer to compute loss on token, coord and cell.

        Parameters
        ----------
        mask_token : bool
            Whether to mask token.
        mask_coord : bool
            Whether to mask coordinate.
        mask_cell : bool
            Whether to mask cell.
        token_loss : float
            The preference factor for token denoise.
        coord_loss : float
            The preference factor for coordinate denoise.
        cell_loss : float
            The preference factor for cell denoise.
        noise_type : str
            The type of noise to add to the coordinate. It can be 'uniform' or 'gaussian'.
        coord_noise : float
            The magnitude of noise to add to the coordinate.
        cell_pert_fraction : float
            A value that determines how much the cell will deform.
        noise_mode : str
            'prob' means the noise is added with a probability.
            'fix_num' means the noise is added to a fixed number of atoms.
        mask_num : int
            The number of atoms to mask coordinates. It is only used when noise_mode is 'fix_num'.
        mask_prob : float
            The probability of masking coordinates. It is only used when noise_mode is 'prob'.
        loss_func : str
            The loss function to minimize, it can be 'mae' or 'rmse'.
        **kwargs
            Other keyword arguments. They are accepted and ignored.
        """
        super().__init__()
        self.mask_token = mask_token
        self.mask_coord = mask_coord
        self.mask_cell = mask_cell
        self.token_loss = token_loss
        self.coord_loss = coord_loss
        self.cell_loss = cell_loss
        self.noise_type = noise_type
        self.coord_noise = coord_noise
        self.cell_pert_fraction = cell_pert_fraction
        self.noise_mode = noise_mode
        self.mask_num = mask_num
        self.mask_prob = mask_prob
        self.loss_func = loss_func

    def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
        """Return loss on token, coord and cell.

        ``input_dict`` and ``label`` are modified in place: the perturbed
        ``"box"`` / ``"coord"`` replace the clean ones in ``input_dict``, and
        the clean references plus denoising targets are written to ``label``.

        Parameters
        ----------
        input_dict : dict[str, torch.Tensor]
            Model inputs. ``"atype"``, ``"coord"`` and ``"box"`` are read here.
        model : torch.nn.Module
            Model to be used to output the predictions.
        label : dict[str, torch.Tensor]
            Labels.
        natoms : int
            The local atom number. Not used by this loss.
        learning_rate : float
            Not used by this loss.
        mae : bool
            Not used by this loss.

        Returns
        -------
        model_pred: dict[str, torch.Tensor]
            Model predictions.
        loss: torch.Tensor
            Loss for model to minimize.
        more_loss: dict[str, torch.Tensor]
            Other losses for display.

        Raises
        ------
        RuntimeError
            If no masking mode is enabled or ``loss_func`` is unknown.
        NotImplementedError
            If ``noise_mode`` or ``noise_type`` is unknown.
        """
        # Validate the configuration before touching the inputs or running
        # the model, so a bad setup fails without side effects.
        if not (self.mask_coord or self.mask_cell or self.mask_token):
            raise RuntimeError(
                "At least one of mask_coord, mask_cell and mask_token should be True!"
            )
        if self.loss_func not in ("rmse", "mae"):
            raise RuntimeError(f"Unknown loss function {self.loss_func}!")

        nloc = input_dict["atype"].shape[1]
        nbz = input_dict["atype"].shape[0]
        # Follow the configured device instead of assuming CUDA is available.
        input_dict["box"] = input_dict["box"].to(env.DEVICE)

        # TODO: Change lattice to lower triangular matrix

        label["clean_coord"] = input_dict["coord"].clone().detach()
        label["clean_box"] = input_dict["box"].clone().detach()
        origin_frac_coord = phys2inter(
            label["clean_coord"], label["clean_box"].reshape(nbz, 3, 3)
        )
        label["clean_frac_coord"] = origin_frac_coord.clone().detach()

        if self.mask_cell:
            self._perturb_cell(input_dict, label, origin_frac_coord, nbz, nloc)

        if self.mask_coord:
            self._perturb_coord(input_dict, label, nbz, nloc)

        if self.mask_token:
            # TODO: mask_token
            pass

        model_pred = model(**input_dict)

        loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
        more_loss = {}

        # Each target only exists when its perturbation was applied, so each
        # term is evaluated only for an enabled mask.
        if self.mask_coord:
            term, shown = self._term_loss(
                label["updated_coord"], model_pred["updated_coord"], per_atom=True
            )
            more_loss[f"{self.loss_func}_coord"] = shown
            loss += self.coord_loss * term.to(GLOBAL_PT_FLOAT_PRECISION)
        if self.mask_cell:
            term, shown = self._term_loss(
                label["strain_components"],
                model_pred["strain_components"],
                per_atom=False,
            )
            more_loss[f"{self.loss_func}_cell"] = shown
            loss += self.cell_loss * term.to(GLOBAL_PT_FLOAT_PRECISION)
        return model_pred, loss, more_loss

    def _perturb_cell(self, input_dict, label, origin_frac_coord, nbz, nloc):
        """Deform every frame's box, move atoms with it, and record the strain."""
        strain_rows = []
        for ii in range(nbz):
            cell_perturb_matrix, strain_components = get_cell_perturb_matrix(
                self.cell_pert_fraction
            )
            # left-multiplied by `cell_perturb_matrix` to get the noise box
            input_dict["box"][ii] = torch.matmul(
                cell_perturb_matrix, input_dict["box"][ii].reshape(3, 3)
            ).reshape(-1)
            # Keep fractional coordinates fixed while the cell deforms.
            input_dict["coord"][ii] = torch.matmul(
                origin_frac_coord[ii].reshape(nloc, 3),
                input_dict["box"][ii].reshape(3, 3),
            )
            strain_rows.append(
                strain_components.reshape(-1).to(
                    dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
                )
            )
        # The label width follows whatever the helper returned rather than a
        # hard-coded size.
        # NOTE(review): confirm the model's "strain_components" output has
        # the same width.
        label["strain_components"] = torch.stack(strain_rows).detach()

    def _resolve_mask_num(self, nloc):
        """Return how many atoms per frame receive coordinate noise."""
        if self.noise_mode == "fix_num":
            return min(self.mask_num, nloc)
        if self.noise_mode == "prob":
            # Always perturb at least one atom, but never more than exist.
            return min(max(int(self.mask_prob * nloc), 1), nloc)
        raise NotImplementedError(f"Unknown noise mode {self.noise_mode}!")

    def _perturb_coord(self, input_dict, label, nbz, nloc):
        """Add noise to randomly chosen atoms and store the correction target."""
        mask_num = self._resolve_mask_num(nloc)
        if self.noise_type not in ("uniform", "gaussian"):
            raise NotImplementedError(f"Unknown noise type {self.noise_type}!")

        coord_mask_all = torch.zeros(
            input_dict["atype"].shape, dtype=torch.bool, device=env.DEVICE
        )
        for ii in range(nbz):
            picked = np.random.choice(range(nloc), mask_num, replace=False)
            frame_mask = torch.zeros(nloc, dtype=torch.bool, device=env.DEVICE)
            frame_mask[
                torch.as_tensor(picked, dtype=torch.long, device=env.DEVICE)
            ] = True
            if self.noise_type == "uniform":
                noise_on_coord = np.random.uniform(
                    low=-self.coord_noise, high=self.coord_noise, size=(mask_num, 3)
                )
            else:
                noise_on_coord = np.random.normal(
                    loc=0.0, scale=self.coord_noise, size=(mask_num, 3)
                )
            noise_on_coord = torch.tensor(
                noise_on_coord,
                dtype=env.GLOBAL_PT_FLOAT_PRECISION,
                device=env.DEVICE,
            )  # mask_num 3
            input_dict["coord"][ii][frame_mask, :] += noise_on_coord
            coord_mask_all[ii] = frame_mask
        label["coord_mask"] = coord_mask_all

        frac_coord = phys2inter(
            input_dict["coord"], input_dict["box"].reshape(nbz, 3, 3)
        )
        # Target: fractional displacement back to the clean structure,
        # expressed through the clean box.
        label["updated_coord"] = (
            (
                (label["clean_frac_coord"] - frac_coord)
                @ label["clean_box"].reshape(nbz, 3, 3)
            )
            .clone()
            .detach()
        )

    def _term_loss(self, target, pred, per_atom):
        """Return ``(loss_term, display_metric)`` for one target/prediction pair."""
        if self.loss_func == "rmse":
            l2 = torch.mean(torch.square((target - pred).reshape(-1)))
            return l2, l2.sqrt().detach()
        l1 = F.l1_loss(target, pred, reduction="none")
        # Coordinates: sum the last axis, average the next, sum the rest.
        # Cell: plain sum over all entries.
        reduced = l1.sum(-1).mean(-1).sum() if per_atom else l1.sum()
        return reduced, l1.mean().detach()

    @property
    def label_requirement(self) -> list[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        # Empty: forward() builds every label it needs from the inputs.
        return []

    def serialize(self) -> dict:
        """Placeholder; not implemented yet and currently returns None."""
        pass

    @classmethod
    def deserialize(cls, data: dict) -> "TaskLoss":
        """Placeholder; not implemented yet and currently returns None."""
        pass
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from .base_atomic_model import (
BaseAtomicModel,
)
from .denoise_atomic_model import (
DPDenoiseAtomicModel,
)
from .dipole_atomic_model import (
DPDipoleAtomicModel,
)
Expand Down Expand Up @@ -47,6 +50,7 @@
"BaseAtomicModel",
"DPAtomicModel",
"DPDOSAtomicModel",
"DPDenoiseAtomicModel",
"DPDipoleAtomicModel",
"DPEnergyAtomicModel",
"DPPolarAtomicModel",
Expand Down
Loading
Loading