Skip to content

Commit 84ca63c

Browse files
authored
[r2] fix seeds in se_a and se_atten (#3880) (#3947)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Resolved inconsistencies in seed values by incrementing `self.seed` conditionally in descriptor modules. - **Tests** - Updated test arrays `refe`, `reff`, and `refv` with new reference values. - Adjusted expected values in `test_model_ener` method for better accuracy. These changes ensure more reliable descriptor computations and improved test accuracy. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- (cherry picked from commit 0c472d1) Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent a85d58f commit 84ca63c

File tree

4 files changed

+50
-33
lines changed

4 files changed

+50
-33
lines changed

deepmd/descriptor/se_a.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,8 @@ def _filter_lower(
10311031
mixed_prec=self.mixed_prec,
10321032
)
10331033
net_output = tf.nn.embedding_lookup(net_output, idx)
1034+
if (not self.uniform_seed) and (self.seed is not None):
1035+
self.seed += self.seed_shift
10341036
net_output = tf.reshape(net_output, [-1, self.filter_neuron[-1]])
10351037
else:
10361038
xyz_scatter = self._concat_type_embedding(
@@ -1042,7 +1044,7 @@ def _filter_lower(
10421044
)
10431045
# natom x 4 x outputs_size
10441046
if nvnmd_cfg.enable:
1045-
return filter_lower_R42GR(
1047+
oo = filter_lower_R42GR(
10461048
type_i,
10471049
type_input,
10481050
inputs_i,
@@ -1060,6 +1062,9 @@ def _filter_lower(
10601062
self.filter_resnet_dt,
10611063
self.embedding_net_variables,
10621064
)
1065+
if (not self.uniform_seed) and (self.seed is not None):
1066+
self.seed += self.seed_shift
1067+
return oo
10631068
if self.compress and (not is_exclude):
10641069
if self.stripped_type_embedding:
10651070
net_output = tf.nn.embedding_lookup(

deepmd/descriptor/se_atten.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,8 @@ def _attention_layers(
959959
uniform_seed=self.uniform_seed,
960960
initial_variables=self.attention_layer_variables,
961961
)
962+
if not self.uniform_seed and self.seed is not None:
963+
self.seed += 1
962964
K_c = one_layer(
963965
input_xyz,
964966
self.att_n,
@@ -972,6 +974,8 @@ def _attention_layers(
972974
uniform_seed=self.uniform_seed,
973975
initial_variables=self.attention_layer_variables,
974976
)
977+
if not self.uniform_seed and self.seed is not None:
978+
self.seed += 1
975979
V_c = one_layer(
976980
input_xyz,
977981
self.att_n,
@@ -985,6 +989,8 @@ def _attention_layers(
985989
uniform_seed=self.uniform_seed,
986990
initial_variables=self.attention_layer_variables,
987991
)
992+
if not self.uniform_seed and self.seed is not None:
993+
self.seed += 1
988994
# # natom x nei_type_i x out_size
989995
# xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1] // 4, outputs_size[-1]))
990996
# natom x nei_type_i x att_n
@@ -1017,6 +1023,8 @@ def _attention_layers(
10171023
uniform_seed=self.uniform_seed,
10181024
initial_variables=self.attention_layer_variables,
10191025
)
1026+
if not self.uniform_seed and self.seed is not None:
1027+
self.seed += 1
10201028
input_xyz = tf.keras.layers.LayerNormalization(
10211029
beta_initializer=tf.constant_initializer(self.beta[i]),
10221030
gamma_initializer=tf.constant_initializer(self.gamma[i]),
@@ -1080,6 +1088,8 @@ def _filter_lower(
10801088
initial_variables=self.embedding_net_variables,
10811089
mixed_prec=self.mixed_prec,
10821090
)
1091+
if (not self.uniform_seed) and (self.seed is not None):
1092+
self.seed += self.seed_shift
10831093
else:
10841094
if self.attn_layer == 0:
10851095
log.info(
@@ -1119,6 +1129,8 @@ def _filter_lower(
11191129
initial_variables=self.embedding_net_variables,
11201130
mixed_prec=self.mixed_prec,
11211131
)
1132+
if (not self.uniform_seed) and (self.seed is not None):
1133+
self.seed += self.seed_shift
11221134
else:
11231135
net = "filter_net"
11241136
info = [
@@ -1176,6 +1188,8 @@ def _filter_lower(
11761188
initial_variables=self.two_side_embeeding_net_variables,
11771189
mixed_prec=self.mixed_prec,
11781190
)
1191+
if (not self.uniform_seed) and (self.seed is not None):
1192+
self.seed += self.seed_shift
11791193
two_embd = tf.nn.embedding_lookup(
11801194
embedding_of_two_side_type_embedding, index_of_two_side
11811195
)
@@ -1194,8 +1208,6 @@ def _filter_lower(
11941208
is_sorted=len(self.exclude_types) == 0,
11951209
)
11961210

1197-
if (not self.uniform_seed) and (self.seed is not None):
1198-
self.seed += self.seed_shift
11991211
input_r = tf.slice(
12001212
tf.reshape(inputs_i, (-1, shape_i[1] // 4, 4)), [0, 0, 1], [-1, -1, 3]
12011213
)

source/tests/test_model_se_a_ebd_v2.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -139,37 +139,37 @@ def test_model(self):
139139
f = f.reshape([-1])
140140
v = v.reshape([-1])
141141

142-
refe = [5.435394596262052014e-01]
142+
refe = [6.100037044296185e-01]
143143
reff = [
144-
6.583728125594628944e-02,
145-
7.228993116083935744e-02,
146-
1.971543579114074483e-03,
147-
6.567474563776359853e-02,
148-
7.809421727465599983e-02,
149-
-4.866958849094786890e-03,
150-
-8.670511901715304004e-02,
151-
3.525374157021862048e-02,
152-
1.415748959800727487e-03,
153-
6.375813001810648473e-02,
154-
-1.139053242798149790e-01,
155-
-4.178593754384440744e-03,
156-
-1.471737787218250215e-01,
157-
4.189712704724830872e-02,
158-
7.011731363309440038e-03,
159-
3.860874082716164030e-02,
160-
-1.136296927731473005e-01,
161-
-1.353471298745012206e-03,
144+
8.448651008616304e-02,
145+
8.613568658155157e-02,
146+
4.377711655236228e-03,
147+
9.264613309788312e-02,
148+
9.351200240060925e-02,
149+
-6.743918515275118e-03,
150+
-1.268078358219972e-01,
151+
4.855965861982662e-02,
152+
1.361334787979757e-04,
153+
4.193213089916692e-02,
154+
-1.324120032345251e-01,
155+
-4.507320444374342e-03,
156+
-1.314595297986654e-01,
157+
4.120567370248839e-02,
158+
7.896917575801866e-03,
159+
3.920259153744955e-02,
160+
-1.370010180699507e-01,
161+
-1.159523750186610e-03,
162162
]
163163
refv = [
164-
-4.243979601186427253e-01,
165-
1.097173849143971286e-01,
166-
1.227299373463585502e-02,
167-
1.097173849143970314e-01,
168-
-2.462891443164323124e-01,
169-
-5.711664180530139426e-03,
170-
1.227299373463585502e-02,
171-
-5.711664180530143763e-03,
172-
-6.217348853341628408e-04,
164+
-0.277134219204478,
165+
0.088897922530779,
166+
0.008633318264458,
167+
0.088897922530779,
168+
-0.292191560546969,
169+
-0.005709595520904,
170+
0.008633318264458,
171+
-0.005709595520904,
172+
-0.000682136341924,
173173
]
174174
refe = np.reshape(refe, [-1])
175175
reff = np.reshape(reff, [-1])

source/tests/test_pairwise_dprc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,8 @@ def test_model_ener(self):
519519
# the model is pairwise!
520520
self.assertAllClose(e[1] + e[2] + e[3] - 3 * e[0], e[4] - e[0])
521521
self.assertAllClose(f[1] + f[2] + f[3] - 3 * f[0], f[4] - f[0])
522-
self.assertAllClose(e[0], 0.189075, 1e-6)
523-
self.assertAllClose(f[0, 0], 0.060047, 1e-6)
522+
self.assertAllClose(e[0], 4.82969, 1e-6)
523+
self.assertAllClose(f[0, 0], -0.104339, 1e-6)
524524

525525
def test_nloc(self):
526526
jfile = tests_path / "pairwise_dprc.json"

0 commit comments

Comments
 (0)