Skip to content

Commit a9b1e7a

Browse files
authored
Merge branch 'master' into 0224-observered-type
Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
2 parents 349315c + a3db25a commit a9b1e7a

File tree

159 files changed

+54710
-8199
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

159 files changed

+54710
-8199
lines changed

deepmd/dpmodel/array_api.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
7878
xp = array_api_compat.array_namespace(input)
7979

8080
# Create flat index array matching input shape
81-
idx = xp.arange(input.size, dtype=xp.int64)
81+
idx = xp.arange(input.size, dtype=xp.int64, device=array_api_compat.device(input))
8282
idx = xp.reshape(idx, input.shape)
8383

8484
# Get flat indices where we want to add values
@@ -190,6 +190,10 @@ def xp_bincount(x: Array, weights: Array | None = None, minlength: int = 0) -> A
190190
else:
191191
if weights is None:
192192
weights = xp.ones_like(x)
193-
result = xp.zeros((max(minlength, int(xp.max(x)) + 1),), dtype=weights.dtype)
193+
result = xp.zeros(
194+
(max(minlength, int(xp.max(x)) + 1),),
195+
dtype=weights.dtype,
196+
device=array_api_compat.device(weights),
197+
)
194198
result = xp_add_at(result, x, weights)
195199
return result

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import functools
23
import math
34
from collections.abc import (
45
Callable,
@@ -52,13 +53,15 @@ def __init__(
5253
pair_exclude_types: list[tuple[int, int]] = [],
5354
rcond: float | None = None,
5455
preset_out_bias: dict[str, Array] | None = None,
56+
data_stat_protect: float = 1e-2,
5557
) -> None:
5658
super().__init__()
5759
self.type_map = type_map
5860
self.reinit_atom_exclude(atom_exclude_types)
5961
self.reinit_pair_exclude(pair_exclude_types)
6062
self.rcond = rcond
6163
self.preset_out_bias = preset_out_bias
64+
self.data_stat_protect = data_stat_protect
6265

6366
def init_out_stat(self) -> None:
6467
"""Initialize the output bias."""
@@ -77,6 +80,14 @@ def init_out_stat(self) -> None:
7780
self.out_bias = out_bias_data
7881
self.out_std = out_std_data
7982

83+
def get_out_bias(self) -> Array:
84+
"""Get the output bias."""
85+
return self.out_bias
86+
87+
def set_out_bias(self, out_bias: Array) -> None:
88+
"""Set the output bias."""
89+
self.out_bias = out_bias
90+
8091
def __setitem__(self, key: str, value: Array) -> None:
8192
if key in ["out_bias"]:
8293
self.out_bias = value
@@ -287,6 +298,57 @@ def compute_or_load_out_stat(
287298
bias_adjust_mode="set-by-statistic",
288299
)
289300

301+
def _make_wrapped_sampler(
302+
self,
303+
sampled_func: Callable[[], list[dict]],
304+
) -> Callable[[], list[dict]]:
305+
"""Wrap the sampled function with exclusion types and default fparam.
306+
307+
The returned callable is cached so that the sampling (which may be
308+
expensive) is performed at most once.
309+
310+
Parameters
311+
----------
312+
sampled_func
313+
The lazy sampled function to get data frames from different data
314+
systems.
315+
316+
Returns
317+
-------
318+
Callable[[], list[dict]]
319+
A cached wrapper around *sampled_func* that additionally sets
320+
``pair_exclude_types``, ``atom_exclude_types`` and default
321+
``fparam`` on every sample dict when applicable.
322+
"""
323+
324+
@functools.lru_cache
325+
def wrapped_sampler() -> list[dict]:
326+
sampled = sampled_func()
327+
if self.pair_excl is not None:
328+
pair_exclude_types = self.pair_excl.get_exclude_types()
329+
for sample in sampled:
330+
sample["pair_exclude_types"] = list(pair_exclude_types)
331+
if self.atom_excl is not None:
332+
atom_exclude_types = self.atom_excl.get_exclude_types()
333+
for sample in sampled:
334+
sample["atom_exclude_types"] = list(atom_exclude_types)
335+
if (
336+
"find_fparam" not in sampled[0]
337+
and "fparam" not in sampled[0]
338+
and self.has_default_fparam()
339+
):
340+
default_fparam = self.get_default_fparam()
341+
if default_fparam is not None:
342+
default_fparam_np = np.array(default_fparam)
343+
for sample in sampled:
344+
nframe = sample["atype"].shape[0]
345+
sample["fparam"] = np.tile(
346+
default_fparam_np.reshape(1, -1), (nframe, 1)
347+
)
348+
return sampled
349+
350+
return wrapped_sampler
351+
290352
def change_out_bias(
291353
self,
292354
sample_merged: Callable[[], list[dict]] | list[dict],

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from collections.abc import (
3+
Callable,
4+
)
25
from typing import (
36
Any,
47
)
@@ -15,6 +18,9 @@
1518
from deepmd.dpmodel.output_def import (
1619
FittingOutputDef,
1720
)
21+
from deepmd.utils.path import (
22+
DPPath,
23+
)
1824
from deepmd.utils.version import (
1925
check_version_compatibility,
2026
)
@@ -26,7 +32,21 @@
2632

2733
@BaseAtomicModel.register("standard")
2834
class DPAtomicModel(BaseAtomicModel):
29-
"""Model give atomic prediction of some physical property.
35+
r"""Model give atomic prediction of some physical property.
36+
37+
The atomic model computes atomic properties by first extracting a descriptor
38+
from the atomic environment, then passing it through a fitting network:
39+
40+
.. math::
41+
\mathcal{D}^i = \mathcal{D}(\mathbf{R}^i, \mathbf{R}_j, \alpha_j),
42+
43+
.. math::
44+
\mathbf{y}^i = \mathcal{F}(\mathcal{D}^i),
45+
46+
where :math:`\mathcal{D}^i` is the descriptor for atom :math:`i`,
47+
:math:`\alpha_j` is the atom type of neighbor :math:`j`,
48+
:math:`\mathcal{F}` is the fitting network, and
49+
:math:`\mathbf{y}^i` is the predicted atomic property (energy, dipole, etc.).
3050
3151
Parameters
3252
----------
@@ -48,17 +68,16 @@ def __init__(
4868
**kwargs: Any,
4969
) -> None:
5070
super().__init__(type_map, **kwargs)
51-
self.type_map = type_map
5271
self.descriptor = descriptor
53-
self.fitting = fitting
54-
if hasattr(self.fitting, "reinit_exclude"):
55-
self.fitting.reinit_exclude(self.atom_exclude_types)
72+
self.fitting_net = fitting
73+
if hasattr(self.fitting_net, "reinit_exclude"):
74+
self.fitting_net.reinit_exclude(self.atom_exclude_types)
5675
self.type_map = type_map
5776
super().init_out_stat()
5877

5978
def fitting_output_def(self) -> FittingOutputDef:
6079
"""Get the output def of the fitting net."""
61-
return self.fitting.output_def()
80+
return self.fitting_net.output_def()
6281

6382
def get_rcut(self) -> float:
6483
"""Get the cut-off radius."""
@@ -73,7 +92,7 @@ def set_case_embd(self, case_idx: int) -> None:
7392
Set the case embedding of this atomic model by the given case_idx,
7493
typically concatenated with the output of the descriptor and fed into the fitting net.
7594
"""
76-
self.fitting.set_case_embd(case_idx)
95+
self.fitting_net.set_case_embd(case_idx)
7796

7897
def mixed_types(self) -> bool:
7998
"""If true, the model
@@ -166,7 +185,7 @@ def forward_atomic(
166185
nlist,
167186
mapping=mapping,
168187
)
169-
ret = self.fitting(
188+
ret = self.fitting_net(
170189
descriptor,
171190
atype,
172191
gr=rot_mat,
@@ -177,6 +196,37 @@ def forward_atomic(
177196
)
178197
return ret
179198

199+
def compute_or_load_stat(
200+
self,
201+
sampled_func: Callable[[], list[dict]],
202+
stat_file_path: DPPath | None = None,
203+
compute_or_load_out_stat: bool = True,
204+
) -> None:
205+
"""Compute or load the statistics parameters of the model,
206+
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
207+
208+
Parameters
209+
----------
210+
sampled_func
211+
The lazy sampled function to get data frames from different data systems.
212+
stat_file_path
213+
The path to the stat file.
214+
compute_or_load_out_stat : bool
215+
Whether to compute the output statistics.
216+
If False, it will only compute the input statistics
217+
(e.g. mean and standard deviation of descriptors).
218+
"""
219+
if stat_file_path is not None and self.type_map is not None:
220+
stat_file_path /= " ".join(self.type_map)
221+
222+
wrapped_sampler = self._make_wrapped_sampler(sampled_func)
223+
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
224+
self.fitting_net.compute_input_stats(
225+
wrapped_sampler, stat_file_path=stat_file_path
226+
)
227+
if compute_or_load_out_stat:
228+
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
229+
180230
def change_type_map(
181231
self, type_map: list[str], model_with_new_type_stat: Any | None = None
182232
) -> None:
@@ -193,7 +243,31 @@ def change_type_map(
193243
if model_with_new_type_stat is not None
194244
else None,
195245
)
196-
self.fitting.change_type_map(type_map=type_map)
246+
self.fitting_net.change_type_map(type_map=type_map)
247+
248+
def compute_fitting_input_stat(
249+
self,
250+
sample_merged: Callable[[], list[dict]] | list[dict],
251+
stat_file_path: DPPath | None = None,
252+
) -> None:
253+
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
254+
255+
Parameters
256+
----------
257+
sample_merged : Union[Callable[[], list[dict]], list[dict]]
258+
- list[dict]: A list of data samples from various data systems.
259+
Each element, ``merged[i]``, is a data dictionary containing
260+
``keys``: ``np.ndarray`` originating from the ``i``-th data system.
261+
- Callable[[], list[dict]]: A lazy function that returns data samples
262+
in the above format only when needed.
263+
stat_file_path : Optional[DPPath]
264+
The path to the stat file.
265+
"""
266+
self.fitting_net.compute_input_stats(
267+
sample_merged,
268+
protection=self.data_stat_protect,
269+
stat_file_path=stat_file_path,
270+
)
197271

198272
def serialize(self) -> dict:
199273
dd = super().serialize()
@@ -204,7 +278,7 @@ def serialize(self) -> dict:
204278
"@version": 2,
205279
"type_map": self.type_map,
206280
"descriptor": self.descriptor.serialize(),
207-
"fitting": self.fitting.serialize(),
281+
"fitting": self.fitting_net.serialize(),
208282
}
209283
)
210284
return dd
@@ -230,19 +304,19 @@ def deserialize(cls, data: dict[str, Any]) -> "DPAtomicModel":
230304

231305
def get_dim_fparam(self) -> int:
232306
"""Get the number (dimension) of frame parameters of this atomic model."""
233-
return self.fitting.get_dim_fparam()
307+
return self.fitting_net.get_dim_fparam()
234308

235309
def get_dim_aparam(self) -> int:
236310
"""Get the number (dimension) of atomic parameters of this atomic model."""
237-
return self.fitting.get_dim_aparam()
311+
return self.fitting_net.get_dim_aparam()
238312

239313
def has_default_fparam(self) -> bool:
240314
"""Check if the model has default frame parameters."""
241-
return self.fitting.has_default_fparam()
315+
return self.fitting_net.has_default_fparam()
242316

243317
def get_default_fparam(self) -> list[float] | None:
244318
"""Get the default frame parameters."""
245-
return self.fitting.get_default_fparam()
319+
return self.fitting_net.get_default_fparam()
246320

247321
def get_sel_type(self) -> list[int]:
248322
"""Get the selected atom types of this model.
@@ -251,7 +325,7 @@ def get_sel_type(self) -> list[int]:
251325
to the result of the model.
252326
If returning an empty list, all atom types are selected.
253327
"""
254-
return self.fitting.get_sel_type()
328+
return self.fitting_net.get_sel_type()
255329

256330
def is_aparam_nall(self) -> bool:
257331
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

0 commit comments

Comments
 (0)