Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
4157a20
fix:orb squeeze incorrect energy shape
thomasloux Sep 18, 2025
646ddf5
Merge branch 'TorchSim:main' into main
thomasloux Oct 8, 2025
69ee796
Merge branch 'TorchSim:main' into main
thomasloux Oct 10, 2025
38c6138
First draft constraints
thomasloux Oct 17, 2025
6eb3d78
change base class name for constraint
thomasloux Oct 21, 2025
c630f39
remove useless methods
thomasloux Oct 21, 2025
bfdf6de
Merge branch 'main' into features/constraints
thomasloux Oct 21, 2025
f5459b9
change redundant definition
thomasloux Oct 21, 2025
6b2710e
constraint to optimizer, compatibility with state manipulation
thomasloux Oct 23, 2025
c955273
Merge branch 'features/constraints' of https://github.com/thomasloux/…
thomasloux Oct 23, 2025
7d63069
test temperature, adapt calc_kt for reduced degrees of freedom
thomasloux Oct 23, 2025
e1388fd
Merge branch 'main' into pr/thomasloux/294
janosh Nov 10, 2025
ad4fa0a
fix typo + unreleased changelog entry
janosh Nov 10, 2025
8beb9d9
renamed validate_constraints now called in SimState.add_constraints a…
janosh Nov 10, 2025
c577e1d
tests for constraint validation warnings and errors
janosh Nov 10, 2025
9cfe52b
refactor to use getter setter and _constraints
thomasloux Nov 10, 2025
be30d45
remove edge case slice(None)
thomasloux Nov 10, 2025
33d6025
new API (remove slice(None) and _constraint as private var
thomasloux Nov 10, 2025
399fbfd
correct get_centers_of_mass
thomasloux Nov 10, 2025
b31ba80
add warnings for npt dynamics
thomasloux Nov 10, 2025
1483977
simplify state updating in _filter_attrs_by_mask
orionarcher Nov 18, 2025
3c267eb
simplify _split_state with select_sub_constraint function
orionarcher Nov 19, 2025
06400e1
make constraint handling more modular with methods, merge states curr…
orionarcher Nov 19, 2025
4081973
No longer allow initializing FixCom() or FixAtoms() with empty arguments
orionarcher Nov 21, 2025
35749c3
vibe code and verify some tests
orionarcher Nov 21, 2025
8c067fb
Merge pull request #1 from TorchSim/contraints
thomasloux Nov 24, 2025
0688bfe
rename update_constraint to select_constraint, remove None Constraint…
thomasloux Nov 24, 2025
be161e3
change to _constraint name
thomasloux Nov 24, 2025
6afab52
revert to previous return as it actually also change the device/dtype…
thomasloux Nov 24, 2025
4aa1447
use post_init to enforce constraint on forces
thomasloux Nov 24, 2025
e61e452
constraint is not a global_attrs anymore
thomasloux Nov 24, 2025
8144ed6
increase slightly steps to test FixCom
thomasloux Nov 24, 2025
940827b
add _constraint to attributes so that it's kept when cloning simstate
thomasloux Nov 24, 2025
1cbd0b0
compute com for all and only subselect depending on system_idx, remov…
thomasloux Nov 24, 2025
6e09895
remove comments
thomasloux Nov 24, 2025
7d8890f
remove comment and raise if dof is negative
thomasloux Nov 24, 2025
be55c9b
remove unwrap_pos and add dummy state to test for validate_constraints
thomasloux Nov 24, 2025
33c6e92
ruff happy, simplify function
thomasloux Nov 24, 2025
d99a1a7
test for unwrap_positions
thomasloux Nov 24, 2025
50d566f
Merge branch 'main' into features/constraints
thomasloux Nov 24, 2025
eb26975
silence ruff
thomasloux Nov 24, 2025
c15a012
modify args names
thomasloux Nov 24, 2025
87644fa
reduce precision for test_unwrap
thomasloux Nov 24, 2025
95857b0
updates names
thomasloux Nov 24, 2025
7022df2
remove einsteinModel (not for this PR)
thomasloux Nov 24, 2025
07624f0
rename var and add mask
thomasloux Nov 26, 2025
0940919
remove comment now that a warning is set up for NPT MD with constraints
thomasloux Nov 26, 2025
b49e309
Add duplicate error in FixAtoms (subclass of AtomConstraint will hand…
thomasloux Nov 26, 2025
65fd0cf
rename args FixAtoms tests
thomasloux Nov 26, 2025
fee207f
system_idx for constraint must be dim 1
thomasloux Nov 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch

import torch_sim as ts
from tests.conftest import DTYPE
from torch_sim.constraints import FixAtoms, FixCom
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.transforms import unwrap_positions
from torch_sim.units import MetalUnits


def test_fix_com_nvt_langevin(
ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel
):
n_steps = 1000
dt = torch.tensor(0.001, dtype=DTYPE)
kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature

dofs_before = ar_double_sim_state.calc_dof()
ar_double_sim_state.constraints = [FixCom()]
assert torch.allclose(ar_double_sim_state.calc_dof(), dofs_before - 3)

state = ts.nvt_langevin_init(
state=ar_double_sim_state, model=lj_model, kT=kT, seed=42
)
positions = []
system_masses = torch.zeros((state.n_systems, 1), dtype=DTYPE).scatter_add_(
0,
state.system_idx.unsqueeze(-1).expand(-1, 1),
state.masses.unsqueeze(-1),
)
for _step in range(n_steps):
state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT)
positions.append(state.positions.clone())
traj_positions = torch.stack(positions)

unwrapped_positions = unwrap_positions(
traj_positions, ar_double_sim_state.cell, state.system_idx
)
coms = torch.zeros((n_steps, state.n_systems, 3), dtype=DTYPE).scatter_add_(
1,
state.system_idx[None, :, None].expand(n_steps, -1, 3),
state.masses.unsqueeze(-1) * unwrapped_positions,
)
coms /= system_masses
coms_drift = coms - coms[0]
assert torch.allclose(coms_drift, torch.zeros_like(coms_drift), atol=1e-4)


def test_fix_atoms_nvt_langevin(
ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel
):
n_steps = 1000
dt = torch.tensor(0.001, dtype=DTYPE)
kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature

dofs_before = ar_double_sim_state.calc_dof()
ar_double_sim_state.constraints = [
FixAtoms(indices=torch.tensor([0, 1], dtype=torch.long))
]
assert torch.allclose(
ar_double_sim_state.calc_dof(), dofs_before - torch.tensor([6, 0])
)
state = ts.nvt_langevin_init(
state=ar_double_sim_state, model=lj_model, kT=kT, seed=42
)
positions = []
for _step in range(n_steps):
state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT)
positions.append(state.positions.clone())
traj_positions = torch.stack(positions)

unwrapped_positions = unwrap_positions(
traj_positions, ar_double_sim_state.cell, state.system_idx
)
diff_positions = unwrapped_positions - unwrapped_positions[0]
assert torch.max(diff_positions[:, :2]) < 1e-8
assert torch.max(diff_positions[:, 2:]) > 1e-2
2 changes: 1 addition & 1 deletion tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None:
per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system"))
assert set(per_system_attrs) == {"cell"}
global_attrs = dict(get_attrs_for_scope(si_sim_state, "global"))
assert set(global_attrs) == {"pbc"}
assert set(global_attrs) == {"pbc", "constraints"}


def test_all_attributes_must_be_specified_in_scopes() -> None:
Expand Down
1 change: 1 addition & 0 deletions torch_sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch_sim as ts
from torch_sim import (
autobatching,
constraints,
elastic,
io,
math,
Expand Down
Loading
Loading