1
1
from typing import Dict , List , Optional
2
2
3
3
import torch
4
- from metatensor .torch import Labels , TensorBlock , TensorMap
4
+ from metatensor .torch import Labels , TensorBlock , TensorMap , multiply
5
5
from metatomic .torch import ModelOutput , NeighborListOptions , System
6
6
7
7
@@ -31,10 +31,14 @@ def forward(
31
31
) -> Dict [str , TensorMap ]:
32
32
if (
33
33
"energy" not in outputs
34
+ and "energy/doubled" not in outputs
34
35
and "energy_ensemble" not in outputs
35
36
and "energy_uncertainty" not in outputs
37
+ and "energy_uncertainty/doubled" not in outputs
36
38
and "non_conservative_forces" not in outputs
39
+ and "non_conservative_forces/doubled" not in outputs
37
40
and "non_conservative_stress" not in outputs
41
+ and "non_conservative_stress/doubled" not in outputs
38
42
):
39
43
return {}
40
44
@@ -79,7 +83,10 @@ def forward(
79
83
all_energies_per_atom .append (energy )
80
84
all_energies .append (energy .sum (0 , keepdim = True ))
81
85
82
- if "non_conservative_forces" in outputs :
86
+ if (
87
+ "non_conservative_forces" in outputs
88
+ or "non_conservative_forces/doubled" in outputs
89
+ ):
83
90
# we fill the non-conservative forces as the negative gradient of the potential
84
91
# with respect to the positions, plus a random term
85
92
forces = torch .zeros (len (system ), 3 , device = device , dtype = dtype )
@@ -104,15 +111,21 @@ def forward(
104
111
105
112
all_non_conservative_forces .append (forces )
106
113
107
- if "non_conservative_stress" in outputs :
114
+ if (
115
+ "non_conservative_stress" in outputs
116
+ or "non_conservative_stress/doubled" in outputs
117
+ ):
108
118
# we fill the non-conservative stress with random numbers
109
119
stress = torch .randn ((3 , 3 ), device = device , dtype = dtype )
110
120
all_non_conservative_stress .append (stress )
111
121
112
122
energy_values = torch .vstack (all_energies ).reshape (- 1 , 1 )
113
123
energies_per_atom_values = torch .vstack (all_energies_per_atom ).reshape (- 1 , 1 )
114
124
115
- if "non_conservative_forces" in outputs :
125
+ if (
126
+ "non_conservative_forces" in outputs
127
+ or "non_conservative_forces/doubled" in outputs
128
+ ):
116
129
nc_forces_values = torch .cat (all_non_conservative_forces ).reshape (- 1 , 3 , 1 )
117
130
else :
118
131
nc_forces_values = torch .empty ((0 , 0 ))
@@ -126,10 +139,15 @@ def forward(
126
139
# randomly shuffle the samples to make sure the different engines handle
127
140
# out of order samples
128
141
indexes = torch .randperm (len (samples_list ))
129
- if "energy" in outputs and outputs ["energy" ].per_atom :
142
+ if ("energy" in outputs and outputs ["energy" ].per_atom ) or (
143
+ "energy/doubled" in outputs and outputs ["energy/doubled" ].per_atom
144
+ ):
130
145
energies_per_atom_values = energies_per_atom_values [indexes ]
131
146
132
- if "non_conservative_forces" in outputs :
147
+ if (
148
+ "non_conservative_forces" in outputs
149
+ or "non_conservative_forces/doubled" in outputs
150
+ ):
133
151
nc_forces_values = nc_forces_values [indexes ]
134
152
135
153
per_atom_samples = Labels (
@@ -143,7 +161,9 @@ def forward(
143
161
)
144
162
single_key = Labels ("_" , torch .tensor ([[0 ]], device = device ))
145
163
146
- if "energy" in outputs and outputs ["energy" ].per_atom :
164
+ if ("energy" in outputs and outputs ["energy" ].per_atom ) or (
165
+ "energy/doubled" in outputs and outputs ["energy/doubled" ].per_atom
166
+ ):
147
167
energy_block = TensorBlock (
148
168
values = energies_per_atom_values ,
149
169
samples = per_atom_samples ,
@@ -159,8 +179,12 @@ def forward(
159
179
)
160
180
161
181
results : Dict [str , TensorMap ] = {}
162
- if "energy" in outputs :
163
- results ["energy" ] = TensorMap (single_key , [energy_block ])
182
+ if "energy" in outputs or "energy/doubled" in outputs :
183
+ result = TensorMap (single_key , [energy_block ])
184
+ if "energy" in outputs :
185
+ results ["energy" ] = result
186
+ if "energy/doubled" in outputs :
187
+ results ["energy/doubled" ] = multiply (result , self ._variant_scale )
164
188
165
189
if "energy_ensemble" in outputs :
166
190
# returns the same energy for all ensemble members
@@ -187,7 +211,7 @@ def forward(
187
211
188
212
results ["energy_ensemble" ] = TensorMap (single_key , [ensemble_block ])
189
213
190
- if "energy_uncertainty" in outputs :
214
+ if "energy_uncertainty" in outputs or "energy_uncertainty/doubled" in outputs :
191
215
# returns an uncertainty of `0.001 * n_atoms^2` (note that the natural
192
216
# scaling would be `sqrt(n_atoms)` or `n_atoms`); this is useful in tests so
193
217
# we can artificially increase the uncertainty with the number of atoms
@@ -220,10 +244,19 @@ def forward(
220
244
properties = energy_block .properties ,
221
245
)
222
246
223
- results ["energy_uncertainty" ] = TensorMap (single_key , [uncertainty_block ])
247
+ result = TensorMap (single_key , [uncertainty_block ])
248
+ if "energy_uncertainty" in outputs :
249
+ results ["energy_uncertainty" ] = result
250
+ if "energy_uncertainty/doubled" in outputs :
251
+ results ["energy_uncertainty/doubled" ] = multiply (
252
+ result , self ._variant_scale
253
+ )
224
254
225
- if "non_conservative_forces" in outputs :
226
- results ["non_conservative_forces" ] = TensorMap (
255
+ if (
256
+ "non_conservative_forces" in outputs
257
+ or "non_conservative_forces/doubled" in outputs
258
+ ):
259
+ result = TensorMap (
227
260
keys = Labels ("_" , torch .tensor ([[0 ]], device = device )),
228
261
blocks = [
229
262
TensorBlock (
@@ -242,9 +275,18 @@ def forward(
242
275
)
243
276
],
244
277
)
278
+ if "non_conservative_forces" in outputs :
279
+ results ["non_conservative_forces" ] = result
280
+ if "non_conservative_forces/doubled" in outputs :
281
+ results ["non_conservative_forces/doubled" ] = multiply (
282
+ result , self ._variant_scale
283
+ )
245
284
246
- if "non_conservative_stress" in outputs :
247
- results ["non_conservative_stress" ] = TensorMap (
285
+ if (
286
+ "non_conservative_stress" in outputs
287
+ or "non_conservative_stress/doubled" in outputs
288
+ ):
289
+ result = TensorMap (
248
290
keys = Labels ("_" , torch .tensor ([[0 ]], device = device )),
249
291
blocks = [
250
292
TensorBlock (
@@ -272,6 +314,12 @@ def forward(
272
314
)
273
315
],
274
316
)
317
+ if "non_conservative_stress" in outputs :
318
+ results ["non_conservative_stress" ] = result
319
+ if "non_conservative_stress/doubled" in outputs :
320
+ results ["non_conservative_stress/doubled" ] = multiply (
321
+ result , self ._variant_scale
322
+ )
275
323
276
324
return results
277
325
0 commit comments