Skip to content

Commit 0ec3e64

Browse files
committed
Add doubled variant
1 parent c1daf1d commit 0ec3e64

File tree

3 files changed

+87
-17
lines changed

3 files changed

+87
-17
lines changed

src/metatomic_lj_test/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def lennard_jones_model(
3232
constant that ensure that the energy given by the above formula goes to 0 at the
3333
cutoff.
3434
35+
The model also provides a ``doubled`` **variant** where the :math:`\epsilon`
36+
parameter is scaled by a factor of 2.
37+
3538
:param atomic_type: atomic type to which sigma/epsilon correspond
3639
:param cutoff: spherical cutoff of the model
3740
:param epsilon: epsilon parameter of Lennard-Jones
@@ -47,6 +50,11 @@ def lennard_jones_model(
4750
unit=energy_unit,
4851
per_atom=True,
4952
),
53+
"energy/doubled": ModelOutput(
54+
quantity="energy",
55+
unit=energy_unit,
56+
per_atom=True,
57+
),
5058
}
5159

5260
if with_extension:
@@ -70,16 +78,31 @@ def lennard_jones_model(
7078
unit=energy_unit,
7179
per_atom=True,
7280
),
81+
"energy_uncertainty/doubled": ModelOutput(
82+
quantity="energy",
83+
unit=energy_unit,
84+
per_atom=True,
85+
),
7386
"non_conservative_forces": ModelOutput(
7487
quantity="force",
7588
unit="eV/Angstrom",
7689
per_atom=True,
7790
),
91+
"non_conservative_forces/doubled": ModelOutput(
92+
quantity="force",
93+
unit="eV/Angstrom",
94+
per_atom=True,
95+
),
7896
"non_conservative_stress": ModelOutput(
7997
quantity="pressure",
8098
unit="eV/Angstrom^3",
8199
per_atom=False,
82100
),
101+
"non_conservative_stress/doubled": ModelOutput(
102+
quantity="pressure",
103+
unit="eV/Angstrom^3",
104+
per_atom=True,
105+
),
83106
}
84107
)
85108

src/metatomic_lj_test/extension.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Dict, List, Optional
44

55
import torch
6-
from metatensor.torch import Labels, TensorBlock, TensorMap
6+
from metatensor.torch import Labels, TensorBlock, TensorMap, multiply
77
from metatomic.torch import ModelOutput, NeighborListOptions, System
88

99
_HERE = os.path.dirname(__file__)
@@ -116,11 +116,16 @@ def forward(
116116
properties=Labels(["energy"], torch.tensor([[0]], device=device)),
117117
)
118118

119-
return {
119+
results = {
120120
"energy": TensorMap(
121121
Labels("_", torch.tensor([[0]], device=device)), [block]
122122
),
123123
}
124124

125+
if "energy/doubled" in outputs:
126+
results["energy/doubled"] = multiply(results["energy"], 2.0)
127+
128+
return results
129+
125130
def requested_neighbor_lists(self) -> List[NeighborListOptions]:
126131
return [self._nl_options]

src/metatomic_lj_test/pure.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Dict, List, Optional
22

33
import torch
4-
from metatensor.torch import Labels, TensorBlock, TensorMap
4+
from metatensor.torch import Labels, TensorBlock, TensorMap, multiply
55
from metatomic.torch import ModelOutput, NeighborListOptions, System
66

77

@@ -31,10 +31,14 @@ def forward(
3131
) -> Dict[str, TensorMap]:
3232
if (
3333
"energy" not in outputs
34+
and "energy/doubled" not in outputs
3435
and "energy_ensemble" not in outputs
3536
and "energy_uncertainty" not in outputs
37+
and "energy_uncertainty/doubled" not in outputs
3638
and "non_conservative_forces" not in outputs
39+
and "non_conservative_forces/doubled" not in outputs
3740
and "non_conservative_stress" not in outputs
41+
and "non_conservative_stress/doubled" not in outputs
3842
):
3943
return {}
4044

@@ -79,7 +83,10 @@ def forward(
7983
all_energies_per_atom.append(energy)
8084
all_energies.append(energy.sum(0, keepdim=True))
8185

82-
if "non_conservative_forces" in outputs:
86+
if (
87+
"non_conservative_forces" in outputs
88+
or "non_conservative_forces/doubled" in outputs
89+
):
8390
# we fill the non-conservative forces as the negative gradient of the potential
8491
# with respect to the positions, plus a random term
8592
forces = torch.zeros(len(system), 3, device=device, dtype=dtype)
@@ -104,15 +111,21 @@ def forward(
104111

105112
all_non_conservative_forces.append(forces)
106113

107-
if "non_conservative_stress" in outputs:
114+
if (
115+
"non_conservative_stress" in outputs
116+
or "non_conservative_stress/doubled" in outputs
117+
):
108118
# we fill the non-conservative stress with random numbers
109119
stress = torch.randn((3, 3), device=device, dtype=dtype)
110120
all_non_conservative_stress.append(stress)
111121

112122
energy_values = torch.vstack(all_energies).reshape(-1, 1)
113123
energies_per_atom_values = torch.vstack(all_energies_per_atom).reshape(-1, 1)
114124

115-
if "non_conservative_forces" in outputs:
125+
if (
126+
"non_conservative_forces" in outputs
127+
or "non_conservative_forces/doubled" in outputs
128+
):
116129
nc_forces_values = torch.cat(all_non_conservative_forces).reshape(-1, 3, 1)
117130
else:
118131
nc_forces_values = torch.empty((0, 0))
@@ -126,10 +139,15 @@ def forward(
126139
# randomly shuffle the samples to make sure the different engines handle
127140
# out of order samples
128141
indexes = torch.randperm(len(samples_list))
129-
if "energy" in outputs and outputs["energy"].per_atom:
142+
if ("energy" in outputs and outputs["energy"].per_atom) or (
143+
"energy/doubled" in outputs and outputs["energy/doubled"].per_atom
144+
):
130145
energies_per_atom_values = energies_per_atom_values[indexes]
131146

132-
if "non_conservative_forces" in outputs:
147+
if (
148+
"non_conservative_forces" in outputs
149+
or "non_conservative_forces/doubled" in outputs
150+
):
133151
nc_forces_values = nc_forces_values[indexes]
134152

135153
per_atom_samples = Labels(
@@ -143,7 +161,9 @@ def forward(
143161
)
144162
single_key = Labels("_", torch.tensor([[0]], device=device))
145163

146-
if "energy" in outputs and outputs["energy"].per_atom:
164+
if ("energy" in outputs and outputs["energy"].per_atom) or (
165+
"energy/doubled" in outputs and outputs["energy/doubled"].per_atom
166+
):
147167
energy_block = TensorBlock(
148168
values=energies_per_atom_values,
149169
samples=per_atom_samples,
@@ -159,8 +179,12 @@ def forward(
159179
)
160180

161181
results: Dict[str, TensorMap] = {}
162-
if "energy" in outputs:
163-
results["energy"] = TensorMap(single_key, [energy_block])
182+
if "energy" in outputs or "energy/doubled" in outputs:
183+
result = TensorMap(single_key, [energy_block])
184+
if "energy" in outputs:
185+
results["energy"] = result
186+
if "energy/doubled" in outputs:
187+
results["energy/doubled"] = multiply(result, 2.0)
164188

165189
if "energy_ensemble" in outputs:
166190
# returns the same energy for all ensemble members
@@ -187,7 +211,7 @@ def forward(
187211

188212
results["energy_ensemble"] = TensorMap(single_key, [ensemble_block])
189213

190-
if "energy_uncertainty" in outputs:
214+
if "energy_uncertainty" in outputs or "energy_uncertainty/doubled" in outputs:
191215
# returns an uncertainty of `0.001 * n_atoms^2` (note that the natural
192216
# scaling would be `sqrt(n_atoms)` or `n_atoms`); this is useful in tests so
193217
# we can artificially increase the uncertainty with the number of atoms
@@ -220,10 +244,17 @@ def forward(
220244
properties=energy_block.properties,
221245
)
222246

223-
results["energy_uncertainty"] = TensorMap(single_key, [uncertainty_block])
247+
result = TensorMap(single_key, [uncertainty_block])
248+
if "energy_uncertainty" in outputs:
249+
results["energy_uncertainty"] = result
250+
if "energy_uncertainty/doubled" in outputs:
251+
results["energy_uncertainty/doubled"] = multiply(result, 2.0)
224252

225-
if "non_conservative_forces" in outputs:
226-
results["non_conservative_forces"] = TensorMap(
253+
if (
254+
"non_conservative_forces" in outputs
255+
or "non_conservative_forces/doubled" in outputs
256+
):
257+
result = TensorMap(
227258
keys=Labels("_", torch.tensor([[0]], device=device)),
228259
blocks=[
229260
TensorBlock(
@@ -242,9 +273,16 @@ def forward(
242273
)
243274
],
244275
)
276+
if "non_conservative_forces" in outputs:
277+
results["non_conservative_forces"] = result
278+
if "non_conservative_forces/doubled" in outputs:
279+
results["non_conservative_forces/doubled"] = multiply(result, 2.0)
245280

246-
if "non_conservative_stress" in outputs:
247-
results["non_conservative_stress"] = TensorMap(
281+
if (
282+
"non_conservative_stress" in outputs
283+
or "non_conservative_stress/doubled" in outputs
284+
):
285+
result = TensorMap(
248286
keys=Labels("_", torch.tensor([[0]], device=device)),
249287
blocks=[
250288
TensorBlock(
@@ -272,6 +310,10 @@ def forward(
272310
)
273311
],
274312
)
313+
if "non_conservative_stress" in outputs:
314+
results["non_conservative_stress"] = result
315+
if "non_conservative_stress/doubled" in outputs:
316+
results["non_conservative_stress/doubled"] = multiply(result, 2.0)
275317

276318
return results
277319

0 commit comments

Comments
 (0)