Coding standards for models implementation
This document defines the coding standards and best practices for implementing model classes in the PhysicsNeMo repository. These rules are designed to ensure consistency, maintainability, and high code quality across all model implementations.
Important: These rules are enforced as strictly as possible. Deviations from these standards should only be made when absolutely necessary and must be documented with clear justification in code comments and approved during code review.
This document is structured in two main sections:
- Rule Index: A quick-reference table listing all rules with their IDs, one-line summaries, and the context in which they apply. Use this section to quickly identify relevant rules when implementing or reviewing code.
- Detailed Rules: Comprehensive descriptions of each rule, including:
  - Clear descriptions of what the rule requires
  - Rationale explaining why the rule exists
  - Examples demonstrating correct implementation
  - Anti-patterns showing common mistakes to avoid

- When creating new models: Review all rules before starting implementation, paying special attention to rules MOD-000 through MOD-003.
- When reviewing code: Use the Rule Index to quickly verify compliance with all applicable rules.
- When refactoring: Ensure refactored code maintains or improves compliance with these standards.
Rule Index

| Rule ID | Summary | Apply When |
|---|---|---|
| MOD-000a | Reusable layers/blocks belong in physicsnemo.nn | Creating or refactoring reusable layer classes |
| MOD-000b | Complete models belong in physicsnemo.models | Creating or refactoring complete model classes |
| MOD-001 | Use physicsnemo.Module as model base classes | Creating or refactoring new model classes |
| MOD-002a | New models and layers belong in physicsnemo.experimental | Creating new model or layer classes |
| MOD-002b | Add deprecation warnings when deprecating a model class | Deprecating existing model classes |
| MOD-002c | Remove deprecated model from codebase | Removing deprecated models after warning period |
| MOD-003a | Missing or incomplete docstring for model/layer code | Creating or editing any model or layer code |
| MOD-003b | Docstring must use raw string prefix r""" | Writing any model or method docstring |
| MOD-003c | Missing required class docstring sections | Writing class docstrings |
| MOD-003d | Missing required method docstring sections | Writing method docstrings |
| MOD-003e | Tensor shapes must use LaTeX math notation | Documenting tensors in docstrings |
| MOD-003f | Callback functions must have code-block specification | Documenting callback function parameters |
| MOD-003g | Inline code must use double backticks | Writing inline code in docstrings |
| MOD-003h | Parameters must be documented on single line | Documenting function/method parameters |
| MOD-003i | Docstrings should include cross-references | Writing comprehensive docstrings |
| MOD-003j | Docstrings should include Examples section | Writing model class docstrings |
| MOD-003k | Add high-level comments for complex tensor operations | Writing model code with complex tensor operations |
| MOD-004 | Model code is not self-contained | Organizing or refactoring model code |
| MOD-005 | Invalid or missing tensor shape validation logic | Implementing model forward or public methods |
| MOD-006 | Invalid or missing jaxtyping tensor annotations in public function signature | Adding type hints to model methods |
| MOD-007a | Cannot add required parameters without defaults | Modifying production model signatures |
| MOD-007b | Cannot remove or rename parameters without compat mapper | Modifying production model signatures |
| MOD-007c | Cannot change return types of public methods | Modifying production model method signatures |
| MOD-008a | Model missing constructor/attributes tests | Adding CI tests for models |
| MOD-008b | Model missing non-regression test with reference data | Adding CI tests for models |
| MOD-008c | Model missing checkpoint loading test | Adding CI tests for models |
| MOD-009 | Avoid string-based class selection in model constructors | Designing model constructor APIs |
| MOD-010 | Avoid splatted kwargs in model constructors | Designing model constructor APIs |
| MOD-011 | Use proper optional dependency handling | Implementing models with optional dependencies |
Detailed Rules

MOD-000a: Reusable layers/blocks belong in physicsnemo.nn

Description:
Reusable layers that are the building blocks of more complex architectures
should go into physicsnemo/nn. Those include for instance FullyConnected,
various variants of attention layers, UNetBlock (a block of a U-Net), etc.
All layers that are directly exposed to the user should be imported in
physicsnemo/nn/__init__.py, such that they can be used as follows:
from physicsnemo.nn import MyLayer

The only exception to this rule is for layers that are highly specific to a
single example. In this case, it may be acceptable to place them in a module
specific to the example code, such as examples/<example_name>/utils/nn.py.
Rationale:
Ensures consistency in the organization of reusable layers in the repository. Keeping all reusable components in a single location makes them easy to find and promotes code reuse across different models.
Example:
# Good: Reusable layer in physicsnemo/nn/attention.py
class MultiHeadAttention(Module):
"""A reusable attention layer that can be used in various architectures."""
pass
# Good: Import in physicsnemo/nn/__init__.py
from physicsnemo.nn.attention import MultiHeadAttention
# Good: Example-specific layer in examples/weather/utils/nn.py
class WeatherSpecificLayer(Module):
"""Layer highly specific to the weather forecasting example."""
pass

Anti-pattern:
# WRONG: Reusable layer placed in physicsnemo/models/
# File: physicsnemo/models/attention.py
class MultiHeadAttention(Module):
"""Should be in physicsnemo/nn/ not physicsnemo/models/"""
pass

MOD-000b: Complete models belong in physicsnemo.models

Description:
More complete models, composed of multiple layers and/or other sub-models,
should go into physicsnemo/models. All models that are directly exposed to
the user should be imported in physicsnemo/models/__init__.py, such that they
can be used as follows:
from physicsnemo.models import MyModel

The only exception to this rule is for models that are highly specific to a
single example. In this case, it may be acceptable to place them in a module
specific to the example code, such as examples/<example_name>/utils/nn.py.
Rationale:
Ensures consistency and clarity in the organization of models in the repository, in particular a clear separation between reusable layers and more complete models that are applicable to a specific domain or specific data modality.
Example:
# Good: Complete model in physicsnemo/models/transformer.py
class TransformerModel(Module):
"""A complete transformer model composed of attention and feedforward layers."""
def __init__(self):
super().__init__()
self.attention = MultiHeadAttention(...)
self.ffn = FeedForward(...)
# Good: Import in physicsnemo/models/__init__.py
from physicsnemo.models.transformer import TransformerModel

Anti-pattern:
# WRONG: Complete model placed in physicsnemo/nn/
# File: physicsnemo/nn/transformer.py
class TransformerModel(Module):
"""Should be in physicsnemo/models/ not physicsnemo/nn/"""
pass

MOD-001: Use physicsnemo.Module as model base classes

Description:
All model classes must inherit from physicsnemo.Module. Direct subclasses of
torch.nn.Module are not allowed. Direct subclasses of physicsnemo.Module
are allowed (note that physicsnemo.Module is a subclass of torch.nn.Module).
Ensure proper initialization of parent classes using super().__init__(). Pass
the meta argument to the super().__init__() call if appropriate, otherwise
set it manually with self.meta = meta.
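If the metadata is not passed through super().__init__(), a minimal sketch of the manual alternative mentioned above (MyOtherModelMetaData is a hypothetical metadata class used only for illustration):

```python
from torch import nn

from physicsnemo import Module


class MyOtherModel(Module):
    def __init__(self, input_dim: int, output_dim: int):
        # Parent initialized without metadata ...
        super().__init__()
        # ... so the metadata is attached manually instead
        self.meta = MyOtherModelMetaData()  # hypothetical metadata class
        self.linear = nn.Linear(input_dim, output_dim)
```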
Rationale:
Ensures invariants and functionality of the physicsnemo.Module class for all
models. In particular, instances of physicsnemo.Module benefit from features
that are not available in torch.nn.Module instances. Those include serialization
for checkpointing and loading modules and submodules, versioning system to
handle backward compatibility, as well as ability to be registered in the
physicsnemo.registry for easy instantiation and use in any codebase.
Example:
from physicsnemo import Module
class MyModel(Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__(meta=MyModelMetaData())
self.linear = nn.Linear(input_dim, output_dim)

Anti-pattern:
from torch import nn
class MyModel(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
self.linear = nn.Linear(input_dim, output_dim)

MOD-002a: New models and layers belong in physicsnemo.experimental

Description:
For the vast majority of models, new classes are created either in
physicsnemo/experimental/nn for reusable layers, or in
physicsnemo/experimental/models for more complete models. The experimental
folder is used to store models that are still under development (beta or alpha
releases), where backward compatibility is not guaranteed.
One exception is when the developer is highly confident that the model is
sufficiently mature and applicable to many domains or use cases. In this case
the model class can be created in the physicsnemo/nn or physicsnemo/models
folders directly, and backward compatibility is guaranteed.
Another exception is when the model class is highly specific to a single
example. In this case, it may be acceptable to place it in a module specific to
the example code, such as examples/<example_name>/utils/nn.py.
After staying in experimental for a sufficient amount of time (typically at
least 1 release cycle), the model class can be promoted to production. It is
then moved to the physicsnemo/nn or physicsnemo/models folders, based on
whether it's a reusable layer (MOD-000a) or complete model (MOD-000b). During
the production stage, backward compatibility is guaranteed.
Note: Per MOD-008a, MOD-008b, and MOD-008c, it is forbidden to move a model out of the experimental stage/directory without the required CI tests.
Rationale:
The experimental stage allows rapid iteration without backward compatibility constraints, enabling developers to refine APIs based on user feedback. This protects users from unstable APIs while allowing innovation.
Example:
# Good: New experimental model
# File: physicsnemo/experimental/models/new_diffusion.py
class DiffusionModel(Module):
"""New diffusion model under active development. API may change."""
pass
# Good: After 1+ release cycles, promoted to production
# File: physicsnemo/models/diffusion.py (moved from experimental/)
class DiffusionModel(Module):
"""Stable diffusion model with backward compatibility guarantees."""
pass

Anti-pattern:
# WRONG: New model directly in production folder
# File: physicsnemo/models/brand_new_model.py (should be in experimental/ first)
class BrandNewModel(Module):
"""Skipped experimental stage - risky for stability"""
pass

MOD-002b: Add deprecation warnings when deprecating a model class

Description:
For a model class being deprecated in physicsnemo/nn or physicsnemo/models,
the developer must add warning messages indicating that the model class is
deprecated and will be removed in a future release.
The warning message should be clear and concise, explaining why the model class is being deprecated and what the user should do instead. The deprecation message must be added to both:
- The docstring, using a .. deprecated:: directive
- Runtime, using warnings.warn(..., DeprecationWarning)
The developer is free to choose the mechanism used to raise the deprecation warning. A model class must remain in the pre-deprecation stage for at least 1 release cycle before it can be deleted (see MOD-002c).
Rationale:
Ensures users have sufficient time to migrate to newer alternatives, preventing breaking changes that could disrupt their workflows. This graduated approach balances innovation with stability.
Example:
# Good: Pre-deprecation with proper warnings
# File: physicsnemo/models/old_diffusion.py
class DiffusionModel(Module):
"""
Legacy diffusion model.
.. deprecated:: 0.5.0
``OldDiffusionModel`` is deprecated and will be removed in version 0.7.0.
Use :class:`~physicsnemo.models.NewDiffusionModel` instead.
"""
def __init__(self):
import warnings
warnings.warn(
"OldDiffusionModel is deprecated. Use NewDiffusionModel instead.",
DeprecationWarning,
stacklevel=2
)
super().__init__()

Anti-pattern:
# WRONG: No runtime warning
# File: physicsnemo/models/old_model.py
class OldModel(Module):
"""Will be removed next release.""" # Docstring mentions it but no runtime warning
def __init__(self):
# Missing: warnings.warn(..., DeprecationWarning)
super().__init__()
# WRONG: Deprecation without sufficient warning period
# (Model deprecated and removed in same release)

MOD-002c: Remove deprecated model from codebase

Description:
After staying in the pre-deprecation stage for at least 1 release cycle, the model class is considered deprecated and can be deleted from the codebase.
A model class cannot be deleted without first spending at least 1 release cycle in the pre-deprecation stage with proper deprecation warnings (see MOD-002b).
Rationale:
This ensures users have sufficient warning and time to migrate their code to newer alternatives. Premature deletion of models would break user code without adequate notice, violating the framework's commitment to stability.
Example:
# Good: Proper deprecation timeline
# v0.5.0: Added deprecation warnings (Stage 3 - pre-deprecation)
# v0.6.0: Model can be safely removed (Stage 4 - deprecation)
# File: physicsnemo/models/old_diffusion.py - DELETED

Anti-pattern:
# WRONG: Deleting model without deprecation period
# v0.5.0: Model exists without warnings
# v0.6.0: Model deleted - BREAKS USER CODE!
# WRONG: Breaking changes without deprecation
# File: physicsnemo/models/diffusion.py
class DiffusionModel(Module):
def __init__(self, new_required_param): # Breaking change!
# Changed API without deprecation warning - breaks user code
pass

MOD-003a: Missing or incomplete docstring for model/layer code

Description:
Every new model or modification of any model code should be documented with a comprehensive docstring following all the sub-rules MOD-003b through MOD-003k. All docstrings should be written in the NumPy style and adopt formatting to be compatible with our Sphinx restructured text (RST) documentation.
Rationale:
Comprehensive and well-formatted documentation is essential for scientific software. It enables users to understand model capabilities, expected inputs, and outputs without inspecting source code.
Example:
class MyEncoder(Module):
r"""
A simple encoder network.
Parameters
----------
input_dim : int
Dimension of input features.
output_dim : int
Dimension of output features.
Forward
-------
x : torch.Tensor
Input tensor of shape :math:`(B, D_{in})`.
Outputs
-------
torch.Tensor
Output tensor of shape :math:`(B, D_{out})`.
Examples
--------
>>> model = MyEncoder(input_dim=784, output_dim=128)
>>> x = torch.randn(32, 784)
>>> output = model(x)
>>> output.shape
torch.Size([32, 128])
"""
pass

Anti-pattern:
# WRONG: Missing all required sections
class BadEncoder(Module):
'''A simple encoder.''' # Wrong quotes, no sections
pass

MOD-003b: Docstring must use raw string prefix r"""

Description:
Each docstring should be prefixed with r""" (not """ or '''). The r
prefix creates a raw string that prevents Python from interpreting backslashes,
which is essential for LaTeX math notation to render correctly in Sphinx
documentation.
Rationale:
LaTeX commands in docstrings use backslashes (e.g., \math, \text). Without
the raw string prefix, Python interprets these as escape sequences, breaking the
documentation rendering.
Example:
class MyModel(Module):
r"""
A model with LaTeX notation.
Parameters
----------
dim : int
Dimension :math:`D` of input features.
"""
pass

Anti-pattern:
# WRONG: Using ''' instead of r"""
class MyModel(Module):
'''
A model with LaTeX notation.
'''
pass

MOD-003c: Missing required class docstring sections

Description:
The class docstring should at least contain three sections: Parameters,
Forward, and Outputs. The forward method should be documented in the
docstring of the model class, instead of being in the docstring of the forward
method itself. A docstring for the forward method is still possible but it
should be concise and to the point.
Other sections such as Notes, Examples, or directives like .. important:: or
.. code-block:: python are also allowed. Any other sections are not recognized by our
Sphinx restructured text (RST) documentation and are prohibited.
Rationale:
Standardized sections ensure documentation is consistent and complete across all models. The Forward and Outputs sections in the class docstring provide a centralized place to document the model's primary behavior, making it easier for users to understand the model's API.
Example:
class MyModel(Module):
r"""
A simple encoder model.
Parameters
----------
input_dim : int
Dimension of input features.
Forward
-------
x : torch.Tensor
Input tensor of shape :math:`(B, D_{in})`.
Outputs
-------
torch.Tensor
Output tensor of shape :math:`(B, D_{out})`.
"""
pass

Anti-pattern:
# WRONG: Missing required sections
class BadModel(Module):
r"""
A simple encoder model.
No proper sections defined.
"""
pass

MOD-003d: Missing required method docstring sections

Description:
All methods should be documented with a docstring, with at least a Parameters
section and a Returns section. Other sections such as Notes, Examples, or directives
like .. important:: or .. code-block:: python are also allowed. Any other sections are
not recognized by our Sphinx documentation and are prohibited.
Note: The forward method is a special case - its full documentation should be in the class docstring (see MOD-003c), though a concise forward method docstring is permitted.
Rationale:
Complete method documentation ensures users understand how to call methods and what to expect in return. Standardized sections make documentation consistent and easier to parse for both humans and AI agents.
Example:
def compute_loss(
self,
pred: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
r"""
Compute mean squared error loss.
Parameters
----------
pred : torch.Tensor
Predicted values of shape :math:`(B, D)`.
target : torch.Tensor
Target values of shape :math:`(B, D)`.
Returns
-------
torch.Tensor
Scalar loss value.
"""
return torch.nn.functional.mse_loss(pred, target)

Anti-pattern:
# WRONG: No docstring
def helper_method(self, x):
return x * 2
# WRONG: Using wrong section names
def compute_loss(self, pred, target):
"""
Args:
pred: predictions
Returns:
loss
"""
pass

MOD-003e: Tensor shapes must use LaTeX math notation

Description:
All tensors should be documented with their shape, using LaTeX math notation
such as :math:`(N, C, H_{in}, W_{in})`. There is flexibility for naming the
dimensions, but the math format should be enforced.
Our documentation is rendered using LaTeX, and supports a rich set of LaTeX commands, so it is recommended to use LaTeX commands whenever possible for mathematical variables in the docstrings. The mathematical notations should be to some degree consistent with the actual variable names in the code.
Rationale:
LaTeX math notation ensures tensor shapes render correctly and consistently in Sphinx documentation. This is critical for scientific software where precise mathematical notation is expected. Plain text shapes don't render properly and can be ambiguous.
Example:
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
Process input tensor.
Parameters
----------
x : torch.Tensor
Input of shape :math:`(B, C, H_{in}, W_{in})` where :math:`B` is batch
size, :math:`C` is channels, and :math:`H_{in}, W_{in}` are spatial dims.
Returns
-------
torch.Tensor
Output of shape :math:`(B, C_{out}, H_{out}, W_{out})`.
"""
pass

Anti-pattern:
# WRONG: Not using :math: notation
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : torch.Tensor
Input of shape (B, C, H, W) # Missing :math:`...`
"""
pass

MOD-003f: Callback functions must have code-block specification

Description:
For arguments or variables that are callback functions (e.g. Callable), the
docstring should include a clearly separated .. code-block:: that specifies the
required signature and return type of the callback function. This is not only
true for callback functions, but for any type of parameters or arguments that
has some complex type specification or API requirements.
The explanation code block should be placed in the top or bottom section of the
docstrings, but not in the Parameters or Forward or Outputs sections, for
readability and clarity.
Rationale:
Callback functions have complex type signatures that are difficult to express clearly in the Parameters section alone. A dedicated code-block provides a clear visual reference for the expected signature, making it much easier for users to implement compatible callbacks.
Example:
class MyModel(Module):
r"""
Model with callback function.
.. code-block:: python
def preprocess_fn(x: torch.Tensor) -> torch.Tensor:
'''Preprocessing function signature.'''
...
return y
where ``x`` is input of shape :math:`(B, D_{in})` and ``y`` is output
of shape :math:`(B, D_{out})`.
Parameters
----------
preprocess_fn : Callable[[torch.Tensor], torch.Tensor], optional
Optional preprocessing function. See code block above for signature.
"""
pass

Anti-pattern:
# WRONG: No code-block specification
class MyModel(Module):
r"""
Parameters
----------
preprocess_fn : Callable, optional
Preprocessing function. # No specification!
"""
pass

MOD-003g: Inline code must use double backticks

Description:
Inline code should be formatted with double backticks, such as ``my_variable``.
Single backticks are not allowed as they don't render properly in our Sphinx
documentation.
Rationale:
Sphinx uses reStructuredText, which requires double backticks for inline code literals. Single backticks are interpreted differently and don't produce the expected code formatting in the rendered documentation.
Example:
class MyModel(Module):
r"""
Model with inline code references.
If ``True``, enables dropout. Set ``model.training`` to control behavior.
The parameter ``hidden_dim`` controls layer size.
Parameters
----------
hidden_dim : int
Size of hidden layer. Access via ``self.hidden_dim``.
"""
pass

Anti-pattern:
# WRONG: Using single backticks
class MyModel(Module):
r"""
If `True`, enables dropout. # WRONG
"""
pass

MOD-003h: Parameters must be documented on single line

Description:
All parameters should be documented with their type and default values on a single line, following the NumPy docstring style format:
parameter_name : type, optional, default=value
The description then follows on the next line(s), indented.
Rationale:
This standardized format makes parameter documentation consistent and easy to parse. It provides all key information (name, type, optionality, default) at a glance, improving readability.
Example:
class MyModel(Module):
r"""
Parameters
----------
input_dim : int
Dimension of input features.
hidden_dim : int, optional, default=128
Dimension of hidden layer.
dropout : float, optional, default=0.1
Dropout probability.
"""
pass

Anti-pattern:
# WRONG: Type and default not on same line
class MyModel(Module):
r"""
Parameters
----------
hidden_dim : int
optional, default=128 # Should be on line above
Dimension of hidden layer.
"""
pass

MOD-003i: Docstrings should include cross-references

Description:
When possible, docstrings should use links to other docstrings using Sphinx cross-reference syntax:
- Classes: :class:`~physicsnemo.models.some_model.SomeModel`
- Functions: :func:`~physicsnemo.utils.common_function`
- Methods: :meth:`~physicsnemo.models.some_model.SomeModel.some_method`
When referencing external resources, such as papers, websites, or other
documentation, docstrings should use links to the external resource in the
format `some link text <some_url>`_.
Rationale:
Cross-references create a navigable documentation structure where users can easily jump between related classes, methods, and functions. External links provide context and attribution for algorithms and techniques.
Example:
class MyEncoder(Module):
r"""
Encoder using attention.
Based on `Transformer Architecture <https://arxiv.org/abs/1706.03762>`_.
See :class:`~physicsnemo.nn.MultiHeadAttention` for attention details.
Parameters
----------
activation : str
Activation function. See :func:`~torch.nn.functional.relu`.
"""
pass

Anti-pattern:
# Not wrong, but missing opportunities for useful links
class MyEncoder(Module):
r"""
Uses MultiHeadAttention. # Could link to class
Based on Transformer paper. # Could link to paper
"""
pass

MOD-003j: Docstrings should include Examples section

Description:
Docstrings are strongly encouraged to have an Examples section that
demonstrates basic construction and usage of the model. These example sections
serve as both documentation and tests, as our CI system automatically tests
these code sections for correctness when present.
Examples should be executable Python code showing typical use cases, including model instantiation, input preparation, and forward pass execution. The examples should use realistic tensor shapes and demonstrate key features of the model.
Rationale:
Example sections provide immediate value to users by showing concrete usage patterns. By automatically testing these examples in CI, we ensure that documentation stays synchronized with code and that examples remain correct as the codebase evolves.
Example:
class MyEncoder(Module):
r"""
A simple encoder network.
Parameters
----------
input_dim : int
Dimension of input features.
Examples
--------
>>> import torch
>>> from physicsnemo.models import MyEncoder
>>> model = MyEncoder(input_dim=784, output_dim=128)
>>> x = torch.randn(32, 784)
>>> output = model(x)
>>> output.shape
torch.Size([32, 128])
"""
pass

Anti-pattern:
# Not wrong, but discouraged - no Examples section
class MyEncoder(Module):
r"""
Parameters
----------
input_dim : int
Dimension of input features.
"""
pass

MOD-003k: Add high-level comments for complex tensor operations

Description:
Model code that involves complex tensor operations should include high-level comments that explain what blocks of code accomplish semantically. One-line comments every few lines of tensor operations is sufficient.
Comments should focus on high-level semantic explanations rather than low-level syntactic details. For example, use "Compute the encodings" instead of "Doing a concatenation followed by a linear projection, followed by a nonlinear activation". The goal is to give a high-level overview of what a block of tensor operations accomplishes.
When multiple tensor operations are chained, it is encouraged to add short inline comments with the tensor shapes of the computed tensors, e.g.:

x = torch.cat([y, z], dim=1)  # (B, 2*C_in, H, W)

The symbols chosen in the comments should be consistent with the docstring (possibly shortened versions of the dimension names for explicitness).
Rationale:
High-level comments make complex tensor manipulation code more understandable without cluttering it with excessive detail. Shape annotations help developers track tensor dimensions through complex operations, catching shape mismatches early. Consistency with docstring notation creates a unified mental model.
Example:
def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
"""Process input with context conditioning."""
# Encode input features
h = self.encoder(x) # (B, C_enc, H, W)
# Combine with context information
c = self.context_proj(context) # (B, C_enc)
c = c[:, :, None, None].expand(-1, -1, h.shape[2], h.shape[3]) # (B, C_enc, H, W)
h = torch.cat([h, c], dim=1) # (B, 2*C_enc, H, W)
# Apply attention mechanism
h = self.attention(h) # (B, 2*C_enc, H, W)
# Decode to output
out = self.decoder(h) # (B, C_out, H, W)
return out

Anti-pattern:
# WRONG: No comments
def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
h = self.encoder(x)
c = self.context_proj(context)
c = c[:, :, None, None].expand(-1, -1, h.shape[2], h.shape[3])
h = torch.cat([h, c], dim=1)
return self.decoder(self.attention(h))
# WRONG: Too low-level, syntactic comments
def forward(self, x, context):
# Pass x through encoder layer
h = self.encoder(x)
# Project context using linear layer
c = self.context_proj(context)
# Add two None dimensions and expand
c = c[:, :, None, None].expand(-1, -1, h.shape[2], h.shape[3])

MOD-004: Model code is not self-contained

Description:
All utility functions for a model class should be organized together with the model class in a clear and logical structure. Acceptable patterns include:
- A single self-contained file: physicsnemo/<models or nn>/model_name.py
- A subdirectory: physicsnemo/<models or nn>/model_name/ containing:
  - model_name.py with the main model class
  - Additional modules for utility functions specific to this model
What should be avoided is a flat organization where model files and their
utility files are all mixed together in physicsnemo/<models or nn>/, making it
unclear which utilities belong to which models.
The only exception is when a utility function is used across multiple models. In that case, the shared utility should be placed in an appropriate shared module.
Rationale:
Self-contained modules are easier to understand, maintain, and navigate. Having all model-specific code in one place reduces cognitive load and makes it clear which utilities are model-specific versus shared. This also simplifies code reviews and reduces the likelihood of orphaned utility files when models are refactored or removed.
Example:
# Good Pattern 1: Single self-contained file
# File: physicsnemo/models/my_simple_model.py
def _compute_attention_mask(seq_length: int) -> torch.Tensor:
"""Helper function specific to MySimpleModel."""
mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1)
return mask.masked_fill(mask == 1, float('-inf'))
class MySimpleModel(Module):
"""A simple model with utilities in same file."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
mask = _compute_attention_mask(x.shape[1])
return self._apply_attention(x, mask)
# Good Pattern 2: Subdirectory organization
# File: physicsnemo/models/my_complex_model/my_complex_model.py
from physicsnemo.models.my_complex_model.utils import helper_function
class MyComplexModel(Module):
"""A complex model with utilities in subdirectory."""
pass
# File: physicsnemo/models/my_complex_model/utils.py
def helper_function(x):
"""Utility specific to MyComplexModel."""
pass

Anti-pattern:
# WRONG: Flat organization with utilities mixed in main directory
# File: physicsnemo/models/my_transformer.py
from physicsnemo.models.my_transformer_utils import _compute_mask # WRONG
class MyTransformer(Module):
pass
# File: physicsnemo/models/my_transformer_utils.py (WRONG: mixed with other models)
# File: physicsnemo/models/other_model.py
# File: physicsnemo/models/other_model_utils.py (WRONG: utilities scattered)
# All mixed together in flat structure - unclear organization!

MOD-005: Invalid or missing tensor shape validation logic

Description:
All forward methods and other public methods that accept tensor arguments must validate tensor shapes at the beginning of the method. This rule applies to:
- Individual tensor arguments
- Containers of tensors (lists, tuples, dictionaries)
For containers, validate their length, required keys, and the shapes of
contained tensors. Validation statements should be concise (ideally one check
per argument). Error messages must follow the standardized format:
"Expected tensor of shape (B, D) but got tensor of shape {actual_shape}".
To avoid interactions with torch.compile, all validation must be wrapped in a
conditional check using torch.compiler.is_compiling(). Follow the "fail-fast"
approach by validating inputs before any computation.
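For dictionary containers, a minimal sketch of the same fail-fast pattern (the key names, expected shapes, and the _process helper are illustrative assumptions, not part of any existing model):

```python
from typing import Dict

import torch


def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Forward pass over a dictionary of tensors, validating keys and shapes first."""
    ### Input validation
    if not torch.compiler.is_compiling():
        # Check required keys before touching any tensor
        required_keys = {"coords", "features"}
        missing = required_keys - inputs.keys()
        if missing:
            raise ValueError(
                f"Expected keys {sorted(required_keys)} in input dict, "
                f"missing {sorted(missing)}"
            )
        # Check the shapes of the contained tensors
        if inputs["coords"].ndim != 3 or inputs["coords"].shape[-1] != 3:
            raise ValueError(
                f"Expected coords tensor of shape (B, N, 3) but got tensor of shape "
                f"{tuple(inputs['coords'].shape)}"
            )
    # Actual computation happens after validation
    return self._process(inputs)
```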
Rationale:
Early shape validation catches errors at the API boundary with clear, actionable
error messages, making debugging significantly easier. Without validation, shape
mismatches result in cryptic errors deep in the computation graph. The
torch.compile guard ensures that validation overhead is eliminated in
production compiled code while preserving debug-time safety.
Example:
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Forward pass with shape validation."""
### Input validation
# Skip validation when running under torch.compile for performance
if not torch.compiler.is_compiling():
# Extract expected dimensions
B, C, H, W = x.shape if x.ndim == 4 else (None, None, None, None)
# Validate x shape
if x.ndim != 4:
raise ValueError(
f"Expected 4D input tensor (B, C, H, W), got {x.ndim}D tensor with shape {tuple(x.shape)}"
)
if C != self.in_channels:
raise ValueError(
f"Expected {self.in_channels} input channels, got {C} channels"
)
# Validate optional mask
if mask is not None:
if mask.shape != (B, H, W):
raise ValueError(
f"Expected mask shape ({B}, {H}, {W}), got {tuple(mask.shape)}"
)
# Actual computation happens after validation
return self._process(x, mask)
def process_list(self, tensors: List[torch.Tensor]) -> torch.Tensor:
"""Process a list of tensors with validation."""
### Input validation
if not torch.compiler.is_compiling():
if len(tensors) == 0:
raise ValueError("Expected non-empty list of tensors")
# Validate all tensors have consistent shapes
ref_shape = tensors[0].shape
for i, t in enumerate(tensors[1:], start=1):
if t.shape != ref_shape:
raise ValueError(
f"All tensors must have the same shape. "
f"Tensor 0 has shape {tuple(ref_shape)}, "
f"but tensor {i} has shape {tuple(t.shape)}"
)
return torch.stack(tensors)

Anti-pattern:
# WRONG: No validation at all
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer(x) # Will fail with cryptic error if shape is wrong
# WRONG: Validation not guarded by torch.compiler.is_compiling()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.ndim != 4: # Breaks torch.compile
raise ValueError(f"Expected 4D tensor, got {x.ndim}D")
return self.layer(x)
# WRONG: Validation after computation has started
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.layer1(x) # Computation started
if x.shape[1] != self.in_channels: # Too late!
raise ValueError(f"Wrong number of channels")
return self.layer2(h)
# WRONG: Non-standard error message format
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not torch.compiler.is_compiling():
if x.ndim != 4:
raise ValueError("Input must be 4D") # Missing actual shape info
return self.layer(x)

MOD-006: Invalid or missing jaxtyping tensor annotations in public function signature

Description:
All tensor arguments and variables in model __init__, forward, and other
public methods must have type annotations using jaxtyping. This provides
runtime-checkable shape information in type hints.
Use the format Float[torch.Tensor, "shape_spec"] where shape_spec describes
tensor dimensions using space-separated dimension names (e.g., "batch channels height width"
or "b c h w").
Rationale:
Jaxtyping annotations provide explicit, machine-readable documentation of expected tensor shapes. This enables better IDE support, catches shape errors earlier, and makes code more self-documenting. The annotations serve as both documentation and optional runtime checks when jaxtyping's validation is enabled.
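As a minimal sketch of how these annotations can be turned into the optional runtime checks mentioned above, one common approach (an assumption here, not necessarily what PhysicsNeMo's CI does) is to combine jaxtyping's jaxtyped decorator with a runtime type checker such as beartype:

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped


@jaxtyped(typechecker=beartype)
def scale_features(
    x: Float[torch.Tensor, "batch channels"],
    weight: Float[torch.Tensor, "channels"],
) -> Float[torch.Tensor, "batch channels"]:
    """Scale each channel by a per-channel weight; shapes are checked at call time."""
    return x * weight
```

With the decorator applied, calling scale_features with a weight tensor whose length does not match the channels dimension fails immediately with a clear type-checking error instead of an error deep inside the computation.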
Example:
from jaxtyping import Float
import torch
class MyConvNet(Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3)
def forward(
self,
x: Float[torch.Tensor, "batch in_channels height width"]
) -> Float[torch.Tensor, "batch out_channels height width"]:
"""Process input with convolution."""
return self.conv(x)
def process_attention(
query: Float[torch.Tensor, "batch seq_len d_model"],
key: Float[torch.Tensor, "batch seq_len d_model"],
value: Float[torch.Tensor, "batch seq_len d_model"]
) -> Float[torch.Tensor, "batch seq_len d_model"]:
"""Compute attention with clear shape annotations."""
pass

Anti-pattern:
# WRONG: No jaxtyping annotations
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer(x)
# WRONG: Using plain comments instead of jaxtyping
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (batch, channels, height, width) # Use jaxtyping instead
return self.layer(x)
# WRONG: Incomplete annotations (missing jaxtyping for tensor arguments)
def forward(
self,
x: Float[torch.Tensor, "b c h w"],
mask: torch.Tensor # Missing jaxtyping annotation
) -> Float[torch.Tensor, "b c h w"]:
return self.layer(x, mask)

MOD-007a: Cannot add required parameters without defaults

Description:
For any model in physicsnemo/nn or physicsnemo/models, adding new required
parameters (parameters without default values) to __init__ or any public
method is strictly forbidden. This breaks backward compatibility.
New parameters must have default values to ensure existing code and checkpoints
continue to work. If a new parameter is truly required, increment the model
version number using __model_checkpoint_version__ and add appropriate
versioning support.
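If a new parameter truly cannot have a default, a hedged sketch of the versioning route described above, following the mechanism shown in MOD-007b (the parameter name and the value injected for old checkpoints are illustrative assumptions):

```python
from typing import Any, Dict


class MyModel(Module):
    # Version bumped because a new required parameter was introduced
    __model_checkpoint_version__ = "2.0"
    __supported_model_checkpoint_version__ = {
        "1.0": (
            "Loading checkpoint from version 1.0 (current is 2.0). "
            "The required parameter 'num_heads' is filled in automatically."
        )
    }

    @classmethod
    def _backward_compat_arg_mapper(
        cls, version: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Supply the new required argument when loading version 1.0 checkpoints."""
        args = super()._backward_compat_arg_mapper(version, args)
        if version == "1.0":
            # Old checkpoints predate 'num_heads'; use the value they implicitly assumed
            args.setdefault("num_heads", 1)
        return args

    def __init__(self, input_dim: int, output_dim: int, num_heads: int):
        super().__init__(meta=MyModelMetaData())
```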
Rationale:
Adding required parameters breaks all existing code that instantiates the model, and breaks loading of old checkpoints. This violates PhysicsNeMo's commitment to backward compatibility and would disrupt user workflows.
Example:
# Good: Adding parameter with default value (backward compatible)
class MyModel(Module):
__model_checkpoint_version__ = "2.0"
__supported_model_checkpoint_version__ = {
"1.0": "Loading checkpoint from version 1.0 (current is 2.0). Still supported."
}
def __init__(
self,
input_dim: int,
output_dim: int,
dropout: float = 0.0, # New parameter with default - backward compatible
new_feature: bool = False # New parameter with default - backward compatible
):
super().__init__(meta=MyModelMetaData())

Anti-pattern:
# WRONG: Adding required parameter without default
class MyModel(Module):
__model_checkpoint_version__ = "2.0"
def __init__(
self,
input_dim: int,
output_dim: int,
new_param: int # WRONG: No default! Breaks old checkpoints
):
super().__init__(meta=MyModelMetaData())

MOD-007b: Cannot remove or rename parameters without compat mapper

Description:
For any model in physicsnemo/nn or physicsnemo/models, removing or renaming
parameters is strictly forbidden without proper backward compatibility support.
If a parameter must be renamed or removed, the developer must:
- Increment __model_checkpoint_version__
- Add the old version to the __supported_model_checkpoint_version__ dict
- Implement the _backward_compat_arg_mapper classmethod to handle the mapping
- Maintain support for the old API for at least 2 release cycles
Rationale:
Removing or renaming parameters breaks existing checkpoints and user code. Proper version management and argument mapping ensures old checkpoints can still be loaded and users have time to migrate to the new API.
Example:
from typing import Any, Dict
# Good: Proper backward compatibility for parameter rename
class MyModel(Module):
__model_checkpoint_version__ = "2.0"
__supported_model_checkpoint_version__ = {
"1.0": (
"Loading checkpoint from version 1.0 (current is 2.0). "
"Parameter 'hidden_dim' renamed to 'hidden_size'."
)
}
@classmethod
def _backward_compat_arg_mapper(
cls, version: str, args: Dict[str, Any]
) -> Dict[str, Any]:
"""Map arguments from older versions."""
args = super()._backward_compat_arg_mapper(version, args)
if version == "1.0":
# Map old parameter name to new name
if "hidden_dim" in args:
args["hidden_size"] = args.pop("hidden_dim")
# Remove deprecated parameters
if "legacy_param" in args:
_ = args.pop("legacy_param")
return args
def __init__(
self,
input_dim: int,
hidden_size: int = 128, # Renamed from 'hidden_dim'
):
super().__init__(meta=MyModelMetaData())

Anti-pattern:
# WRONG: Renaming without backward compat
class MyModel(Module):
__model_checkpoint_version__ = "2.0"
# Missing: __supported_model_checkpoint_version__ and _backward_compat_arg_mapper
def __init__(self, input_dim: int, hidden_size: int): # Renamed!
super().__init__(meta=MyModelMetaData())
# WRONG: Old checkpoints with 'hidden_dim' will fail!
# WRONG: Not calling super()
class MyModel(Module):
@classmethod
def _backward_compat_arg_mapper(cls, version: str, args: Dict[str, Any]) -> Dict[str, Any]:
# WRONG: Missing super()._backward_compat_arg_mapper(version, args)
if "hidden_dim" in args:
args["hidden_size"] = args.pop("hidden_dim")
return args

MOD-007c: Cannot change return types of public methods

Description:
For any model in physicsnemo/nn or physicsnemo/models, changing the return
type of any public method (including forward) is strictly forbidden. This
includes:
- Changing from returning a single value to returning a tuple
- Changing from a tuple to a single value
- Changing the number of elements in a returned tuple
- Changing the type of returned values
If a return type change is absolutely necessary, create a new method with a different name and deprecate the old method following MOD-002b.
Rationale:
Changing return types is a breaking change that silently breaks user code. Users who unpack return values or depend on specific return structures will experience runtime errors. Unlike parameter changes (which can be managed with versioning), return type changes affect runtime behavior and are harder to detect.
Example:
# Good: Keeping consistent return type
class MyModel(Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Always returns single tensor."""
return self.process(x)
# Good: If new return is needed, add new method
class MyModel(Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Returns output tensor."""
output, loss = self._forward_with_loss(x)
return output
def forward_with_loss(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""New method for returning both output and loss."""
return self._forward_with_loss(x)

Anti-pattern:
# WRONG: Changing return type
class MyModel(Module):
__model_checkpoint_version__ = "2.0"
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# WRONG: v1.0 returned single tensor, v2.0 returns tuple - breaks user code!
return output, loss

MOD-008a: Model missing constructor/attributes tests

Description:
Every model in physicsnemo/nn or physicsnemo/models must have tests that
verify model instantiation and all public attributes (excluding buffers and
parameters).
These tests should:
- Use pytest parameterization to test at least 2 configurations
- Test one configuration with all default arguments
- Test another configuration with non-default arguments
- Verify all public attributes have expected values
Rationale:
Constructor tests ensure the model can be instantiated correctly with various configurations and that all attributes are properly initialized. This catches issues early in the development cycle.
Example:
@pytest.mark.parametrize(
"config",
["default", "custom"],
ids=["with_defaults", "with_custom_args"]
)
def test_my_model_constructor(config):
"""Test model constructor and attributes."""
if config == "default":
model = MyModel(input_dim=64, output_dim=32)
assert model.hidden_dim == 128 # Default value
assert model.dropout == 0.0 # Default value
else:
model = MyModel(
input_dim=64,
output_dim=32,
hidden_dim=256,
dropout=0.1
)
assert model.hidden_dim == 256
assert model.dropout == 0.1
# Test common attributes
assert model.input_dim == 64
assert model.output_dim == 32

Anti-pattern:
# WRONG: Only testing default configuration
def test_my_model_bad():
model = MyModel(input_dim=64, output_dim=32)
# Only tests defaults

MOD-008b: Model missing non-regression test with reference data

Description:
Every model must have non-regression tests that:
- Instantiate the model with reproducible random parameters
- Run forward pass with test data
- Compare outputs against reference data saved in a .pth file
Requirements:
- Use pytest parameterization to test multiple configurations
- Test tensors must have realistic shapes (no singleton dimensions except batch)
- Test data should be meaningful and representative of actual use cases
- Compare actual tensor values, not just shapes
- All public methods (not just forward) need similar non-regression tests
Critical: Per MOD-002a, models cannot move out of experimental without these tests.
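A hedged sketch of how the reference .pth file might be generated once, reusing the _instantiate_model helper from the example below (the file path and seed are illustrative assumptions):

```python
import torch


def generate_reference_data(
    path: str = "test/models/data/my_model_default_v1.0.pth",
) -> None:
    """One-off helper that produces reference inputs/outputs for non-regression tests."""
    model = _instantiate_model(MyModel, input_dim=64, output_dim=32)
    model.eval()
    # Realistic shapes: a small batch, no singleton dimensions
    gen = torch.Generator().manual_seed(42)
    x = torch.randn(4, 64, generator=gen)
    with torch.no_grad():
        out = model(x)
    torch.save({"x": x, "out": out}, path)
```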
Rationale:
Non-regression tests with reference data catch subtle numerical changes that could break reproducibility. Simply checking output shapes is insufficient to detect algorithmic changes or numerical instabilities.
Example:
import pytest
import torch
from physicsnemo.models import MyModel
def _instantiate_model(cls, seed: int = 0, **kwargs):
"""Helper to create model with reproducible parameters."""
model = cls(**kwargs)
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
with torch.no_grad():
for param in model.parameters():
param.copy_(torch.randn(param.shape, generator=gen, dtype=param.dtype))
return model
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
@pytest.mark.parametrize("config", ["default", "custom"])
def test_my_model_non_regression(device, config):
"""Test model forward pass against reference output."""
if config == "default":
model = _instantiate_model(MyModel, input_dim=64, output_dim=32)
else:
model = _instantiate_model(
MyModel,
input_dim=64,
output_dim=32,
hidden_dim=256
)
model = model.to(device)
# Load reference data (meaningful shapes, no singleton dimensions)
data = torch.load(f"test/models/data/my_model_{config}_v1.0.pth")
x = data["x"].to(device) # Shape: (4, 64), not (1, 64)
out_ref = data["out"].to(device)
# Run forward and compare values
out = model(x)
assert torch.allclose(out, out_ref, atol=1e-5, rtol=1e-5)

Anti-pattern:
# WRONG: Only testing output shapes
def test_my_model_bad(device):
model = MyModel(input_dim=64, output_dim=32).to(device)
x = torch.randn(4, 64).to(device)
out = model(x)
assert out.shape == (4, 32) # NOT SUFFICIENT!
# WRONG: Using singleton dimensions
def test_my_model_bad(device):
x = torch.randn(1, 1, 64) # WRONG: Trivial shapes

MOD-008c: Model missing checkpoint loading test

Description:
Every model must have tests that load the model from a checkpoint file
(.mdlus) using physicsnemo.Module.from_checkpoint() and verify that:
- The model loads successfully
- All public attributes have expected values
- Forward pass outputs match reference data
This ensures the model's serialization and deserialization work correctly.
Critical: Per MOD-002a, models cannot move out of experimental without these tests.
Rationale:
Checkpoint tests verify that the model's custom serialization logic works correctly and that saved models can be loaded in different environments. This is critical for reproducibility and for users who need to save and load trained models.
Example:
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_my_model_from_checkpoint(device):
"""Test loading model from checkpoint and verify outputs."""
model = physicsnemo.Module.from_checkpoint(
"test/models/data/my_model_default_v1.0.mdlus"
).to(device)
# Verify attributes after loading
assert model.input_dim == 64
assert model.output_dim == 32
# Load reference data and verify outputs
data = torch.load("test/models/data/my_model_default_v1.0.pth")
x = data["x"].to(device)
out_ref = data["out"].to(device)
out = model(x)
assert torch.allclose(out, out_ref, atol=1e-5, rtol=1e-5)

Anti-pattern:
# WRONG: No checkpoint loading test
# (Missing test_my_model_from_checkpoint entirely)

MOD-009: Avoid string-based class selection in model constructors

Description:
Passing a string that represents a class name, which is then used to instantiate an internal submodule, should be avoided unless there are only a few choices (2 or 3 maximum) for the class name.
When there are more than 2-3 choices, the recommended practice is to pass an already instantiated instance of a submodule instead of a string primitive for dependency injection. This promotes better type safety, clearer APIs, and easier testing.
Rationale:
String-based class selection makes code harder to type-check, debug, and test. It obscures dependencies and makes it difficult for static analysis tools to understand the code. Direct instance injection provides better IDE support, type safety, and makes testing easier by allowing mock object injection.
Example:
# Good: Limited choices (2-3 max) - string selection acceptable
class MyModel(Module):
def __init__(
self,
activation: Literal["relu", "gelu"] = "relu"
):
if activation == "relu":
self.act = nn.ReLU()
elif activation == "gelu":
self.act = nn.GELU()
# Good: Many choices - use instance injection
class MyModel(Module):
def __init__(
self,
encoder: Module, # Pass instance, not string
decoder: Module # Pass instance, not string
):
self.encoder = encoder
self.decoder = decoder
# Usage:
model = MyModel(
encoder=MyCustomEncoder(dim=128),
decoder=MyCustomDecoder(dim=128)
)

Anti-pattern:
# WRONG: String selection with many choices
class MyModel(Module):
def __init__(
self,
encoder_type: str = "transformer" # Many possible values
):
# String-based factory pattern with 10+ choices
if encoder_type == "transformer":
self.encoder = TransformerEncoder()
elif encoder_type == "cnn":
self.encoder = CNNEncoder()
elif encoder_type == "rnn":
self.encoder = RNNEncoder()
# ... many more options
# WRONG: Should accept encoder instance instead

MOD-010: Avoid splatted kwargs in model constructors

Description:
Passing splatted arguments like **kwargs_for_submodules should be avoided in
model constructors as it might create conflicts in the names of these kwargs and
makes the API unclear.
Instead, it is recommended to pass non-splatted arguments in the form of a
Dict when configuration for submodules needs to be passed through. This makes
parameter passing explicit and avoids naming conflicts.
Rationale:
Splatted kwargs obscure the actual parameters being passed, make type checking impossible, and can lead to subtle bugs from name conflicts. Explicit dictionary parameters make the API clearer and enable better IDE support and error detection.
Example:
# Good: Explicit dict parameter
class MyModel(Module):
def __init__(
self,
input_dim: int,
output_dim: int,
encoder_config: Optional[Dict[str, Any]] = None
):
encoder_config = encoder_config or {}
self.encoder = Encoder(input_dim=input_dim, **encoder_config)
# Usage:
model = MyModel(
input_dim=64,
output_dim=32,
encoder_config={"hidden_dim": 128, "num_layers": 3}
)

Anti-pattern:
# WRONG: Splatted kwargs
class MyModel(Module):
def __init__(
self,
input_dim: int,
output_dim: int,
**encoder_kwargs # WRONG: Unclear what's accepted
):
self.encoder = Encoder(input_dim=input_dim, **encoder_kwargs)
# Risk of name conflicts, unclear API
# Usage - unclear what parameters are valid:
model = MyModel(input_dim=64, output_dim=32, hidden_dim=128, num_layers=3)
# Are hidden_dim and num_layers for MyModel or Encoder? Unclear!

MOD-011: Use proper optional dependency handling

Description:
When a model requires optional dependencies (packages not installed by default), use the PhysicsNeMo APIs for dependency handling:
- check_min_version(package, version, hard_fail=False): Use this function to check if a package is installed and available without actually importing it. Set hard_fail=True for hard requirements, hard_fail=False for soft requirements. This is the primary method for handling optional dependencies.
- @require_version(package, version): Use this decorator when core code must always be available but certain features need to be protected against older versions. This is rare and should only be used when you need to protect specific methods or classes against version incompatibilities.
- pyproject.toml: This file is the one, only, and universal source of truth for all dependencies in PhysicsNeMo. All optional dependencies must be declared there.
Rationale:
Centralized dependency handling ensures consistent error messages and version
checking across the codebase. Checking availability without importing prevents
import errors and allows graceful degradation when optional packages are not
available. Using pyproject.toml as the single source of truth prevents
dependency specification from becoming scattered and inconsistent.
Example:
import torch
from physicsnemo.core import Module
from physicsnemo.core.version_check import check_min_version, require_version
# Check optional dependency availability without importing
APEX_AVAILABLE = check_min_version("apex", "0.1.0", hard_fail=False)
class MyModel(Module):
def __init__(
self,
input_dim: int,
use_apex: bool = False
):
super().__init__()
self.use_apex = use_apex
if use_apex and not APEX_AVAILABLE:
raise RuntimeError(
"apex is required for use_apex=True but is not installed. "
"Install with: pip install apex>=0.1.0"
)
if use_apex:
import apex # Only import when actually needed
self.fused_layer = apex.FusedLayer()
else:
self.fused_layer = None
# Using @require_version for protecting version-specific features
class AdvancedModel(Module):
@require_version("torch", "2.4.0")
def use_device_mesh(self):
"""This feature requires torch>=2.4.0."""
from torch.distributed.device_mesh import DeviceMesh
# Protected code that needs torch>=2.4.0

Anti-pattern:
# WRONG: Direct import without checking availability
import apex # Will fail if apex not installed!
class MyModel(Module):
def __init__(self, use_apex: bool = False):
if use_apex:
self.layer = apex.FusedLayer() # Already failed at import!
# WRONG: Try/except for dependency checking
try:
import apex
APEX_AVAILABLE = True
except ImportError:
APEX_AVAILABLE = False
# Use check_min_version instead!
# WRONG: Hardcoded version strings in multiple places
if version.parse(apex.__version__) < version.parse("0.1.0"):
raise ImportError("apex>=0.1.0 required")
# Should use check_min_version or require_version!
# WRONG: Not declaring dependency in pyproject.toml
# All optional dependencies must be in pyproject.toml!

When implementing models, ensure all rules are followed. Code reviews should verify each rule is followed and enforce the rules as strictly as possible. For exceptions to these rules, document the reasoning in code comments and obtain approval during code review.