diff --git a/deepmd/dpmodel/atomic_model/__init__.py b/deepmd/dpmodel/atomic_model/__init__.py
index 4d882d5e4b..ef9ea9d7e3 100644
--- a/deepmd/dpmodel/atomic_model/__init__.py
+++ b/deepmd/dpmodel/atomic_model/__init__.py
@@ -17,6 +17,9 @@
 from .base_atomic_model import (
     BaseAtomicModel,
 )
+from .denoise_atomic_model import (
+    DPDenoiseAtomicModel,
+)
 from .dipole_atomic_model import (
     DPDipoleAtomicModel,
 )
@@ -50,6 +53,7 @@
     "BaseAtomicModel",
     "DPAtomicModel",
     "DPDOSAtomicModel",
+    "DPDenoiseAtomicModel",
     "DPDipoleAtomicModel",
     "DPEnergyAtomicModel",
     "DPPolarAtomicModel",
diff --git a/deepmd/dpmodel/atomic_model/denoise_atomic_model.py b/deepmd/dpmodel/atomic_model/denoise_atomic_model.py
new file mode 100644
index 0000000000..71b5ad90a7
--- /dev/null
+++ b/deepmd/dpmodel/atomic_model/denoise_atomic_model.py
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import numpy as np
+
+from deepmd.dpmodel.fitting.denoise_fitting import (
+    DenoiseFitting,
+)
+
+from .dp_atomic_model import (
+    DPAtomicModel,
+)
+
+
+class DPDenoiseAtomicModel(DPAtomicModel):
+    def __init__(self, descriptor, fitting, type_map, **kwargs):
+        if not isinstance(fitting, DenoiseFitting):
+            raise TypeError(
+                "fitting must be an instance of DenoiseFitting for DPDenoiseAtomicModel"
+            )
+        super().__init__(descriptor, fitting, type_map, **kwargs)
+
+    def apply_out_stat(
+        self,
+        ret: dict[str, np.ndarray],
+        atype: np.ndarray,
+    ):
+        """Apply the stat to each atomic output.
+
+        In denoise fitting, each output will be multiplied by the label std.
+
+        Parameters
+        ----------
+        ret
+            The returned dict by the forward_atomic method
+        atype
+            The atom types. nf x nloc. It is not used in denoise fitting.
+
+        """
+        # Scale values to appropriate magnitudes.
+        # 1.732 ~= sqrt(3): the std of a uniform distribution on [-a, a] is a / sqrt(3).
+        noise_type = self.fitting_net.get_noise_type()
+        cell_std = self.fitting_net.get_cell_pert_fraction() / 1.732
+        if noise_type == "gaussian":
+            coord_std = self.fitting_net.get_coord_noise()
+        elif noise_type == "uniform":
+            coord_std = self.fitting_net.get_coord_noise() / 1.732
+        else:
+            raise RuntimeError(f"Unknown noise type {noise_type}")
+        ret["strain_components"] = (
+            ret["strain_components"] * cell_std
+            if cell_std > 0
+            else ret["strain_components"]
+        )
+        ret["updated_coord"] = (
+            ret["updated_coord"] * coord_std if coord_std > 0 else ret["updated_coord"]
+        )
+        return ret
diff --git a/deepmd/dpmodel/fitting/__init__.py b/deepmd/dpmodel/fitting/__init__.py
index 5bdfff2571..21b5928275 100644
--- a/deepmd/dpmodel/fitting/__init__.py
+++ b/deepmd/dpmodel/fitting/__init__.py
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from .denoise_fitting import (
+    DenoiseFitting,
+)
 from .dipole_fitting import (
     DipoleFitting,
 )
@@ -23,6 +26,7 @@
 __all__ = [
     "DOSFittingNet",
+    "DenoiseFitting",
     "DipoleFitting",
     "EnergyFittingNet",
     "InvarFitting",
diff --git a/deepmd/dpmodel/fitting/denoise_fitting.py b/deepmd/dpmodel/fitting/denoise_fitting.py
new file mode 100644
index 0000000000..e88a5f8d3e
--- /dev/null
+++ b/deepmd/dpmodel/fitting/denoise_fitting.py
@@ -0,0 +1,579 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+    Union,
+)
+
+import array_api_compat
+import numpy as np
+
+from deepmd.dpmodel import (
+    DEFAULT_PRECISION,
+    PRECISION_DICT,
+    NativeOP,
+)
+from deepmd.dpmodel.common import (
+    cast_precision,
+    get_xp_precision,
+    to_numpy_array,
+)
+from deepmd.dpmodel.output_def import (
+    FittingOutputDef,
+    OutputVariableDef,
+)
+from deepmd.dpmodel.utils import (
+    AtomExcludeMask,
+    FittingNet,
+    NetworkCollection,
+)
+from deepmd.dpmodel.utils.seed import (
+    child_seed,
+)
+from deepmd.env import (
+    GLOBAL_NP_FLOAT_PRECISION,
+)
+from deepmd.utils.finetune import (
+    get_index_between_two_maps,
+    map_atom_exclude_types,
+)
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
+from .base_fitting import (
+    BaseFitting,
+)
+
+
+@BaseFitting.register("denoise")
+class DenoiseFitting(NativeOP, BaseFitting):
+    r"""Denoise fitting class.
+
+    Parameters
+    ----------
+    ntypes
+        The number of atom types.
+    dim_descrpt
+        The dimension of the input descriptor.
+    neuron
+        Number of neurons :math:`N` in each hidden layer of the fitting net
+    bias_atom_e
+        The average output bias per atom for each element.
+    resnet_dt
+        Time-step `dt` in the resnet construction:
+        :math:`y = x + dt * \phi (Wx + b)`
+    numb_fparam
+        Number of frame parameter
+    numb_aparam
+        Number of atomic parameter
+    trainable
+        If the weights of fitting net are trainable.
+        Suppose that we have :math:`N_l` hidden layers in the fitting net,
+        this list is of length :math:`N_l + 1`, specifying if the hidden layers and the output layer are trainable.
+    activation_function
+        The activation function :math:`\boldsymbol{\phi}` in the embedding net. Supported options are |ACTIVATION_FN|
+    precision
+        The precision of the embedding net parameters. Supported options are |PRECISION|
+    use_aparam_as_mask: bool, optional
+        If True, the atomic parameters will be used as a mask that determines whether the atom is real/virtual.
+        And the aparam will not be used as the atomic parameters for embedding.
+    mixed_types
+        If true, use a uniform fitting net for all atom types, otherwise use
+        different fitting nets for different atom types.
+    exclude_types: list[int]
+        Atomic contributions of the excluded atom types are set zero.
+    type_map: list[str], Optional
+        A list of strings. Give the name to each type of atoms.
+    seed: Optional[Union[int, list[int]]]
+        Random seed for initializing the network parameters.
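+    embedding_width
+        The output dimension of the coordinate branch of the fitting net,
+        which is contracted with the equivariant representation `gr` to
+        produce the updated coordinates.
+    dim_case_embd
+        Dimension of the case-specific embedding.
+    coord_noise: float, Optional
+        The magnitude of the noise added to the coordinates.
+    cell_pert_fraction: float, Optional
+        The fraction by which the cell is perturbed.
+    noise_type: str, Optional
+        The type of coordinate noise, 'gaussian' or 'uniform'.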
+    """
+
+    def __init__(
+        self,
+        ntypes: int,
+        dim_descrpt: int,
+        embedding_width: int,
+        neuron: list[int] = [120, 120, 120],
+        bias_atom_e: Optional[np.ndarray] = None,
+        resnet_dt: bool = True,
+        numb_fparam: int = 0,
+        numb_aparam: int = 0,
+        dim_case_embd: int = 0,
+        activation_function: str = "tanh",
+        precision: str = DEFAULT_PRECISION,
+        mixed_types: bool = True,
+        seed: Optional[Union[int, list[int]]] = None,
+        exclude_types: list[int] = [],
+        trainable: Optional[list[bool]] = None,
+        type_map: Optional[list[str]] = None,
+        use_aparam_as_mask: bool = False,
+        coord_noise: Optional[float] = None,
+        cell_pert_fraction: Optional[float] = None,
+        noise_type: Optional[str] = None,
+    ) -> None:
+        self.ntypes = ntypes
+        self.dim_descrpt = dim_descrpt
+        self.neuron = neuron
+        self.embedding_width = embedding_width
+        self.resnet_dt = resnet_dt
+        self.numb_fparam = numb_fparam
+        self.numb_aparam = numb_aparam
+        self.dim_case_embd = dim_case_embd
+        self.trainable = trainable
+        self.type_map = type_map
+        self.seed = seed
+        self.var_name = ["strain_components", "updated_coord", "logits"]
+        self.coord_noise = coord_noise
+        self.cell_pert_fraction = cell_pert_fraction
+        self.noise_type = noise_type
+        if self.trainable is None:
+            self.trainable = [True for ii in range(len(self.neuron) + 1)]
+        if isinstance(self.trainable, bool):
+            self.trainable = [self.trainable] * (len(self.neuron) + 1)
+        self.activation_function = activation_function
+        self.precision = precision
+        if self.precision.lower() not in PRECISION_DICT:
+            raise ValueError(
+                f"Unsupported precision '{self.precision}'. Supported options are: {list(PRECISION_DICT.keys())}"
+            )
+        self.prec = PRECISION_DICT[self.precision.lower()]
+        self.use_aparam_as_mask = use_aparam_as_mask
+        self.mixed_types = mixed_types
+        # order matters, should be placed after the assignment of ntypes
+        self.reinit_exclude(exclude_types)
+
+        # init constants
+        if bias_atom_e is None:
+            self.bias_atom_e = np.zeros(
+                [self.ntypes, self.embedding_width], dtype=GLOBAL_NP_FLOAT_PRECISION
+            )
+        else:
+            assert bias_atom_e.shape == (self.ntypes, self.embedding_width)
+            self.bias_atom_e = bias_atom_e.astype(GLOBAL_NP_FLOAT_PRECISION)
+        if self.numb_fparam > 0:
+            self.fparam_avg = np.zeros(self.numb_fparam, dtype=self.prec)
+            self.fparam_inv_std = np.ones(self.numb_fparam, dtype=self.prec)
+        else:
+            self.fparam_avg, self.fparam_inv_std = None, None
+        if self.numb_aparam > 0:
+            self.aparam_avg = np.zeros(self.numb_aparam, dtype=self.prec)
+            self.aparam_inv_std = np.ones(self.numb_aparam, dtype=self.prec)
+        else:
+            self.aparam_avg, self.aparam_inv_std = None, None
+        if self.dim_case_embd > 0:
+            self.case_embd = np.zeros(self.dim_case_embd, dtype=self.prec)
+        else:
+            self.case_embd = None
+        # init networks
+        in_dim = (
+            self.dim_descrpt
+            + self.numb_fparam
+            + (0 if self.use_aparam_as_mask else self.numb_aparam)
+            + self.dim_case_embd
+        )
+        self.coord_nets = NetworkCollection(
+            1 if not self.mixed_types else 0,
+            self.ntypes,
+            network_type="fitting_network",
+            networks=[
+                FittingNet(
+                    in_dim,
+                    self.embedding_width,
+                    self.neuron,
+                    self.activation_function,
+                    self.resnet_dt,
+                    self.precision,
+                    bias_out=True,
+                    seed=child_seed(seed, ii),
+                )
+                for ii in range(self.ntypes if not self.mixed_types else 1)
+            ],
+        )
+        self.cell_nets = NetworkCollection(
+            1 if not self.mixed_types else 0,
+            self.ntypes,
+            network_type="fitting_network",
+            networks=[
+                FittingNet(
+                    in_dim,
+                    6,
+                    self.neuron,
+                    self.activation_function,
+                    self.resnet_dt,
+                    self.precision,
+                    bias_out=True,
+ seed=child_seed(self.seed, ii), + ) + for ii in range(self.ntypes if not self.mixed_types else 1) + ], + ) + self.token_nets = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + self.ntypes - 1, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii), + ) + for ii in range(self.ntypes if not self.mixed_types else 1) + ], + ) + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.numb_fparam + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.numb_aparam + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return [ii for ii in range(self.ntypes) if ii not in self.exclude_types] + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_coord_noise(self): + """Get the noise level of the coordinates.""" + return self.coord_noise + + def get_cell_pert_fraction(self): + """Get the fraction of the cell perturbation.""" + return self.cell_pert_fraction + + def get_noise_type(self): + """Get the noise type.""" + return self.noise_type + + def set_case_embd(self, case_idx: int): + """ + Set the case embedding of this fitting net by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. + """ + self.case_embd = np.eye(self.dim_case_embd, dtype=self.prec)[case_idx] + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert self.type_map is not None, ( + "'type_map' must be defined when performing type changing!" + ) + assert self.mixed_types, "Only models in mixed types can perform type changing!" 
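+        # remap the excluded types and the per-type bias onto the new type_map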
+ remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.ntypes = len(type_map) + self.reinit_exclude(map_atom_exclude_types(self.exclude_types, remap_index)) + if has_new_type: + extend_shape = [len(type_map), *list(self.bias_atom_e.shape[1:])] + extend_bias_atom_e = np.zeros(extend_shape, dtype=self.bias_atom_e.dtype) + self.bias_atom_e = np.concatenate( + [self.bias_atom_e, extend_bias_atom_e], axis=0 + ) + self.bias_atom_e = self.bias_atom_e[remap_index] + + def __setitem__(self, key, value) -> None: + if key in ["bias_atom_e"]: + self.bias_atom_e = value + elif key in ["fparam_avg"]: + self.fparam_avg = value + elif key in ["fparam_inv_std"]: + self.fparam_inv_std = value + elif key in ["aparam_avg"]: + self.aparam_avg = value + elif key in ["aparam_inv_std"]: + self.aparam_inv_std = value + elif key in ["case_embd"]: + self.case_embd = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ["bias_atom_e"]: + return self.bias_atom_e + elif key in ["fparam_avg"]: + return self.fparam_avg + elif key in ["fparam_inv_std"]: + return self.fparam_inv_std + elif key in ["aparam_avg"]: + return self.aparam_avg + elif key in ["aparam_inv_std"]: + return self.aparam_inv_std + elif key in ["case_embd"]: + return self.case_embd + else: + raise KeyError(key) + + def reinit_exclude( + self, + exclude_types: list[int] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + "@class": "Fitting", + "@version": 3, + "type": "denoise", + "ntypes": self.ntypes, + "embedding_width": self.embedding_width, + "dim_descrpt": self.dim_descrpt, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, + "activation_function": self.activation_function, + "precision": self.precision, + "mixed_types": self.mixed_types, + "cell_nets": self.cell_nets.serialize(), + "coord_nets": self.coord_nets.serialize(), + "token_nets": self.token_nets.serialize(), + "exclude_types": self.exclude_types, + "coord_noise": self.coord_noise, + "cell_pert_fraction": self.cell_pert_fraction, + "noise_type": self.noise_type, + "@variables": { + "bias_atom_e": to_numpy_array(self.bias_atom_e), + "case_embd": to_numpy_array(self.case_embd), + "fparam_avg": to_numpy_array(self.fparam_avg), + "fparam_inv_std": to_numpy_array(self.fparam_inv_std), + "aparam_avg": to_numpy_array(self.aparam_avg), + "aparam_inv_std": to_numpy_array(self.aparam_inv_std), + }, + "type_map": self.type_map, + } + + @classmethod + def deserialize(cls, data: dict) -> "DenoiseFitting": + data = data.copy() + data.pop("@class") + data.pop("type") + check_version_compatibility(data.pop("@version"), 3, 1) + variables = data.pop("@variables") + cell_nets = data.pop("cell_nets") + coord_nets = data.pop("coord_nets") + token_nets = data.pop("token_nets") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = variables[kk] + obj.cell_nets = NetworkCollection.deserialize(cell_nets) + obj.coord_nets = NetworkCollection.deserialize(coord_nets) + obj.token_nets = NetworkCollection.deserialize(token_nets) + return obj + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "strain_components", + [6], + reducible=True, + r_differentiable=False, + c_differentiable=False, + intensive=True, + ), + OutputVariableDef( + 
"updated_coord", + [3], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + OutputVariableDef( + "logits", + [self.ntypes - 1], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + @cast_precision + def call( + self, + descriptor: np.ndarray, + atype: np.ndarray, + gr: Optional[np.ndarray] = None, + g2: Optional[np.ndarray] = None, + h2: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + ) -> dict[str, np.ndarray]: + """Calculate the fitting. + + Parameters + ---------- + descriptor + input descriptor. shape: nf x nloc x nd + atype + the atom type. shape: nf x nloc + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + fparam + The frame parameter. shape: nf x nfp. nfp being `numb_fparam` + aparam + The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` + + """ + xp = array_api_compat.array_namespace(descriptor, atype) + nf, nloc, nd = descriptor.shape + # check input dim + if nd != self.dim_descrpt: + raise ValueError( + f"get an input descriptor of dim {nd}," + f"which is not consistent with {self.dim_descrpt}." + ) + xx = descriptor + + # check fparam dim, concate to input descriptor + if self.numb_fparam > 0: + assert fparam is not None, "fparam should not be None" + if fparam.shape[-1] != self.numb_fparam: + raise ValueError( + f"get an input fparam of dim {fparam.shape[-1]}, " + f"which is not consistent with {self.numb_fparam}." + ) + fparam = (fparam - self.fparam_avg) * self.fparam_inv_std + fparam = xp.tile( + xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1) + ) + xx = xp.concat( + [xx, fparam], + axis=-1, + ) + # check aparam dim, concate to input descriptor + if self.numb_aparam > 0 and not self.use_aparam_as_mask: + assert aparam is not None, "aparam should not be None" + if aparam.shape[-1] != self.numb_aparam: + raise ValueError( + f"get an input aparam of dim {aparam.shape[-1]}, " + f"which is not consistent with {self.numb_aparam}." 
+ ) + aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam]) + aparam = (aparam - self.aparam_avg) * self.aparam_inv_std + xx = xp.concat( + [xx, aparam], + axis=-1, + ) + + if self.dim_case_embd > 0: + assert self.case_embd is not None + case_embd = xp.tile(xp.reshape(self.case_embd, [1, 1, -1]), [nf, nloc, 1]) + xx = xp.concat( + [xx, case_embd], + axis=-1, + ) + + # calculate the prediction + if not self.mixed_types: + strain_components = xp.zeros( + [nf, nloc, 6], dtype=get_xp_precision(xp, self.precision) + ) + updated_coord = xp.zeros( + [nf, nloc, 3], dtype=get_xp_precision(xp, self.precision) + ) + logits = xp.zeros( + [nf, nloc, self.ntypes - 1], dtype=get_xp_precision(xp, self.precision) + ) + # coord fitting + for type_i in range(self.ntypes): + mask = xp.tile(xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, 3)) + updated_coord_type = self.coord_nets[(type_i,)](xx) + assert list(updated_coord_type.shape) == [ + nf, + nloc, + self.embedding_width, + ] + updated_coord_type = xp.reshape( + updated_coord_type, (-1, 1, self.embedding_width) + ) # (nf * nloc, 1, embedding_width) + gr = xp.reshape( + gr, (nf * nloc, -1, 3) + ) # (nf * nloc, embedding_width, 3) + updated_coord_type = updated_coord_type @ gr # (nf, nloc, 3) + updated_coord_type = xp.reshape(updated_coord_type, (nf, nloc, 3)) + updated_coord_type = xp.where( + mask, updated_coord_type, xp.zeros_like(updated_coord_type) + ) + updated_coord = ( + updated_coord + updated_coord_type + ) # Shape is [nf, nloc, 3] + # cell fitting + for type_i in range(self.ntypes): + mask = xp.tile(xp.reshape((atype == type_i), [nf, nloc, 1]), (1, 1, 6)) + strain_components_type = self.cell_nets[(type_i,)](xx) + strain_components_type = xp.where( + mask, strain_components_type, xp.zeros_like(strain_components_type) + ) + strain_components = strain_components + strain_components_type + # token fitting + for type_i in range(self.ntypes): + mask = xp.tile( + xp.reshape((atype == type_i), [nf, nloc, 1]), + (1, 1, self.ntypes - 1), + ) + logits_type = self.token_nets[(type_i,)](xx) + logits_type = xp.where(mask, logits_type, xp.zeros_like(logits_type)) + logits = logits + logits_type + else: + # coord fitting + updated_coord = self.coord_nets[()](xx) + assert list(updated_coord.shape) == [nf, nloc, self.embedding_width] + updated_coord = xp.reshape( + updated_coord, (-1, 1, self.embedding_width) + ) # (nf * nloc, 1, embedding_width) + gr = xp.reshape(gr, (nf * nloc, -1, 3)) # (nf * nloc, embedding_width, 3) + updated_coord = updated_coord @ gr # (nf, nloc, 3) + updated_coord = xp.reshape(updated_coord, (nf, nloc, 3)) + # cell fitting + strain_components = self.cell_nets[()](xx) # [nf, nloc, 6] + # token fitting + logits = self.token_nets[()](xx) # [nf, natoms[0], ntypes-1] + # nf x nloc + exclude_mask = self.emask.build_type_exclude_mask(atype) + exclude_mask = xp.astype(exclude_mask, xp.bool) + # nf x nloc x od + strain_components = xp.where( + exclude_mask[:, :, None], + strain_components, + xp.zeros_like(strain_components), + ) + updated_coord = xp.where( + exclude_mask[:, :, None], updated_coord, xp.zeros_like(updated_coord) + ) + logits = xp.where(exclude_mask[:, :, None], logits, xp.zeros_like(logits)) + + return { + "strain_components": strain_components, + "updated_coord": updated_coord, + "logits": logits, + } diff --git a/deepmd/dpmodel/model/__init__.py b/deepmd/dpmodel/model/__init__.py index 37ef57b38b..31277fdaed 100644 --- a/deepmd/dpmodel/model/__init__.py +++ b/deepmd/dpmodel/model/__init__.py @@ -12,6 +12,9 @@ Models 
generated by `make_model` have already done it.
 """
+from .denoise_model import (
+    DenoiseModel,
+)
 from .dp_model import (
     DPModelCommon,
 )
@@ -30,6 +33,7 @@
 __all__ = [
     "DPModelCommon",
+    "DenoiseModel",
     "EnergyModel",
     "PropertyModel",
     "SpinModel",
diff --git a/deepmd/dpmodel/model/denoise_model.py b/deepmd/dpmodel/model/denoise_model.py
new file mode 100644
index 0000000000..ddde836469
--- /dev/null
+++ b/deepmd/dpmodel/model/denoise_model.py
@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.dpmodel.atomic_model import (
+    DPDenoiseAtomicModel,
+)
+from deepmd.dpmodel.model.base_model import (
+    BaseModel,
+)
+
+from .dp_model import (
+    DPModelCommon,
+)
+from .make_model import (
+    make_model,
+)
+
+DPDenoiseModel_ = make_model(DPDenoiseAtomicModel)
+
+
+@BaseModel.register("denoise")
+class DenoiseModel(DPModelCommon, DPDenoiseModel_):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ) -> None:
+        DPModelCommon.__init__(self)
+        DPDenoiseModel_.__init__(self, *args, **kwargs)
diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py
index 1d18b70e8e..e49773dbca 100644
--- a/deepmd/dpmodel/model/model.py
+++ b/deepmd/dpmodel/model/model.py
@@ -19,6 +19,9 @@
 from deepmd.dpmodel.model.base_model import (
     BaseModel,
 )
+from deepmd.dpmodel.model.denoise_model import (
+    DenoiseModel,
+)
 from deepmd.dpmodel.model.dipole_model import (
     DipoleModel,
 )
@@ -60,6 +63,14 @@ def _get_standard_model_components(data, ntypes):
     fitting_net["embedding_width"] = descriptor.get_dim_emb()
     fitting_net["dim_descrpt"] = descriptor.get_dim_out()
     grad_force = "direct" not in fitting_net["type"]
+    if fitting_net["type"] in ["denoise"]:
+        assert data["type_map"][-1] == "MASKED_TOKEN", (
+            f"When using denoise fitting, the last element in `type_map` must be 'MASKED_TOKEN', but got '{data['type_map'][-1]}'"
+        )
+        fitting_net["embedding_width"] = descriptor.get_dim_emb()
+        fitting_net["coord_noise"] = data.get("coord_noise", 0.2)
+        fitting_net["cell_pert_fraction"] = data.get("cell_pert_fraction", 0.0)
+        fitting_net["noise_type"] = data.get("noise_type", "gaussian")
     if not grad_force:
         fitting_net["out_dim"] = descriptor.get_dim_emb()
     if "ener" in fitting_net["type"]:
@@ -96,6 +107,8 @@
         modelcls = EnergyModel
     elif fitting_net_type == "property":
         modelcls = PropertyModel
+    elif fitting_net_type == "denoise":
+        modelcls = DenoiseModel
     else:
         raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")
diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py
index 919d23f757..71e048fd34 100644
--- a/deepmd/entrypoints/test.py
+++ b/deepmd/entrypoints/test.py
@@ -15,6 +15,9 @@
 from deepmd.common import (
     expand_sys_str,
 )
+from deepmd.infer.deep_denoise import (
+    DeepDenoise,
+)
 from deepmd.infer.deep_dipole import (
     DeepDipole,
 )
@@ -174,6 +177,8 @@
             err = test_polar(
                 dp, data, numb_test, detail_file, atomic=False
             )  # YWolfeee: downward compatibility
+        elif isinstance(dp, DeepDenoise):
+            raise NotImplementedError("DeepDenoise is not supported in test mode.")
         log.info("# ----------------------------------------------- ")
         err_coll.append(err)
diff --git a/deepmd/infer/deep_denoise.py b/deepmd/infer/deep_denoise.py
new file mode 100644
index 0000000000..cc6498e070
--- /dev/null
+++ b/deepmd/infer/deep_denoise.py
@@ -0,0 +1,136 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+    Optional,
+    Union,
+)
+
+import numpy as np
+
+from deepmd.dpmodel.output_def import (
+    FittingOutputDef,
+    ModelOutputDef,
+    OutputVariableDef,
+)
+
+from .deep_eval import (
+    DeepEval,
+)
+
+
+class DeepDenoise(DeepEval):
+    """Given noisy structures, denoise them to obtain relaxed structures.
+
+    Parameters
+    ----------
+    model_file : Path
+        The name of the frozen model file.
+    *args : list
+        Positional arguments.
+    auto_batch_size : bool or int or AutoBatchSize, default: True
+        If True, automatic batch size will be used. If int, it will be used
+        as the initial batch size.
+    neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional
+        The ASE neighbor list class to produce the neighbor list. If None, the
+        neighbor list will be built natively in the model.
+    **kwargs : dict
+        Keyword arguments.
+    """
+
+    @property
+    def output_def(self) -> ModelOutputDef:
+        """
+        Get the output definition of this model.
+        """
+        return ModelOutputDef(
+            FittingOutputDef(
+                [
+                    OutputVariableDef(
+                        "strain_components",
+                        [6],
+                        reducible=True,
+                        r_differentiable=False,
+                        c_differentiable=False,
+                        intensive=True,
+                    ),
+                    OutputVariableDef(
+                        "updated_coord",
+                        [3],
+                        reducible=False,
+                        r_differentiable=False,
+                        c_differentiable=False,
+                    ),
+                    OutputVariableDef(
+                        "logits",
+                        [-1],
+                        reducible=False,
+                        r_differentiable=False,
+                        c_differentiable=False,
+                    ),
+                ]
+            )
+        )
+
+    def eval(
+        self,
+        coords: np.ndarray,
+        cells: Optional[np.ndarray],
+        atom_types: Union[list[int], np.ndarray],
+        atomic: bool = False,
+        fparam: Optional[np.ndarray] = None,
+        aparam: Optional[np.ndarray] = None,
+        mixed_type: bool = False,
+        **kwargs: dict[str, Any],
+    ) -> tuple[np.ndarray, ...]:
+        """Evaluate properties. If atomic is True, also return atomic property.
+
+        Parameters
+        ----------
+        coords : np.ndarray
+            The coordinates of the atoms, in shape (nframes, natoms, 3).
+        cells : np.ndarray
+            The cell vectors of the system, in shape (nframes, 9). If the system
+            is not periodic, set it to None.
+        atom_types : list[int] or np.ndarray
+            The types of the atoms. If mixed_type is False, the shape is (natoms,);
+            otherwise, the shape is (nframes, natoms).
+        atomic : bool, optional
+            Whether to return atomic property, by default False.
+        fparam : np.ndarray, optional
+            The frame parameters, by default None.
+        aparam : np.ndarray, optional
+            The atomic parameters, by default None.
+        mixed_type : bool, optional
+            Whether the atom_types is mixed type, by default False.
+        **kwargs : dict[str, Any]
+            Keyword arguments.
+
+        Returns
+        -------
+        results
+            The denoise predictions: the reduced strain components, the
+            per-atom coordinate updates, and the atom-type logits.
+        """
+        (
+            coords,
+            cells,
+            atom_types,
+            fparam,
+            aparam,
+            nframes,
+            natoms,
+        ) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type)
+        results = self.deep_eval.eval(
+            coords,
+            cells,
+            atom_types,
+            atomic,
+            fparam=fparam,
+            aparam=aparam,
+            **kwargs,
+        )
+
+        # TODO:
+        return None
+
+
+__all__ = ["DeepDenoise"]
diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py
index d1cc4fb82f..5b28c8e3b3 100644
--- a/deepmd/pt/infer/deep_eval.py
+++ b/deepmd/pt/infer/deep_eval.py
@@ -17,6 +17,9 @@
     OutputVariableCategory,
     OutputVariableDef,
 )
+from deepmd.infer.deep_denoise import (
+    DeepDenoise,
+)
 from deepmd.infer.deep_dipole import (
     DeepDipole,
 )
@@ -211,6 +214,8 @@
             return DeepGlobalPolar
         elif "wfc" in model_output_type:
             return DeepWFC
+        elif "updated_coord" in model_output_type:
+            return DeepDenoise
         elif self.get_var_name() in model_output_type:
             return DeepProperty
         else:
diff --git a/deepmd/pt/loss/denoise.py b/deepmd/pt/loss/denoise.py
index 574210adb6..06d1202bdc 100644
--- a/deepmd/pt/loss/denoise.py
+++ b/deepmd/pt/loss/denoise.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import numpy as np
 import torch
 import torch.nn.functional as F
@@ -8,102 +9,321 @@
 from deepmd.pt.utils import (
     env,
 )
+from deepmd.pt.utils.env import (
+    GLOBAL_PT_FLOAT_PRECISION,
+)
+from deepmd.pt.utils.region import (
+    phys2inter,
+)
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+
+
+def get_cell_perturb_matrix(cell_pert_fraction: float):
+    # TODO: allow the user to fix selected strain components
+    if cell_pert_fraction < 0:
+        raise RuntimeError("cell_pert_fraction can not be negative")
+    e0 = torch.rand(6, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
+    e = e0 * 2 * cell_pert_fraction - cell_pert_fraction
+    cell_pert_matrix = torch.tensor(
+        [
+            [1 + e[0], 0, 0],
+            [e[5], 1 + e[1], 0],
+            [e[4], e[3], 1 + e[2]],
+        ],
+        dtype=env.GLOBAL_PT_FLOAT_PRECISION,
+        device=env.DEVICE,
+    )
+    return cell_pert_matrix, e
 
 
 class DenoiseLoss(TaskLoss):
     def __init__(
         self,
-        ntypes,
-        masked_token_loss=1.0,
-        masked_coord_loss=1.0,
-        norm_loss=0.01,
-        use_l1=True,
-        beta=1.00,
-        mask_loss_coord=True,
-        mask_loss_token=True,
+        ntypes: int,
+        mask_token: bool = False,
+        mask_coord: bool = True,
+        mask_cell: bool = False,
+        token_loss: float = 1.0,
+        coord_loss: float = 1.0,
+        cell_loss: float = 1.0,
+        noise_type: str = "gaussian",
+        coord_noise: float = 0.2,
+        cell_pert_fraction: float = 0.0,
+        noise_mode: str = "prob",
+        mask_num: int = 1,
+        mask_prob: float = 0.2,
+        same_mask: bool = False,
+        loss_func: str = "rmse",
         **kwargs,
     ) -> None:
-        """Construct a layer to compute loss on coord, and type reconstruction."""
+        r"""Construct a layer to compute loss on token, coord and cell.
+
+        Parameters
+        ----------
+        mask_token : bool
+            Whether to mask token.
+        mask_coord : bool
+            Whether to mask coordinate.
+        mask_cell : bool
+            Whether to mask cell.
+        token_loss : float
+            The preference factor for token denoise.
+        coord_loss : float
+            The preference factor for coordinate denoise.
+        cell_loss : float
+            The preference factor for cell denoise.
+        noise_type : str
+            The type of noise to add to the coordinate. It can be 'uniform' or 'gaussian'.
+        coord_noise : float
+            The magnitude of noise to add to the coordinate.
+        cell_pert_fraction : float
+            A value that determines how much the cell will deform.
+        noise_mode : str
+            'prob' means a fraction `mask_prob` of the atoms is masked;
+            'fix_num' means a fixed number (`mask_num`) of atoms is masked.
+        mask_num : int
+            The number of atoms to mask coordinates. It is only used when noise_mode is 'fix_num'.
+        mask_prob : float
+            The probability of masking coordinates. It is only used when noise_mode is 'prob'.
+        same_mask : bool
+            Whether the token mask reuses the same atoms as the coordinate mask.
+        loss_func : str
+            The loss function to minimize, it can be 'mae' or 'rmse'.
+        **kwargs
+            Other keyword arguments.
+        """
         super().__init__()
         self.ntypes = ntypes
-        self.masked_token_loss = masked_token_loss
-        self.masked_coord_loss = masked_coord_loss
-        self.norm_loss = norm_loss
-        self.has_coord = self.masked_coord_loss > 0.0
-        self.has_token = self.masked_token_loss > 0.0
-        self.has_norm = self.norm_loss > 0.0
-        self.use_l1 = use_l1
-        self.beta = beta
-        self.frac_beta = 1.00 / self.beta
-        self.mask_loss_coord = mask_loss_coord
-        self.mask_loss_token = mask_loss_token
-
-    def forward(self, model_pred, label, natoms, learning_rate, mae=False):
-        """Return loss on coord and type denoise.
+        self.mask_type_idx = self.ntypes - 1
+        self.mask_token = mask_token
+        self.mask_coord = mask_coord
+        self.mask_cell = mask_cell
+        self.token_loss = token_loss
+        self.coord_loss = coord_loss
+        self.cell_loss = cell_loss
+        self.noise_type = noise_type
+        self.coord_noise = coord_noise
+        self.cell_pert_fraction = cell_pert_fraction
+        self.noise_mode = noise_mode
+        self.mask_num = mask_num
+        self.mask_prob = mask_prob
+        self.same_mask = same_mask
+        self.loss_func = loss_func
+
+    def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
+        """Return loss on token, coord and cell.
+
+        Parameters
+        ----------
+        input_dict : dict[str, torch.Tensor]
+            Model inputs.
+        model : torch.nn.Module
+            Model to be used to output the predictions.
+        label : dict[str, torch.Tensor]
+            Labels.
+        natoms : int
+            The local atom number.
 
         Returns
         -------
-        - loss: Loss to minimize.
+        model_pred: dict[str, torch.Tensor]
+            Model predictions.
+        loss: torch.Tensor
+            Loss for model to minimize.
+        more_loss: dict[str, torch.Tensor]
+            Other losses for display.
         """
-        updated_coord = model_pred["updated_coord"]
-        logits = model_pred["logits"]
-        clean_coord = label["clean_coord"]
-        clean_type = label["clean_type"]
-        coord_mask = label["coord_mask"]
-        type_mask = label["type_mask"]
+        rng = np.random.default_rng()
+        nloc = input_dict["atype"].shape[1]
+        nbz = input_dict["atype"].shape[0]
+        # keep the box on the same device as the other inputs
+        input_dict["box"] = input_dict["box"].to(env.DEVICE)
 
-        loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
-        more_loss = {}
-        if self.has_coord:
-            if self.mask_loss_coord:
-                masked_updated_coord = updated_coord[coord_mask]
-                masked_clean_coord = clean_coord[coord_mask]
-                if masked_updated_coord.size(0) > 0:
-                    coord_loss = F.smooth_l1_loss(
-                        masked_updated_coord.view(-1, 3),
-                        masked_clean_coord.view(-1, 3),
-                        reduction="mean",
-                        beta=self.beta,
+        # TODO: Change lattice to lower triangular matrix
+
+        label["clean_coord"] = input_dict["coord"].clone().detach()
+        label["clean_box"] = input_dict["box"].clone().detach()
+        origin_frac_coord = phys2inter(
+            label["clean_coord"], label["clean_box"].reshape(nbz, 3, 3)
+        )
+        label["clean_frac_coord"] = origin_frac_coord.clone().detach()
+        label["clean_type"] = input_dict["atype"].clone().detach().to(torch.int64)
+        if self.mask_cell:
+            strain_components_all = torch.zeros(
+                (nbz, 6), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
+            )
+            for ii in range(nbz):
+                cell_perturb_matrix, strain_components = get_cell_perturb_matrix(
+                    self.cell_pert_fraction
+                )
+                # left-multiply by `cell_perturb_matrix` to get the perturbed box
+                input_dict["box"][ii] = torch.matmul(
+                    cell_perturb_matrix, input_dict["box"][ii].reshape(3, 3)
+                ).reshape(-1)
+                input_dict["coord"][ii] = torch.matmul(
+                    origin_frac_coord[ii].reshape(nloc, 3),
+                    input_dict["box"][ii].reshape(3, 3),
+                )
+                strain_components_all[ii] = strain_components.reshape(-1)
+            label["strain_components"] = strain_components_all.clone().detach()
+
+        if self.mask_coord:
+            # add noise to coordinates and update label['updated_coord']
+            mask_num = 0
+            if self.noise_mode == "fix_num":
+                mask_num = self.mask_num
+                if nloc < mask_num:
+                    mask_num = nloc
+            elif self.noise_mode == "prob":
+                mask_num = int(self.mask_prob * nloc)
+                if mask_num == 0:
+                    mask_num = 1
+            else:
+                raise NotImplementedError(f"Unknown noise mode {self.noise_mode}!")
+
+            coord_mask_all = torch.zeros(
+                input_dict["atype"].shape, dtype=torch.bool, device=env.DEVICE
+            )
+            for ii in range(nbz):
+                coord_mask_res = rng.choice(
+                    range(nloc), mask_num, replace=False
+                ).tolist()
+                coord_mask = np.isin(range(nloc), coord_mask_res)
+                if self.noise_type == "uniform":
+                    noise_on_coord = rng.uniform(
+                        low=-self.coord_noise, high=self.coord_noise, size=(mask_num, 3)
+                    )
+                elif self.noise_type == "gaussian":
+                    noise_on_coord = rng.normal(
+                        loc=0.0, scale=self.coord_noise, size=(mask_num, 3)
                     )
                 else:
-                    coord_loss = torch.zeros(
-                        1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
-                    )[0]
-            else:
-                coord_loss = F.smooth_l1_loss(
-                    updated_coord.view(-1, 3),
-                    clean_coord.view(-1, 3),
-                    reduction="mean",
-                    beta=self.beta,
+                    raise NotImplementedError(f"Unknown noise type {self.noise_type}!")
+
+                noise_on_coord = torch.tensor(
+                    noise_on_coord,
+                    dtype=env.GLOBAL_PT_FLOAT_PRECISION,
+                    device=env.DEVICE,
+                )  # (mask_num, 3)
+                input_dict["coord"][ii][coord_mask, :] += (
+                    noise_on_coord
                 )
-            loss += self.masked_coord_loss * coord_loss
-            more_loss["coord_l1_error"] = coord_loss.detach()
-        if self.has_token:
-            if self.mask_loss_token:
-                masked_logits = logits[type_mask]
-                masked_target = clean_type[type_mask]
-                if masked_logits.size(0) > 0:
-                    token_loss = F.nll_loss(
-                        F.log_softmax(masked_logits, dim=-1),
-                        masked_target,
-                        reduction="mean",
-                    )
+                coord_mask_all[ii] = torch.tensor(
+                    coord_mask, dtype=torch.bool, device=env.DEVICE
+                )
+            label["coord_mask"] = coord_mask_all
+            frac_coord = phys2inter(
+                input_dict["coord"], input_dict["box"].reshape(nbz, 3, 3)
+            )
+            # the label is the cartesian displacement from the noised structure
+            # back to the clean one, measured in the clean box
+            label["updated_coord"] = (
+                (
+                    (label["clean_frac_coord"] - frac_coord)
+                    @ label["clean_box"].reshape(nbz, 3, 3)
+                )
+                .clone()
+                .detach()
+            )
+
+        if self.mask_token:
+            if self.same_mask and not self.mask_coord:
+                raise RuntimeError("same_mask requires mask_coord to be enabled!")
+            type_mask_all = torch.zeros(
+                input_dict["atype"].shape, dtype=torch.bool, device=env.DEVICE
+            )
+            for ii in range(nbz):
+                if self.same_mask:
+                    type_mask = coord_mask_all[ii].clone()
                 else:
-                    token_loss = torch.zeros(
-                        1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
-                    )[0]
-            else:
-                token_loss = F.nll_loss(
-                    F.log_softmax(logits.view(-1, self.ntypes - 1), dim=-1),
-                    clean_type.view(-1),
-                    reduction="mean",
+                    mask_count = min(self.mask_num, nloc)
+                    type_mask_res = rng.choice(
+                        range(nloc), mask_count, replace=False
+                    ).tolist()
+                    type_mask = np.isin(range(nloc), type_mask_res)
+                input_dict["atype"][ii][type_mask] = self.mask_type_idx
+                type_mask_all[ii] = torch.tensor(
+                    type_mask, dtype=torch.bool, device=env.DEVICE
                 )
-            loss += self.masked_token_loss * token_loss
-            more_loss["token_error"] = token_loss.detach()
-        if self.has_norm:
-            norm_x = model_pred["norm_x"]
-            norm_delta_pair_rep = model_pred["norm_delta_pair_rep"]
-            loss += self.norm_loss * (norm_x + norm_delta_pair_rep)
-            more_loss["norm_loss"] = norm_x.detach() + norm_delta_pair_rep.detach()
-
-        return loss, more_loss
+            label["type_mask"] = type_mask_all
+
+        if (not self.mask_coord) and (not self.mask_cell) and (not self.mask_token):
+            raise RuntimeError(
+                "At least one of mask_coord, mask_cell and mask_token should be True!"
+            )
+
+        model_pred = model(**input_dict)
+
+        loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
+        more_loss = {}
+
+        # coord and cell loss; only the masked quantities carry labels
+        if self.loss_func == "rmse":
+            if self.mask_coord:
+                diff_coord = (
+                    label["updated_coord"] - model_pred["updated_coord"]
+                ).reshape(-1)
+                l2_coord_loss = torch.mean(torch.square(diff_coord))
+                more_loss["rmse_coord"] = l2_coord_loss.sqrt().detach()
+                loss += self.coord_loss * l2_coord_loss.to(GLOBAL_PT_FLOAT_PRECISION)
+            if self.mask_cell:
+                diff_cell = (
+                    label["strain_components"] - model_pred["strain_components"]
+                ).reshape(-1)
+                l2_cell_loss = torch.mean(torch.square(diff_cell))
+                more_loss["rmse_cell"] = l2_cell_loss.sqrt().detach()
+                loss += self.cell_loss * l2_cell_loss.to(GLOBAL_PT_FLOAT_PRECISION)
+        elif self.loss_func == "mae":
+            if self.mask_coord:
+                l1_coord_loss = F.l1_loss(
+                    label["updated_coord"], model_pred["updated_coord"], reduction="none"
+                )
+                more_loss["mae_coord"] = l1_coord_loss.mean().detach()
+                l1_coord_loss = l1_coord_loss.sum(-1).mean(-1).sum()
+                loss += self.coord_loss * l1_coord_loss.to(GLOBAL_PT_FLOAT_PRECISION)
+            if self.mask_cell:
+                l1_cell_loss = F.l1_loss(
+                    label["strain_components"],
+                    model_pred["strain_components"],
+                    reduction="none",
+                )
+                more_loss["mae_cell"] = l1_cell_loss.mean().detach()
+                l1_cell_loss = l1_cell_loss.sum()
+                loss += self.cell_loss * l1_cell_loss.to(GLOBAL_PT_FLOAT_PRECISION)
+        else:
+            raise RuntimeError(f"Unknown loss function {self.loss_func}!")
+        # token loss; only defined when tokens were masked
+        if self.mask_token:
+            type_mask = label["type_mask"]
+            masked_logits = model_pred["logits"][type_mask]
+            masked_target = label["clean_type"][type_mask]
+            token_loss = F.nll_loss(
+                F.log_softmax(masked_logits, dim=-1),
+                masked_target,
+                reduction="mean",
+            )
+            more_loss["token_loss"] = token_loss.detach()
+            loss += self.token_loss * token_loss.to(GLOBAL_PT_FLOAT_PRECISION)
+
+        return model_pred, loss, more_loss
+
+    @property
+    def label_requirement(self) -> list[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+        label_requirement = [
+            DataRequirementItem(
+                "strain_components",
+                ndof=6,
+                atomic=False,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                "updated_coord",
+                ndof=3,
+                atomic=True,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                "logits",
+                ndof=self.ntypes - 1,
+                atomic=True,
+                must=False,
+                high_prec=False,
+            ),
+        ]
+        return label_requirement
diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py
index 4da9bf781b..8079678592 100644
--- a/deepmd/pt/model/atomic_model/__init__.py
+++ b/deepmd/pt/model/atomic_model/__init__.py
@@ -17,6 +17,9 @@
 from .base_atomic_model import (
     BaseAtomicModel,
 )
+from .denoise_atomic_model import (
+    DPDenoiseAtomicModel,
+)
 from .dipole_atomic_model import (
     DPDipoleAtomicModel,
 )
@@ -47,6 +50,7 @@
     "BaseAtomicModel",
     "DPAtomicModel",
     "DPDOSAtomicModel",
+    "DPDenoiseAtomicModel",
     "DPDipoleAtomicModel",
     "DPEnergyAtomicModel",
     "DPPolarAtomicModel",
diff --git a/deepmd/pt/model/atomic_model/denoise_atomic_model.py b/deepmd/pt/model/atomic_model/denoise_atomic_model.py
new file mode 100644
index 0000000000..3dd32809e9
--- /dev/null
+++ b/deepmd/pt/model/atomic_model/denoise_atomic_model.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+
+import logging
+
+import torch
+
+from deepmd.pt.model.task.denoise import (
+    DenoiseFittingNet,
+)
+
+from .dp_atomic_model import (
+    DPAtomicModel,
+)
+
+log = logging.getLogger(__name__)
+
+
+class DPDenoiseAtomicModel(DPAtomicModel):
+    def __init__(self, descriptor, fitting, type_map, **kwargs):
+        if not isinstance(fitting, DenoiseFittingNet):
+            raise TypeError(
+                "fitting must be an instance of DenoiseFittingNet for DPDenoiseAtomicModel"
+            )
+        super().__init__(descriptor, fitting, type_map, **kwargs)
+
+    def apply_out_stat(
+        self,
+        ret: dict[str, torch.Tensor],
+        atype: torch.Tensor,
+    ):
+        """Apply the stat to each atomic output.
+
+        In denoise fitting, each output will be multiplied by the label std.
+
+        Parameters
+        ----------
+        ret
+            The returned dict by the forward_atomic method
+        atype
+            The atom types. nf x nloc. It is not used in denoise fitting.
+
+        """
+        # Scale values to appropriate magnitudes.
+        # 1.732 ~= sqrt(3): the std of a uniform distribution on [-a, a] is a / sqrt(3).
+        noise_type = self.fitting_net.get_noise_type()
+        cell_std = self.fitting_net.get_cell_pert_fraction() / 1.732
+        if noise_type == "gaussian":
+            coord_std = self.fitting_net.get_coord_noise()
+        elif noise_type == "uniform":
+            coord_std = self.fitting_net.get_coord_noise() / 1.732
+        else:
+            raise RuntimeError(f"Unknown noise type {noise_type}")
+        ret["strain_components"] = (
+            ret["strain_components"] * cell_std
+            if cell_std > 0
+            else ret["strain_components"]
+        )
+        ret["updated_coord"] = (
+            ret["updated_coord"] * coord_std if coord_std > 0 else ret["updated_coord"]
+        )
+        return ret
diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py
index 8d451f087f..53f76c3ece 100644
--- a/deepmd/pt/model/model/__init__.py
+++ b/deepmd/pt/model/model/__init__.py
@@ -33,6 +33,9 @@
     Spin,
 )
+from .denoise_model import (
+    DenoiseModel,
+)
 from .dipole_model import (
     DipoleModel,
 )
@@ -94,6 +97,14 @@ def _get_standard_model_components(model_params, ntypes):
     fitting_net["embedding_width"] = descriptor.get_dim_emb()
     fitting_net["dim_descrpt"] = descriptor.get_dim_out()
     grad_force = "direct" not in fitting_net["type"]
+    if fitting_net["type"] in ["denoise"]:
+        assert model_params["type_map"][-1] == "MASKED_TOKEN", (
+            f"When using denoise fitting, the last element in `type_map` must be 'MASKED_TOKEN', but got '{model_params['type_map'][-1]}'"
+        )
+        fitting_net["embedding_width"] = descriptor.get_dim_emb()
+        fitting_net["coord_noise"] = model_params.get("coord_noise", 0.2)
+        fitting_net["cell_pert_fraction"] = model_params.get("cell_pert_fraction", 0.0)
+        fitting_net["noise_type"] = model_params.get("noise_type", "gaussian")
     if not grad_force:
         fitting_net["out_dim"] = descriptor.get_dim_emb()
     if "ener" in fitting_net["type"]:
@@ -266,6 +277,8 @@
         modelcls = EnergyModel
     elif fitting_net_type == "property":
         modelcls = PropertyModel
+    elif fitting_net_type == "denoise":
+        modelcls = DenoiseModel
     else:
         raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")
diff --git a/deepmd/pt/model/model/denoise_model.py b/deepmd/pt/model/model/denoise_model.py
new file mode 100644
index 0000000000..e2bbf38241
--- /dev/null
+++ b/deepmd/pt/model/model/denoise_model.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+)
+
+import torch
+
+from deepmd.pt.model.atomic_model import (
+    DPDenoiseAtomicModel,
+)
+from deepmd.pt.model.model.model import (
+    BaseModel,
+)
+
+from .dp_model import (
+    DPModelCommon,
+)
+from .make_model import (
+    make_model,
+)
+
+DPDenoiseModel_ = make_model(DPDenoiseAtomicModel)
+
+
+@BaseModel.register("denoise")
+class DenoiseModel(DPModelCommon, DPDenoiseModel_):
+    model_type = "denoise"
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ) -> None:
+        DPModelCommon.__init__(self)
+        DPDenoiseModel_.__init__(self, *args, **kwargs)
+
+    def 
translated_output_def(self): + out_def_data = self.model_output_def().get_data() + output_def = { + "strain_components": out_def_data["strain_components_redu"], + "atom_strain_components": out_def_data["strain_components"], + "updated_coord": out_def_data["updated_coord"], + "logits": out_def_data["logits"], + } + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def + + def forward( + self, + coord, + atype, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["updated_coord"] = model_ret["updated_coord"] + model_predict["atom_strain_components"] = model_ret["strain_components"] + model_predict["strain_components"] = model_ret["strain_components_redu"] + model_predict["logits"] = model_ret["logits"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ): + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + model_predict = {} + model_predict["updated_coord"] = model_ret["updated_coord"] + model_predict["atom_strain_components"] = model_ret["strain_components"] + model_predict["strain_components"] = model_ret["strain_components_redu"] + model_predict["logits"] = model_ret["logits"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 37ffec2725..a142b69c65 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -3,7 +3,7 @@ BaseFitting, ) from .denoise import ( - DenoiseNet, + DenoiseFittingNet, ) from .dipole import ( DipoleFittingNet, @@ -31,7 +31,7 @@ __all__ = [ "BaseFitting", "DOSFittingNet", - "DenoiseNet", + "DenoiseFittingNet", "DipoleFittingNet", "EnergyFittingNet", "EnergyFittingNetDirect", diff --git a/deepmd/pt/model/task/denoise.py b/deepmd/pt/model/task/denoise.py index fc9e8943e9..72906494c4 100644 --- a/deepmd/pt/model/task/denoise.py +++ b/deepmd/pt/model/task/denoise.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Optional, + Union, ) +import numpy as np import torch from deepmd.dpmodel import ( @@ -10,9 +12,12 @@ OutputVariableDef, fitting_check_output, ) -from deepmd.pt.model.network.network import ( - MaskLMHead, - NonLinearHead, +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + FittingNet, + NetworkCollection, ) from deepmd.pt.model.task.fitting import ( Fitting, @@ -20,60 +25,276 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pt.utils.exclude_mask import ( + AtomExcludeMask, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.finetune import ( 
+    get_index_between_two_maps,
+    map_atom_exclude_types,
+)
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
+dtype = env.GLOBAL_PT_FLOAT_PRECISION
+device = env.DEVICE
+
+
+@Fitting.register("denoise")
 @fitting_check_output
-class DenoiseNet(Fitting):
+class DenoiseFittingNet(Fitting):
     def __init__(
         self,
-        feature_dim,
-        ntypes,
-        attn_head=8,
-        prefactor=[0.5, 0.5],
-        activation_function="gelu",
+        ntypes: int,
+        dim_descrpt: int,
+        embedding_width: int,
+        neuron: list[int] = [128, 128, 128],
+        bias_atom_e: Optional[torch.Tensor] = None,
+        resnet_dt: bool = True,
+        numb_fparam: int = 0,
+        numb_aparam: int = 0,
+        dim_case_embd: int = 0,
+        activation_function: str = "tanh",
+        precision: str = DEFAULT_PRECISION,
+        mixed_types: bool = True,
+        seed: Optional[Union[int, list[int]]] = None,
+        exclude_types: list[int] = [],
+        trainable: Union[bool, list[bool]] = True,
+        type_map: Optional[list[str]] = None,
+        use_aparam_as_mask: bool = False,
+        coord_noise: Optional[float] = None,
+        cell_pert_fraction: Optional[float] = None,
+        noise_type: Optional[str] = None,
         **kwargs,
     ) -> None:
-        """Construct a denoise net.
-
-        Args:
-        - ntypes: Element count.
-        - embedding_width: Embedding width per atom.
-        - neuron: Number of neurons in each hidden layers of the fitting net.
-        - bias_atom_e: Average energy per atom for each element.
-        - resnet_dt: Using time-step in the ResNet construction.
+        """Construct a direct token, coordinate and cell fitting net.
+
+        Parameters
+        ----------
+        ntypes : int
+            Element count.
+        dim_descrpt : int
+            The dimension of the input descriptor.
+        neuron : list[int]
+            Number of neurons in each hidden layers of the fitting net.
+        bias_atom_e : torch.Tensor, optional
+            Average output bias per atom for each element.
+        resnet_dt : bool
+            Using time-step in the ResNet construction.
+        embedding_width : int
+            The output dimension of the coordinate branch of the fitting net.
+        numb_fparam : int
+            Number of frame parameters.
+        numb_aparam : int
+            Number of atomic parameters.
+        dim_case_embd : int
+            Dimension of case specific embedding.
+        activation_function : str
+            Activation function.
+        precision : str
+            Numerical precision.
+        mixed_types : bool
+            If true, use a uniform fitting net for all atom types, otherwise use
+            different fitting nets for different atom types.
+        seed : int, optional
+            Random seed.
+        exclude_types : list[int]
+            Atomic contributions of the excluded atom types are set zero.
+        trainable : Union[list[bool], bool]
+            If the parameters in the fitting net are trainable.
+            Currently this only supports one setting for all parameters in the fitting net.
+            If a list[bool] is given, trainable is True only when all entries are True.
+        type_map : list[str], Optional
+            A list of strings. Give the name to each type of atoms.
+        use_aparam_as_mask : bool
+            If True, the aparam will not be used in fitting net for embedding.
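+        coord_noise : float, optional
+            The magnitude of the noise added to the coordinates.
+        cell_pert_fraction : float, optional
+            The fraction by which the cell is perturbed.
+        noise_type : str, optional
+            The type of coordinate noise, 'gaussian' or 'uniform'.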
""" super().__init__() - self.feature_dim = feature_dim self.ntypes = ntypes - self.attn_head = attn_head - self.prefactor = torch.tensor( - prefactor, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + self.dim_descrpt = dim_descrpt + self.neuron = neuron + self.mixed_types = mixed_types + self.resnet_dt = resnet_dt + self.embedding_width = embedding_width + self.numb_fparam = numb_fparam + self.numb_aparam = numb_aparam + self.dim_case_embd = dim_case_embd + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.seed = seed + self.var_name = ["strain_components", "updated_coord", "logits"] + self.type_map = type_map + self.use_aparam_as_mask = use_aparam_as_mask + self.coord_noise = coord_noise + self.cell_pert_fraction = cell_pert_fraction + self.noise_type = noise_type + # order matters, should be place after the assignment of ntypes + self.reinit_exclude(exclude_types) + self.trainable = trainable + # need support for each layer settings + self.trainable = ( + all(self.trainable) if isinstance(self.trainable, list) else self.trainable ) - self.lm_head = MaskLMHead( - embed_dim=self.feature_dim, - output_dim=ntypes, - activation_fn=activation_function, - weight=None, + # init constants + if bias_atom_e is None: + bias_atom_e = np.zeros([self.ntypes, 1], dtype=np.float64) + bias_atom_e = torch.tensor( + bias_atom_e, device=env.DEVICE, dtype=env.GLOBAL_PT_FLOAT_PRECISION ) + bias_atom_e = bias_atom_e.view([self.ntypes, 1]) + if not self.mixed_types: + assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!" + self.register_buffer("bias_atom_e", bias_atom_e) - if not isinstance(self.attn_head, list): - self.pair2coord_proj = NonLinearHead( - self.attn_head, 1, activation_fn=activation_function + if self.numb_fparam > 0: + self.register_buffer( + "fparam_avg", + torch.zeros(self.numb_fparam, dtype=self.prec, device=device), + ) + self.register_buffer( + "fparam_inv_std", + torch.ones(self.numb_fparam, dtype=self.prec, device=device), ) else: - self.pair2coord_proj = [] - self.ndescriptor = len(self.attn_head) - for ii in range(self.ndescriptor): - _pair2coord_proj = NonLinearHead( - self.attn_head[ii], 1, activation_fn=activation_function + self.fparam_avg, self.fparam_inv_std = None, None + if self.numb_aparam > 0: + self.register_buffer( + "aparam_avg", + torch.zeros(self.numb_aparam, dtype=self.prec, device=device), + ) + self.register_buffer( + "aparam_inv_std", + torch.ones(self.numb_aparam, dtype=self.prec, device=device), + ) + else: + self.aparam_avg, self.aparam_inv_std = None, None + + if self.dim_case_embd > 0: + self.register_buffer( + "case_embd", + torch.zeros(self.dim_case_embd, dtype=self.prec, device=device), + # torch.eye(self.dim_case_embd, dtype=self.prec, device=device)[0], + ) + else: + self.case_embd = None + + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + + self.dim_case_embd + ) + + self.filter_layers_coord = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + self.embedding_width, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii), + ) + for ii in range(self.ntypes if not self.mixed_types else 1) + ], + ) + + self.filter_layers_cell = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + 
networks=[ + FittingNet( + in_dim, + 6, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii), ) - self.pair2coord_proj.append(_pair2coord_proj) - self.pair2coord_proj = torch.nn.ModuleList(self.pair2coord_proj) + for ii in range(self.ntypes if not self.mixed_types else 1) + ], + ) + + self.filter_layers_token = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + self.ntypes - 1, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii), + ) + for ii in range(self.ntypes if not self.mixed_types else 1) + ], + ) + + # set trainable + for param in self.parameters(): + param.requires_grad = self.trainable + + def reinit_exclude( + self, + exclude_types: list[int] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert self.type_map is not None, ( + "'type_map' must be defined when performing type changing!" + ) + assert self.mixed_types, "Only models in mixed types can perform type changing!" + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.ntypes = len(type_map) + self.reinit_exclude(map_atom_exclude_types(self.exclude_types, remap_index)) + if has_new_type: + extend_shape = [len(type_map), *list(self.bias_atom_e.shape[1:])] + extend_bias_atom_e = torch.zeros( + extend_shape, + dtype=self.bias_atom_e.dtype, + device=self.bias_atom_e.device, + ) + self.bias_atom_e = torch.cat([self.bias_atom_e, extend_bias_atom_e], dim=0) + self.bias_atom_e = self.bias_atom_e[remap_index] def output_def(self): return FittingOutputDef( [ + OutputVariableDef( + "strain_components", + [6], + reducible=True, + r_differentiable=False, + c_differentiable=False, + intensive=True, + ), OutputVariableDef( "updated_coord", [3], @@ -83,7 +304,7 @@ def output_def(self): ), OutputVariableDef( "logits", - [-1], + [self.ntypes - 1], reducible=False, r_differentiable=False, c_differentiable=False, @@ -91,47 +312,326 @@ def output_def(self): ] ) + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + "@class": "Fitting", + "@version": 3, + "type": "denoise", + "ntypes": self.ntypes, + "embedding_width": self.embedding_width, + "dim_descrpt": self.dim_descrpt, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, + "activation_function": self.activation_function, + "precision": self.precision, + "mixed_types": self.mixed_types, + "cell_nets": self.filter_layers_cell.serialize(), + "coord_nets": self.filter_layers_coord.serialize(), + "token_nets": self.filter_layers_token.serialize(), + "exclude_types": self.exclude_types, + "coord_noise": self.coord_noise, + "cell_pert_fraction": self.cell_pert_fraction, + "noise_type": self.noise_type, + "@variables": { + "bias_atom_e": to_numpy_array(self.bias_atom_e), + "case_embd": to_numpy_array(self.case_embd), + "fparam_avg": 
+    def serialize(self) -> dict:
+        """Serialize the fitting to dict."""
+        return {
+            "@class": "Fitting",
+            "@version": 3,
+            "type": "denoise",
+            "ntypes": self.ntypes,
+            "embedding_width": self.embedding_width,
+            "dim_descrpt": self.dim_descrpt,
+            "neuron": self.neuron,
+            "resnet_dt": self.resnet_dt,
+            "numb_fparam": self.numb_fparam,
+            "numb_aparam": self.numb_aparam,
+            "dim_case_embd": self.dim_case_embd,
+            "activation_function": self.activation_function,
+            "precision": self.precision,
+            "mixed_types": self.mixed_types,
+            "cell_nets": self.filter_layers_cell.serialize(),
+            "coord_nets": self.filter_layers_coord.serialize(),
+            "token_nets": self.filter_layers_token.serialize(),
+            "exclude_types": self.exclude_types,
+            "coord_noise": self.coord_noise,
+            "cell_pert_fraction": self.cell_pert_fraction,
+            "noise_type": self.noise_type,
+            "@variables": {
+                "bias_atom_e": to_numpy_array(self.bias_atom_e),
+                "case_embd": to_numpy_array(self.case_embd),
+                "fparam_avg": to_numpy_array(self.fparam_avg),
+                "fparam_inv_std": to_numpy_array(self.fparam_inv_std),
+                "aparam_avg": to_numpy_array(self.aparam_avg),
+                "aparam_inv_std": to_numpy_array(self.aparam_inv_std),
+            },
+            "type_map": self.type_map,
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "DenoiseFittingNet":
+        data = data.copy()
+        data.pop("@class")
+        data.pop("type")
+        variables = data.pop("@variables")
+        cell_nets = data.pop("cell_nets")
+        coord_nets = data.pop("coord_nets")
+        token_nets = data.pop("token_nets")
+        obj = cls(**data)
+        for kk in variables.keys():
+            obj[kk] = to_torch_tensor(variables[kk])
+        obj.filter_layers_cell = NetworkCollection.deserialize(cell_nets)
+        obj.filter_layers_coord = NetworkCollection.deserialize(coord_nets)
+        obj.filter_layers_token = NetworkCollection.deserialize(token_nets)
+        return obj
+
+    def get_dim_fparam(self) -> int:
+        """Get the number (dimension) of frame parameters of this atomic model."""
+        return self.numb_fparam
+
+    def get_dim_aparam(self) -> int:
+        """Get the number (dimension) of atomic parameters of this atomic model."""
+        return self.numb_aparam
+
+    # make jit happy
+    exclude_types: list[int]
+
+    def get_sel_type(self) -> list[int]:
+        """Get the selected atom types of this model.
+
+        Only atoms with selected atom types have atomic contribution
+        to the result of the model.
+        If returning an empty list, all atom types are selected.
+        """
+        # make jit happy
+        sel_type: list[int] = []
+        for ii in range(self.ntypes):
+            if ii not in self.exclude_types:
+                sel_type.append(ii)
+        return sel_type
+
+    def get_type_map(self) -> list[str]:
+        """Get the name to each type of atoms."""
+        return self.type_map
+
+    def get_coord_noise(self):
+        """Get the noise level of the coordinates."""
+        return self.coord_noise
+
+    def get_cell_pert_fraction(self):
+        """Get the fraction of the cell perturbation."""
+        return self.cell_pert_fraction
+
+    def get_noise_type(self):
+        """Get the noise type."""
+        return self.noise_type
+
+    def set_case_embd(self, case_idx: int):
+        """
+        Set the case embedding of this fitting net by the given case_idx,
+        typically concatenated with the output of the descriptor and fed into the fitting net.
+ """ + self.case_embd = torch.eye(self.dim_case_embd, dtype=self.prec, device=device)[ + case_idx + ] + + def __setitem__(self, key, value) -> None: + if key in ["bias_atom_e"]: + value = value.view([self.ntypes, 1]) + self.bias_atom_e = value + elif key in ["fparam_avg"]: + self.fparam_avg = value + elif key in ["fparam_inv_std"]: + self.fparam_inv_std = value + elif key in ["aparam_avg"]: + self.aparam_avg = value + elif key in ["aparam_inv_std"]: + self.aparam_inv_std = value + elif key in ["case_embd"]: + self.case_embd = value + elif key in ["scale"]: + self.scale = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ["bias_atom_e"]: + return self.bias_atom_e + elif key in ["fparam_avg"]: + return self.fparam_avg + elif key in ["fparam_inv_std"]: + return self.fparam_inv_std + elif key in ["aparam_avg"]: + return self.aparam_avg + elif key in ["aparam_inv_std"]: + return self.aparam_inv_std + elif key in ["case_embd"]: + return self.case_embd + elif key in ["scale"]: + return self.scale + else: + raise KeyError(key) + + def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: + return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1]) + + def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor: + return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1]) + def forward( self, - pair_weights, - diff, - nlist_mask, - features, - sw, - masked_tokens: Optional[torch.Tensor] = None, - ): - """Calculate the updated coord. - Args: - - coord: Input noisy coord with shape [nframes, nloc, 3]. - - pair_weights: Input pair weights with shape [nframes, nloc, nnei, head]. - - diff: Input pair relative coord list with shape [nframes, nloc, nnei, 3]. - - nlist_mask: Input nlist mask with shape [nframes, nloc, nnei]. - - Returns - ------- - - denoised_coord: Denoised updated coord with shape [nframes, nloc, 3]. + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + """Calculate the fitting. + + Parameters + ---------- + descriptor + input descriptor. shape: nf x nloc x nd + atype + the atom type. shape: nf x nloc + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + fparam + The frame parameter. shape: nf x nfp. nfp being `numb_fparam` + aparam + The atomic parameter. shape: nf x nloc x nap. nap being `numb_aparam` + """ - # [nframes, nloc, nnei, 1] - logits = self.lm_head(features, masked_tokens=masked_tokens) - if not isinstance(self.attn_head, list): - attn_probs = self.pair2coord_proj(pair_weights) - out_coord = (attn_probs * diff).sum(dim=-2) / ( - sw.sum(dim=-1).unsqueeze(-1) + 1e-6 + # cast the input to internal precsion + xx = descriptor.to(self.prec) + fparam = fparam.to(self.prec) if fparam is not None else None + aparam = aparam.to(self.prec) if aparam is not None else None + + nf, nloc, nd = xx.shape + + if nd != self.dim_descrpt: + raise ValueError( + f"get an input descriptor of dim {nd}," + f"which is not consistent with {self.dim_descrpt}." 
-        # [nframes, nloc, nnei, 1]
-        logits = self.lm_head(features, masked_tokens=masked_tokens)
-        if not isinstance(self.attn_head, list):
-            attn_probs = self.pair2coord_proj(pair_weights)
-            out_coord = (attn_probs * diff).sum(dim=-2) / (
-                sw.sum(dim=-1).unsqueeze(-1) + 1e-6
+        # cast the input to internal precision
+        xx = descriptor.to(self.prec)
+        fparam = fparam.to(self.prec) if fparam is not None else None
+        aparam = aparam.to(self.prec) if aparam is not None else None
+
+        nf, nloc, nd = xx.shape
+
+        if nd != self.dim_descrpt:
+            raise ValueError(
+                f"get an input descriptor of dim {nd}, "
+                f"which is not consistent with {self.dim_descrpt}."
             )
-        else:
-            assert len(self.prefactor) == self.ndescriptor
-            all_coord_update = []
-            assert len(pair_weights) == len(diff) == len(nlist_mask) == self.ndescriptor
-            for ii in range(self.ndescriptor):
-                _attn_probs = self.pair2coord_proj[ii](pair_weights[ii])
-                _coord_update = (_attn_probs * diff[ii]).sum(dim=-2) / (
-                    nlist_mask[ii].sum(dim=-1).unsqueeze(-1) + 1e-6
+        # check fparam dim, concatenate to input descriptor
+        if self.numb_fparam > 0:
+            assert fparam is not None, "fparam should not be None"
+            assert self.fparam_avg is not None
+            assert self.fparam_inv_std is not None
+            if fparam.shape[-1] != self.numb_fparam:
+                raise ValueError(
+                    f"get an input fparam of dim {fparam.shape[-1]}, "
+                    f"which is not consistent with {self.numb_fparam}."
+                )
+            fparam = fparam.view([nf, self.numb_fparam])
+            nb, _ = fparam.shape
+            t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb)
+            t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb)
+            fparam = (fparam - t_fparam_avg) * t_fparam_inv_std
+            fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1])
+            xx = torch.cat(
+                [xx, fparam],
+                dim=-1,
+            )
+        # check aparam dim, concatenate to input descriptor
+        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
+            assert aparam is not None, "aparam should not be None"
+            assert self.aparam_avg is not None
+            assert self.aparam_inv_std is not None
+            if aparam.shape[-1] != self.numb_aparam:
+                raise ValueError(
+                    f"get an input aparam of dim {aparam.shape[-1]}, "
+                    f"which is not consistent with {self.numb_aparam}."
                 )
-            all_coord_update.append(_coord_update)
-            out_coord = self.prefactor[0] * all_coord_update[0]
-            for ii in range(self.ndescriptor - 1):
-                out_coord += self.prefactor[ii + 1] * all_coord_update[ii + 1]
+            aparam = aparam.view([nf, -1, self.numb_aparam])
+            nb, nloc, _ = aparam.shape
+            t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc)
+            t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc)
+            aparam = (aparam - t_aparam_avg) * t_aparam_inv_std
+            xx = torch.cat(
+                [xx, aparam],
+                dim=-1,
+            )
+
+        if self.dim_case_embd > 0:
+            assert self.case_embd is not None
+            case_embd = torch.tile(self.case_embd.reshape([1, 1, -1]), [nf, nloc, 1])
+            xx = torch.cat(
+                [xx, case_embd],
+                dim=-1,
+            )
+
+        if self.mixed_types:
+            # coord fitting
+            updated_coord = self.filter_layers_coord.networks[0](xx)
+            assert list(updated_coord.size()) == [nf, nloc, self.embedding_width]
+            updated_coord = updated_coord.view(
+                -1, 1, self.embedding_width
+            )  # (nf x nloc) x 1 x od
+            assert gr is not None
+            gr = gr.view(-1, self.embedding_width, 3)  # (nf x nloc) x od x 3
+            # contract with the equivariant representation gr so the coordinate
+            # update transforms correctly under rotation
+            updated_coord = (
+                torch.bmm(updated_coord, gr).squeeze(-2).view(nf, nloc, 3)
+            )  # [nf, nloc, 3]
+            # cell fitting
+            strain_components = self.filter_layers_cell.networks[0](
+                xx
+            )  # [nframes, natoms[0], 6]
+            # token fitting
+            logits = self.filter_layers_token.networks[0](
+                xx
+            )  # [nframes, natoms[0], ntypes-1]
+        else:
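+            # non-mixed types: one fitting net per atom type; per-type results
+            # are masked by atype and summed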
+            strain_components = torch.zeros(
+                (nf, nloc, 6),
+                dtype=self.prec,
+                device=descriptor.device,
+            )
+            updated_coord = torch.zeros(
+                (nf, nloc, 3),
+                dtype=self.prec,
+                device=descriptor.device,
+            )
+            logits = torch.zeros(
+                (nf, nloc, self.ntypes - 1),
+                dtype=self.prec,
+                device=descriptor.device,
+            )
+            # coord fitting
+            for type_i, ll in enumerate(self.filter_layers_coord.networks):
+                mask = (atype == type_i).unsqueeze(-1)
+                mask = torch.tile(mask, (1, 1, 3))
+                updated_coord_type = ll(xx)
+                assert list(updated_coord_type.size()) == [
+                    nf,
+                    nloc,
+                    self.embedding_width,
+                ]
+                updated_coord_type = updated_coord_type.view(
+                    -1, 1, self.embedding_width
+                )  # (nf x nloc) x 1 x od
+                assert gr is not None
+                gr = gr.view(-1, self.embedding_width, 3)  # (nf x nloc) x od x 3
+                updated_coord_type = (
+                    torch.bmm(updated_coord_type, gr).squeeze(-2).view(nf, nloc, 3)
+                )  # [nf, nloc, 3]
+                updated_coord_type = torch.where(mask, updated_coord_type, 0.0)
+                updated_coord = (
+                    updated_coord + updated_coord_type
+                )  # [nframes, natoms[0], 3]
+            # cell fitting
+            for type_i, ll in enumerate(self.filter_layers_cell.networks):
+                mask = (atype == type_i).unsqueeze(-1)
+                mask = torch.tile(mask, (1, 1, 6))
+                strain_components_type = ll(xx)
+                strain_components_type = torch.where(mask, strain_components_type, 0.0)
+                strain_components = (
+                    strain_components + strain_components_type
+                )  # [nframes, natoms[0], 6]
+            # token fitting
+            for type_i, ll in enumerate(self.filter_layers_token.networks):
+                mask = (atype == type_i).unsqueeze(-1)
+                mask = torch.tile(mask, (1, 1, self.ntypes - 1))
+                logits_type = ll(xx)
+                logits_type = torch.where(mask, logits_type, 0.0)
+                logits = logits + logits_type
+        # nf x nloc
+        mask = self.emask(atype).to(torch.bool)
+        # nf x nloc x nod
+        strain_components = torch.where(mask[:, :, None], strain_components, 0.0)
+        updated_coord = torch.where(mask[:, :, None], updated_coord, 0.0)
+        logits = torch.where(mask[:, :, None], logits, 0.0)
         return {
-            "updated_coord": out_coord,
+            "strain_components": strain_components,
+            "updated_coord": updated_coord,
             "logits": logits,
         }
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 63e0180ace..23552ade14 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -1292,6 +1292,10 @@ def get_model_for_wrapper(
     if "model_dict" not in _model_params:
         if _loss_params is not None and whether_hessian(_loss_params):
             _model_params["hessian_mode"] = True
+        if _loss_params is not None and _loss_params.get("type", "ener") == "denoise":
+            _model_params["coord_noise"] = _loss_params.get("coord_noise")
+            _model_params["cell_pert_fraction"] = _loss_params.get("cell_pert_fraction")
+            _model_params["noise_type"] = _loss_params.get("noise_type")
         _model = get_single_model(
             _model_params,
         )
diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py
index cf6892b49d..c23f34929e 100644
--- a/deepmd/pt/utils/stat.py
+++ b/deepmd/pt/utils/stat.py
@@ -281,6 +281,10 @@ def compute_output_stats(
     intensive : bool, optional
         Whether the fitting target is intensive.
     """
+    # in denoise mode, labels are created in the loss, so there is no need to compute the bias
+    if ("strain_components" in keys) or ("updated_coord" in keys) or ("logits" in keys):
+        keys = []
+
     # try to restore the bias from stat file
     bias_atom_e, std_atom_e = _restore_from_file(stat_file_path, keys)
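The 6-component `strain_components` output fitted above presumably collects the independent entries of a symmetric 3x3 strain tensor (cf. `OutputVariableDef("strain_components", [6], ...)` in `output_def`). The diff does not document the component ordering; a minimal NumPy sketch, assuming a hypothetical Voigt-like ordering `[xx, yy, zz, yz, xz, xy]`, of how such a vector maps back to a matrix:

```python
import numpy as np

def strain_vector_to_matrix(c6):
    """Assemble a symmetric 3x3 strain matrix from 6 components.

    The ordering [xx, yy, zz, yz, xz, xy] is an assumption for
    illustration; the ordering actually used by the denoise fitting
    is not specified in this diff.
    """
    xx, yy, zz, yz, xz, xy = c6
    return np.array(
        [
            [xx, xy, xz],
            [xy, yy, yz],
            [xz, yz, zz],
        ]
    )

# example: a pure 1% volumetric strain has equal diagonal components
print(strain_vector_to_matrix([0.01, 0.01, 0.01, 0.0, 0.0, 0.0]))
```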
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index a1afcaf1d0..d469d6c78e 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1864,6 +1864,47 @@ def fitting_property():
     ]
 
 
+@fitting_args_plugin.register("denoise", doc=doc_only_pt_supported)
+def fitting_denoise():
+    doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provide the input fparams."
+    doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provide the input aparams."
+    doc_dim_case_embd = "The dimension of the case embedding. When training or fine-tuning a multitask model with case embeddings, this number should be set to the number of model branches."
+    doc_neuron = "The number of neurons in each hidden layer of the fitting net. When two hidden layers are of the same size, a skip connection is built."
+    doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
+    doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
+    doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision."
+    doc_seed = "Random seed for parameter initialization of the fitting net"
+    return [
+        Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam),
+        Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
+        Argument(
+            "dim_case_embd",
+            int,
+            optional=True,
+            default=0,
+            doc=doc_only_pt_supported + doc_dim_case_embd,
+        ),
+        Argument(
+            "neuron",
+            list[int],
+            optional=True,
+            default=[120, 120, 120],
+            alias=["n_neuron"],
+            doc=doc_neuron,
+        ),
+        Argument(
+            "activation_function",
+            str,
+            optional=True,
+            default="tanh",
+            doc=doc_activation_function,
+        ),
+        Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
+        Argument("precision", str, optional=True, default="default", doc=doc_precision),
+        Argument("seed", [int, None], optional=True, doc=doc_seed),
+    ]
+
+
 @fitting_args_plugin.register("polar", doc=doc_polar)
 def fitting_polar():
     doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams."
@@ -2768,6 +2809,126 @@ def loss_property():
     ]
 
 
+@loss_args_plugin.register("denoise")
+def loss_denoise():
+    doc_mask_token = "Whether to mask the token."
+    doc_mask_coord = "Whether to mask the coordinate."
+    doc_mask_cell = "Whether to mask the cell."
+    doc_token_loss = "The preference factor for token denoise."
+    doc_coord_loss = "The preference factor for coordinate denoise."
+    doc_cell_loss = "The preference factor for cell denoise."
+    doc_noise_type = (
+        "The type of noise to add to the coordinate. It can be 'uniform' or 'gaussian'."
+    )
+    doc_coord_noise = "The magnitude of noise to add to the coordinate."
+    doc_cell_pert_fraction = "A value that determines how much the cell will deform."
+    doc_noise_mode = "'prob' means the noise is added with a probability. 'fix_num' means the noise is added to a fixed number of atoms."
+    doc_mask_num = "The number of atoms to mask coordinates. It is only used when noise_mode is 'fix_num'."
+    doc_mask_prob = "The probability of masking coordinates. It is only used when noise_mode is 'prob'."
+    doc_same_mask = "Whether to mask the same atoms when masking coordinates and tokens."
+    doc_loss_func = "The loss function to minimize; it can be 'mae' or 'rmse'."
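+    # coord_noise, cell_pert_fraction and noise_type set here are also forwarded to the
+    # model parameters by get_model_for_wrapper (deepmd/pt/train/training.py above), so
+    # the fitting net and the loss see consistent noise settings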
+ return [ + Argument( + "mask_token", + bool, + optional=True, + default=False, + doc=doc_mask_token, + ), + Argument( + "mask_coord", + bool, + optional=True, + default=True, + doc=doc_mask_coord, + ), + Argument( + "mask_cell", + bool, + optional=True, + default=False, + doc=doc_mask_cell, + ), + Argument( + "token_loss", + float, + optional=True, + default=1.0, + doc=doc_token_loss, + ), + Argument( + "coord_loss", + float, + optional=True, + default=1.0, + doc=doc_coord_loss, + ), + Argument( + "cell_loss", + float, + optional=True, + default=1.0, + doc=doc_cell_loss, + ), + Argument( + "noise_type", + str, + optional=True, + default="gaussian", + doc=doc_noise_type, + ), + Argument( + "coord_noise", + float, + optional=True, + default=0.2, + doc=doc_coord_noise, + ), + Argument( + "cell_pert_fraction", + float, + optional=True, + default=0.0, + doc=doc_cell_pert_fraction, + ), + Argument( + "noise_mode", + str, + optional=True, + default="prob", + doc=doc_noise_mode, + ), + Argument( + "mask_num", + int, + optional=True, + default=1, + doc=doc_mask_num, + ), + Argument( + "mask_prob", + float, + optional=True, + default=0.2, + doc=doc_mask_prob, + ), + Argument( + "same_mask", + bool, + optional=True, + default=False, + doc=doc_same_mask, + ), + Argument( + "loss_func", + str, + optional=True, + default="rmse", + doc=doc_loss_func, + ), + ] + + # YWolfeee: Modified to support tensor type of loss args. @loss_args_plugin.register("tensor") def loss_tensor(): diff --git a/examples/denoise/data/data_0/set.000/box.npy b/examples/denoise/data/data_0/set.000/box.npy new file mode 100644 index 0000000000..e17441f1a8 Binary files /dev/null and b/examples/denoise/data/data_0/set.000/box.npy differ diff --git a/examples/denoise/data/data_0/set.000/coord.npy b/examples/denoise/data/data_0/set.000/coord.npy new file mode 100644 index 0000000000..67bfa35f97 Binary files /dev/null and b/examples/denoise/data/data_0/set.000/coord.npy differ diff --git a/examples/denoise/data/data_0/set.000/real_atom_types.npy b/examples/denoise/data/data_0/set.000/real_atom_types.npy new file mode 100644 index 0000000000..f2d13ddfa9 Binary files /dev/null and b/examples/denoise/data/data_0/set.000/real_atom_types.npy differ diff --git a/examples/denoise/data/data_0/type.raw b/examples/denoise/data/data_0/type.raw new file mode 100644 index 0000000000..e2338f70ca --- /dev/null +++ b/examples/denoise/data/data_0/type.raw @@ -0,0 +1,46 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 diff --git a/examples/denoise/data/data_0/type_map.raw b/examples/denoise/data/data_0/type_map.raw new file mode 100644 index 0000000000..c365ac55fd --- /dev/null +++ b/examples/denoise/data/data_0/type_map.raw @@ -0,0 +1,7 @@ +Ru +Pt +Ir +Pd +O +Ag +H diff --git a/examples/denoise/data/data_1/set.000/box.npy b/examples/denoise/data/data_1/set.000/box.npy new file mode 100644 index 0000000000..9c0da084cf Binary files /dev/null and b/examples/denoise/data/data_1/set.000/box.npy differ diff --git a/examples/denoise/data/data_1/set.000/coord.npy b/examples/denoise/data/data_1/set.000/coord.npy new file mode 100644 index 0000000000..aecbc8cb85 Binary files /dev/null and b/examples/denoise/data/data_1/set.000/coord.npy differ diff --git a/examples/denoise/data/data_1/set.000/real_atom_types.npy b/examples/denoise/data/data_1/set.000/real_atom_types.npy new file mode 100644 index 0000000000..19f40a37f8 Binary files /dev/null and 
b/examples/denoise/data/data_1/set.000/real_atom_types.npy differ diff --git a/examples/denoise/data/data_1/type.raw b/examples/denoise/data/data_1/type.raw new file mode 100644 index 0000000000..e2338f70ca --- /dev/null +++ b/examples/denoise/data/data_1/type.raw @@ -0,0 +1,46 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 diff --git a/examples/denoise/data/data_1/type_map.raw b/examples/denoise/data/data_1/type_map.raw new file mode 100644 index 0000000000..c365ac55fd --- /dev/null +++ b/examples/denoise/data/data_1/type_map.raw @@ -0,0 +1,7 @@ +Ru +Pt +Ir +Pd +O +Ag +H diff --git a/examples/denoise/data/data_2/set.000/box.npy b/examples/denoise/data/data_2/set.000/box.npy new file mode 100644 index 0000000000..f6ed7286b8 Binary files /dev/null and b/examples/denoise/data/data_2/set.000/box.npy differ diff --git a/examples/denoise/data/data_2/set.000/coord.npy b/examples/denoise/data/data_2/set.000/coord.npy new file mode 100644 index 0000000000..3683dd0cca Binary files /dev/null and b/examples/denoise/data/data_2/set.000/coord.npy differ diff --git a/examples/denoise/data/data_2/set.000/real_atom_types.npy b/examples/denoise/data/data_2/set.000/real_atom_types.npy new file mode 100644 index 0000000000..3003dc1cd3 Binary files /dev/null and b/examples/denoise/data/data_2/set.000/real_atom_types.npy differ diff --git a/examples/denoise/data/data_2/type.raw b/examples/denoise/data/data_2/type.raw new file mode 100644 index 0000000000..e2338f70ca --- /dev/null +++ b/examples/denoise/data/data_2/type.raw @@ -0,0 +1,46 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 diff --git a/examples/denoise/data/data_2/type_map.raw b/examples/denoise/data/data_2/type_map.raw new file mode 100644 index 0000000000..c365ac55fd --- /dev/null +++ b/examples/denoise/data/data_2/type_map.raw @@ -0,0 +1,7 @@ +Ru +Pt +Ir +Pd +O +Ag +H diff --git a/examples/denoise/train/input.json b/examples/denoise/train/input.json new file mode 100644 index 0000000000..1a87a19987 --- /dev/null +++ b/examples/denoise/train/input.json @@ -0,0 +1,99 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "H", + "O", + "Ru", + "Pt", + "Ir", + "Pd", + "Ag", + "MASKED_TOKEN" + ], + "descriptor": { + "type": "dpa1", + "sel": 120, + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [ + 25, + 50, + 100 + ], + "tebd_dim": 8, + "axis_neuron": 16, + "type_one_side": true, + "attn": 128, + "attn_layer": 0, + "attn_dotr": true, + "attn_mask": false, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": true, + "temperature": 1.0 + }, + "fitting_net": { + "type": "denoise", + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float64", + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "denoise", + "mask_token": true, + "mask_coord": true, + "mask_cell": true, + "coord_loss": 1.0, + "cell_loss": 1.0, + "token_loss": 1.0, + "noise_type": "gaussian", + "coord_noise": 0.06, + "noise_mode": "prob", + "mask_prob": 0.2, + "cell_pert_fraction": 0.008, + "same_mask": true, + "loss_func": "rmse", + "_comment": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1" + ], + "batch_size": 
2, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 200, + "save_freq": 2000, + "_comment": "that's all" + } + } diff --git a/source/tests/pt/common.py b/source/tests/pt/common.py index 8709c8b4f9..415954662f 100644 --- a/source/tests/pt/common.py +++ b/source/tests/pt/common.py @@ -66,6 +66,8 @@ def eval_model( force_mag_out = [] virial_out = [] atomic_virial_out = [] + strain_components_out = [] + atom_strain_components_out = [] updated_coord_out = [] logits_out = [] err_msg = ( @@ -162,6 +164,14 @@ def eval_model( atomic_virial_out.append( batch_output["atom_virial"].detach().cpu().numpy() ) + if "strain_components" in batch_output: + strain_components_out.append( + batch_output["strain_components"].detach().cpu().numpy() + ) + if "atom_strain_components" in batch_output: + atom_strain_components_out.append( + batch_output["atom_strain_components"].detach().cpu().numpy() + ) if "updated_coord" in batch_output: updated_coord_out.append( batch_output["updated_coord"].detach().cpu().numpy() @@ -181,6 +191,12 @@ def eval_model( virial_out.append(batch_output["virial"]) if "atom_virial" in batch_output: atomic_virial_out.append(batch_output["atom_virial"]) + if "strain_components" in batch_output: + strain_components_out.append(batch_output["strain_components"]) + if "atom_strain_components" in batch_output: + atom_strain_components_out.append( + batch_output["atom_strain_components"] + ) if "updated_coord" in batch_output: updated_coord_out.append(batch_output["updated_coord"]) if "logits" in batch_output: @@ -210,8 +226,20 @@ def eval_model( if atomic_virial_out else np.zeros([nframes, natoms, 3, 3]) # pylint: disable=no-explicit-dtype ) + strain_components_out = ( + np.concatenate(strain_components_out) + if strain_components_out + else np.zeros([nframes, 6]) # pylint: disable=no-explicit-dtype + ) + atom_strain_components_out = ( + np.concatenate(atom_strain_components_out) + if atom_strain_components_out + else np.zeros([nframes, natoms, 6]) # pylint: disable=no-explicit-dtype + ) updated_coord_out = ( - np.concatenate(updated_coord_out) if updated_coord_out else None + np.concatenate(updated_coord_out) + if updated_coord_out + else np.zeros([nframes, natoms, 3]) ) logits_out = np.concatenate(logits_out) if logits_out else None else: @@ -257,10 +285,37 @@ def eval_model( [nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE ) ) - updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None + strain_components_out = ( + torch.cat(strain_components_out) + if strain_components_out + else torch.zeros( + [nframes, 6], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + atom_strain_components_out = ( + torch.cat(atom_strain_components_out) + if atom_strain_components_out + else torch.zeros( + [nframes, natoms, 6], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) + updated_coord_out = ( + torch.cat(updated_coord_out) + if updated_coord_out + else torch.zeros( + [nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ) logits_out = torch.cat(logits_out) if logits_out else None if denoise: - return updated_coord_out, logits_out + results_dict = { + "strain_components": strain_components_out, + "updated_coord": updated_coord_out, + "logits": logits_out, + } + if atomic: + 
results_dict["atom_strain_components"] = atom_strain_components_out + return results_dict else: results_dict = { "energy": energy_out, diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py index 0354336e37..f2443cd15e 100644 --- a/source/tests/pt/model/test_permutation.py +++ b/source/tests/pt/model/test_permutation.py @@ -377,6 +377,36 @@ }, } +model_denoise = { + "type_map": ["H", "C", "N", "O", "MASKED_TOKEN"], + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [25, 50, 100], + "axis_neuron": 16, + "attn": 64, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "type": "denoise", + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + "_comment": " that's all", + }, +} + class PermutationTest: def test( @@ -396,7 +426,10 @@ def test( atype = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device=env.DEVICE) idx_perm = [1, 0, 4, 3, 2] test_spin = getattr(self, "test_spin", False) - if not test_spin: + test_denoise = getattr(self, "test_denoise", False) + if test_denoise: + test_keys = ["strain_components", "updated_coord", "logits"] + elif not test_spin: test_keys = ["energy", "force", "virial"] else: test_keys = ["energy", "force", "force_mag", "virial"] @@ -406,6 +439,7 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret0 = {key: result_0[key].squeeze(0) for key in test_keys} result_1 = eval_model( @@ -414,13 +448,14 @@ def test( cell.unsqueeze(0), atype[idx_perm], spins=spin[idx_perm].unsqueeze(0), + denoise=test_denoise, ) ret1 = {key: result_1[key].squeeze(0) for key in test_keys} prec = 1e-10 for key in test_keys: - if key in ["energy"]: + if key in ["energy", "strain_components"]: torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) - elif key in ["force", "force_mag"]: + elif key in ["force", "force_mag", "updated_coord", "logits"]: torch.testing.assert_close( ret0[key][idx_perm], ret1[key], rtol=prec, atol=prec ) @@ -501,6 +536,14 @@ def setUp(self) -> None: self.model = get_model(model_params).to(env.DEVICE) +class TestDenoiseModelDPA1(unittest.TestCase, PermutationTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_denoise) + self.type_split = False + self.test_denoise = True + self.model = get_model(model_params).to(env.DEVICE) + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pt/model/test_permutation_denoise.py b/source/tests/pt/model/test_permutation_denoise.py deleted file mode 100644 index 389520daa3..0000000000 --- a/source/tests/pt/model/test_permutation_denoise.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import copy -import unittest - -import torch - -from deepmd.pt.model.model import ( - get_model, -) -from deepmd.pt.utils import ( - env, -) - -from ...seed import ( - GLOBAL_SEED, -) -from ..common import ( - eval_model, -) -from .test_permutation import ( # model_dpau, - model_dpa1, - model_dpa2, - model_hybrid, -) - -dtype = torch.float64 - -model_dpa1 = copy.deepcopy(model_dpa1) -model_dpa2 = copy.deepcopy(model_dpa2) -model_hybrid = copy.deepcopy(model_hybrid) -model_dpa1["type_map"] = ["O", "H", "B", "MASKED_TOKEN"] -model_dpa1.pop("fitting_net") -model_dpa2["type_map"] = ["O", "H", 
"B", "MASKED_TOKEN"] -model_dpa2.pop("fitting_net") -model_hybrid["type_map"] = ["O", "H", "B", "MASKED_TOKEN"] -model_hybrid.pop("fitting_net") - - -class PermutationDenoiseTest: - def test( - self, - ) -> None: - generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) - natoms = 5 - cell = torch.rand([3, 3], dtype=dtype, generator=generator).to(env.DEVICE) - cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) - coord = torch.rand([natoms, 3], dtype=dtype, generator=generator).to(env.DEVICE) - coord = torch.matmul(coord, cell) - atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) - idx_perm = [1, 0, 4, 3, 2] - updated_c0, logits0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True - ) - ret0 = {"updated_coord": updated_c0.squeeze(0), "logits": logits0.squeeze(0)} - updated_c1, logits1 = eval_model( - self.model, - coord[idx_perm].unsqueeze(0), - cell.unsqueeze(0), - atype[idx_perm], - denoise=True, - ) - ret1 = {"updated_coord": updated_c1.squeeze(0), "logits": logits1.squeeze(0)} - prec = 1e-10 - torch.testing.assert_close( - ret0["updated_coord"][idx_perm], ret1["updated_coord"], rtol=prec, atol=prec - ) - torch.testing.assert_close( - ret0["logits"][idx_perm], ret1["logits"], rtol=prec, atol=prec - ) - - -@unittest.skip("support of the denoise is temporally disabled") -class TestDenoiseModelDPA1(unittest.TestCase, PermutationDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_dpa1) - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - - -@unittest.skip("support of the denoise is temporally disabled") -class TestDenoiseModelDPA2(unittest.TestCase, PermutationDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_dpa2) - self.type_split = True - self.model = get_model( - model_params, - ).to(env.DEVICE) - - -# @unittest.skip("hybrid not supported at the moment") -# class TestDenoiseModelHybrid(unittest.TestCase, TestPermutationDenoise): -# def setUp(self): -# model_params = copy.deepcopy(model_hybrid_denoise) -# self.type_split = True -# self.model = get_model(model_params).to(env.DEVICE) - - -if __name__ == "__main__": - unittest.main() diff --git a/source/tests/pt/model/test_rot.py b/source/tests/pt/model/test_rot.py index 283dbb31d7..c1e1632dfe 100644 --- a/source/tests/pt/model/test_rot.py +++ b/source/tests/pt/model/test_rot.py @@ -18,6 +18,7 @@ eval_model, ) from .test_permutation import ( # model_dpau, + model_denoise, model_dos, model_dpa1, model_dpa2, @@ -51,7 +52,10 @@ def test( ) test_spin = getattr(self, "test_spin", False) - if not test_spin: + test_denoise = getattr(self, "test_denoise", False) + if test_denoise: + test_keys = ["strain_components", "updated_coord", "logits"] + elif not test_spin: test_keys = ["energy", "force", "virial"] else: test_keys = ["energy", "force", "force_mag"] @@ -66,6 +70,7 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret0 = {key: result_0[key].squeeze(0) for key in test_keys} result_1 = eval_model( @@ -74,12 +79,13 @@ def test( cell.unsqueeze(0), atype, spins=spin_rot.unsqueeze(0), + denoise=test_denoise, ) ret1 = {key: result_1[key].squeeze(0) for key in test_keys} for key in test_keys: - if key in ["energy"]: + if key in ["energy", "strain_components", "logits"]: torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) - elif key in ["force", "force_mag"]: + elif key in ["force", "force_mag", "updated_coord"]: torch.testing.assert_close( 
torch.matmul(ret0[key], rmat), ret1[key], rtol=prec, atol=prec ) @@ -116,6 +122,7 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret0 = {key: result_0[key].squeeze(0) for key in test_keys} result_1 = eval_model( @@ -124,12 +131,13 @@ def test( cell_rot.unsqueeze(0), atype, spins=spin_rot.unsqueeze(0), + denoise=test_denoise, ) ret1 = {key: result_1[key].squeeze(0) for key in test_keys} for key in test_keys: - if key in ["energy"]: + if key in ["energy", "strain_components", "logits"]: torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) - elif key in ["force", "force_mag"]: + elif key in ["force", "force_mag", "updated_coord"]: torch.testing.assert_close( torch.matmul(ret0[key], rmat), ret1[key], rtol=prec, atol=prec ) @@ -213,5 +221,13 @@ def setUp(self) -> None: self.model = get_model(model_params).to(env.DEVICE) +class TestDenoiseModelDPA1(unittest.TestCase, RotTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_denoise) + self.type_split = False + self.test_denoise = True + self.model = get_model(model_params).to(env.DEVICE) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/model/test_rot_denoise.py b/source/tests/pt/model/test_rot_denoise.py deleted file mode 100644 index fcae4b23d7..0000000000 --- a/source/tests/pt/model/test_rot_denoise.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import copy -import unittest - -import torch - -from deepmd.pt.model.model import ( - get_model, -) -from deepmd.pt.utils import ( - env, -) - -from ...seed import ( - GLOBAL_SEED, -) -from ..common import ( - eval_model, -) -from .test_permutation_denoise import ( - model_dpa1, - model_dpa2, -) - -dtype = torch.float64 - - -class RotDenoiseTest: - def test( - self, - ) -> None: - generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) - prec = 1e-10 - natoms = 5 - cell = 10.0 * torch.eye(3, dtype=dtype).to(env.DEVICE) - coord = 2 * torch.rand( - [natoms, 3], dtype=dtype, generator=generator, device=env.DEVICE - ) - shift = torch.tensor([4, 4, 4], dtype=dtype).to(env.DEVICE) - atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) - from scipy.stats import ( - special_ortho_group, - ) - - rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) - - # rotate only coord and shift to the center of cell - coord_rot = torch.matmul(coord, rmat) - update_c0, logits0 = eval_model( - self.model, - (coord + shift).unsqueeze(0), - cell.unsqueeze(0), - atype, - denoise=True, - ) - update_c0 = update_c0 - (coord + shift).unsqueeze(0) - ret0 = {"updated_coord": update_c0.squeeze(0), "logits": logits0.squeeze(0)} - update_c1, logits1 = eval_model( - self.model, - (coord_rot + shift).unsqueeze(0), - cell.unsqueeze(0), - atype, - denoise=True, - ) - update_c1 = update_c1 - (coord_rot + shift).unsqueeze(0) - ret1 = {"updated_coord": update_c1.squeeze(0), "logits": logits1.squeeze(0)} - torch.testing.assert_close( - torch.matmul(ret0["updated_coord"], rmat), - ret1["updated_coord"], - rtol=prec, - atol=prec, - ) - torch.testing.assert_close(ret0["logits"], ret1["logits"], rtol=prec, atol=prec) - - # rotate coord and cell - torch.manual_seed(0) - cell = torch.rand([3, 3], dtype=dtype, generator=generator).to(env.DEVICE) - cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) - coord = torch.rand([natoms, 3], dtype=dtype, generator=generator).to(env.DEVICE) - coord = torch.matmul(coord, cell) - atype = torch.IntTensor([0, 0, 0, 1, 
1]).to(env.DEVICE) - coord_rot = torch.matmul(coord, rmat) - cell_rot = torch.matmul(cell, rmat) - update_c0, logits0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True - ) - ret0 = {"updated_coord": update_c0.squeeze(0), "logits": logits0.squeeze(0)} - update_c1, logits1 = eval_model( - self.model, - coord_rot.unsqueeze(0), - cell_rot.unsqueeze(0), - atype, - denoise=True, - ) - ret1 = {"updated_coord": update_c1.squeeze(0), "logits": logits1.squeeze(0)} - torch.testing.assert_close(ret0["logits"], ret1["logits"], rtol=prec, atol=prec) - torch.testing.assert_close( - torch.matmul(ret0["updated_coord"], rmat), - ret1["updated_coord"], - rtol=prec, - atol=prec, - ) - - -@unittest.skip("support of the denoise is temporally disabled") -class TestDenoiseModelDPA1(unittest.TestCase, RotDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_dpa1) - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - - -@unittest.skip("support of the denoise is temporally disabled") -class TestDenoiseModelDPA2(unittest.TestCase, RotDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_dpa2) - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - - -# @unittest.skip("hybrid not supported at the moment") -# class TestEnergyModelHybrid(unittest.TestCase, TestRotDenoise): -# def setUp(self): -# model_params = copy.deepcopy(model_hybrid_denoise) -# self.type_split = True -# self.model = get_model(model_params).to(env.DEVICE) - - -if __name__ == "__main__": - unittest.main() diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index 1c6303d14c..8eb430cf3d 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -18,6 +18,7 @@ eval_model, ) from .test_permutation import ( # model_dpau, + model_denoise, model_dos, model_dpa1, model_dpa2, @@ -83,7 +84,10 @@ def test( coord3[1][0] += epsilon coord3[2][1] += epsilon test_spin = getattr(self, "test_spin", False) - if not test_spin: + test_denoise = getattr(self, "test_denoise", False) + if test_denoise: + test_keys = ["strain_components", "updated_coord", "logits"] + elif not test_spin: test_keys = ["energy", "force", "virial"] else: test_keys = ["energy", "force", "force_mag", "virial"] @@ -94,6 +98,7 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret0 = {key: result_0[key].squeeze(0) for key in test_keys} result_1 = eval_model( @@ -102,6 +107,7 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret1 = {key: result_1[key].squeeze(0) for key in test_keys} result_2 = eval_model( @@ -110,6 +116,7 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret2 = {key: result_2[key].squeeze(0) for key in test_keys} result_3 = eval_model( @@ -118,12 +125,13 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret3 = {key: result_3[key].squeeze(0) for key in test_keys} def compare(ret0, ret1) -> None: for key in test_keys: - if key in ["energy"]: + if key in ["energy", "strain_components", "updated_coord", "logits"]: torch.testing.assert_close( ret0[key], ret1[key], rtol=rprec, atol=aprec ) @@ -251,6 +259,15 @@ def setUp(self) -> None: self.epsilon, self.aprec = None, None +class TestDenoiseModelDPA1(unittest.TestCase, SmoothTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_denoise) + self.type_split = 
False + self.test_denoise = True + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pt/model/test_smooth_denoise.py b/source/tests/pt/model/test_smooth_denoise.py deleted file mode 100644 index 199d6664a1..0000000000 --- a/source/tests/pt/model/test_smooth_denoise.py +++ /dev/null @@ -1,141 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import copy -import unittest - -import torch - -from deepmd.pt.model.model import ( - get_model, -) -from deepmd.pt.utils import ( - env, -) - -from ...seed import ( - GLOBAL_SEED, -) -from ..common import ( - eval_model, -) -from .test_permutation_denoise import ( - model_dpa2, -) - -dtype = torch.float64 - - -class SmoothDenoiseTest: - def test( - self, - ) -> None: - # displacement of atoms - epsilon = 1e-5 if self.epsilon is None else self.epsilon - # required prec. relative prec is not checked. - rprec = 0 - aprec = 1e-5 if self.aprec is None else self.aprec - - natoms = 10 - cell = 8.6 * torch.eye(3, dtype=dtype).to(env.DEVICE) - generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) - atype = torch.randint(0, 3, [natoms], generator=generator, device=env.DEVICE) - coord0 = ( - torch.tensor( - [ - 0.0, - 0.0, - 0.0, - 4.0 - 0.5 * epsilon, - 0.0, - 0.0, - 0.0, - 4.0 - 0.5 * epsilon, - 0.0, - ], - dtype=dtype, - ) - .view([-1, 3]) - .to(env.DEVICE) - ) - coord1 = torch.rand( - [natoms - coord0.shape[0], 3], dtype=dtype, generator=generator - ).to(env.DEVICE) - coord1 = torch.matmul(coord1, cell) - coord = torch.concat([coord0, coord1], dim=0) - - coord0 = torch.clone(coord) - coord1 = torch.clone(coord) - coord1[1][0] += epsilon - coord2 = torch.clone(coord) - coord2[2][1] += epsilon - coord3 = torch.clone(coord) - coord3[1][0] += epsilon - coord3[2][1] += epsilon - - update_c0, logits0 = eval_model( - self.model, coord0.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True - ) - ret0 = {"updated_coord": update_c0.squeeze(0), "logits": logits0.squeeze(0)} - update_c1, logits1 = eval_model( - self.model, coord1.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True - ) - ret1 = {"updated_coord": update_c1.squeeze(0), "logits": logits1.squeeze(0)} - update_c2, logits2 = eval_model( - self.model, coord2.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True - ) - ret2 = {"updated_coord": update_c2.squeeze(0), "logits": logits2.squeeze(0)} - update_c3, logits3 = eval_model( - self.model, coord3.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True - ) - ret3 = {"updated_coord": update_c3.squeeze(0), "logits": logits3.squeeze(0)} - - def compare(ret0, ret1) -> None: - torch.testing.assert_close( - ret0["updated_coord"], ret1["updated_coord"], rtol=rprec, atol=aprec - ) - torch.testing.assert_close( - ret0["logits"], ret1["logits"], rtol=rprec, atol=aprec - ) - - compare(ret0, ret1) - compare(ret1, ret2) - compare(ret0, ret3) - - -@unittest.skip("support of the denoise is temporally disabled") -class TestDenoiseModelDPA2(unittest.TestCase, SmoothDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_dpa2) - model_params["descriptor"]["sel"] = 8 - model_params["descriptor"]["rcut_smth"] = 3.5 - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - self.epsilon, self.aprec = None, None - self.epsilon = 1e-7 - self.aprec = 1e-5 - - -@unittest.skip("support of the denoise is temporally disabled") -class 
TestDenoiseModelDPA2_1(unittest.TestCase, SmoothDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_dpa2) - # model_params["descriptor"]["combine_grrg"] = True - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - self.epsilon, self.aprec = None, None - self.epsilon = 1e-7 - self.aprec = 1e-5 - - -# @unittest.skip("hybrid not supported at the moment") -# class TestDenoiseModelHybrid(unittest.TestCase, TestSmoothDenoise): -# def setUp(self): -# model_params = copy.deepcopy(model_hybrid_denoise) -# self.type_split = True -# self.model = get_model(model_params).to(env.DEVICE) -# self.epsilon, self.aprec = None, None -# self.epsilon = 1e-7 -# self.aprec = 1e-5 - - -if __name__ == "__main__": - unittest.main() diff --git a/source/tests/pt/model/test_trans.py b/source/tests/pt/model/test_trans.py index 2e39cc4bd5..916dce8bff 100644 --- a/source/tests/pt/model/test_trans.py +++ b/source/tests/pt/model/test_trans.py @@ -18,6 +18,7 @@ eval_model, ) from .test_permutation import ( # model_dpau, + model_denoise, model_dos, model_dpa1, model_dpa2, @@ -54,7 +55,10 @@ def test( cell, ) test_spin = getattr(self, "test_spin", False) - if not test_spin: + test_denoise = getattr(self, "test_denoise", False) + if test_denoise: + test_keys = ["strain_components", "updated_coord", "logits"] + elif not test_spin: test_keys = ["energy", "force", "virial"] else: test_keys = ["energy", "force", "force_mag", "virial"] @@ -64,6 +68,7 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret0 = {key: result_0[key].squeeze(0) for key in test_keys} result_1 = eval_model( @@ -72,11 +77,19 @@ def test( cell.unsqueeze(0), atype, spins=spin.unsqueeze(0), + denoise=test_denoise, ) ret1 = {key: result_1[key].squeeze(0) for key in test_keys} prec = 1e-7 for key in test_keys: - if key in ["energy", "force", "force_mag"]: + if key in [ + "energy", + "force", + "force_mag", + "strain_components", + "updated_coord", + "logits", + ]: torch.testing.assert_close(ret0[key], ret1[key], rtol=prec, atol=prec) elif key == "virial": if not hasattr(self, "test_virial") or self.test_virial: @@ -155,5 +168,13 @@ def setUp(self) -> None: self.model = get_model(model_params).to(env.DEVICE) +class TestDenoiseModelDPA1(unittest.TestCase, TransTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_denoise) + self.type_split = False + self.test_denoise = True + self.model = get_model(model_params).to(env.DEVICE) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt/model/test_trans_denoise.py b/source/tests/pt/model/test_trans_denoise.py deleted file mode 100644 index 77bff7980a..0000000000 --- a/source/tests/pt/model/test_trans_denoise.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import copy -import unittest - -import torch - -from deepmd.pt.model.model import ( - get_model, -) -from deepmd.pt.utils import ( - env, -) - -from ...seed import ( - GLOBAL_SEED, -) -from ..common import ( - eval_model, -) -from .test_permutation_denoise import ( - model_dpa1, - model_dpa2, - model_hybrid, -) - -dtype = torch.float64 - - -class TransDenoiseTest: - def test( - self, - ) -> None: - natoms = 5 - generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) - cell = torch.rand([3, 3], dtype=dtype, generator=generator).to(env.DEVICE) - cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE) - coord = torch.rand([natoms, 3], dtype=dtype).to(env.DEVICE) - coord = torch.matmul(coord, 
cell) - atype = torch.IntTensor([0, 0, 0, 1, 1]).to(env.DEVICE) - shift = (torch.rand([3], dtype=dtype, generator=generator) - 0.5).to( - env.DEVICE - ) * 2.0 - coord_s = torch.matmul( - torch.remainder(torch.matmul(coord + shift, torch.linalg.inv(cell)), 1.0), - cell, - ) - updated_c0, logits0 = eval_model( - self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True - ) - updated_c0 = updated_c0 - coord.unsqueeze(0) - ret0 = {"updated_coord": updated_c0.squeeze(0), "logits": logits0.squeeze(0)} - updated_c1, logits1 = eval_model( - self.model, coord_s.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True - ) - updated_c1 = updated_c1 - coord_s.unsqueeze(0) - ret1 = {"updated_coord": updated_c1.squeeze(0), "logits": logits1.squeeze(0)} - prec = 1e-10 - torch.testing.assert_close( - ret0["updated_coord"], ret1["updated_coord"], rtol=prec, atol=prec - ) - torch.testing.assert_close(ret0["logits"], ret1["logits"], rtol=prec, atol=prec) - - -@unittest.skip("support of the denoise is temporally disabled") -class TestDenoiseModelDPA1(unittest.TestCase, TransDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_dpa1) - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - - -@unittest.skip("support of the denoise is temporally disabled") -class TestDenoiseModelDPA2(unittest.TestCase, TransDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_dpa2) - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - - -@unittest.skip("hybrid not supported at the moment") -class TestDenoiseModelHybrid(unittest.TestCase, TransDenoiseTest): - def setUp(self) -> None: - model_params = copy.deepcopy(model_hybrid) - self.type_split = True - self.model = get_model(model_params).to(env.DEVICE) - - -if __name__ == "__main__": - unittest.main() diff --git a/source/tests/universal/common/cases/atomic_model/atomic_model.py b/source/tests/universal/common/cases/atomic_model/atomic_model.py index 499a313a32..1969404c4f 100644 --- a/source/tests/universal/common/cases/atomic_model/atomic_model.py +++ b/source/tests/universal/common/cases/atomic_model/atomic_model.py @@ -127,3 +127,27 @@ def setUpClass(cls) -> None: cls.aprec_dict = {} cls.rprec_dict = {} cls.epsilon_dict = {} + + +class DenoiseAtomicModelTest(AtomicModelTestCase): + @classmethod + def setUpClass(cls) -> None: + cls.expected_rcut = 5.0 + cls.expected_type_map = ["O", "H"] + cls.expected_dim_fparam = 0 + cls.expected_dim_aparam = 0 + cls.expected_sel_type = [0, 1] + cls.expected_aparam_nall = False + cls.expected_model_output_type = [ + "strain_components", + "updated_coord", + "logits", + "mask", + ] + cls.model_output_equivariant = ["updated_coord"] + cls.expected_sel = [46, 92] + cls.expected_sel_mix = sum(cls.expected_sel) + cls.expected_has_message_passing = False + cls.aprec_dict = {} + cls.rprec_dict = {} + cls.epsilon_dict = {} diff --git a/source/tests/universal/common/cases/fitting/utils.py b/source/tests/universal/common/cases/fitting/utils.py index de6b12c3a2..55d1e14247 100644 --- a/source/tests/universal/common/cases/fitting/utils.py +++ b/source/tests/universal/common/cases/fitting/utils.py @@ -76,33 +76,38 @@ def test_exclude_types( ) self.module = self.module.deserialize(serialize_dict) ff = self.forward_wrapper(self.module) - var_name = self.module.var_name - if var_name == "polar": - var_name = "polarizability" + var_names = self.module.var_name + if isinstance(var_names, str): + var_names = [var_names] + var_names = 
["polarizability" if v == "polar" else v for v in var_names] for em in [[0], [1]]: ex_pair = AtomExcludeMask(self.nt, em) atom_mask = ex_pair.build_type_exclude_mask(atype_device) # exclude neighbors in the output - rd = ff( + result = ff( self.mock_descriptor, self.atype_ext[:, : self.nloc], gr=self.mock_gr, - )[var_name] - for _ in range(len(rd.shape) - len(atom_mask.shape)): - atom_mask = atom_mask[..., None] - rd = rd * atom_mask + ) + for var in var_names: + rd = result[var] + _atom_mask = atom_mask.copy() + for _ in range(len(rd.shape) - len(_atom_mask.shape)): + _atom_mask = _atom_mask[..., None] + rd_masked = rd * _atom_mask + # normal nlist but use exclude_types params + serialize_dict_em = deepcopy(serialize_dict) + serialize_dict_em.update({"exclude_types": em}) + ff_ex = self.forward_wrapper(self.module.deserialize(serialize_dict_em)) + result_ex = ff_ex( + self.mock_descriptor, + self.atype_ext[:, : self.nloc], + gr=self.mock_gr, + ) + rd_ex = result_ex[var] - # normal nlist but use exclude_types params - serialize_dict_em = deepcopy(serialize_dict) - serialize_dict_em.update({"exclude_types": em}) - ff_ex = self.forward_wrapper(self.module.deserialize(serialize_dict_em)) - rd_ex = ff_ex( - self.mock_descriptor, - self.atype_ext[:, : self.nloc], - gr=self.mock_gr, - )[var_name] - np.testing.assert_allclose(rd, rd_ex) + np.testing.assert_allclose(rd_masked, rd_ex) def test_change_type_map(self) -> None: if not self.module.mixed_types: @@ -168,23 +173,29 @@ def test_change_type_map(self) -> None: size=serialize_dict["@variables"]["bias_atom_e"].shape ) old_tm_module = old_tm_module.deserialize(serialize_dict) - var_name = old_tm_module.var_name - if var_name == "polar": - var_name = "polarizability" + var_names = old_tm_module.var_name + if isinstance(var_names, str): + var_names = [var_names] + var_names = ["polarizability" if v == "polar" else v for v in var_names] old_tm_ff = self.forward_wrapper(old_tm_module) - rd_old_tm = old_tm_ff( + + result_old = old_tm_ff( self.mock_descriptor, old_tm_index[atype_device], gr=self.mock_gr, - )[var_name] + ) old_tm_module.change_type_map(new_tm) new_tm_ff = self.forward_wrapper(old_tm_module) - rd_new_tm = new_tm_ff( + result_new = new_tm_ff( self.mock_descriptor, new_tm_index[atype_device], gr=self.mock_gr, - )[var_name] - np.testing.assert_allclose(rd_old_tm, rd_new_tm) + ) + for var in var_names: + np.testing.assert_allclose( + result_old[var], + result_new[var], + ) def remap_exclude_types(exclude_types, ori_tm, new_tm): diff --git a/source/tests/universal/common/cases/loss/utils.py b/source/tests/universal/common/cases/loss/utils.py index 63e6e3ed27..240a9e51ca 100644 --- a/source/tests/universal/common/cases/loss/utils.py +++ b/source/tests/universal/common/cases/loss/utils.py @@ -35,7 +35,7 @@ def test_forward(self): natoms = 5 nframes = 2 - def fake_model(): + def fake_model(**kwargs): model_predict = { data_key: fake_input( label_dict[data_key], natoms=natoms, nframes=nframes @@ -55,13 +55,53 @@ def fake_model(): } labels.update({"find_" + data_key: 1.0 for data_key in label_keys}) - _, loss, more_loss = module( - {}, - fake_model, - labels, - natoms, - 1.0, - ) + if "updated_coord" in self.key_to_pref_map: + import torch + + from deepmd.pt.utils import ( + env, + ) + + labels.update( + { + "type_mask": torch.tensor( + [[False] * natoms, [False] * natoms], + dtype=torch.bool, + device=env.DEVICE, + ) + } + ) + input_dict = {} + input_dict["box"] = torch.tensor( + [[1, 0, 0, 0, 1, 0, 0, 0, 1]] * nframes, + 
dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + input_dict["atype"] = torch.tensor( + [[0] * natoms, [0] * natoms], + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + input_dict["coord"] = torch.tensor( + [[[0] * 3] * natoms] * nframes, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + _, loss, more_loss = module( + input_dict, + fake_model, + labels, + natoms, + 1.0, + ) + else: + _, loss, more_loss = module( + {}, + fake_model, + labels, + natoms, + 1.0, + ) def fake_input(data_item: DataRequirementItem, natoms=5, nframes=2) -> np.ndarray: diff --git a/source/tests/universal/common/cases/model/model.py b/source/tests/universal/common/cases/model/model.py index 06ddd90970..791e00fdb5 100644 --- a/source/tests/universal/common/cases/model/model.py +++ b/source/tests/universal/common/cases/model/model.py @@ -174,3 +174,28 @@ def setUpClass(cls) -> None: cls.rprec_dict = {} cls.epsilon_dict = {} cls.skip_test_autodiff = True + + +class DenoiseModelTest(ModelTestCase): + @classmethod + def setUpClass(cls) -> None: + cls.expected_rcut = 5.0 + cls.expected_type_map = ["O", "H"] + cls.expected_dim_fparam = 0 + cls.expected_dim_aparam = 0 + cls.expected_sel_type = [0, 1] + cls.expected_aparam_nall = False + cls.expected_model_output_type = [ + "strain_components", + "updated_coord", + "logits", + "mask", + ] + cls.model_output_equivariant = ["updated_coord"] + cls.expected_sel = [46, 92] + cls.expected_sel_mix = sum(cls.expected_sel) + cls.expected_has_message_passing = False + cls.aprec_dict = {} + cls.rprec_dict = {} + cls.epsilon_dict = {} + cls.skip_test_autodiff = True diff --git a/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py b/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py index 7b579ae82c..03d0871be6 100644 --- a/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py +++ b/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py @@ -15,6 +15,7 @@ DescrptSeT, ) from deepmd.dpmodel.fitting import ( + DenoiseFitting, DipoleFitting, DOSFittingNet, EnergyFittingNet, @@ -30,6 +31,7 @@ TEST_DEVICE, ) from ...common.cases.atomic_model.atomic_model import ( + DenoiseAtomicModelTest, DipoleAtomicModelTest, DosAtomicModelTest, EnerAtomicModelTest, @@ -59,6 +61,8 @@ DPTestCase, ) from ..fitting.test_fitting import ( + FittingParamDenoise, + FittingParamDenoiseList, FittingParamDipole, FittingParamDipoleList, FittingParamDos, @@ -469,3 +473,63 @@ def setUpClass(cls) -> None: cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybridMixed, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamDenoise, DenoiseFitting),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[(param_func, DenoiseFitting) for param_func in FittingParamDenoiseList], + ), # fitting_class_param & class + ), +) +@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") +class TestDenoiseAtomicModelDP(unittest.TestCase, DenoiseAtomicModelTest, DPTestCase): + @classmethod + def setUpClass(cls) -> None: + DenoiseAtomicModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + 
diff --git a/source/tests/universal/dpmodel/fitting/test_fitting.py b/source/tests/universal/dpmodel/fitting/test_fitting.py
index 90b0668d20..451cd7e4de 100644
--- a/source/tests/universal/dpmodel/fitting/test_fitting.py
+++ b/source/tests/universal/dpmodel/fitting/test_fitting.py
@@ -5,6 +5,7 @@
 )
 
 from deepmd.dpmodel.fitting import (
+    DenoiseFitting,
     DipoleFitting,
     DOSFittingNet,
     EnergyFittingNet,
@@ -234,6 +235,52 @@ def FittingParamProperty(
 FittingParamProperty = FittingParamPropertyList[0]
 
 
+def FittingParamDenoise(
+    ntypes,
+    dim_descrpt,
+    mixed_types,
+    type_map,
+    exclude_types=[],
+    precision="float64",
+    embedding_width=None,
+    numb_param=0,  # test numb_fparam, numb_aparam and dim_case_embd together
+):
+    assert embedding_width is not None, (
+        "embedding_width for denoise fitting is required."
+    )
+    input_dict = {
+        "ntypes": ntypes,
+        "dim_descrpt": dim_descrpt,
+        "mixed_types": mixed_types,
+        "type_map": type_map,
+        "embedding_width": embedding_width,
+        "exclude_types": exclude_types,
+        "seed": GLOBAL_SEED,
+        "precision": precision,
+        "numb_fparam": numb_param,
+        "numb_aparam": numb_param,
+        "dim_case_embd": numb_param,
+        "coord_noise": 0.2,
+        "cell_pert_fraction": 0.008,
+        "noise_type": "gaussian",
+    }
+    return input_dict
+
+
+FittingParamDenoiseList = parameterize_func(
+    FittingParamDenoise,
+    OrderedDict(
+        {
+            "exclude_types": ([], [0]),
+            "precision": ("float64",),
+            "numb_param": (0, 2),
+        }
+    ),
+)
+# to get name for the default function
+FittingParamDenoise = FittingParamDenoiseList[0]
+
+
 @parameterized(
     (
         (FittingParamEnergy, EnergyFittingNet),
@@ -241,6 +288,7 @@ def FittingParamProperty(
         (FittingParamDipole, DipoleFitting),
         (FittingParamPolar, PolarFitting),
         (FittingParamProperty, PropertyFittingNet),
+        (FittingParamDenoise, DenoiseFitting),
     ),  # class_param & class
     (True, False),  # mixed_types
 )
diff --git a/source/tests/universal/dpmodel/loss/test_loss.py b/source/tests/universal/dpmodel/loss/test_loss.py
index 79c67cdba4..ffd427180c 100644
--- a/source/tests/universal/dpmodel/loss/test_loss.py
+++ b/source/tests/universal/dpmodel/loss/test_loss.py
@@ -204,3 +204,35 @@ def LossParamProperty():
 LossParamPropertyList = [LossParamProperty]
 # to get name for the default function
 LossParamProperty = LossParamPropertyList[0]
+
+
+def LossParamDenoise():
+    key_to_pref_map = {
+        "strain_components": 1.0,
+        "updated_coord": 1.0,
+        "logits": 1.0,
+    }
+    input_dict = {
+        "key_to_pref_map": key_to_pref_map,
+        "ntypes": 1,
+        "mask_token": False,
+        "mask_coord": True,
+        "mask_cell": False,
+        "token_loss": 1.0,
+        "coord_loss": 1.0,
+        "cell_loss": 1.0,
+        "noise_type": "gaussian",
+        "coord_noise": 0.2,
+        "cell_pert_fraction": 0.0,
+        "noise_mode": "prob",
+        "mask_num": 1,
+        "mask_prob": 0.2,
+        "same_mask": False,
+        "loss_func": "rmse",
+    }
+    return input_dict
+
+
+LossParamDenoiseList = [LossParamDenoise]
+# to get name for the default function
+LossParamDenoise = LossParamDenoiseList[0]
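Note: `key_to_pref_map` weights each predicted key in the total loss, and the parameter dict doubles as a constructor recipe. A sketch under the assumption that the PT `DenoiseLoss` accepts exactly these keys, which is what the pairing in the PT loss test below implies (LossParamDenoise is the function defined above):

    from deepmd.pt.loss import DenoiseLoss

    # Conceptually the preference map scales per-key losses before summation:
    #     total = sum(key_to_pref_map[k] * loss_k for each predicted key k)
    loss_module = DenoiseLoss(**LossParamDenoise())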
"cell_loss": 1.0, + "noise_type": "gaussian", + "coord_noise": 0.2, + "cell_pert_fraction": 0.0, + "noise_mode": "prob", + "mask_num": 1, + "mask_prob": 0.2, + "same_mask": False, + "loss_func": "rmse", + } + return input_dict + + +LossParamDenoiseList = [LossParamDenoise] +# to get name for the default function +LossParamDenoise = LossParamDenoiseList[0] diff --git a/source/tests/universal/pt/atomc_model/test_atomic_model.py b/source/tests/universal/pt/atomc_model/test_atomic_model.py index f41d384b6b..8269942fd8 100644 --- a/source/tests/universal/pt/atomc_model/test_atomic_model.py +++ b/source/tests/universal/pt/atomc_model/test_atomic_model.py @@ -15,6 +15,7 @@ DescrptSeT, ) from deepmd.pt.model.task import ( + DenoiseFittingNet, DipoleFittingNet, DOSFittingNet, EnergyFittingNet, @@ -26,6 +27,7 @@ parameterized, ) from ...common.cases.atomic_model.atomic_model import ( + DenoiseAtomicModelTest, DipoleAtomicModelTest, DosAtomicModelTest, EnerAtomicModelTest, @@ -49,6 +51,8 @@ DescriptorParamSeTList, ) from ...dpmodel.fitting.test_fitting import ( + FittingParamDenoise, + FittingParamDenoiseList, FittingParamDipole, FittingParamDipoleList, FittingParamDos, @@ -459,3 +463,66 @@ def setUpClass(cls) -> None: cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + + +@parameterized( + des_parameterized=( + ( + *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], + *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + (DescriptorParamHybrid, DescrptHybrid), + (DescriptorParamHybridMixed, DescrptHybrid), + ), # descrpt_class_param & class + ((FittingParamDenoise, DenoiseFittingNet),), # fitting_class_param & class + ), + fit_parameterized=( + ( + (DescriptorParamDPA1, DescrptDPA1), + (DescriptorParamDPA2, DescrptDPA2), + ), # descrpt_class_param & class + ( + *[ + (param_func, DenoiseFittingNet) + for param_func in FittingParamDenoiseList + ], + ), # fitting_class_param & class + ), +) +class TestDenoiseAtomicModelPT(unittest.TestCase, DenoiseAtomicModelTest, PTTestCase): + @classmethod + def setUpClass(cls) -> None: + DenoiseAtomicModelTest.setUpClass() + (DescriptorParam, Descrpt) = cls.param[0] + (FittingParam, Fitting) = cls.param[1] + cls.input_dict_ds = DescriptorParam( + len(cls.expected_type_map), + cls.expected_rcut, + cls.expected_rcut / 2, + cls.expected_sel, + cls.expected_type_map, + ) + # set skip tests + skiptest, skip_reason = skip_model_tests(cls) + if skiptest: + raise cls.skipTest(cls, skip_reason) + ds = Descrpt(**cls.input_dict_ds) + cls.input_dict_ft = FittingParam( + ntypes=len(cls.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + type_map=cls.expected_type_map, + embedding_width=ds.get_dim_emb(), + ) + ft = Fitting( + **cls.input_dict_ft, + ) + cls.module = DPAtomicModel( + ds, + ft, + type_map=cls.expected_type_map, + ) + cls.output_def = cls.module.atomic_output_def().get_data() + cls.expected_has_message_passing = ds.has_message_passing() + cls.expected_sel_type = ft.get_sel_type() + cls.expected_dim_fparam = ft.get_dim_fparam() + cls.expected_dim_aparam = ft.get_dim_aparam() diff --git a/source/tests/universal/pt/fitting/test_fitting.py b/source/tests/universal/pt/fitting/test_fitting.py index efda3b9619..ee44afb4ca 100644 --- a/source/tests/universal/pt/fitting/test_fitting.py +++ b/source/tests/universal/pt/fitting/test_fitting.py @@ -2,6 +2,7 @@ import unittest from deepmd.pt.model.task import ( + 
diff --git a/source/tests/universal/pt/loss/test_loss.py b/source/tests/universal/pt/loss/test_loss.py
index 47c2d06fbc..c3f4cdce26 100644
--- a/source/tests/universal/pt/loss/test_loss.py
+++ b/source/tests/universal/pt/loss/test_loss.py
@@ -2,6 +2,7 @@
 import unittest
 
 from deepmd.pt.loss import (
+    DenoiseLoss,
     DOSLoss,
     EnergySpinLoss,
     EnergyStdLoss,
@@ -16,6 +17,7 @@
     LossTest,
 )
 from ...dpmodel.loss.test_loss import (
+    LossParamDenoiseList,
     LossParamDosList,
     LossParamEnergyList,
     LossParamEnergySpinList,
@@ -34,6 +36,7 @@
         *[(param_func, DOSLoss) for param_func in LossParamDosList],
         *[(param_func, TensorLoss) for param_func in LossParamTensorList],
         *[(param_func, PropertyLoss) for param_func in LossParamPropertyList],
+        *[(param_func, DenoiseLoss) for param_func in LossParamDenoiseList],
     )  # class_param & class
 )
 class TestLossPT(unittest.TestCase, LossTest, PTTestCase):
diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py
index 867fa48b87..e32ddfab6d 100644
--- a/source/tests/universal/pt/model/test_model.py
+++ b/source/tests/universal/pt/model/test_model.py
@@ -18,6 +18,7 @@
     DescrptSeTTebd,
 )
 from deepmd.pt.model.model import (
+    DenoiseModel,
    DipoleModel,
     DOSModel,
     DPZBLModel,
@@ -28,6 +29,7 @@
     SpinEnergyModel,
 )
 from deepmd.pt.model.task import (
+    DenoiseFittingNet,
     DipoleFittingNet,
     DOSFittingNet,
     EnergyFittingNet,
@@ -42,6 +44,7 @@
     parameterized,
 )
 from ...common.cases.model.model import (
+    DenoiseModelTest,
     DipoleModelTest,
     DosModelTest,
     EnerModelTest,
@@ -71,6 +74,8 @@
     DescriptorParamSeTTebdList,
 )
 from ...dpmodel.fitting.test_fitting import (
+    FittingParamDenoise,
+    FittingParamDenoiseList,
     FittingParamDipole,
     FittingParamDipoleList,
     FittingParamDos,
@@ -106,6 +111,7 @@
     FittingParamDipole,
     FittingParamPolar,
     FittingParamProperty,
+    FittingParamDenoise,
 ]
 
 
@@ -919,3 +925,92 @@ def setUpClass(cls) -> None:
         cls.expected_dim_fparam = ft1.get_dim_fparam()
         cls.expected_dim_aparam = ft1.get_dim_aparam()
         cls.expected_sel_type = ft1.get_sel_type()
+
+
+@parameterized(
+    des_parameterized=(
+        (
+            *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List],
+            *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List],
+            *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List],
+            (DescriptorParamHybrid, DescrptHybrid),
+            (DescriptorParamHybridMixed, DescrptHybrid),
+        ),  # descrpt_class_param & class
+        ((FittingParamDenoise, DenoiseFittingNet),),  # fitting_class_param & class
+    ),
+    fit_parameterized=(
+        (
+            (DescriptorParamDPA1, DescrptDPA1),
+            (DescriptorParamDPA2, DescrptDPA2),
+            (DescriptorParamDPA3, DescrptDPA3),
+        ),  # descrpt_class_param & class
+        (
+            *[
+                (param_func, DenoiseFittingNet)
+                for param_func in FittingParamDenoiseList
+            ],
+        ),  # fitting_class_param & class
+    ),
+)
+class TestDenoiseModelPT(unittest.TestCase, DenoiseModelTest, PTTestCase):
+    @property
+    def modules_to_test(self):
+        skip_test_jit = getattr(self, "skip_test_jit", False)
+        modules = PTTestCase.modules_to_test.fget(self)
+        if not skip_test_jit:
+            # for Model, we can test script module API
+            modules += [
+                self._script_module
+                if hasattr(self, "_script_module")
+                else self.script_module
+            ]
+        return modules
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        DenoiseModelTest.setUpClass()
+        (DescriptorParam, Descrpt) = cls.param[0]
+        (FittingParam, Fitting) = cls.param[1]
+        cls.input_dict_ds = DescriptorParam(
+            len(cls.expected_type_map),
+            cls.expected_rcut,
+            cls.expected_rcut / 2,
+            cls.expected_sel,
+            cls.expected_type_map,
+        )
+
+        # set skip tests
+        skiptest, skip_reason = skip_model_tests(cls)
+        if skiptest:
+            raise cls.skipTest(cls, skip_reason)
+
+        ds = Descrpt(**cls.input_dict_ds)
+        cls.input_dict_ft = FittingParam(
+            ntypes=len(cls.expected_type_map),
+            dim_descrpt=ds.get_dim_out(),
+            mixed_types=ds.mixed_types(),
+            type_map=cls.expected_type_map,
+            embedding_width=ds.get_dim_emb(),
+        )
+        ft = Fitting(
+            **cls.input_dict_ft,
+        )
+        cls.module = DenoiseModel(
+            ds,
+            ft,
+            type_map=cls.expected_type_map,
+        )
+        # only test jit API once for different models
+        if (
+            DescriptorParam not in defalut_des_param
+            or FittingParam not in defalut_fit_param
+        ):
+            cls.skip_test_jit = True
+        else:
+            with torch.jit.optimized_execution(False):
+                cls._script_module = torch.jit.script(cls.module)
+        cls.output_def = cls.module.translated_output_def()
+        cls.expected_has_message_passing = ds.has_message_passing()
+        cls.expected_sel_type = ft.get_sel_type()
+        cls.expected_dim_fparam = ft.get_dim_fparam()
+        cls.expected_dim_aparam = ft.get_dim_aparam()
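Note: the JIT gate above scripts the module only for the default descriptor/fitting pair. The standalone equivalent, mirroring the test's own calls:

    import torch

    from deepmd.pt.model.model import DenoiseModel

    def script_once(model: DenoiseModel) -> torch.jit.ScriptModule:
        # Disable optimized execution while scripting, as TestDenoiseModelPT does.
        with torch.jit.optimized_execution(False):
            return torch.jit.script(model)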