fix(stat): Calculate correct fitting stat when using default fparam and using share fitting. #5038
Conversation
for more information, see https://pre-commit.ci
…an-Zhang/deepmd-kit into 1108_default_fparam_stat
📝 Walkthrough

Expose and propagate default frame-parameter (fparam) through model APIs; thread an optional stat_file_path into fitting-statistics computation and I/O; add per-link model probability and protection to parameter-sharing; add StatItem scalar scaling; and add tests and test data for fitting-stat workflows.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    %% Accessible colors used sparingly via notes
    participant Trainer
    participant Wrapper
    participant Model as DPAtomicModel
    participant Fitting
    Trainer->>Trainer: build model_key_prob_map & data_stat_protect
    Trainer->>Wrapper: share_params(shared_links, model_key_prob_map, data_stat_protect)
    activate Wrapper
    Wrapper->>Wrapper: for each link compute frac_prob = prob_link/prob_base
    Wrapper->>Fitting: share_params(base, level, model_prob=frac_prob, protection=data_stat_protect, resume)
    deactivate Wrapper
    Trainer->>Model: request data requirements (get_default_fparam)
    Model->>Fitting: compute_input_stats(merged, protection, stat_file_path)
    activate Fitting
    alt stat files exist
        Fitting->>Fitting: restore_fparam_from_file / restore_aparam_from_file
    else
        Fitting->>Fitting: aggregate stats (NumPy), apply protection
        Fitting->>Fitting: save_to_file_fparam / save_to_file_aparam
    end
    Fitting-->>Model: return stats and default_fparam (if any)
    Model-->>Trainer: include default fparam in DataRequirementItem
    deactivate Fitting
```
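The per-link probability fraction computed in the Wrapper step can be sketched as plain dictionary math. The function and variable names below are illustrative only, not the actual deepmd-kit API:

```python
# Hypothetical sketch: fraction of probability each linked model contributes
# relative to the base model it shares fitting parameters with.

def link_fractions(prob_map: dict, base: str, links: list) -> dict:
    """Return prob_link / prob_base for every linked model key."""
    prob_base = prob_map[base]
    if prob_base <= 0.0:
        raise ValueError("base model must have positive probability")
    return {link: prob_map[link] / prob_base for link in links}

# e.g. three multitask branches whose probabilities sum to 1.0
fracs = link_fractions(
    {"base": 0.5, "task_a": 0.3, "task_b": 0.2},
    "base",
    ["task_a", "task_b"],
)
```

With fractions like these, statistics accumulated on each linked model can be scaled before being summed into the base model's buffers.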
Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)

- deepmd/pd/model/task/fitting.py

🧰 Additional context used

🧬 Code graph analysis (1)

deepmd/pd/model/task/fitting.py (1)

🪛 Ruff (0.14.5)

deepmd/pd/model/task/fitting.py

82-82: Unused method argument (ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
Actionable comments posted: 2
🧹 Nitpick comments (2)
deepmd/pt/model/model/make_model.py (1)
9-9: Remove unused numpy import.

The numpy import is not used anywhere in this file.
Apply this diff:
```diff
-import numpy as np
```

deepmd/pt/train/training.py (1)
636-642: Fix unnecessary f-string prefix.

The assertion message on line 637 uses an f-string without any placeholders.
Apply this diff:
```diff
- assert np.allclose(_data_stat_protect, _data_stat_protect[0]), f"Model key 'data_stat_protect' must be the same in each branch when multitask!"
+ assert np.allclose(_data_stat_protect, _data_stat_protect[0]), "Model key 'data_stat_protect' must be the same in each branch when multitask!"
```

The logic correctly validates consistency and propagates the protection value to parameter sharing.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (3 hunks)
- deepmd/pt/model/model/make_model.py (2 hunks)
- deepmd/pt/model/task/fitting.py (6 hunks)
- deepmd/pt/train/training.py (2 hunks)
- deepmd/pt/train/wrapper.py (2 hunks)
- deepmd/utils/env_mat_stat.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code
Files:

- deepmd/pt/train/wrapper.py
- deepmd/pt/model/atomic_model/dp_atomic_model.py
- deepmd/utils/env_mat_stat.py
- deepmd/pt/train/training.py
- deepmd/pt/model/task/fitting.py
- deepmd/pt/model/model/make_model.py
🧬 Code graph analysis (5)

deepmd/pt/train/wrapper.py (1)

- deepmd/pt/model/task/fitting.py (1): share_params (66-128)

deepmd/pt/model/atomic_model/dp_atomic_model.py (4)

- deepmd/pt/model/model/make_model.py (2): has_default_fparam (530-532), get_default_fparam (535-536)
- deepmd/pt/model/task/fitting.py (3): has_default_fparam (599-601), get_default_fparam (603-604), compute_input_stats (208-269)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (2): has_default_fparam (414-416), wrapped_sampler (387-397)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): has_default_fparam (138-140)

deepmd/pt/train/training.py (4)

- deepmd/pt/model/task/fitting.py (4): share_params (66-128), get_default_fparam (603-604), has_default_fparam (599-601), get_dim_fparam (595-597)
- deepmd/pt/train/wrapper.py (1): share_params (63-139)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (3): get_default_fparam (355-356), has_default_fparam (351-353), get_dim_fparam (347-349)
- deepmd/utils/data.py (1): DataRequirementItem (745-825)

deepmd/pt/model/task/fitting.py (5)

- deepmd/utils/path.py (13): DPPath (28-158), mkdir (149-158), mkdir (270-282), mkdir (472-490), save_numpy (70-77), save_numpy (200-211), save_numpy (358-370), load_numpy (50-57), load_numpy (180-188), load_numpy (335-343), is_dir (115-116), is_dir (249-251), is_dir (439-445)
- deepmd/utils/env_mat_stat.py (3): StatItem (26-98), compute_avg (58-73), compute_std (75-98)
- deepmd/pt/utils/utils.py (6): to_numpy_array (224-224), to_numpy_array (228-228), to_numpy_array (231-247), to_torch_tensor (251-251), to_torch_tensor (255-255), to_torch_tensor (258-276)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): get_default_fparam (355-356)
- deepmd/pt/model/model/make_model.py (1): get_default_fparam (535-536)

deepmd/pt/model/model/make_model.py (3)

- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): get_default_fparam (355-356)
- deepmd/pt/model/task/fitting.py (1): get_default_fparam (603-604)
- deepmd/pt/model/network/network.py (1): Tensor (36-37)
🪛 Ruff (0.14.3)
deepmd/pt/train/training.py
637-637: f-string without any placeholders
Remove extraneous f prefix
(F541)
deepmd/pt/model/task/fitting.py
269-270: Expected an indented block after if statement
(invalid-syntax)
272-272: unindent does not match any outer indentation level
(invalid-syntax)
272-272: Expected a statement
(invalid-syntax)
272-272: Expected a statement
(invalid-syntax)
272-273: Expected a statement
(invalid-syntax)
273-273: Unexpected indentation
(invalid-syntax)
297-297: unindent does not match any outer indentation level
(invalid-syntax)
298-298: Unexpected indentation
(invalid-syntax)
304-304: unindent does not match any outer indentation level
(invalid-syntax)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (13)

deepmd/utils/env_mat_stat.py (1)

51-56: LGTM! The scalar multiplication operator correctly scales all statistical components for probability-weighted aggregation in multitask training. The implementation properly supports the weighted averaging workflow where statistics from multiple models are combined using probability weights.
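A minimal sketch of a stat container with this kind of scalar scaling is shown below. The field names follow the review's description of `StatItem` (number / sum / squared_sum); the real class in deepmd/utils/env_mat_stat.py differs in detail:

```python
from dataclasses import dataclass


@dataclass
class StatItem:
    # running totals over frames; floats so they can be scaled
    # by a model probability before aggregation
    number: float = 0.0
    sum: float = 0.0
    squared_sum: float = 0.0

    def __mul__(self, scalar: float) -> "StatItem":
        # scale every component uniformly for probability weighting
        return StatItem(self.number * scalar, self.sum * scalar,
                        self.squared_sum * scalar)

    def __add__(self, other: "StatItem") -> "StatItem":
        return StatItem(self.number + other.number, self.sum + other.sum,
                        self.squared_sum + other.squared_sum)

    def compute_avg(self) -> float:
        return self.sum / self.number


# weighted combination of stats from two models (weights 0.6 / 0.4)
combined = StatItem(10, 50, 300) * 0.6 + StatItem(10, 30, 120) * 0.4
```

Because every component is scaled by the same factor, summing weighted items keeps `compute_avg` a proper probability-weighted mean.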
deepmd/pt/model/model/make_model.py (1)

534-536: LGTM! The method correctly delegates to the atomic model and follows the established pattern for other similar accessors in this class.

deepmd/pt/train/wrapper.py (1)

63-63: LGTM! The extended signature correctly supports probability-weighted parameter sharing for multitask training. The parameters align with the updated `share_params` implementation in the fitting net.

deepmd/pt/model/atomic_model/dp_atomic_model.py (2)

329-337: LGTM! The logic correctly populates missing fparam with default values when available. The check for both `"find_fparam"` and `"fparam"` ensures proper handling of data loading states.

342-342: LGTM! The stat_file_path propagation enables proper persistence of fparam/aparam statistics, and the `get_default_fparam` method correctly delegates to the fitting net.

Also applies to: 355-356

deepmd/pt/train/training.py (2)

619-632: LGTM! The model probability calculation correctly supports both explicit configuration and data-driven defaults, with proper normalization and validation to ensure a valid probability distribution.
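The explicit-configuration-with-data-driven-fallback pattern described here can be sketched as follows. The function name, config argument, and the frame-count fallback are assumptions based on the review text, not the actual deepmd-kit implementation:

```python
# Hypothetical sketch: normalized per-model probability map for multitask
# training, falling back to per-task data sizes when not configured.

def build_model_prob_map(model_keys, configured=None, data_sizes=None):
    if configured is not None:
        raw = [configured[k] for k in model_keys]
    else:
        # data-driven default: weight each task by its number of frames
        raw = [data_sizes[k] for k in model_keys]
    total = float(sum(raw))
    assert total > 0, "model probabilities must sum to a positive value"
    return {k: v / total for k, v in zip(model_keys, raw)}


probs = build_model_prob_map(
    ["water", "metal"], data_sizes={"water": 3000, "metal": 1000}
)
```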
1344-1351: LGTM! The default fparam handling correctly retrieves and converts the default value from the model, passing it to the data requirement with proper type conversion.

deepmd/pt/model/task/fitting.py (6)

66-128: LGTM! The extended `share_params` correctly implements probability-weighted parameter sharing for multitask training. The logic properly accumulates weighted statistics for fparam/aparam buffers and links them to the base class.

130-206: LGTM! The persistence methods correctly save and restore fparam/aparam statistics using numpy arrays, with proper path handling and logging.
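The save/restore flow can be sketched with plain NumPy and pathlib. File names and the helper below are illustrative; the real code routes I/O through `DPPath` and separate save/restore methods:

```python
import tempfile
from pathlib import Path
from typing import Optional

import numpy as np


def compute_or_restore_fparam_stats(stat_dir: Optional[Path], samples):
    """Load fparam avg/std from disk if present, else compute and persist."""
    if stat_dir is not None and (stat_dir / "fparam_avg.npy").exists():
        return (np.load(stat_dir / "fparam_avg.npy"),
                np.load(stat_dir / "fparam_std.npy"))
    data = np.concatenate(samples, axis=0)
    avg, std = data.mean(axis=0), data.std(axis=0)
    if stat_dir is not None:
        stat_dir.mkdir(parents=True, exist_ok=True)
        np.save(stat_dir / "fparam_avg.npy", avg)
        np.save(stat_dir / "fparam_std.npy", std)
    return avg, std


with tempfile.TemporaryDirectory() as d:
    p = Path(d)
    avg1, _ = compute_or_restore_fparam_stats(
        p, [np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]])]
    )
    # second call restores from file and ignores the new samples
    avg2, _ = compute_or_restore_fparam_stats(p, [np.array([[100.0, 100.0]])])
```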
208-266: LGTM! The fparam statistics computation correctly implements the load-or-compute pattern with proper persistence and type conversions.

304-310: LGTM! The `get_stats` method properly validates that statistics have been computed before returning them.

603-604: LGTM! The method correctly exposes the default fparam tensor and aligns with the existing `has_default_fparam` accessor.

11-11: LGTM! The new imports are properly used throughout the file for type hints and statistics handling.

Also applies to: 45-50
for more information, see https://pre-commit.ci
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@ Coverage Diff @@
##            devel    #5038   +/-   ##
========================================
  Coverage   84.26%   84.26%
========================================
  Files         709      709
  Lines       70279    70391   +112
  Branches     3620     3619     -1
========================================
+ Hits        59220    59317    +97
- Misses       9892     9905    +13
- Partials     1167     1169     +2
```

☔ View full report in Codecov by Sentry.
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/utils/env_mat_stat.py (1)

53-58: Consider adding `__rmul__` for symmetric scalar multiplication. Right now `StatItem * scalar` works but `scalar * StatItem` will not. If you expect stats to be scaled inside generic numeric code (e.g., with map/sum or broadcasting), adding `__rmul__` improves ergonomics without changing behavior.

You could implement it as:

```diff
 class StatItem:
 @@
     def __mul__(self, scalar: float) -> "StatItem":
         return StatItem(
             number=self.number * scalar,
             sum=self.sum * scalar,
             squared_sum=self.squared_sum * scalar,
         )
+
+    def __rmul__(self, scalar: float) -> "StatItem":
+        return self * scalar
```
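A tiny self-contained sketch of why the suggested `__rmul__` matters (a simplified stand-in class, not the real `StatItem`):

```python
# With only __mul__, `0.5 * item` raises TypeError, so generic code like
# sum(w * s for ...) silently depends on operand order. __rmul__ fixes that.

class Stat:
    def __init__(self, number, sum_, squared_sum):
        self.number, self.sum, self.squared_sum = number, sum_, squared_sum

    def __mul__(self, scalar):
        return Stat(self.number * scalar, self.sum * scalar,
                    self.squared_sum * scalar)

    def __rmul__(self, scalar):
        return self * scalar  # delegate so scalar * Stat also works


left = Stat(4, 8, 20) * 0.5
right = 0.5 * Stat(4, 8, 20)  # only valid because of __rmul__
```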
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/utils/env_mat_stat.py(2 hunks)
🔇 Additional comments (1)
deepmd/utils/env_mat_stat.py (1)
31-42: Switchingnumberto float aligns with scalar scaling and persistence.Updating
numbertofloatin both the docstring and constructor is consistent with__mul__and avoids the previous int/float type mismatch when scaling stats (e.g., by probabilities or weights). This also matches how values are saved/loaded via NumPy arrays where everything is stored as floats.
Could someone please help me rerun the UT?
Signed-off-by: Chenqqian Zhang <[email protected]>
Head branch was pushed to by a user without write access
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
- deepmd/pt/model/model/make_model.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)

- deepmd/pt/model/model/make_model.py (2): has_default_fparam (531-533), get_default_fparam (535-536)
- deepmd/pt/model/task/fitting.py (2): has_default_fparam (639-641), get_default_fparam (643-644)

deepmd/pt/model/model/make_model.py (2)

- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): get_default_fparam (373-374)
- deepmd/pt/model/task/fitting.py (1): get_default_fparam (643-644)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test C++ (false)
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (2)

deepmd/pt/model/model/make_model.py (1)

535-536: LGTM! Consistent delegation pattern. The new `get_default_fparam()` method correctly delegates to the atomic model, following the same pattern as existing methods like `has_default_fparam()` (lines 531-533). This provides clean access to default frame parameters through the model hierarchy.

deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

373-374: LGTM! Consistent delegation pattern. The new `get_default_fparam()` method correctly delegates to the fitting net, following the same pattern as `has_default_fparam()` (lines 369-371). This provides the necessary accessor for default frame parameters.
for more information, see https://pre-commit.ci
Actionable comments posted: 0
♻️ Duplicate comments (1)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
318-338: Handle empty/heterogeneous sampled lists when injecting default_fparam

The idea of auto-filling `fparam` from `default_fparam` is good and matches the PR goals, but the current implementation has two issues:

- `sampled[0]` will raise `IndexError` if `sampled_func()` ever returns an empty list.
- You only inspect `sampled[0]` for `"find_fparam"`/`"fparam"` keys, which breaks if different samples have different keys.

You can make this robust by checking per sample and avoiding `sampled[0]` entirely:

```diff
 @functools.lru_cache
 def wrapped_sampler() -> list[dict]:
     sampled = sampled_func()
 @@
-    if (
-        "find_fparam" not in sampled[0]
-        and "fparam" not in sampled[0]
-        and self.has_default_fparam()
-    ):
-        default_fparam = self.get_default_fparam()
-        for sample in sampled:
-            nframe = sample["atype"].shape[0]
-            sample["fparam"] = default_fparam.repeat(nframe, 1)
+    if self.has_default_fparam():
+        default_fparam = self.get_default_fparam()
+        for sample in sampled:
+            if "find_fparam" not in sample and "fparam" not in sample:
+                nframe = sample["atype"].shape[0]
+                sample["fparam"] = default_fparam.repeat(nframe, 1)
     return sampled
```

This handles empty `sampled`, heterogeneous keys, and keeps the intended behavior.
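The per-sample fill suggested above can be demonstrated with NumPy arrays standing in for the torch tensors the real code uses (the helper name is illustrative):

```python
import numpy as np


def fill_default_fparam(sampled, default_fparam):
    """Fill fparam only for samples that lack both keys."""
    for sample in sampled:
        if "find_fparam" not in sample and "fparam" not in sample:
            nframe = sample["atype"].shape[0]
            # broadcast the default row to one row per frame
            sample["fparam"] = np.tile(default_fparam, (nframe, 1))
    return sampled


samples = [
    {"atype": np.zeros((3, 5))},                          # missing fparam
    {"atype": np.zeros((2, 5)), "fparam": np.ones((2, 1))},  # already set
]
samples = fill_default_fparam(samples, np.array([0.5]))
```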
🧹 Nitpick comments (2)
deepmd/pd/model/atomic_model/dp_atomic_model.py (1)
399-427: Stat file path propagation into fitting stats looks correct; minor doc and backend-consistency follow‑upsThreading
stat_file_paththroughcompute_fitting_input_statand intoself.fitting_net.compute_input_stats(...)aligns PDDPAtomicModelwith the dpmodel/PT implementations and should allow consistent saving/loading of fitting stats alongside descriptor stats. No functional issues seen here.Two small follow‑ups you might want to consider:
- The
stat_file_pathdocstring still calls this “The dictionary of paths to the statistics files.”, but the type isOptional[DPPath]. Consider updating the wording to reflect that it is a path object (or path root) rather than a dict of paths, for clarity.- In the PT backend (
deepmd/pt/model/atomic_model/dp_atomic_model.py),wrapped_samplernow auto‑fillssample["fparam"]fromdefault_fparamwhen neither"find_fparam"nor"fparam"is present andhas_default_fparam()is true. The PD backend’swrapped_samplercurrently does not do this. If PD fitting expects the same default‑fparam semantics for stats as PT, you may want to mirror that logic here for cross‑backend consistency; if Paddle fitting handles default fparams internally, then the current PD behavior is fine.deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
340-368: Stat file path propagation into fitting stats is consistent; only docstring is slightly misleadingPassing
stat_file_pathintocompute_fitting_input_stat(wrapped_sampler, stat_file_path)and forwarding it toself.fitting_net.compute_input_stats(..., stat_file_path=stat_file_path)wires the PT backend cleanly into the new stat‑file I/O workflow; this matches the dpmodel / PD patterns and looks correct.The only nit is the docstring for
stat_file_path, which still calls it “The dictionary of paths to the statistics files.” while the type isOptional[DPPath]. Consider rephrasing to reflect that it’s a path object (or directory root) instead of a dict.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (2 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/model/atomic_model/dp_atomic_model.py (4)
- deepmd/dpmodel/fitting/general_fitting.py (2): has_default_fparam (303-305), compute_input_stats (225-288)
- deepmd/pd/model/atomic_model/base_atomic_model.py (2): has_default_fparam (157-159), compute_fitting_input_stat (518-534)
- deepmd/pt/model/model/make_model.py (2): has_default_fparam (531-533), get_default_fparam (535-536)
- deepmd/dpmodel/atomic_model/dp_atomic_model.py (1): has_default_fparam (238-240)
🔇 Additional comments (1)

deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

374-380: get_default_fparam delegator is appropriate and aligns with higher-level APIs

Adding `get_default_fparam()` as a thin delegator to `self.fitting_net.get_default_fparam()` is consistent with `has_default_fparam()` and with the CM wrapper exposing `get_default_fparam`. This should make it straightforward for callers (and `wrapped_sampler`) to access default frame parameters without poking the fitting net directly.
In this PR:
- Support saving the fitting stat to `stat_file` and loading the fitting stat from `stat_file`
- Support `default_fparam`
- Support `share_fitting` in multitask mode
- `log.info` updates

Summary by CodeRabbit
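A hypothetical input fragment illustrating how such options might sit together in a training configuration. The key names and their placement are assumptions for illustration, not the verified deepmd-kit input schema:

```json
{
  "model": {
    "fitting_net": {
      "default_fparam": [0.5]
    }
  },
  "training": {
    "stat_file": "stat_files/model.hdf5",
    "data_stat_protect": 0.01
  }
}
```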
New Features
Refactor
Tests