Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/metatomic_lj_test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def lennard_jones_model(
constant that ensure that the energy given by the above formula goes to 0 at the
cutoff.

The model also provides a ``doubled`` **variant** where the :math:`\epsilon`
parameter is scaled by a factor of 2.

:param atomic_type: atomic type to which sigma/epsilon correspond
:param cutoff: spherical cutoff of the model
:param epsilon: epsilon parameter of Lennard-Jones
Expand All @@ -47,6 +50,11 @@ def lennard_jones_model(
unit=energy_unit,
per_atom=True,
),
"energy/doubled": ModelOutput(
quantity="energy",
unit=energy_unit,
per_atom=True,
),
}

if with_extension:
Expand All @@ -70,16 +78,31 @@ def lennard_jones_model(
unit=energy_unit,
per_atom=True,
),
"energy_uncertainty/doubled": ModelOutput(
quantity="energy",
unit=energy_unit,
per_atom=True,
),
"non_conservative_forces": ModelOutput(
quantity="force",
unit="eV/Angstrom",
per_atom=True,
),
"non_conservative_forces/doubled": ModelOutput(
quantity="force",
unit="eV/Angstrom",
per_atom=True,
),
"non_conservative_stress": ModelOutput(
quantity="pressure",
unit="eV/Angstrom^3",
per_atom=False,
),
"non_conservative_stress/doubled": ModelOutput(
quantity="pressure",
unit="eV/Angstrom^3",
per_atom=True,
),
}
)

Expand Down
9 changes: 7 additions & 2 deletions src/metatomic_lj_test/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, List, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch import Labels, TensorBlock, TensorMap, multiply
from metatomic.torch import ModelOutput, NeighborListOptions, System

_HERE = os.path.dirname(__file__)
Expand Down Expand Up @@ -116,11 +116,16 @@ def forward(
properties=Labels(["energy"], torch.tensor([[0]], device=device)),
)

return {
results = {
"energy": TensorMap(
Labels("_", torch.tensor([[0]], device=device)), [block]
),
}

if "energy/doubled" in outputs:
results["energy/doubled"] = multiply(results["energy"], 2.0)

return results

def requested_neighbor_lists(self) -> List[NeighborListOptions]:
return [self._nl_options]
72 changes: 57 additions & 15 deletions src/metatomic_lj_test/pure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, List, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch import Labels, TensorBlock, TensorMap, multiply
from metatomic.torch import ModelOutput, NeighborListOptions, System


Expand Down Expand Up @@ -31,10 +31,14 @@ def forward(
) -> Dict[str, TensorMap]:
if (
"energy" not in outputs
and "energy/doubled" not in outputs
and "energy_ensemble" not in outputs
and "energy_uncertainty" not in outputs
and "energy_uncertainty/doubled" not in outputs
and "non_conservative_forces" not in outputs
and "non_conservative_forces/doubled" not in outputs
and "non_conservative_stress" not in outputs
and "non_conservative_stress/doubled" not in outputs
):
return {}

Expand Down Expand Up @@ -79,7 +83,10 @@ def forward(
all_energies_per_atom.append(energy)
all_energies.append(energy.sum(0, keepdim=True))

if "non_conservative_forces" in outputs:
if (
"non_conservative_forces" in outputs
or "non_conservative_forces/doubled" in outputs
):
# we fill the non-conservative forces as the negative gradient of the potential
# with respect to the positions, plus a random term
forces = torch.zeros(len(system), 3, device=device, dtype=dtype)
Expand All @@ -104,15 +111,21 @@ def forward(

all_non_conservative_forces.append(forces)

if "non_conservative_stress" in outputs:
if (
"non_conservative_stress" in outputs
or "non_conservative_stress/doubled" in outputs
):
# we fill the non-conservative stress with random numbers
stress = torch.randn((3, 3), device=device, dtype=dtype)
all_non_conservative_stress.append(stress)

energy_values = torch.vstack(all_energies).reshape(-1, 1)
energies_per_atom_values = torch.vstack(all_energies_per_atom).reshape(-1, 1)

if "non_conservative_forces" in outputs:
if (
"non_conservative_forces" in outputs
or "non_conservative_forces/doubled" in outputs
):
nc_forces_values = torch.cat(all_non_conservative_forces).reshape(-1, 3, 1)
else:
nc_forces_values = torch.empty((0, 0))
Expand All @@ -126,10 +139,15 @@ def forward(
# randomly shuffle the samples to make sure the different engines handle
# out of order samples
indexes = torch.randperm(len(samples_list))
if "energy" in outputs and outputs["energy"].per_atom:
if ("energy" in outputs and outputs["energy"].per_atom) or (
"energy/doubled" in outputs and outputs["energy/doubled"].per_atom
):
energies_per_atom_values = energies_per_atom_values[indexes]

if "non_conservative_forces" in outputs:
if (
"non_conservative_forces" in outputs
or "non_conservative_forces/doubled" in outputs
):
nc_forces_values = nc_forces_values[indexes]

per_atom_samples = Labels(
Expand All @@ -143,7 +161,9 @@ def forward(
)
single_key = Labels("_", torch.tensor([[0]], device=device))

if "energy" in outputs and outputs["energy"].per_atom:
if ("energy" in outputs and outputs["energy"].per_atom) or (
"energy/doubled" in outputs and outputs["energy/doubled"].per_atom
):
energy_block = TensorBlock(
values=energies_per_atom_values,
samples=per_atom_samples,
Expand All @@ -159,8 +179,12 @@ def forward(
)

results: Dict[str, TensorMap] = {}
if "energy" in outputs:
results["energy"] = TensorMap(single_key, [energy_block])
if "energy" in outputs or "energy/doubled" in outputs:
result = TensorMap(single_key, [energy_block])
if "energy" in outputs:
results["energy"] = result
if "energy/doubled" in outputs:
results["energy/doubled"] = multiply(result, 2.0)

if "energy_ensemble" in outputs:
# returns the same energy for all ensemble members
Expand All @@ -187,7 +211,7 @@ def forward(

results["energy_ensemble"] = TensorMap(single_key, [ensemble_block])

if "energy_uncertainty" in outputs:
if "energy_uncertainty" in outputs or "energy_uncertainty/doubled" in outputs:
# returns an uncertainty of `0.001 * n_atoms^2` (note that the natural
# scaling would be `sqrt(n_atoms)` or `n_atoms`); this is useful in tests so
# we can artificially increase the uncertainty with the number of atoms
Expand Down Expand Up @@ -220,10 +244,17 @@ def forward(
properties=energy_block.properties,
)

results["energy_uncertainty"] = TensorMap(single_key, [uncertainty_block])
result = TensorMap(single_key, [uncertainty_block])
if "energy_uncertainty" in outputs:
results["energy_uncertainty"] = result
if "energy_uncertainty/doubled" in outputs:
results["energy_uncertainty/doubled"] = multiply(result, 2.0)

if "non_conservative_forces" in outputs:
results["non_conservative_forces"] = TensorMap(
if (
"non_conservative_forces" in outputs
or "non_conservative_forces/doubled" in outputs
):
result = TensorMap(
keys=Labels("_", torch.tensor([[0]], device=device)),
blocks=[
TensorBlock(
Expand All @@ -242,9 +273,16 @@ def forward(
)
],
)
if "non_conservative_forces" in outputs:
results["non_conservative_forces"] = result
if "non_conservative_forces/doubled" in outputs:
results["non_conservative_forces/doubled"] = multiply(result, 2.0)

if "non_conservative_stress" in outputs:
results["non_conservative_stress"] = TensorMap(
if (
"non_conservative_stress" in outputs
or "non_conservative_stress/doubled" in outputs
):
result = TensorMap(
keys=Labels("_", torch.tensor([[0]], device=device)),
blocks=[
TensorBlock(
Expand Down Expand Up @@ -272,6 +310,10 @@ def forward(
)
],
)
if "non_conservative_stress" in outputs:
results["non_conservative_stress"] = result
if "non_conservative_stress/doubled" in outputs:
results["non_conservative_stress/doubled"] = multiply(result, 2.0)

return results

Expand Down