1
1
from typing import Dict , List , Optional
2
2
3
3
import torch
4
- from metatensor .torch import Labels , TensorBlock , TensorMap
4
+ from metatensor .torch import Labels , TensorBlock , TensorMap , multiply
5
5
from metatomic .torch import ModelOutput , NeighborListOptions , System
6
6
7
7
@@ -31,10 +31,14 @@ def forward(
31
31
) -> Dict [str , TensorMap ]:
32
32
if (
33
33
"energy" not in outputs
34
+ and "energy/doubled" not in outputs
34
35
and "energy_ensemble" not in outputs
35
36
and "energy_uncertainty" not in outputs
37
+ and "energy_uncertainty/doubled" not in outputs
36
38
and "non_conservative_forces" not in outputs
39
+ and "non_conservative_forces/doubled" not in outputs
37
40
and "non_conservative_stress" not in outputs
41
+ and "non_conservative_stress/doubled" not in outputs
38
42
):
39
43
return {}
40
44
@@ -79,7 +83,10 @@ def forward(
79
83
all_energies_per_atom .append (energy )
80
84
all_energies .append (energy .sum (0 , keepdim = True ))
81
85
82
- if "non_conservative_forces" in outputs :
86
+ if (
87
+ "non_conservative_forces" in outputs
88
+ or "non_conservative_forces/doubled" in outputs
89
+ ):
83
90
# we fill the non-conservative forces as the negative gradient of the potential
84
91
# with respect to the positions, plus a random term
85
92
forces = torch .zeros (len (system ), 3 , device = device , dtype = dtype )
@@ -104,15 +111,21 @@ def forward(
104
111
105
112
all_non_conservative_forces .append (forces )
106
113
107
- if "non_conservative_stress" in outputs :
114
+ if (
115
+ "non_conservative_stress" in outputs
116
+ or "non_conservative_stress/doubled" in outputs
117
+ ):
108
118
# we fill the non-conservative stress with random numbers
109
119
stress = torch .randn ((3 , 3 ), device = device , dtype = dtype )
110
120
all_non_conservative_stress .append (stress )
111
121
112
122
energy_values = torch .vstack (all_energies ).reshape (- 1 , 1 )
113
123
energies_per_atom_values = torch .vstack (all_energies_per_atom ).reshape (- 1 , 1 )
114
124
115
- if "non_conservative_forces" in outputs :
125
+ if (
126
+ "non_conservative_forces" in outputs
127
+ or "non_conservative_forces/doubled" in outputs
128
+ ):
116
129
nc_forces_values = torch .cat (all_non_conservative_forces ).reshape (- 1 , 3 , 1 )
117
130
else :
118
131
nc_forces_values = torch .empty ((0 , 0 ))
@@ -126,10 +139,15 @@ def forward(
126
139
# randomly shuffle the samples to make sure the different engines handle
127
140
# out of order samples
128
141
indexes = torch .randperm (len (samples_list ))
129
- if "energy" in outputs and outputs ["energy" ].per_atom :
142
+ if ("energy" in outputs and outputs ["energy" ].per_atom ) or (
143
+ "energy/doubled" in outputs and outputs ["energy/doubled" ].per_atom
144
+ ):
130
145
energies_per_atom_values = energies_per_atom_values [indexes ]
131
146
132
- if "non_conservative_forces" in outputs :
147
+ if (
148
+ "non_conservative_forces" in outputs
149
+ or "non_conservative_forces/doubled" in outputs
150
+ ):
133
151
nc_forces_values = nc_forces_values [indexes ]
134
152
135
153
per_atom_samples = Labels (
@@ -143,7 +161,9 @@ def forward(
143
161
)
144
162
single_key = Labels ("_" , torch .tensor ([[0 ]], device = device ))
145
163
146
- if "energy" in outputs and outputs ["energy" ].per_atom :
164
+ if ("energy" in outputs and outputs ["energy" ].per_atom ) or (
165
+ "energy/doubled" in outputs and outputs ["energy/doubled" ].per_atom
166
+ ):
147
167
energy_block = TensorBlock (
148
168
values = energies_per_atom_values ,
149
169
samples = per_atom_samples ,
@@ -159,8 +179,12 @@ def forward(
159
179
)
160
180
161
181
results : Dict [str , TensorMap ] = {}
162
- if "energy" in outputs :
163
- results ["energy" ] = TensorMap (single_key , [energy_block ])
182
+ if "energy" in outputs or "energy/doubled" in outputs :
183
+ result = TensorMap (single_key , [energy_block ])
184
+ if "energy" in outputs :
185
+ results ["energy" ] = result
186
+ if "energy/doubled" in outputs :
187
+ results ["energy/doubled" ] = multiply (result , 2.0 )
164
188
165
189
if "energy_ensemble" in outputs :
166
190
# returns the same energy for all ensemble members
@@ -187,7 +211,7 @@ def forward(
187
211
188
212
results ["energy_ensemble" ] = TensorMap (single_key , [ensemble_block ])
189
213
190
- if "energy_uncertainty" in outputs :
214
+ if "energy_uncertainty" in outputs or "energy_uncertainty/doubled" in outputs :
191
215
# returns an uncertainty of `0.001 * n_atoms^2` (note that the natural
192
216
# scaling would be `sqrt(n_atoms)` or `n_atoms`); this is useful in tests so
193
217
# we can artificially increase the uncertainty with the number of atoms
@@ -220,10 +244,17 @@ def forward(
220
244
properties = energy_block .properties ,
221
245
)
222
246
223
- results ["energy_uncertainty" ] = TensorMap (single_key , [uncertainty_block ])
247
+ result = TensorMap (single_key , [uncertainty_block ])
248
+ if "energy_uncertainty" in outputs :
249
+ results ["energy_uncertainty" ] = result
250
+ if "energy_uncertainty/doubled" in outputs :
251
+ results ["energy_uncertainty/doubled" ] = multiply (result , 2.0 )
224
252
225
- if "non_conservative_forces" in outputs :
226
- results ["non_conservative_forces" ] = TensorMap (
253
+ if (
254
+ "non_conservative_forces" in outputs
255
+ or "non_conservative_forces/doubled" in outputs
256
+ ):
257
+ result = TensorMap (
227
258
keys = Labels ("_" , torch .tensor ([[0 ]], device = device )),
228
259
blocks = [
229
260
TensorBlock (
@@ -242,9 +273,16 @@ def forward(
242
273
)
243
274
],
244
275
)
276
+ if "non_conservative_forces" in outputs :
277
+ results ["non_conservative_forces" ] = result
278
+ if "non_conservative_forces/doubled" in outputs :
279
+ results ["non_conservative_forces/doubled" ] = multiply (result , 2.0 )
245
280
246
- if "non_conservative_stress" in outputs :
247
- results ["non_conservative_stress" ] = TensorMap (
281
+ if (
282
+ "non_conservative_stress" in outputs
283
+ or "non_conservative_stress/doubled" in outputs
284
+ ):
285
+ result = TensorMap (
248
286
keys = Labels ("_" , torch .tensor ([[0 ]], device = device )),
249
287
blocks = [
250
288
TensorBlock (
@@ -272,6 +310,10 @@ def forward(
272
310
)
273
311
],
274
312
)
313
+ if "non_conservative_stress" in outputs :
314
+ results ["non_conservative_stress" ] = result
315
+ if "non_conservative_stress/doubled" in outputs :
316
+ results ["non_conservative_stress/doubled" ] = multiply (result , 2.0 )
275
317
276
318
return results
277
319
0 commit comments