1 change: 1 addition & 0 deletions docs/src/dev-docs/index.rst
@@ -12,6 +12,7 @@ module.
getting-started
architecture-life-cycle
new-architecture
new-mlip
dataset-information
new-loss
cli/index
65 changes: 3 additions & 62 deletions docs/src/dev-docs/new-architecture.rst
@@ -157,68 +157,9 @@ the user request and machines' availability, the optimal ``dtype`` and
.. note::

For MLIP-only models (models that only predict energies and forces),
``metatrain`` provides base classes :py:class:`metatrain.utils.mlip.MLIPModel`
and :py:class:`metatrain.utils.mlip.MLIPTrainer` that implement most of the
boilerplate code. See :doc:`utils/mlip` for more details.

Example: Creating an MLIP-only architecture
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To demonstrate how easy it is to add a new MLIP-only architecture using the base
classes, let's look at the ``mlip_example`` architecture in ``metatrain``. This minimal
architecture always predicts zero energy, serving as a simple template for MLIP
development.

The model (``model.py``) only needs to implement the ``compute_energy`` method:

.. code-block:: python

    import torch

    from metatrain.utils.mlip import MLIPModel


    class ZeroModel(MLIPModel):
        """A minimal example MLIP model that always predicts zero energy."""

        __checkpoint_version__ = 1

        def __init__(self, hypers, dataset_info):
            super().__init__(hypers, dataset_info)
            # Request a neighbor list with the cutoff from hyperparameters
            cutoff = hypers["cutoff"]
            self.request_neighbor_list(cutoff)

        def compute_energy(
            self,
            edge_vectors,
            species,
            centers,
            neighbors,
            system_indices,
        ):
            # Get the number of systems and return one zero energy per system
            n_systems = system_indices.max().item() + 1
            return torch.zeros(n_systems, device=edge_vectors.device)

The trainer (``trainer.py``) only needs to specify whether to use rotational
augmentation:

.. code-block:: python

    from metatrain.utils.mlip import MLIPTrainer


    class ZeroTrainer(MLIPTrainer):
        """Trainer for the ZeroModel."""

        __checkpoint_version__ = 1

        def use_rotational_augmentation(self):
            return False  # No rotational augmentation for this example

That's it! The base classes handle the entire training loop, as well as data
loading, composition weights, scaling, checkpointing, and export. This lets
you focus on implementing the core physics of your model in the
``compute_energy`` method.

The complete example architecture can be found in ``src/metatrain/mlip_example/``.
``metatrain`` provides base classes that implement most of the boilerplate
code. See :doc:`new-mlip` for a complete guide on adding MLIP-only
architectures.

Trainer class (``trainer.py``)
------------------------------
68 changes: 68 additions & 0 deletions docs/src/dev-docs/new-mlip.rst
@@ -0,0 +1,68 @@
.. _adding-new-mlip:

Adding a new MLIP-only architecture
====================================

For MLIP-only models (models that only predict energies and forces),
``metatrain`` provides base classes :py:class:`metatrain.utils.mlip.MLIPModel`
and :py:class:`metatrain.utils.mlip.MLIPTrainer` that implement most of the
boilerplate code. See :doc:`utils/mlip` for more details.

Example: Creating an MLIP-only architecture
--------------------------------------------

To demonstrate how easy it is to add a new MLIP-only architecture using the base
classes, let's look at the ``mlip_example`` architecture in ``metatrain``. This minimal
architecture always predicts zero energy, serving as a simple template for MLIP
development.

The model (``model.py``) only needs to implement the ``compute_energy`` method:

.. code-block:: python

    import torch

    from metatrain.utils.mlip import MLIPModel


    class ZeroModel(MLIPModel):
        """A minimal example MLIP model that always predicts zero energy."""

        __checkpoint_version__ = 1

        def __init__(self, hypers, dataset_info):
            super().__init__(hypers, dataset_info)
            # Request a neighbor list with the cutoff from hyperparameters
            cutoff = hypers["cutoff"]
            self.request_neighbor_list(cutoff)

        def compute_energy(
            self,
            edge_vectors,
            species,
            centers,
            neighbors,
            system_indices,
        ):
            # Get the number of systems and return one zero energy per system
            n_systems = system_indices.max().item() + 1
            return torch.zeros(n_systems, device=edge_vectors.device)
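
For a slightly less trivial sketch of what ``compute_energy`` can look like,
here is a hypothetical, purely repulsive pair potential. The names are
illustrative, and it assumes (consistent with the example above) that
``system_indices`` maps each atom to its system, so that
``system_indices[centers]`` maps each edge to a system; whether each pair
appears once or twice depends on your architecture's neighbor-list convention:

.. code-block:: python

    import torch

    from metatrain.utils.mlip import MLIPModel


    class PairRepulsionModel(MLIPModel):
        """Hypothetical example: a learnable, purely repulsive pair potential."""

        __checkpoint_version__ = 1

        def __init__(self, hypers, dataset_info):
            super().__init__(hypers, dataset_info)
            self.request_neighbor_list(hypers["cutoff"])
            # A single learnable prefactor shared by all pairs
            self.prefactor = torch.nn.Parameter(torch.tensor(1.0))

        def compute_energy(
            self, edge_vectors, species, centers, neighbors, system_indices
        ):
            # Pair distances from the neighbor-list edge vectors
            distances = torch.linalg.norm(edge_vectors, dim=-1)
            # Repulsive 1/r^12 contribution for every edge
            pair_energies = self.prefactor / distances**12
            # Accumulate each edge's energy into the system of its center atom
            n_systems = int(system_indices.max().item()) + 1
            energies = torch.zeros(
                n_systems, dtype=pair_energies.dtype, device=pair_energies.device
            )
            energies.index_add_(0, system_indices[centers], pair_energies)
            return energies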

The trainer (``trainer.py``) only needs to specify whether to use rotational
augmentation:

.. code-block:: python

    from metatrain.utils.mlip import MLIPTrainer


    class ZeroTrainer(MLIPTrainer):
        """Trainer for the ZeroModel."""

        __checkpoint_version__ = 1

        def use_rotational_augmentation(self):
            return False  # No rotational augmentation for this example
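
Rotational augmentation applies a random rigid rotation to each training
structure. Energies are invariant under such rotations, so a model whose
architecture does not build in this symmetry can still learn it from the
augmented data. A conceptual sketch of the idea (illustrative only, not
``metatrain``'s internal implementation):

.. code-block:: python

    import torch

    # Draw a random orthogonal matrix via a QR decomposition
    Q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
    if torch.linalg.det(Q) < 0:
        Q[:, 0] = -Q[:, 0]  # flip one column to get a proper rotation

    positions = torch.randn(5, 3, dtype=torch.float64)
    rotated = positions @ Q.T
    # An invariant model should predict the same energy
    # for `positions` and `rotated`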

That's it! The base classes handle the entire training loop, as well as data
loading, composition weights, scaling, checkpointing, and export. This lets
you focus on implementing the core physics of your model in the
``compute_energy`` method.
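
Although you only implement ``compute_energy``, MLIP models also predict
forces, which follow from the energy as its negative gradient with respect to
the atomic positions. A toy, self-contained sketch of this relationship
(independent of ``metatrain``):

.. code-block:: python

    import torch

    positions = torch.randn(8, 3, dtype=torch.float64, requires_grad=True)

    # Toy energy: sum of squared distances between every atom and atom 0
    energy = ((positions - positions[0]) ** 2).sum()

    # Forces are the negative gradient of the energy w.r.t. positions
    (gradient,) = torch.autograd.grad(energy, positions)
    forces = -gradient
    print(forces.shape)  # torch.Size([8, 3])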

The complete example architecture can be found in ``src/metatrain/mlip_example/``.
63 changes: 0 additions & 63 deletions docs/src/dev-docs/utils/mlip.rst
@@ -5,66 +5,3 @@ MLIP Base Classes
:members:
:undoc-members:
:show-inheritance:

MLIPModel
---------

The :py:class:`metatrain.utils.mlip.MLIPModel` class is a base class for
MLIP-only models that predict only energies and forces. It provides:

- Common forward pass logic with neighbor list processing
- Automatic integration of :py:class:`~metatrain.utils.additive.CompositionModel`
for composition-based energy corrections
- Automatic integration of :py:class:`~metatrain.utils.scaler.Scaler` for
target scaling
- Checkpoint saving/loading (``get_checkpoint``, ``load_checkpoint``)
- Model export to metatomic format (``export``)
- Support for restarting training (``restart``)

Derived classes only need to implement the
:py:meth:`~metatrain.utils.mlip.MLIPModel.compute_energy` method.

The base class automatically handles additive models and scaling at evaluation
time, so the derived class only needs to compute the "raw" energy predictions.

MLIPTrainer
-----------

The :py:class:`metatrain.utils.mlip.MLIPTrainer` class is a base trainer for
MLIP-only models. It implements the complete training loop and handles:

- Distributed training
- Data loading with optional rotational augmentation
- Loss computation
- Checkpointing

Derived classes only need to implement the
:py:meth:`~metatrain.utils.mlip.MLIPTrainer.use_rotational_augmentation` method
to specify whether rotational data augmentation should be used during training.

Example
^^^^^^^

Here's how to use the base classes to create a new MLIP architecture:

.. code-block:: python

    import torch

    from metatrain.utils.mlip import MLIPModel, MLIPTrainer


    class MyMLIPModel(MLIPModel):
        def compute_energy(
            self,
            edge_vectors: torch.Tensor,
            species: torch.Tensor,
            centers: torch.Tensor,
            neighbors: torch.Tensor,
            system_indices: torch.Tensor,
        ) -> torch.Tensor:
            # Implement your energy computation here
            ...
            return energies  # shape: (N_systems,)


    class MyMLIPTrainer(MLIPTrainer):
        def use_rotational_augmentation(self) -> bool:
            # Return True to use rotational augmentation, False otherwise
            return False
2 changes: 1 addition & 1 deletion examples/0-beginner/00-basic-usage.sh
@@ -122,4 +122,4 @@ mtt eval --help
# ---------------------
#
# The trained model can also be used to run molecular simulations.
# You can find how in the examples section.
# You can learn how in the :ref:`tutorials` section.
77 changes: 75 additions & 2 deletions src/metatrain/utils/mlip.py
@@ -50,6 +50,52 @@


class MLIPModel(ModelInterface):
"""
Base model class for MLIP-only architectures.

This is a base class for MLIP-only models, i.e. models that predict only
energies and forces. It provides:

- Common forward pass logic with neighbor list processing
- Automatic integration of :py:class:`~metatrain.utils.additive.CompositionModel`
for composition-based energy corrections
- Automatic integration of :py:class:`~metatrain.utils.scaler.Scaler` for
target scaling
- Checkpoint saving/loading (``get_checkpoint``, ``load_checkpoint``)
- Model export to metatomic format (``export``)
- Support for restarting training (``restart``)

Derived classes only need to implement the
:py:meth:`~metatrain.utils.mlip.MLIPModel.compute_energy` method.

The base class automatically handles additive models and scaling at evaluation
time, so the derived class only needs to compute the "raw" energy predictions.

Example:

.. code-block:: python

    import torch

    from metatrain.utils.mlip import MLIPModel


    class MyMLIPModel(MLIPModel):
        def compute_energy(
            self,
            edge_vectors: torch.Tensor,
            species: torch.Tensor,
            centers: torch.Tensor,
            neighbors: torch.Tensor,
            system_indices: torch.Tensor,
        ) -> torch.Tensor:
            # Implement your energy computation here
            ...
            return energies  # shape: (N_systems,)

:param hypers: Model hyperparameters.
:param dataset_info: Information about the dataset, including atomic types and
targets.
"""

__checkpoint_version__ = 1
__supported_devices__ = ["cuda", "cpu"]
__supported_dtypes__ = [torch.float64, torch.float32]
@@ -517,8 +563,35 @@ class MLIPTrainer(TrainerInterface):
"""
Base trainer class for MLIP-only architectures.

This class provides common training logic for models that only predict energies
and forces. Derived classes can customize behavior by implementing abstract methods.
This class is a base trainer for MLIP-only models. It implements the complete
training loop and handles:

- Distributed training
- Data loading with optional rotational augmentation
- Loss computation
- Checkpointing

Derived classes only need to implement the
:py:meth:`~metatrain.utils.mlip.MLIPTrainer.use_rotational_augmentation` method
to specify whether rotational data augmentation should be used during training.

Note on rotational augmentation: it is unnecessary if rotational invariance is
built into the architecture itself (e.g., through equivariant message
passing). If your architecture does not enforce this invariance, enable
rotational augmentation so that the model can learn an approximately invariant
mapping from the augmented data.

Example:

.. code-block:: python

    from metatrain.utils.mlip import MLIPTrainer


    class MyMLIPTrainer(MLIPTrainer):
        def use_rotational_augmentation(self) -> bool:
            # Return True to use rotational augmentation, False otherwise
            return False

:param hypers: Training hyperparameters.
"""