Skip to content

Commit 156736f

Browse files
Authored by wanghan-iapcm (Han Wang), with njzjz (Jinzhe Zeng)
feat(pt_expt): implement se_t and se_t_tebd descriptors. (#5208)
This PR is considered after #5194 #5204 and #5205 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added experimental PyTorch support for SeT and SeT-TEBD descriptors, enabling model training and serialization/export. * Introduced TypeEmbedNet wrapper for type embedding integration in PyTorch workflows. * **Bug Fixes** * Improved backend compatibility and device-aware tensor allocation across descriptor implementations. * Fixed PyTorch tensor indexing compatibility issues. * **Tests** * Added comprehensive test coverage for new experimental descriptors and consistency validation. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: Han Wang <wang_han@iapcm.ac.cn> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 2d7fdc5 commit 156736f

File tree

17 files changed

+599
-24
lines changed

17 files changed

+599
-24
lines changed

deepmd/dpmodel/descriptor/descriptor.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
NoReturn,
1313
)
1414

15-
import numpy as np
15+
import array_api_compat
1616

1717
from deepmd.dpmodel.array_api import (
1818
Array,
@@ -173,7 +173,18 @@ def extend_descrpt_stat(
173173
extend_dstd = des_with_stat["dstd"]
174174
else:
175175
extend_shape = [len(type_map), *list(des["davg"].shape[1:])]
176-
extend_davg = np.zeros(extend_shape, dtype=des["davg"].dtype)
177-
extend_dstd = np.ones(extend_shape, dtype=des["dstd"].dtype)
178-
des["davg"] = np.concatenate([des["davg"], extend_davg], axis=0)
179-
des["dstd"] = np.concatenate([des["dstd"], extend_dstd], axis=0)
176+
# Use array_api_compat to infer device and dtype from context
177+
xp = array_api_compat.array_namespace(des["davg"])
178+
extend_davg = xp.zeros(
179+
extend_shape,
180+
dtype=des["davg"].dtype,
181+
device=array_api_compat.device(des["davg"]),
182+
)
183+
extend_dstd = xp.ones(
184+
extend_shape,
185+
dtype=des["dstd"].dtype,
186+
device=array_api_compat.device(des["dstd"]),
187+
)
188+
xp = array_api_compat.array_namespace(des["davg"])
189+
des["davg"] = xp.concat([des["davg"], extend_davg], axis=0)
190+
des["dstd"] = xp.concat([des["dstd"], extend_dstd], axis=0)

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,8 @@ def call(
10491049
idx_j = xp.reshape(nei_type, (-1,))
10501050
# (nf x nl x nnei) x ng
10511051
idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng))
1052+
# Cast to int64 for PyTorch backend (take_along_dim requires Long indices)
1053+
idx = xp.astype(idx, xp.int64)
10521054
# (ntypes) * ntypes * nt
10531055
type_embedding_nei = xp.tile(
10541056
xp.reshape(type_embedding, (1, ntypes_with_padding, nt)),

deepmd/dpmodel/descriptor/se_t.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,11 @@ def call(
369369
sec = self.sel_cumsum
370370

371371
ng = self.neuron[-1]
372-
result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision))
372+
result = xp.zeros(
373+
[nf * nloc, ng],
374+
dtype=get_xp_precision(xp, self.precision),
375+
device=array_api_compat.device(coord_ext),
376+
)
373377
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
374378
# merge nf and nloc axis, so for type_one_side == False,
375379
# we don't require atype is the same in all frames

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,9 @@ def call(
769769
sw = xp.where(
770770
nlist_mask[:, :, None],
771771
xp.reshape(sw, (nf * nloc, nnei, 1)),
772-
xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype),
772+
xp.zeros(
773+
(nf * nloc, nnei, 1), dtype=sw.dtype, device=array_api_compat.device(sw)
774+
),
773775
)
774776

775777
# nfnl x nnei x 4
@@ -832,6 +834,8 @@ def call(
832834

833835
# (nf x nl x nt_i x nt_j) x ng
834836
idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng))
837+
# Cast to int64 for PyTorch backend (take_along_dim requires Long indices)
838+
idx = xp.astype(idx, xp.int64)
835839

836840
# ntypes * (ntypes) * nt
837841
type_embedding_i = xp.tile(

deepmd/dpmodel/utils/type_embed.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,21 @@ def call(self) -> Array:
100100
sample_array = self.embedding_net[0]["w"]
101101
xp = array_api_compat.array_namespace(sample_array)
102102
if not self.use_econf_tebd:
103-
embed = self.embedding_net(xp.eye(self.ntypes, dtype=sample_array.dtype))
103+
embed = self.embedding_net(
104+
xp.eye(
105+
self.ntypes,
106+
dtype=sample_array.dtype,
107+
device=array_api_compat.device(sample_array),
108+
)
109+
)
104110
else:
105111
embed = self.embedding_net(self.econf_tebd)
106112
if self.padding:
107-
embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype)
113+
embed_pad = xp.zeros(
114+
(1, embed.shape[-1]),
115+
dtype=embed.dtype,
116+
device=array_api_compat.device(embed),
117+
)
108118
embed = xp.concat([embed, embed_pad], axis=0)
109119
return embed
110120

@@ -180,32 +190,51 @@ def change_type_map(
180190
"'activation_function' must be 'Linear' when performing type changing on resnet structure!"
181191
)
182192
first_layer_matrix = self.embedding_net.layers[0].w
183-
eye_vector = np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
193+
# Use array_api_compat to handle both numpy and torch
194+
xp = array_api_compat.array_namespace(first_layer_matrix)
195+
eye_vector = xp.eye(
196+
self.ntypes,
197+
dtype=first_layer_matrix.dtype,
198+
device=array_api_compat.device(first_layer_matrix),
199+
)
184200
# preprocess for resnet connection
185201
if self.neuron[0] == self.ntypes:
186-
first_layer_matrix += eye_vector
202+
first_layer_matrix = first_layer_matrix + eye_vector
187203
elif self.neuron[0] == self.ntypes * 2:
188-
first_layer_matrix += np.concatenate([eye_vector, eye_vector], axis=-1)
204+
first_layer_matrix = first_layer_matrix + xp.concat(
205+
[eye_vector, eye_vector], axis=-1
206+
)
189207

190208
# randomly initialize params for the unseen types
191-
rng = np.random.default_rng()
192209
if has_new_type:
193-
extend_type_params = rng.random(
210+
# Create random params with same dtype and device as first_layer_matrix
211+
extend_type_params = np.random.default_rng().random(
194212
[len(type_map), first_layer_matrix.shape[-1]],
213+
dtype=PRECISION_DICT[self.precision],
214+
)
215+
extend_type_params = xp.asarray(
216+
extend_type_params,
195217
dtype=first_layer_matrix.dtype,
218+
device=array_api_compat.device(first_layer_matrix),
196219
)
197-
first_layer_matrix = np.concatenate(
220+
first_layer_matrix = xp.concat(
198221
[first_layer_matrix, extend_type_params], axis=0
199222
)
200223

201224
first_layer_matrix = first_layer_matrix[remap_index]
202225
new_ntypes = len(type_map)
203-
eye_vector = np.eye(new_ntypes, dtype=PRECISION_DICT[self.precision])
226+
eye_vector = xp.eye(
227+
new_ntypes,
228+
dtype=first_layer_matrix.dtype,
229+
device=array_api_compat.device(first_layer_matrix),
230+
)
204231

205232
if self.neuron[0] == new_ntypes:
206-
first_layer_matrix -= eye_vector
233+
first_layer_matrix = first_layer_matrix - eye_vector
207234
elif self.neuron[0] == new_ntypes * 2:
208-
first_layer_matrix -= np.concatenate([eye_vector, eye_vector], axis=-1)
235+
first_layer_matrix = first_layer_matrix - xp.concat(
236+
[eye_vector, eye_vector], axis=-1
237+
)
209238

210239
self.embedding_net.layers[0].num_in = new_ntypes
211240
self.embedding_net.layers[0].w = first_layer_matrix

deepmd/pt_expt/descriptor/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
# Import to register converters
3+
from . import se_t_tebd_block # noqa: F401
24
from .base_descriptor import (
35
BaseDescriptor,
46
)
@@ -8,9 +10,17 @@
810
from .se_r import (
911
DescrptSeR,
1012
)
13+
from .se_t import (
14+
DescrptSeT,
15+
)
16+
from .se_t_tebd import (
17+
DescrptSeTTebd,
18+
)
1119

1220
__all__ = [
1321
"BaseDescriptor",
1422
"DescrptSeA",
1523
"DescrptSeR",
24+
"DescrptSeT",
25+
"DescrptSeTTebd",
1626
]

deepmd/pt_expt/descriptor/se_e2_a.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,14 @@ def forward(
3535
extended_coord: torch.Tensor,
3636
extended_atype: torch.Tensor,
3737
nlist: torch.Tensor,
38-
extended_atype_embd: torch.Tensor | None = None,
3938
mapping: torch.Tensor | None = None,
40-
type_embedding: torch.Tensor | None = None,
4139
) -> tuple[
4240
torch.Tensor,
4341
torch.Tensor | None,
4442
torch.Tensor | None,
4543
torch.Tensor | None,
4644
torch.Tensor | None,
4745
]:
48-
del extended_atype_embd, type_embedding
4946
descrpt, rot_mat, g2, h2, sw = self.call(
5047
extended_coord,
5148
extended_atype,

deepmd/pt_expt/descriptor/se_r.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,14 @@ def forward(
3535
extended_coord: torch.Tensor,
3636
extended_atype: torch.Tensor,
3737
nlist: torch.Tensor,
38-
extended_atype_embd: torch.Tensor | None = None,
3938
mapping: torch.Tensor | None = None,
40-
type_embedding: torch.Tensor | None = None,
4139
) -> tuple[
4240
torch.Tensor,
4341
torch.Tensor | None,
4442
torch.Tensor | None,
4543
torch.Tensor | None,
4644
torch.Tensor | None,
4745
]:
48-
del extended_atype_embd, type_embedding
4946
descrpt, rot_mat, g2, h2, sw = self.call(
5047
extended_coord,
5148
extended_atype,

deepmd/pt_expt/descriptor/se_t.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
import torch
7+
8+
from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
9+
from deepmd.pt_expt.common import (
10+
dpmodel_setattr,
11+
)
12+
from deepmd.pt_expt.descriptor.base_descriptor import (
13+
BaseDescriptor,
14+
)
15+
16+
17+
@BaseDescriptor.register("se_e3_expt")
18+
@BaseDescriptor.register("se_at_expt")
19+
@BaseDescriptor.register("se_a_3be_expt")
20+
class DescrptSeT(DescrptSeTDP, torch.nn.Module):
21+
def __init__(self, *args: Any, **kwargs: Any) -> None:
22+
torch.nn.Module.__init__(self)
23+
DescrptSeTDP.__init__(self, *args, **kwargs)
24+
25+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
26+
# Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
27+
return torch.nn.Module.__call__(self, *args, **kwargs)
28+
29+
def __setattr__(self, name: str, value: Any) -> None:
30+
handled, value = dpmodel_setattr(self, name, value)
31+
if not handled:
32+
super().__setattr__(name, value)
33+
34+
def forward(
35+
self,
36+
extended_coord: torch.Tensor,
37+
extended_atype: torch.Tensor,
38+
nlist: torch.Tensor,
39+
mapping: torch.Tensor | None = None,
40+
) -> tuple[
41+
torch.Tensor,
42+
torch.Tensor | None,
43+
torch.Tensor | None,
44+
torch.Tensor | None,
45+
torch.Tensor | None,
46+
]:
47+
descrpt, rot_mat, g2, h2, sw = self.call(
48+
extended_coord,
49+
extended_atype,
50+
nlist,
51+
mapping=mapping,
52+
)
53+
return descrpt, rot_mat, g2, h2, sw
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
import torch
7+
8+
from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
9+
from deepmd.pt_expt.common import (
10+
dpmodel_setattr,
11+
)
12+
from deepmd.pt_expt.descriptor.base_descriptor import (
13+
BaseDescriptor,
14+
)
15+
16+
17+
@BaseDescriptor.register("se_e3_tebd_expt")
18+
class DescrptSeTTebd(DescrptSeTTebdDP, torch.nn.Module):
19+
def __init__(self, *args: Any, **kwargs: Any) -> None:
20+
torch.nn.Module.__init__(self)
21+
DescrptSeTTebdDP.__init__(self, *args, **kwargs)
22+
23+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
24+
# Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
25+
return torch.nn.Module.__call__(self, *args, **kwargs)
26+
27+
def __setattr__(self, name: str, value: Any) -> None:
28+
handled, value = dpmodel_setattr(self, name, value)
29+
if not handled:
30+
super().__setattr__(name, value)
31+
32+
def forward(
33+
self,
34+
extended_coord: torch.Tensor,
35+
extended_atype: torch.Tensor,
36+
nlist: torch.Tensor,
37+
mapping: torch.Tensor | None = None,
38+
) -> tuple[
39+
torch.Tensor,
40+
torch.Tensor | None,
41+
torch.Tensor | None,
42+
torch.Tensor | None,
43+
torch.Tensor | None,
44+
]:
45+
descrpt, rot_mat, g2, h2, sw = self.call(
46+
extended_coord,
47+
extended_atype,
48+
nlist,
49+
mapping=mapping,
50+
)
51+
return descrpt, rot_mat, g2, h2, sw

Commit comments: 0