11# SPDX-License-Identifier: LGPL-3.0-or-later
2+ from collections .abc import (
3+ Callable ,
4+ )
25from typing import (
36 Any ,
47)
1518from deepmd .dpmodel .output_def import (
1619 FittingOutputDef ,
1720)
21+ from deepmd .utils .path import (
22+ DPPath ,
23+ )
1824from deepmd .utils .version import (
1925 check_version_compatibility ,
2026)
2632
2733@BaseAtomicModel .register ("standard" )
2834class DPAtomicModel (BaseAtomicModel ):
29- """Model give atomic prediction of some physical property.
35+ r"""Model give atomic prediction of some physical property.
36+
37+ The atomic model computes atomic properties by first extracting a descriptor
38+ from the atomic environment, then passing it through a fitting network:
39+
40+ .. math::
41+ \mathcal{D}^i = \mathcal{D}(\mathbf{R}^i, \mathbf{R}_j, \alpha_j),
42+
43+ .. math::
44+ \mathbf{y}^i = \mathcal{F}(\mathcal{D}^i),
45+
46+ where :math:`\mathcal{D}^i` is the descriptor for atom :math:`i`,
47+ :math:`\alpha_j` is the atom type of neighbor :math:`j`,
48+ :math:`\mathcal{F}` is the fitting network, and
49+ :math:`\mathbf{y}^i` is the predicted atomic property (energy, dipole, etc.).
3050
3151 Parameters
3252 ----------
@@ -48,17 +68,16 @@ def __init__(
4868 ** kwargs : Any ,
4969 ) -> None :
5070 super ().__init__ (type_map , ** kwargs )
51- self .type_map = type_map
5271 self .descriptor = descriptor
53- self .fitting = fitting
54- if hasattr (self .fitting , "reinit_exclude" ):
55- self .fitting .reinit_exclude (self .atom_exclude_types )
72+ self .fitting_net = fitting
73+ if hasattr (self .fitting_net , "reinit_exclude" ):
74+ self .fitting_net .reinit_exclude (self .atom_exclude_types )
5675 self .type_map = type_map
5776 super ().init_out_stat ()
5877
def fitting_output_def(self) -> FittingOutputDef:
    """Get the output def of the fitting net.

    Returns
    -------
    FittingOutputDef
        Whatever ``self.fitting_net.output_def()`` returns.
    """
    # The fitting network is stored under ``fitting_net`` (see ``__init__``).
    return self.fitting_net.output_def()
6281
6382 def get_rcut (self ) -> float :
6483 """Get the cut-off radius."""
@@ -73,7 +92,7 @@ def set_case_embd(self, case_idx: int) -> None:
7392 Set the case embedding of this atomic model by the given case_idx,
7493 typically concatenated with the output of the descriptor and fed into the fitting net.
7594 """
76- self .fitting .set_case_embd (case_idx )
95+ self .fitting_net .set_case_embd (case_idx )
7796
7897 def mixed_types (self ) -> bool :
7998 """If true, the model
@@ -166,7 +185,7 @@ def forward_atomic(
166185 nlist ,
167186 mapping = mapping ,
168187 )
169- ret = self .fitting (
188+ ret = self .fitting_net (
170189 descriptor ,
171190 atype ,
172191 gr = rot_mat ,
@@ -177,6 +196,37 @@ def forward_atomic(
177196 )
178197 return ret
179198
def compute_or_load_stat(
    self,
    sampled_func: Callable[[], list[dict]],
    stat_file_path: DPPath | None = None,
    compute_or_load_out_stat: bool = True,
) -> None:
    """Compute or load the statistics parameters of the model.

    Examples of such statistics are the mean and standard deviation of
    descriptors or the energy bias of the fitting net.

    Parameters
    ----------
    sampled_func
        The lazy sampled function to get data frames from different data systems.
    stat_file_path
        The path to the stat file.
    compute_or_load_out_stat : bool
        Whether to compute the output statistics.
        If False, it will only compute the input statistics
        (e.g. mean and standard deviation of descriptors).
    """
    if stat_file_path is not None and self.type_map is not None:
        # Extend the stat path with the space-joined type map, so the
        # location used below is specific to this model's type map.
        stat_file_path /= " ".join(self.type_map)

    # The same wrapped sampler is handed to every statistics consumer.
    wrapped_sampler = self._make_wrapped_sampler(sampled_func)
    self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
    self.fitting_net.compute_input_stats(
        wrapped_sampler, stat_file_path=stat_file_path
    )
    # The flag shares its name with the method; the method is reached via self.
    if compute_or_load_out_stat:
        self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
229+
180230 def change_type_map (
181231 self , type_map : list [str ], model_with_new_type_stat : Any | None = None
182232 ) -> None :
@@ -193,7 +243,31 @@ def change_type_map(
193243 if model_with_new_type_stat is not None
194244 else None ,
195245 )
196- self .fitting .change_type_map (type_map = type_map )
246+ self .fitting_net .change_type_map (type_map = type_map )
247+
def compute_fitting_input_stat(
    self,
    sample_merged: Callable[[], list[dict]] | list[dict],
    stat_file_path: DPPath | None = None,
) -> None:
    """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

    The work is delegated to ``self.fitting_net.compute_input_stats``, with
    ``self.data_stat_protect`` forwarded as the ``protection`` argument.

    Parameters
    ----------
    sample_merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, ``merged[i]``, is a data dictionary containing
            ``keys``: ``np.ndarray`` originating from the ``i``-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples
            in the above format only when needed.
    stat_file_path : Optional[DPPath]
        The path to the stat file.
    """
    self.fitting_net.compute_input_stats(
        sample_merged,
        protection=self.data_stat_protect,
        stat_file_path=stat_file_path,
    )
197271
198272 def serialize (self ) -> dict :
199273 dd = super ().serialize ()
@@ -204,7 +278,7 @@ def serialize(self) -> dict:
204278 "@version" : 2 ,
205279 "type_map" : self .type_map ,
206280 "descriptor" : self .descriptor .serialize (),
207- "fitting" : self .fitting .serialize (),
281+ "fitting" : self .fitting_net .serialize (),
208282 }
209283 )
210284 return dd
@@ -230,19 +304,19 @@ def deserialize(cls, data: dict[str, Any]) -> "DPAtomicModel":
230304
def get_dim_fparam(self) -> int:
    """Get the number (dimension) of frame parameters of this atomic model.

    The value is taken from ``self.fitting_net.get_dim_fparam()``.
    """
    return self.fitting_net.get_dim_fparam()
234308
def get_dim_aparam(self) -> int:
    """Get the number (dimension) of atomic parameters of this atomic model.

    The value is taken from ``self.fitting_net.get_dim_aparam()``.
    """
    return self.fitting_net.get_dim_aparam()
238312
def has_default_fparam(self) -> bool:
    """Check if the model has default frame parameters.

    The answer is taken from ``self.fitting_net.has_default_fparam()``.
    """
    return self.fitting_net.has_default_fparam()
242316
def get_default_fparam(self) -> list[float] | None:
    """Get the default frame parameters.

    The value is taken from ``self.fitting_net.get_default_fparam()``;
    per the return annotation it is either a list of floats or ``None``.
    """
    return self.fitting_net.get_default_fparam()
246320
247321 def get_sel_type (self ) -> list [int ]:
248322 """Get the selected atom types of this model.
@@ -251,7 +325,7 @@ def get_sel_type(self) -> list[int]:
251325 to the result of the model.
252326 If returning an empty list, all atom types are selected.
253327 """
254- return self .fitting .get_sel_type ()
328+ return self .fitting_net .get_sel_type ()
255329
256330 def is_aparam_nall (self ) -> bool :
257331 """Check whether the shape of atomic parameters is (nframes, nall, ndim).
0 commit comments