
Conversation

@Chengqian-Zhang (Collaborator) commented Nov 8, 2025

In this PR:

  1. Support writing the fitting stat to stat_file and loading the fitting stat from stat_file.
  2. Ensure the fitting stat calculation is correct when using default_fparam.
  3. Support sharing the fitting stat when using share_fitting in multitask mode.
  4. Log the progress of the fitting stat calculation via log.info.

Summary by CodeRabbit

  • New Features

    • Default frame parameters auto-fill missing samples and are exposed via a new accessor.
    • Compute/load/save per-parameter statistics to/from optional stat files.
    • Multitask training adds probability-weighted parameter sharing and propagates default fparam into data requirements.
    • Statistic items support scalar scaling for aggregation.
  • Refactor

    • Parameter-sharing and statistic-propagation flows reorganized for consistent buffering and persistence.
  • Tests

    • New tests and test data for stats computation, persistence, and multitask sharing.


@github-actions github-actions bot added the Python label Nov 8, 2025
@Chengqian-Zhang Chengqian-Zhang marked this pull request as draft November 8, 2025 10:16
coderabbitai bot (Contributor) commented Nov 8, 2025

📝 Walkthrough

Expose and propagate default frame-parameter (fparam) through model APIs; thread an optional stat_file_path into fitting-statistics computation and I/O; add per-link model probability and protection to parameter-sharing; add StatItem scalar scaling; and add tests and test data for fitting-stat workflows.

Changes

Cohort / File(s) Summary
DP atomic & model API
deepmd/pt/model/atomic_model/dp_atomic_model.py, deepmd/pd/model/atomic_model/dp_atomic_model.py, deepmd/pt/model/model/make_model.py
Add get_default_fparam(); DPAtomicModel now fills missing sample fparam from default and threads optional stat_file_path into fitting-stat calls (compute_fitting_input_stat signature updated and propagated).
Fitting / statistics core (PT & PD)
deepmd/pt/model/task/fitting.py, deepmd/pd/model/task/fitting.py
Added stat persistence (save/restore per-type stat files), NumPy-based aggregation, compute_input_stats(..., stat_file_path), get_stats(), get_default_fparam(), and changed share_params to accept model_prob/protection. PD variant accepts stat_file_path in signature/docstring.
Training orchestration
deepmd/pt/train/training.py
Compute/normalize per-model probabilities for multitask, derive and validate data_stat_protect, pass model_key_prob_map and data_stat_protect into sharing, and propagate default fparam into DataRequirementItem defaults.
Wrapper parameter sharing
deepmd/pt/train/wrapper.py
share_params() signature extended to accept model_key_prob_map and data_stat_protect; computes per-link frac_prob and forwards model_prob (frac) and protection to underlying share calls.
Stat utilities
deepmd/utils/env_mat_stat.py
StatItem now accepts float number and supports scalar multiplication via __mul__.
Tests & test data
source/tests/pt/model/water/data/..., source/tests/pt/test_fitting_stat.py
Add raw test data files and comprehensive tests covering stat computation, file I/O (save/restore), multi-task weighting, protection behavior, and default fparam handling.
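The save/restore behavior listed for the fitting-statistics cohort follows a load-or-compute pattern: restore stats from the stat file when it exists, otherwise aggregate and persist them. A minimal sketch of that pattern, using plain pathlib/NumPy I/O as a stand-in for the actual DPPath-based helpers (the function name and file layout here are illustrative, not the deepmd API):

```python
import numpy as np
from pathlib import Path

def load_or_compute_stat(stat_file, compute):
    """Restore (avg, std) from stat_file if present; otherwise compute and save.

    `stat_file` is an optional Path to an .npz file; `compute` is a callable
    returning (avg, std) NumPy arrays aggregated from the sampled data.
    """
    if stat_file is not None and stat_file.exists():
        data = np.load(stat_file)          # restore previously saved stats
        return data["avg"], data["std"]
    avg, std = compute()                   # aggregate stats from samples
    if stat_file is not None:
        np.savez(stat_file, avg=avg, std=std)  # persist for later runs
    return avg, std
```

With this shape, a second training run pointed at the same stat file skips the (potentially expensive) aggregation entirely.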

Sequence Diagram(s)

%% Accessible colors used sparingly via notes
sequenceDiagram
    participant Trainer
    participant Wrapper
    participant Model as DPAtomicModel
    participant Fitting

    Trainer->>Trainer: build model_key_prob_map & data_stat_protect
    Trainer->>Wrapper: share_params(shared_links, model_key_prob_map, data_stat_protect)
    activate Wrapper
    Wrapper->>Wrapper: for each link compute frac_prob = prob_link/prob_base
    Wrapper->>Fitting: share_params(base, level, model_prob=frac_prob, protection=data_stat_protect, resume)
    deactivate Wrapper

    Trainer->>Model: request data requirements (get_default_fparam)
    Model->>Fitting: compute_input_stats(merged, protection, stat_file_path)
    activate Fitting
    alt stat files exist
        Fitting->>Fitting: restore_fparam_from_file / restore_aparam_from_file
    else
        Fitting->>Fitting: aggregate stats (NumPy), apply protection
        Fitting->>Fitting: save_to_file_fparam / save_to_file_aparam
    end
    Fitting-->>Model: return stats and default_fparam (if any)
    Model-->>Trainer: include default fparam in DataRequirementItem
    deactivate Fitting
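The frac_prob weighting in the diagram can be made concrete with a small self-contained sketch. StatItem here mirrors deepmd.utils.env_mat_stat.StatItem (including the new __mul__), but the merge helper is hypothetical, shown only to illustrate probability-weighted aggregation across multitask models:

```python
from dataclasses import dataclass

@dataclass
class StatItem:
    number: float = 0.0
    sum: float = 0.0
    squared_sum: float = 0.0

    def __add__(self, other: "StatItem") -> "StatItem":
        return StatItem(
            number=self.number + other.number,
            sum=self.sum + other.sum,
            squared_sum=self.squared_sum + other.squared_sum,
        )

    def __mul__(self, scalar: float) -> "StatItem":
        # Scale every component so weighted sums stay self-consistent.
        return StatItem(
            number=self.number * scalar,
            sum=self.sum * scalar,
            squared_sum=self.squared_sum * scalar,
        )

def merge_weighted(stats: list, probs: list) -> StatItem:
    """Combine per-model stats using normalized model probabilities."""
    total = sum(probs)
    merged = StatItem()
    for s, p in zip(stats, probs):
        merged = merged + s * (p / total)
    return merged
```

Averages derived from the merged item (sum / number) then reflect each model's probability weight, which is what the per-link frac_prob in the wrapper is driving at.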

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Areas needing extra attention:

  • deepmd/pt/model/task/fitting.py — stat aggregation (NumPy vs torch), file I/O format, StatItem semantics, and share_params linkage logic.
  • deepmd/pt/train/wrapper.py & deepmd/pt/train/training.py — correctness of model probability computation/normalization and propagation to share_params.
  • deepmd/pt/model/atomic_model/dp_atomic_model.py — sampler injection of default fparam and interaction with DataRequirementItem defaults.
  • Tests & test data — validate assumptions, formats, and teardown behavior.

Possibly related PRs

Suggested reviewers

  • njzjz
  • iProzd
  • wanghan-iapcm

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 55.56% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title addresses the main objectives of the PR: fixing fitting statistics calculation when using default fparam and share fitting. However, it contains a typo ('Caculate' instead of 'Calculate') which affects clarity.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dbe92ab and 9ba8496.

📒 Files selected for processing (1)
  • deepmd/pd/model/task/fitting.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pd/model/task/fitting.py (1)
deepmd/utils/path.py (1)
  • DPPath (28-158)
🪛 Ruff (0.14.5)
deepmd/pd/model/task/fitting.py

82-82: Unused method argument: stat_file_path

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (1, 3.12)


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
deepmd/pt/model/model/make_model.py (1)

9-9: Remove unused numpy import.

The numpy import is not used anywhere in this file.

Apply this diff:

-import numpy as np
deepmd/pt/train/training.py (1)

636-642: Fix unnecessary f-string prefix.

The assertion message on line 637 uses an f-string without any placeholders.

Apply this diff:

-            assert np.allclose(_data_stat_protect, _data_stat_protect[0]), f"Model key 'data_stat_protect' must be the same in each branch when multitask!"
+            assert np.allclose(_data_stat_protect, _data_stat_protect[0]), "Model key 'data_stat_protect' must be the same in each branch when multitask!"

The logic correctly validates consistency and propagates the protection value to parameter sharing.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 25fa707 and 4c3072e.

📒 Files selected for processing (6)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (3 hunks)
  • deepmd/pt/model/model/make_model.py (2 hunks)
  • deepmd/pt/model/task/fitting.py (6 hunks)
  • deepmd/pt/train/training.py (2 hunks)
  • deepmd/pt/train/wrapper.py (2 hunks)
  • deepmd/utils/env_mat_stat.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

📄 CodeRabbit inference engine (AGENTS.md)

Always run ruff check . and ruff format . before committing changes to Python code

Files:

  • deepmd/pt/train/wrapper.py
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/utils/env_mat_stat.py
  • deepmd/pt/train/training.py
  • deepmd/pt/model/task/fitting.py
  • deepmd/pt/model/model/make_model.py
🧬 Code graph analysis (5)
deepmd/pt/train/wrapper.py (1)
deepmd/pt/model/task/fitting.py (1)
  • share_params (66-128)
deepmd/pt/model/atomic_model/dp_atomic_model.py (4)
deepmd/pt/model/model/make_model.py (2)
  • has_default_fparam (530-532)
  • get_default_fparam (535-536)
deepmd/pt/model/task/fitting.py (3)
  • has_default_fparam (599-601)
  • get_default_fparam (603-604)
  • compute_input_stats (208-269)
deepmd/pd/model/atomic_model/dp_atomic_model.py (2)
  • has_default_fparam (414-416)
  • wrapped_sampler (387-397)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
  • has_default_fparam (138-140)
deepmd/pt/train/training.py (4)
deepmd/pt/model/task/fitting.py (4)
  • share_params (66-128)
  • get_default_fparam (603-604)
  • has_default_fparam (599-601)
  • get_dim_fparam (595-597)
deepmd/pt/train/wrapper.py (1)
  • share_params (63-139)
deepmd/pt/model/atomic_model/dp_atomic_model.py (3)
  • get_default_fparam (355-356)
  • has_default_fparam (351-353)
  • get_dim_fparam (347-349)
deepmd/utils/data.py (1)
  • DataRequirementItem (745-825)
deepmd/pt/model/task/fitting.py (5)
deepmd/utils/path.py (13)
  • DPPath (28-158)
  • mkdir (149-158)
  • mkdir (270-282)
  • mkdir (472-490)
  • save_numpy (70-77)
  • save_numpy (200-211)
  • save_numpy (358-370)
  • load_numpy (50-57)
  • load_numpy (180-188)
  • load_numpy (335-343)
  • is_dir (115-116)
  • is_dir (249-251)
  • is_dir (439-445)
deepmd/utils/env_mat_stat.py (3)
  • StatItem (26-98)
  • compute_avg (58-73)
  • compute_std (75-98)
deepmd/pt/utils/utils.py (6)
  • to_numpy_array (224-224)
  • to_numpy_array (228-228)
  • to_numpy_array (231-247)
  • to_torch_tensor (251-251)
  • to_torch_tensor (255-255)
  • to_torch_tensor (258-276)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • get_default_fparam (355-356)
deepmd/pt/model/model/make_model.py (1)
  • get_default_fparam (535-536)
deepmd/pt/model/model/make_model.py (3)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • get_default_fparam (355-356)
deepmd/pt/model/task/fitting.py (1)
  • get_default_fparam (603-604)
deepmd/pt/model/network/network.py (1)
  • Tensor (36-37)
🪛 Ruff (0.14.3)
deepmd/pt/train/training.py

637-637: f-string without any placeholders

Remove extraneous f prefix

(F541)

deepmd/pt/model/task/fitting.py

269-270: Expected an indented block after if statement

(invalid-syntax)


272-272: unindent does not match any outer indentation level

(invalid-syntax)


272-272: Expected a statement

(invalid-syntax)


272-272: Expected a statement

(invalid-syntax)


272-273: Expected a statement

(invalid-syntax)


273-273: Unexpected indentation

(invalid-syntax)


297-297: unindent does not match any outer indentation level

(invalid-syntax)


298-298: Unexpected indentation

(invalid-syntax)


304-304: unindent does not match any outer indentation level

(invalid-syntax)

⏰ Context from checks skipped due to timeout of 90000ms. (Same 29 checks as listed above.)
🔇 Additional comments (13)
deepmd/utils/env_mat_stat.py (1)

51-56: LGTM!

The scalar multiplication operator correctly scales all statistical components for probability-weighted aggregation in multitask training. The implementation properly supports the weighted averaging workflow where statistics from multiple models are combined using probability weights.

deepmd/pt/model/model/make_model.py (1)

534-536: LGTM!

The method correctly delegates to the atomic model and follows the established pattern for other similar accessors in this class.

deepmd/pt/train/wrapper.py (1)

63-63: LGTM!

The extended signature correctly supports probability-weighted parameter sharing for multitask training. The parameters align with the updated share_params implementation in the fitting net.

deepmd/pt/model/atomic_model/dp_atomic_model.py (2)

329-337: LGTM!

The logic correctly populates missing fparam with default values when available. The check for both "find_fparam" and "fparam" ensures proper handling of data loading states.


342-342: LGTM!

The stat_file_path propagation enables proper persistence of fparam/aparam statistics, and the get_default_fparam method correctly delegates to the fitting net.

Also applies to: 355-356

deepmd/pt/train/training.py (2)

619-632: LGTM!

The model probability calculation correctly supports both explicit configuration and data-driven defaults, with proper normalization and validation to ensure a valid probability distribution.


1344-1351: LGTM!

The default fparam handling correctly retrieves and converts the default value from the model, passing it to the data requirement with proper type conversion.
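A hedged sketch of that fill behavior, with NumPy standing in for the torch tensors used in the actual code; fill_default_fparam and the dict keys shown are illustrative of the sampler logic, not the deepmd API itself:

```python
import numpy as np

def fill_default_fparam(sample: dict, default_fparam: np.ndarray) -> dict:
    """If a sample carries no fparam, tile the default value per frame."""
    if "fparam" not in sample and "find_fparam" not in sample:
        nframes = sample["atype"].shape[0]
        # tile the (dim_fparam,) default into a (nframes, dim_fparam) array
        sample["fparam"] = np.tile(default_fparam, (nframes, 1))
    return sample
```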

deepmd/pt/model/task/fitting.py (6)

66-128: LGTM!

The extended share_params correctly implements probability-weighted parameter sharing for multitask training. The logic properly accumulates weighted statistics for fparam/aparam buffers and links them to the base class.


130-206: LGTM!

The persistence methods correctly save and restore fparam/aparam statistics using numpy arrays, with proper path handling and logging.


208-266: LGTM!

The fparam statistics computation correctly implements the load-or-compute pattern with proper persistence and type conversions.


304-310: LGTM!

The get_stats method properly validates that statistics have been computed before returning them.


603-604: LGTM!

The method correctly exposes the default fparam tensor and aligns with the existing has_default_fparam accessor.


11-11: LGTM!

The new imports are properly used throughout the file for type hints and statistics handling.

Also applies to: 45-50

codecov bot commented Nov 8, 2025

Codecov Report

❌ Patch coverage is 89.03226% with 17 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.26%. Comparing base (e98dc5a) to head (9ba8496).
⚠️ Report is 1 commit behind head on devel.

Files with missing lines Patch % Lines
deepmd/pt/model/task/fitting.py 87.80% 15 Missing ⚠️
deepmd/pt/train/training.py 86.66% 2 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##            devel    #5038    +/-   ##
========================================
  Coverage   84.26%   84.26%            
========================================
  Files         709      709            
  Lines       70279    70391   +112     
  Branches     3620     3619     -1     
========================================
+ Hits        59220    59317    +97     
- Misses       9892     9905    +13     
- Partials     1167     1169     +2     


@Chengqian-Zhang Chengqian-Zhang marked this pull request as ready for review November 9, 2025 12:20
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/utils/env_mat_stat.py (1)

53-58: Consider adding __rmul__ for symmetric scalar multiplication.

Right now StatItem * scalar works but scalar * StatItem will not. If you expect stats to be scaled inside generic numeric code (e.g., with map/sum or broadcasting), adding __rmul__ improves ergonomics without changing behavior.

You could implement it as:

 class StatItem:
@@
     def __mul__(self, scalar: float) -> "StatItem":
         return StatItem(
             number=self.number * scalar,
             sum=self.sum * scalar,
             squared_sum=self.squared_sum * scalar,
         )
+
+    def __rmul__(self, scalar: float) -> "StatItem":
+        return self * scalar
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6f18144 and d6120a0.

📒 Files selected for processing (1)
  • deepmd/utils/env_mat_stat.py (2 hunks)
🔇 Additional comments (1)
deepmd/utils/env_mat_stat.py (1)

31-42: Switching number to float aligns with scalar scaling and persistence.

Updating number to float in both the docstring and constructor is consistent with __mul__ and avoids the previous int/float type mismatch when scaling stats (e.g., by probabilities or weights). This also matches how values are saved/loaded via NumPy arrays where everything is stored as floats.
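As a side note on the data_stat_protect value discussed earlier in this thread: a protection floor typically keeps the derived standard deviation away from zero so it can safely be used as a divisor. A minimal sketch under that assumption (illustrative only, not the deepmd implementation):

```python
import math

def compute_std(squared_sum: float, sum_: float, number: float,
                protection: float = 1e-2) -> float:
    """Std from accumulated moments, clamped below by `protection`."""
    if number == 0:
        return protection
    mean = sum_ / number
    var = squared_sum / number - mean * mean
    std = math.sqrt(max(var, 0.0))     # guard against negative rounding error
    return max(std, protection)        # clamp tiny std to the protection value
```

This is why the trainer insists that data_stat_protect agree across branches: shared statistics would otherwise be clamped inconsistently.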

@Chengqian-Zhang (Collaborator, Author) commented

Could someone please help me rerun the UT?
I did not change any code related to the C library.

@iProzd iProzd enabled auto-merge November 20, 2025 11:13
auto-merge was automatically disabled November 23, 2025 09:01

Head branch was pushed to by a user without write access

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d6120a0 and 07483a7.

📒 Files selected for processing (2)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
  • deepmd/pt/model/model/make_model.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
deepmd/pt/model/model/make_model.py (2)
  • has_default_fparam (531-533)
  • get_default_fparam (535-536)
deepmd/pt/model/task/fitting.py (2)
  • has_default_fparam (639-641)
  • get_default_fparam (643-644)
deepmd/pt/model/model/make_model.py (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • get_default_fparam (373-374)
deepmd/pt/model/task/fitting.py (1)
  • get_default_fparam (643-644)
⏰ Context from checks skipped due to timeout of 90000ms. (Same 29 checks as listed above.)
🔇 Additional comments (2)
deepmd/pt/model/model/make_model.py (1)

535-536: LGTM! Consistent delegation pattern.

The new get_default_fparam() method correctly delegates to the atomic model, following the same pattern as existing methods like has_default_fparam() (line 531-533). This provides clean access to default frame parameters through the model hierarchy.

deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

373-374: LGTM! Consistent delegation pattern.

The new get_default_fparam() method correctly delegates to the fitting net, following the same pattern as has_default_fparam() (line 369-371). This provides the necessary accessor for default frame parameters.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

318-338: Handle empty/heterogeneous sampled lists when injecting default_fparam

The idea of auto‑filling fparam from default_fparam is good and matches the PR goals, but the current implementation has two issues:

  • sampled[0] will raise IndexError if sampled_func() ever returns an empty list.
  • You only inspect sampled[0] for "find_fparam" / "fparam" keys, which breaks if different samples have different keys.

You can make this robust by checking per sample and avoiding sampled[0] entirely:

         @functools.lru_cache
         def wrapped_sampler() -> list[dict]:
             sampled = sampled_func()
@@
-            if (
-                "find_fparam" not in sampled[0]
-                and "fparam" not in sampled[0]
-                and self.has_default_fparam()
-            ):
-                default_fparam = self.get_default_fparam()
-                for sample in sampled:
-                    nframe = sample["atype"].shape[0]
-                    sample["fparam"] = default_fparam.repeat(nframe, 1)
+            if self.has_default_fparam():
+                default_fparam = self.get_default_fparam()
+                for sample in sampled:
+                    if "find_fparam" not in sample and "fparam" not in sample:
+                        nframe = sample["atype"].shape[0]
+                        sample["fparam"] = default_fparam.repeat(nframe, 1)
             return sampled

This handles empty sampled, heterogeneous keys, and keeps the intended behavior.

🧹 Nitpick comments (2)
deepmd/pd/model/atomic_model/dp_atomic_model.py (1)

399-427: Stat file path propagation into fitting stats looks correct; minor doc and backend-consistency follow‑ups

Threading stat_file_path through compute_fitting_input_stat and into self.fitting_net.compute_input_stats(...) aligns PD DPAtomicModel with the dpmodel/PT implementations and should allow consistent saving/loading of fitting stats alongside descriptor stats. No functional issues seen here.

Two small follow‑ups you might want to consider:

  • The stat_file_path docstring still calls this “The dictionary of paths to the statistics files.”, but the type is Optional[DPPath]. Consider updating the wording to reflect that it is a path object (or path root) rather than a dict of paths, for clarity.
  • In the PT backend (deepmd/pt/model/atomic_model/dp_atomic_model.py), wrapped_sampler now auto‑fills sample["fparam"] from default_fparam when neither "find_fparam" nor "fparam" is present and has_default_fparam() is true. The PD backend’s wrapped_sampler currently does not do this. If PD fitting expects the same default‑fparam semantics for stats as PT, you may want to mirror that logic here for cross‑backend consistency; if Paddle fitting handles default fparams internally, then the current PD behavior is fine.
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

340-368: Stat file path propagation into fitting stats is consistent; only docstring is slightly misleading

Passing stat_file_path into compute_fitting_input_stat(wrapped_sampler, stat_file_path) and forwarding it to self.fitting_net.compute_input_stats(..., stat_file_path=stat_file_path) wires the PT backend cleanly into the new stat‑file I/O workflow; this matches the dpmodel / PD patterns and looks correct.

The only nit is the docstring for stat_file_path, which still calls it “The dictionary of paths to the statistics files.” while the type is Optional[DPPath]. Consider rephrasing to reflect that it’s a path object (or directory root) instead of a dict.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 07483a7 and dbe92ab.

📒 Files selected for processing (2)
  • deepmd/pd/model/atomic_model/dp_atomic_model.py (2 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/model/atomic_model/dp_atomic_model.py (4)
deepmd/dpmodel/fitting/general_fitting.py (2)
  • has_default_fparam (303-305)
  • compute_input_stats (225-288)
deepmd/pd/model/atomic_model/base_atomic_model.py (2)
  • has_default_fparam (157-159)
  • compute_fitting_input_stat (518-534)
deepmd/pt/model/model/make_model.py (2)
  • has_default_fparam (531-533)
  • get_default_fparam (535-536)
deepmd/dpmodel/atomic_model/dp_atomic_model.py (1)
  • has_default_fparam (238-240)
🔇 Additional comments (1)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

374-380: get_default_fparam delegator is appropriate and aligns with higher‑level APIs

Adding get_default_fparam() as a thin delegator to self.fitting_net.get_default_fparam() is consistent with has_default_fparam() and with the CM wrapper exposing get_default_fparam. This should make it straightforward for callers (and wrapped_sampler) to access default frame parameters without poking the fitting net directly.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Nov 24, 2025
Merged via the queue into deepmodeling:devel with commit 0a278e7 Nov 24, 2025
60 checks passed