Skip to content

Commit f6d5d95

Browse files
wanghan-iapcm (Han Wang)
and co-authors authored
feat(pt_expt): add fitting for energy (#5218)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved tensor/device and dtype handling for consistent behavior across NumPy and PyTorch backends. * Fixed deserialization when layer collections are empty to avoid errors. * **New Features** * Added experimental PyTorch fitting wrappers: EnergyFittingNet and InvarFitting for tensor-based workflows and export/tracing. * Renamed descriptor registration keys to streamlined identifiers. * Package exports updated to surface new fitting entry points. * **Tests** * Added extensive tests for energy/invariant fitting, statistics computation, and PyTorch export compatibility. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 2901448 commit f6d5d95

File tree

16 files changed

+1245
-20
lines changed

16 files changed

+1245
-20
lines changed

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,18 @@ def compute_input_stats(
261261
fparam_std,
262262
)
263263
fparam_inv_std = 1.0 / fparam_std
264-
self.fparam_avg = fparam_avg.astype(self.fparam_avg.dtype)
265-
self.fparam_inv_std = fparam_inv_std.astype(self.fparam_inv_std.dtype)
264+
# Use array_api_compat to handle both numpy and torch
265+
xp = array_api_compat.array_namespace(self.fparam_avg)
266+
self.fparam_avg = xp.asarray(
267+
fparam_avg,
268+
dtype=self.fparam_avg.dtype,
269+
device=array_api_compat.device(self.fparam_avg),
270+
)
271+
self.fparam_inv_std = xp.asarray(
272+
fparam_inv_std,
273+
dtype=self.fparam_inv_std.dtype,
274+
device=array_api_compat.device(self.fparam_inv_std),
275+
)
266276
# stat aparam
267277
if self.numb_aparam > 0:
268278
sys_sumv = []
@@ -284,8 +294,18 @@ def compute_input_stats(
284294
aparam_std,
285295
)
286296
aparam_inv_std = 1.0 / aparam_std
287-
self.aparam_avg = aparam_avg.astype(self.aparam_avg.dtype)
288-
self.aparam_inv_std = aparam_inv_std.astype(self.aparam_inv_std.dtype)
297+
# Use array_api_compat to handle both numpy and torch
298+
xp = array_api_compat.array_namespace(self.aparam_avg)
299+
self.aparam_avg = xp.asarray(
300+
aparam_avg,
301+
dtype=self.aparam_avg.dtype,
302+
device=array_api_compat.device(self.aparam_avg),
303+
)
304+
self.aparam_inv_std = xp.asarray(
305+
aparam_inv_std,
306+
dtype=self.aparam_inv_std.dtype,
307+
device=array_api_compat.device(self.aparam_inv_std),
308+
)
289309

290310
@abstractmethod
291311
def _net_out_dim(self) -> int:
@@ -566,7 +586,9 @@ def _call_common(
566586
# calculate the prediction
567587
if not self.mixed_types:
568588
outs = xp.zeros(
569-
[nf, nloc, net_dim_out], dtype=get_xp_precision(xp, self.precision)
589+
[nf, nloc, net_dim_out],
590+
dtype=get_xp_precision(xp, self.precision),
591+
device=array_api_compat.device(descriptor),
570592
)
571593
for type_i in range(self.ntypes):
572594
mask = xp.tile(

deepmd/dpmodel/utils/network.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,10 +1110,13 @@ def deserialize(cls, data: dict) -> "FittingNet":
11101110
layers = data.pop("layers")
11111111
obj = cls(**data)
11121112
# Use type(obj.layers[0]) to respect subclass layer types
1113-
layer_type = type(obj.layers[0])
1114-
obj.layers = type(obj.layers)(
1115-
[layer_type.deserialize(layer) for layer in layers]
1116-
)
1113+
if obj.layers:
1114+
layer_type = type(obj.layers[0])
1115+
obj.layers = type(obj.layers)(
1116+
[layer_type.deserialize(layer) for layer in layers]
1117+
)
1118+
else:
1119+
obj.layers = type(obj.layers)([])
11171120
return obj
11181121

11191122

deepmd/pt_expt/descriptor/se_e2_a.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
)
1212

1313

14-
@BaseDescriptor.register("se_e2_a_expt")
15-
@BaseDescriptor.register("se_a_expt")
14+
@BaseDescriptor.register("se_e2_a")
15+
@BaseDescriptor.register("se_a")
1616
@torch_module
1717
class DescrptSeA(DescrptSeADP):
1818
def forward(

deepmd/pt_expt/descriptor/se_r.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
)
1212

1313

14-
@BaseDescriptor.register("se_e2_r_expt")
15-
@BaseDescriptor.register("se_r_expt")
14+
@BaseDescriptor.register("se_e2_r")
15+
@BaseDescriptor.register("se_r")
1616
@torch_module
1717
class DescrptSeR(DescrptSeRDP):
1818
def forward(

deepmd/pt_expt/descriptor/se_t.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
)
1212

1313

14-
@BaseDescriptor.register("se_e3_expt")
15-
@BaseDescriptor.register("se_at_expt")
16-
@BaseDescriptor.register("se_a_3be_expt")
14+
@BaseDescriptor.register("se_e3")
15+
@BaseDescriptor.register("se_at")
16+
@BaseDescriptor.register("se_a_3be")
1717
@torch_module
1818
class DescrptSeT(DescrptSeTDP):
1919
def forward(

deepmd/pt_expt/descriptor/se_t_tebd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212

1313

14-
@BaseDescriptor.register("se_e3_tebd_expt")
14+
@BaseDescriptor.register("se_e3_tebd")
1515
@torch_module
1616
class DescrptSeTTebd(DescrptSeTTebdDP):
1717
def forward(

deepmd/pt_expt/fitting/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from .base_fitting import (
3+
BaseFitting,
4+
)
5+
from .ener_fitting import (
6+
EnergyFittingNet,
7+
)
8+
from .invar_fitting import (
9+
InvarFitting,
10+
)
11+
12+
__all__ = [
13+
"BaseFitting",
14+
"EnergyFittingNet",
15+
"InvarFitting",
16+
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
import torch
4+
5+
from deepmd.dpmodel.fitting import (
6+
make_base_fitting,
7+
)
8+
9+
BaseFitting = make_base_fitting(torch.Tensor, "forward")
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
import torch
7+
8+
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
9+
from deepmd.pt_expt.common import (
10+
dpmodel_setattr,
11+
register_dpmodel_mapping,
12+
)
13+
from deepmd.pt_expt.utils.network import (
14+
NetworkCollection,
15+
)
16+
17+
from .base_fitting import (
18+
BaseFitting,
19+
)
20+
21+
22+
@BaseFitting.register("ener")
23+
class EnergyFittingNet(EnergyFittingNetDP, torch.nn.Module):
24+
"""Energy fitting net for pt_expt backend.
25+
26+
This inherits from dpmodel EnergyFittingNet to get the correct serialize() method.
27+
"""
28+
29+
def __init__(self, *args: Any, **kwargs: Any) -> None:
30+
torch.nn.Module.__init__(self)
31+
EnergyFittingNetDP.__init__(self, *args, **kwargs)
32+
# Convert dpmodel NetworkCollection to pt_expt NetworkCollection
33+
self.nets = NetworkCollection.deserialize(self.nets.serialize())
34+
35+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
36+
# Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
37+
return torch.nn.Module.__call__(self, *args, **kwargs)
38+
39+
def __setattr__(self, name: str, value: Any) -> None:
40+
handled, value = dpmodel_setattr(self, name, value)
41+
if not handled:
42+
super().__setattr__(name, value)
43+
44+
def forward(
45+
self,
46+
descriptor: torch.Tensor,
47+
atype: torch.Tensor,
48+
gr: torch.Tensor | None = None,
49+
g2: torch.Tensor | None = None,
50+
h2: torch.Tensor | None = None,
51+
fparam: torch.Tensor | None = None,
52+
aparam: torch.Tensor | None = None,
53+
) -> dict[str, torch.Tensor]:
54+
return self.call(
55+
descriptor,
56+
atype,
57+
gr=gr,
58+
g2=g2,
59+
h2=h2,
60+
fparam=fparam,
61+
aparam=aparam,
62+
)
63+
64+
65+
register_dpmodel_mapping(
66+
EnergyFittingNetDP,
67+
lambda v: EnergyFittingNet.deserialize(v.serialize()),
68+
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
import torch
7+
8+
from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP
9+
from deepmd.pt_expt.common import (
10+
dpmodel_setattr,
11+
register_dpmodel_mapping,
12+
)
13+
from deepmd.pt_expt.fitting.base_fitting import (
14+
BaseFitting,
15+
)
16+
from deepmd.pt_expt.utils.network import (
17+
NetworkCollection,
18+
)
19+
20+
21+
@BaseFitting.register("invar")
22+
class InvarFitting(InvarFittingDP, torch.nn.Module):
23+
def __init__(self, *args: Any, **kwargs: Any) -> None:
24+
torch.nn.Module.__init__(self)
25+
InvarFittingDP.__init__(self, *args, **kwargs)
26+
# Convert dpmodel NetworkCollection to pt_expt NetworkCollection
27+
self.nets = NetworkCollection.deserialize(self.nets.serialize())
28+
29+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
30+
# Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
31+
return torch.nn.Module.__call__(self, *args, **kwargs)
32+
33+
def __setattr__(self, name: str, value: Any) -> None:
34+
handled, value = dpmodel_setattr(self, name, value)
35+
if not handled:
36+
super().__setattr__(name, value)
37+
38+
def forward(
39+
self,
40+
descriptor: torch.Tensor,
41+
atype: torch.Tensor,
42+
gr: torch.Tensor | None = None,
43+
g2: torch.Tensor | None = None,
44+
h2: torch.Tensor | None = None,
45+
fparam: torch.Tensor | None = None,
46+
aparam: torch.Tensor | None = None,
47+
) -> dict[str, torch.Tensor]:
48+
return self.call(
49+
descriptor,
50+
atype,
51+
gr=gr,
52+
g2=g2,
53+
h2=h2,
54+
fparam=fparam,
55+
aparam=aparam,
56+
)
57+
58+
59+
register_dpmodel_mapping(
60+
InvarFittingDP,
61+
lambda v: InvarFitting.deserialize(v.serialize()),
62+
)

0 commit comments

Comments (0)