Skip to content

Commit 9b19544

Browse files
committed
Add doubled variant
1 parent c1daf1d commit 9b19544

File tree

3 files changed

+96
-18
lines changed

3 files changed

+96
-18
lines changed

src/metatomic_lj_test/__init__.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def lennard_jones_model(
3232
constant that ensure that the energy given by the above formula goes to 0 at the
3333
cutoff.
3434
35+
The model also provided a **variant** denoted with ``doubled`` of the potential
36+
where the :math:`\epsilon` parameter is scaled by a factor of 2.
37+
3538
:param atomic_type: atomic type to which sigma/epsilon correspond
3639
:param cutoff: spherical cutoff of the model
3740
:param epsilon: epsilon parameter of Lennard-Jones
@@ -47,6 +50,11 @@ def lennard_jones_model(
4750
unit=energy_unit,
4851
per_atom=True,
4952
),
53+
"energy/doubled": ModelOutput(
54+
quantity="energy",
55+
unit=energy_unit,
56+
per_atom=True,
57+
),
5058
}
5159

5260
if with_extension:
@@ -56,7 +64,9 @@ def lennard_jones_model(
5664
else:
5765
from .pure import LennardJonesPurePyTorch
5866

59-
model = LennardJonesPurePyTorch(cutoff=cutoff, epsilon=epsilon, sigma=sigma)
67+
model = LennardJonesPurePyTorch(
68+
cutoff=cutoff, epsilon=epsilon, sigma=sigma, variant_scale=variant_scale
69+
)
6070

6171
outputs.update(
6272
{
@@ -70,16 +80,31 @@ def lennard_jones_model(
7080
unit=energy_unit,
7181
per_atom=True,
7282
),
83+
"energy_uncertainty/doubled": ModelOutput(
84+
quantity="energy",
85+
unit=energy_unit,
86+
per_atom=True,
87+
),
7388
"non_conservative_forces": ModelOutput(
7489
quantity="force",
7590
unit="eV/Angstrom",
7691
per_atom=True,
7792
),
93+
"non_conservative_forces/doubled": ModelOutput(
94+
quantity="force",
95+
unit="eV/Angstrom",
96+
per_atom=True,
97+
),
7898
"non_conservative_stress": ModelOutput(
7999
quantity="pressure",
80100
unit="eV/Angstrom^3",
81101
per_atom=False,
82102
),
103+
"non_conservative_stress/doubled": ModelOutput(
104+
quantity="pressure",
105+
unit="eV/Angstrom^3",
106+
per_atom=True,
107+
),
83108
}
84109
)
85110

src/metatomic_lj_test/extension.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Dict, List, Optional
44

55
import torch
6-
from metatensor.torch import Labels, TensorBlock, TensorMap
6+
from metatensor.torch import Labels, TensorBlock, TensorMap, multiply
77
from metatomic.torch import ModelOutput, NeighborListOptions, System
88

99
_HERE = os.path.dirname(__file__)
@@ -116,11 +116,16 @@ def forward(
116116
properties=Labels(["energy"], torch.tensor([[0]], device=device)),
117117
)
118118

119-
return {
119+
results = {
120120
"energy": TensorMap(
121121
Labels("_", torch.tensor([[0]], device=device)), [block]
122122
),
123123
}
124124

125+
if "energy/doubled" in outputs:
126+
results["energy/doubled"] = multiply(results["energy"], 2.0)
127+
128+
return results
129+
125130
def requested_neighbor_lists(self) -> List[NeighborListOptions]:
126131
return [self._nl_options]

src/metatomic_lj_test/pure.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Dict, List, Optional
22

33
import torch
4-
from metatensor.torch import Labels, TensorBlock, TensorMap
4+
from metatensor.torch import Labels, TensorBlock, TensorMap, multiply
55
from metatomic.torch import ModelOutput, NeighborListOptions, System
66

77

@@ -31,10 +31,14 @@ def forward(
3131
) -> Dict[str, TensorMap]:
3232
if (
3333
"energy" not in outputs
34+
and "energy/doubled" not in outputs
3435
and "energy_ensemble" not in outputs
3536
and "energy_uncertainty" not in outputs
37+
and "energy_uncertainty/doubled" not in outputs
3638
and "non_conservative_forces" not in outputs
39+
and "non_conservative_forces/doubled" not in outputs
3740
and "non_conservative_stress" not in outputs
41+
and "non_conservative_stress/doubled" not in outputs
3842
):
3943
return {}
4044

@@ -79,7 +83,10 @@ def forward(
7983
all_energies_per_atom.append(energy)
8084
all_energies.append(energy.sum(0, keepdim=True))
8185

82-
if "non_conservative_forces" in outputs:
86+
if (
87+
"non_conservative_forces" in outputs
88+
or "non_conservative_forces/doubled" in outputs
89+
):
8390
# we fill the non-conservative forces as the negative gradient of the potential
8491
# with respect to the positions, plus a random term
8592
forces = torch.zeros(len(system), 3, device=device, dtype=dtype)
@@ -104,15 +111,21 @@ def forward(
104111

105112
all_non_conservative_forces.append(forces)
106113

107-
if "non_conservative_stress" in outputs:
114+
if (
115+
"non_conservative_stress" in outputs
116+
or "non_conservative_stress/doubled" in outputs
117+
):
108118
# we fill the non-conservative stress with random numbers
109119
stress = torch.randn((3, 3), device=device, dtype=dtype)
110120
all_non_conservative_stress.append(stress)
111121

112122
energy_values = torch.vstack(all_energies).reshape(-1, 1)
113123
energies_per_atom_values = torch.vstack(all_energies_per_atom).reshape(-1, 1)
114124

115-
if "non_conservative_forces" in outputs:
125+
if (
126+
"non_conservative_forces" in outputs
127+
or "non_conservative_forces/doubled" in outputs
128+
):
116129
nc_forces_values = torch.cat(all_non_conservative_forces).reshape(-1, 3, 1)
117130
else:
118131
nc_forces_values = torch.empty((0, 0))
@@ -126,10 +139,15 @@ def forward(
126139
# randomly shuffle the samples to make sure the different engines handle
127140
# out of order samples
128141
indexes = torch.randperm(len(samples_list))
129-
if "energy" in outputs and outputs["energy"].per_atom:
142+
if ("energy" in outputs and outputs["energy"].per_atom) or (
143+
"energy/doubled" in outputs and outputs["energy/doubled"].per_atom
144+
):
130145
energies_per_atom_values = energies_per_atom_values[indexes]
131146

132-
if "non_conservative_forces" in outputs:
147+
if (
148+
"non_conservative_forces" in outputs
149+
or "non_conservative_forces/doubled" in outputs
150+
):
133151
nc_forces_values = nc_forces_values[indexes]
134152

135153
per_atom_samples = Labels(
@@ -143,7 +161,9 @@ def forward(
143161
)
144162
single_key = Labels("_", torch.tensor([[0]], device=device))
145163

146-
if "energy" in outputs and outputs["energy"].per_atom:
164+
if ("energy" in outputs and outputs["energy"].per_atom) or (
165+
"energy/doubled" in outputs and outputs["energy/doubled"].per_atom
166+
):
147167
energy_block = TensorBlock(
148168
values=energies_per_atom_values,
149169
samples=per_atom_samples,
@@ -159,8 +179,12 @@ def forward(
159179
)
160180

161181
results: Dict[str, TensorMap] = {}
162-
if "energy" in outputs:
163-
results["energy"] = TensorMap(single_key, [energy_block])
182+
if "energy" in outputs or "energy/doubled" in outputs:
183+
result = TensorMap(single_key, [energy_block])
184+
if "energy" in outputs:
185+
results["energy"] = result
186+
if "energy/doubled" in outputs:
187+
results["energy/doubled"] = multiply(result, self._variant_scale)
164188

165189
if "energy_ensemble" in outputs:
166190
# returns the same energy for all ensemble members
@@ -187,7 +211,7 @@ def forward(
187211

188212
results["energy_ensemble"] = TensorMap(single_key, [ensemble_block])
189213

190-
if "energy_uncertainty" in outputs:
214+
if "energy_uncertainty" in outputs or "energy_uncertainty/doubled" in outputs:
191215
# returns an uncertainty of `0.001 * n_atoms^2` (note that the natural
192216
# scaling would be `sqrt(n_atoms)` or `n_atoms`); this is useful in tests so
193217
# we can artificially increase the uncertainty with the number of atoms
@@ -220,10 +244,19 @@ def forward(
220244
properties=energy_block.properties,
221245
)
222246

223-
results["energy_uncertainty"] = TensorMap(single_key, [uncertainty_block])
247+
result = TensorMap(single_key, [uncertainty_block])
248+
if "energy_uncertainty" in outputs:
249+
results["energy_uncertainty"] = result
250+
if "energy_uncertainty/doubled" in outputs:
251+
results["energy_uncertainty/doubled"] = multiply(
252+
result, self._variant_scale
253+
)
224254

225-
if "non_conservative_forces" in outputs:
226-
results["non_conservative_forces"] = TensorMap(
255+
if (
256+
"non_conservative_forces" in outputs
257+
or "non_conservative_forces/doubled" in outputs
258+
):
259+
result = TensorMap(
227260
keys=Labels("_", torch.tensor([[0]], device=device)),
228261
blocks=[
229262
TensorBlock(
@@ -242,9 +275,18 @@ def forward(
242275
)
243276
],
244277
)
278+
if "non_conservative_forces" in outputs:
279+
results["non_conservative_forces"] = result
280+
if "non_conservative_forces/doubled" in outputs:
281+
results["non_conservative_forces/doubled"] = multiply(
282+
result, self._variant_scale
283+
)
245284

246-
if "non_conservative_stress" in outputs:
247-
results["non_conservative_stress"] = TensorMap(
285+
if (
286+
"non_conservative_stress" in outputs
287+
or "non_conservative_stress/doubled" in outputs
288+
):
289+
result = TensorMap(
248290
keys=Labels("_", torch.tensor([[0]], device=device)),
249291
blocks=[
250292
TensorBlock(
@@ -272,6 +314,12 @@ def forward(
272314
)
273315
],
274316
)
317+
if "non_conservative_stress" in outputs:
318+
results["non_conservative_stress"] = result
319+
if "non_conservative_stress/doubled" in outputs:
320+
results["non_conservative_stress/doubled"] = multiply(
321+
result, self._variant_scale
322+
)
275323

276324
return results
277325

0 commit comments

Comments
 (0)