Skip to content

Commit 0a278e7

Browse files
fix(stat): Caculate correct fitting stat when using default fparam and using share fitting. (#5038)
In this PR: 1. Support writing fitting stat to `stat_file` and loading fitting stat from `stat_file` 2. Ensure the fitting stat calculate is correct when using `default_fparam` 3. Support sharing fitting stat when using `share_fitting` in multitask mode. 4. Print the process of calculating fitting stat to the board via `log.info`. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Default frame parameters auto-fill missing samples and are exposed via a new accessor. * Compute/load/save per-parameter statistics to/from optional stat files. * Multitask training adds probability-weighted parameter sharing and propagates default fparam into data requirements. * Statistic items support scalar scaling for aggregation. * **Refactor** * Parameter-sharing and statistic-propagation flows reorganized for consistent buffering and persistence. * **Tests** * New tests and test data for stats computation, persistence, and multitask sharing. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chenqqian Zhang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e98dc5a commit 0a278e7

File tree

15 files changed

+935
-82
lines changed

15 files changed

+935
-82
lines changed

deepmd/pd/model/atomic_model/dp_atomic_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,13 +397,14 @@ def wrapped_sampler():
397397
return sampled
398398

399399
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
400-
self.compute_fitting_input_stat(wrapped_sampler)
400+
self.compute_fitting_input_stat(wrapped_sampler, stat_file_path)
401401
if compute_or_load_out_stat:
402402
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
403403

404404
def compute_fitting_input_stat(
405405
self,
406406
sample_merged: Union[Callable[[], list[dict]], list[dict]],
407+
stat_file_path: Optional[DPPath] = None,
407408
) -> None:
408409
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
409410
@@ -416,9 +417,13 @@ def compute_fitting_input_stat(
416417
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
417418
only when needed. Since the sampling process can be slow and memory-intensive,
418419
the lazy function helps by only sampling once.
420+
stat_file_path : Optional[DPPath]
421+
The dictionary of paths to the statistics files.
419422
"""
420423
self.fitting_net.compute_input_stats(
421-
sample_merged, protection=self.data_stat_protect
424+
sample_merged,
425+
protection=self.data_stat_protect,
426+
stat_file_path=stat_file_path,
422427
)
423428

424429
def get_dim_fparam(self) -> int:

deepmd/pd/model/task/fitting.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
get_index_between_two_maps,
4141
map_atom_exclude_types,
4242
)
43+
from deepmd.utils.path import (
44+
DPPath,
45+
)
4346

4447
dtype = env.GLOBAL_PD_FLOAT_PRECISION
4548
device = env.DEVICE
@@ -76,6 +79,7 @@ def compute_input_stats(
7679
self,
7780
merged: Union[Callable[[], list[dict]], list[dict]],
7881
protection: float = 1e-2,
82+
stat_file_path: Optional[DPPath] = None,
7983
) -> None:
8084
"""
8185
Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
@@ -91,6 +95,8 @@ def compute_input_stats(
9195
the lazy function helps by only sampling once.
9296
protection : float
9397
Divided-by-zero protection
98+
stat_file_path : Optional[DPPath]
99+
The path to the stat file.
94100
"""
95101
if self.numb_fparam == 0 and self.numb_aparam == 0:
96102
# skip data statistics

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,16 +326,26 @@ def wrapped_sampler() -> list[dict]:
326326
atom_exclude_types = self.atom_excl.get_exclude_types()
327327
for sample in sampled:
328328
sample["atom_exclude_types"] = list(atom_exclude_types)
329+
if (
330+
"find_fparam" not in sampled[0]
331+
and "fparam" not in sampled[0]
332+
and self.has_default_fparam()
333+
):
334+
default_fparam = self.get_default_fparam()
335+
for sample in sampled:
336+
nframe = sample["atype"].shape[0]
337+
sample["fparam"] = default_fparam.repeat(nframe, 1)
329338
return sampled
330339

331340
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
332-
self.compute_fitting_input_stat(wrapped_sampler)
341+
self.compute_fitting_input_stat(wrapped_sampler, stat_file_path)
333342
if compute_or_load_out_stat:
334343
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
335344

336345
def compute_fitting_input_stat(
337346
self,
338347
sample_merged: Union[Callable[[], list[dict]], list[dict]],
348+
stat_file_path: Optional[DPPath] = None,
339349
) -> None:
340350
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
341351
@@ -348,9 +358,13 @@ def compute_fitting_input_stat(
348358
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
349359
only when needed. Since the sampling process can be slow and memory-intensive,
350360
the lazy function helps by only sampling once.
361+
stat_file_path : Optional[DPPath]
362+
The dictionary of paths to the statistics files.
351363
"""
352364
self.fitting_net.compute_input_stats(
353-
sample_merged, protection=self.data_stat_protect
365+
sample_merged,
366+
protection=self.data_stat_protect,
367+
stat_file_path=stat_file_path,
354368
)
355369

356370
def get_dim_fparam(self) -> int:
@@ -361,6 +375,9 @@ def has_default_fparam(self) -> bool:
361375
"""Check if the model has default frame parameters."""
362376
return self.fitting_net.has_default_fparam()
363377

378+
def get_default_fparam(self) -> Optional[torch.Tensor]:
379+
return self.fitting_net.get_default_fparam()
380+
364381
def get_dim_aparam(self) -> int:
365382
"""Get the number (dimension) of atomic parameters of this atomic model."""
366383
return self.fitting_net.get_dim_aparam()

deepmd/pt/model/model/make_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,9 @@ def has_default_fparam(self) -> bool:
532532
"""Check if the model has default frame parameters."""
533533
return self.atomic_model.has_default_fparam()
534534

535+
def get_default_fparam(self) -> Optional[torch.Tensor]:
536+
return self.atomic_model.get_default_fparam()
537+
535538
@torch.jit.export
536539
def get_dim_aparam(self) -> int:
537540
"""Get the number (dimension) of atomic parameters of this atomic model."""

0 commit comments

Comments
 (0)