diff --git a/CHANGELOG.md b/CHANGELOG.md index b2e87489..9f662a8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ # Changelog +## Unreleased + +### 🎉 New Features +* Constraints support for molecular dynamics and optimization by @thomasloux in [#294](https://github.com/TorchSim/torch-sim/pull/294) + - Added `FixAtoms` constraint to fix specific atoms in place + - Added `FixCom` constraint to prevent center of mass drift + - Constraints automatically adjust degrees of freedom for accurate temperature calculations + - Full support across all integrators (NVE, NVT, NPT) and optimizers (FIRE, Gradient Descent) + - Constraints preserved during state manipulation (slicing, splitting, concatenation) ## v0.4.1 Thank you to everyone who contributed to this release! This release includes important bug fixes, new features, and API improvements. @thomasloux, @curtischong, @CompRhys, @orionarcher, @WillEngler, @samanvya10, @hn-yu, @wendymak8, @chuin-wei, @pragnya17, and many others made valuable contributions. 🚀 diff --git a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py index 8a25f0c9..ee638199 100644 --- a/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py +++ b/examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py @@ -67,9 +67,8 @@ class HybridSwapMCState(ts.SwapMCState, MDState): last_swap: Last swap attempted """ - last_permutation: torch.Tensor _atom_attributes = ( - ts.SwapMCState._atom_attributes | MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ts.SwapMCState._atom_attributes | MDState._atom_attributes # noqa: SLF001 ) _system_attributes = ( ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 diff --git a/examples/tutorials/hybrid_swap_tutorial.py b/examples/tutorials/hybrid_swap_tutorial.py index 4ceca9f9..2f959fb6 100644 --- a/examples/tutorials/hybrid_swap_tutorial.py +++ b/examples/tutorials/hybrid_swap_tutorial.py @@ -100,9 +100,11 @@ class HybridSwapMCState(SwapMCState, MDState): from MDState. """ - last_permutation: torch.Tensor _atom_attributes = ( - MDState._atom_attributes | {"last_permutation"} # noqa: SLF001 + ts.SwapMCState._atom_attributes | MDState._atom_attributes # noqa: SLF001 + ) + _system_attributes = ( + ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001 ) diff --git a/tests/test_constraints.py b/tests/test_constraints.py new file mode 100644 index 00000000..7704bb8e --- /dev/null +++ b/tests/test_constraints.py @@ -0,0 +1,839 @@ +from typing import get_args + +import pytest +import torch + +import torch_sim as ts +from tests.conftest import DTYPE +from torch_sim.constraints import ( + Constraint, + FixAtoms, + FixCom, + merge_constraints, + validate_constraints, +) +from torch_sim.models.interface import ModelInterface +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.optimizers import FireFlavor +from torch_sim.transforms import get_centers_of_mass +from torch_sim.units import MetalUnits + + +def test_fix_com(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test adjustment of positions and momenta with FixCom constraint.""" + ar_supercell_sim_state.constraints = [FixCom([0])] + initial_positions = ar_supercell_sim_state.positions.clone() + ar_supercell_sim_state.set_positions(initial_positions + 0.5) + assert torch.allclose(ar_supercell_sim_state.positions, initial_positions, atol=1e-8) + + ar_supercell_md_state = ts.nve_init( + state=ar_supercell_sim_state, + model=lj_model, + kT=torch.tensor(10.0, dtype=DTYPE), + seed=42, + ) + ar_supercell_md_state.set_momenta( + torch.randn_like(ar_supercell_md_state.momenta) * 0.1 + ) + assert torch.allclose( + ar_supercell_md_state.momenta.mean(dim=0), + torch.zeros(3, dtype=DTYPE), + atol=1e-8, + ) + + +def test_fix_atoms(ar_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test adjustment of positions and momenta with FixAtoms constraint.""" + indices_to_fix = torch.tensor([0, 5, 10], dtype=torch.long) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=indices_to_fix)] + initial_positions = ar_supercell_sim_state.positions.clone() + # displacement = torch.randn_like(ar_supercell_sim_state.positions) * 0.5 + displacement = 0.5 + ar_supercell_sim_state.set_positions(initial_positions + displacement) + assert torch.allclose( + ar_supercell_sim_state.positions[indices_to_fix], + initial_positions[indices_to_fix], + atol=1e-8, + ) + # Check that other positions have changed + unfixed_indices = torch.tensor( + [i for i in range(ar_supercell_sim_state.n_atoms) if i not in indices_to_fix], + dtype=torch.long, + ) + assert not torch.allclose( + ar_supercell_sim_state.positions[unfixed_indices], + initial_positions[unfixed_indices], + atol=1e-8, + ) + + ar_supercell_md_state = ts.nve_init( + state=ar_supercell_sim_state, + model=lj_model, + kT=torch.tensor(10.0, dtype=DTYPE), + seed=42, + ) + ar_supercell_md_state.set_momenta( + torch.randn_like(ar_supercell_md_state.momenta) * 0.1 + ) + assert torch.allclose( + ar_supercell_md_state.momenta[indices_to_fix], + torch.zeros_like(ar_supercell_md_state.momenta[indices_to_fix]), + atol=1e-8, + ) + + +def test_fix_com_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test FixCom constraint in NVT Langevin dynamics.""" + n_steps = 1000 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + dofs_before = cu_sim_state.get_number_of_degrees_of_freedom() + cu_sim_state.constraints = [FixCom([0])] + assert torch.allclose( + cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - 3 + ) + + state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) + positions = [] + system_masses = torch.zeros((state.n_systems, 1), dtype=DTYPE).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 1), + state.masses.unsqueeze(-1), + ) + temperatures = [] + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.clone()) + temp = ts.calc_kT( + masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, + dof_per_system=state.get_number_of_degrees_of_freedom(), + ) + temperatures.append(temp / MetalUnits.temperature) + temperatures = torch.stack(temperatures) + + traj_positions = torch.stack(positions) + + coms = torch.zeros((n_steps, state.n_systems, 3), dtype=DTYPE).scatter_add_( + 1, + state.system_idx[None, :, None].expand(n_steps, -1, 3), + state.masses.unsqueeze(-1) * traj_positions, + ) + coms /= system_masses + coms_drift = coms - coms[0] + assert torch.allclose(coms_drift, torch.zeros_like(coms_drift), atol=1e-6) + assert (torch.mean(temperatures[len(temperatures) // 2 :]) - 300) / 300 < 0.30 + + +def test_fix_atoms_nvt_langevin(cu_sim_state: ts.SimState, lj_model: LennardJonesModel): + """Test FixAtoms constraint in NVT Langevin dynamics.""" + n_steps = 1000 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + dofs_before = cu_sim_state.get_number_of_degrees_of_freedom() + cu_sim_state.constraints = [FixAtoms(atom_idx=torch.tensor([0, 1], dtype=torch.long))] + assert torch.allclose( + cu_sim_state.get_number_of_degrees_of_freedom(), dofs_before - torch.tensor([6]) + ) + state = ts.nvt_langevin_init(state=cu_sim_state, model=lj_model, kT=kT, seed=42) + positions = [] + temperatures = [] + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.clone()) + temp = ts.calc_kT( + masses=state.masses, + momenta=state.momenta, + system_idx=state.system_idx, + dof_per_system=state.get_number_of_degrees_of_freedom(), + ) + temperatures.append(temp / MetalUnits.temperature) + temperatures = torch.stack(temperatures) + traj_positions = torch.stack(positions) + + diff_positions = traj_positions - traj_positions[0] + assert torch.max(diff_positions[:, :2]) < 1e-8 + assert torch.max(diff_positions[:, 2:]) > 1e-3 + assert (torch.mean(temperatures[len(temperatures) // 2 :]) - 300) / 300 < 0.30 + + +def test_state_manipulation_with_constraints(ar_double_sim_state: ts.SimState): + """Test that constraints are properly propagated during state manipulation.""" + # Set up constraints on the original state + ar_double_sim_state.constraints = [ + FixAtoms(atom_idx=torch.tensor([0, 1])), # Only applied to first system + FixCom([0, 1]), + ] + + # Extract individual systems from the double system state + first_system = ar_double_sim_state[0] # FixAtoms + FixCom + second_system = ar_double_sim_state[1] # FixCom only + concatenated_state = ts.concatenate_states( + [first_system, first_system, second_system] + ) + + # Verify constraint propagation to subsystems + assert len(first_system.constraints) == 2 + assert len(second_system.constraints) == 1 + assert len(concatenated_state.constraints) == 2 + + # Verify FixAtoms constraint indices are correctly mapped + assert torch.all(first_system.constraints[0].atom_idx == torch.tensor([0, 1])) + assert torch.all( + concatenated_state.constraints[0].atom_idx == torch.tensor([0, 1, 32, 33]) + ) + + # Verify FixCom constraint system masks + assert torch.all( + concatenated_state.constraints[1].system_idx == torch.tensor([0, 1, 2]) + ) + + # Test constraint propagation after splitting concatenated state + split_systems = concatenated_state.split() + assert len(split_systems[0].constraints) == 2 + assert torch.all(split_systems[0].constraints[0].atom_idx == torch.tensor([0, 1])) + assert torch.all(split_systems[1].constraints[0].atom_idx == torch.tensor([0, 1])) + assert len(split_systems[2].constraints) == 1 + + # Test constraint manipulation with different configurations + ar_double_sim_state.constraints = [] + ar_double_sim_state.constraints = [FixCom([0, 1])] + isolated_system = ar_double_sim_state[0] + assert torch.all( + isolated_system.constraints[0].system_idx == torch.tensor([0], dtype=torch.long) + ) + + # Test concatenation with mixed constraint states + isolated_system.constraints = [] + mixed_concatenated_state = ts.concatenate_states( + [isolated_system, ar_double_sim_state, isolated_system] + ) + assert torch.all( + mixed_concatenated_state.constraints[0].system_idx == torch.tensor([1, 2]) + ) + + +def test_fix_com_gradient_descent_optimization( + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface +) -> None: + """Test FixCom constraint in Gradient Descent optimization.""" + # Add some random displacement to positions + perturbed_positions = ( + ar_supercell_sim_state.positions + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + ar_supercell_sim_state.positions = perturbed_positions + initial_state = ar_supercell_sim_state + ar_supercell_sim_state.constraints = [FixCom([0])] + + initial_coms = get_centers_of_mass( + positions=initial_state.positions, + masses=initial_state.masses, + system_idx=initial_state.system_idx, + n_systems=initial_state.n_systems, + ) + + # Initialize Gradient Descent optimizer + state = ts.gradient_descent_init( + state=ar_supercell_sim_state, model=lj_model, lr=0.01 + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + while abs(energies[-2] - energies[-1]) > 1e-6: + state = ts.gradient_descent_step(state=state, model=lj_model, pos_lr=0.01) + energies.append(state.energy.item()) + + final_coms = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=initial_state.n_systems, + ) + + assert torch.allclose(final_coms, initial_coms, atol=1e-4) + assert not torch.allclose(state.positions, initial_state.positions) + + +def test_fix_atoms_gradient_descent_optimization( + ar_supercell_sim_state: ts.SimState, lj_model: ModelInterface +) -> None: + """Test FixAtoms constraint in Gradient Descent optimization.""" + # Add some random displacement to positions + perturbed_positions = ( + ar_supercell_sim_state.positions + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + ar_supercell_sim_state.positions = perturbed_positions + initial_state = ar_supercell_sim_state + initial_state.constraints = [FixAtoms(atom_idx=[0])] + initial_position = initial_state.positions[0].clone() + + # Initialize Gradient Descent optimizer + state = ts.gradient_descent_init( + state=ar_supercell_sim_state, model=lj_model, lr=0.01 + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + while abs(energies[-2] - energies[-1]) > 1e-6: + state = ts.gradient_descent_step(state=state, model=lj_model, pos_lr=0.01) + energies.append(state.energy.item()) + + final_position = state.positions[0] + + assert torch.allclose(final_position, initial_position, atol=1e-5) + assert not torch.allclose(state.positions, initial_state.positions) + + +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +def test_test_atoms_fire_optimization( + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + fire_flavor: FireFlavor, +) -> None: + """Test FixAtoms constraint in FIRE optimization.""" + # Add some random displacement to positions + # Create a fresh copy for each test run to avoid interference + + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + indices = torch.tensor([0, 2], dtype=torch.long) + current_sim_state.constraints = [FixAtoms(atom_idx=indices)] + + # Initialize FIRE optimizer + state = ts.fire_init( + current_sim_state, lj_model, fire_flavor=fire_flavor, dt_start=0.1 + ) + initial_position = state.positions[indices].clone() + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 # Add max step to prevent infinite loop + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) + energies.append(state.energy.item()) + steps_taken += 1 + + final_position = state.positions[indices] + + assert torch.allclose(final_position, initial_position, atol=1e-5) + + +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +def test_fix_com_fire_optimization( + ar_supercell_sim_state: ts.SimState, + lj_model: ModelInterface, + fire_flavor: FireFlavor, +) -> None: + """Test FixCom constraint in FIRE optimization.""" + # Add some random displacement to positions + # Create a fresh copy for each test run to avoid interference + + current_positions = ( + ar_supercell_sim_state.positions.clone() + + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + + current_sim_state = ts.SimState( + positions=current_positions, + masses=ar_supercell_sim_state.masses.clone(), + cell=ar_supercell_sim_state.cell.clone(), + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(), + system_idx=ar_supercell_sim_state.system_idx.clone(), + ) + current_sim_state.constraints = [FixCom([0])] + + # Initialize FIRE optimizer + state = ts.fire_init( + current_sim_state, lj_model, fire_flavor=fire_flavor, dt_start=0.1 + ) + initial_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + # Run optimization for a few steps + energies = [1000, state.energy.item()] + max_steps = 1000 # Add max step to prevent infinite loop + steps_taken = 0 + while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps: + state = ts.fire_step(state=state, model=lj_model, dt_max=0.3) + energies.append(state.energy.item()) + steps_taken += 1 + + final_com = get_centers_of_mass( + positions=state.positions, + masses=state.masses, + system_idx=state.system_idx, + n_systems=state.n_systems, + ) + + assert torch.allclose(final_com, initial_com, atol=1e-4) + + +def test_fix_atoms_validation() -> None: + """Test FixAtoms construction and validation.""" + # Boolean mask conversion + mask = torch.zeros(10, dtype=torch.bool) + mask[:3] = True + assert torch.all(FixAtoms(atom_mask=mask).atom_idx == torch.tensor([0, 1, 2])) + + # Invalid indices + with pytest.raises(ValueError, match="Indices must be integers"): + FixAtoms(atom_idx=torch.tensor([0.5, 1.5])) + with pytest.raises(ValueError, match="Duplicate"): + FixAtoms(atom_idx=torch.tensor([0, 1, 1])) + with pytest.raises(ValueError, match="wrong number of dimensions"): + FixAtoms(atom_idx=torch.tensor([[0, 1]])) + + +def test_constraint_validation_warnings(ar_double_sim_state: ts.SimState) -> None: + """Test validation warnings for constraint conflicts.""" + with pytest.warns(UserWarning, match="Multiple constraints.*same atoms"): + validate_constraints( + [FixAtoms(atom_idx=[0, 1, 2]), FixAtoms(atom_idx=[2, 3, 4])], + ar_double_sim_state, + ) + with pytest.warns(UserWarning, match="FixCom together with other constraints"): + validate_constraints( + [FixCom([0]), FixAtoms(atom_idx=[0, 1])], ar_double_sim_state + ) + + +def test_constraint_validation_errors( + cu_sim_state: ts.SimState, + ar_supercell_sim_state: ts.SimState, +) -> None: + """Test validation errors for invalid constraints.""" + # Out of bounds + with pytest.raises(ValueError, match=r"has indices up to.*only has.*atoms"): + cu_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 100])] + + # Validation in __post_init__ + with pytest.raises(ValueError, match="Duplicate"): + ts.SimState( + positions=ar_supercell_sim_state.positions.clone(), + masses=ar_supercell_sim_state.masses, + cell=ar_supercell_sim_state.cell, + pbc=ar_supercell_sim_state.pbc, + atomic_numbers=ar_supercell_sim_state.atomic_numbers, + system_idx=ar_supercell_sim_state.system_idx, + _constraints=[FixAtoms(atom_idx=[0, 0, 1])], + ) + + +@pytest.mark.parametrize( + ("integrator", "constraint", "n_steps"), + [ + ("nve", FixAtoms(atom_idx=[0, 1]), 100), + ("nvt_nose_hoover", FixCom([0]), 200), + ("npt_langevin", FixAtoms(atom_idx=[0, 3]), 200), + ("npt_nose_hoover", FixCom([0]), 200), + ], +) +def test_integrators_with_constraints( + cu_sim_state: ts.SimState, + lj_model: LennardJonesModel, + integrator: str, + constraint: Constraint, + n_steps: int, +) -> None: + """Test all integrators respect constraints.""" + cu_sim_state.constraints = [constraint] + kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature + dt = torch.tensor(0.001, dtype=DTYPE) + + # Store initial state + if isinstance(constraint, FixAtoms): + initial = cu_sim_state.positions[constraint.atom_idx].clone() + else: + initial = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + + # Run integration + if integrator == "nve": + state = ts.nve_init(cu_sim_state, lj_model, kT=kT, seed=42) + for _ in range(n_steps): + state = ts.nve_step(state, lj_model, dt=dt) + elif integrator == "nvt_nose_hoover": + state = ts.nvt_nose_hoover_init(cu_sim_state, lj_model, kT=kT, dt=dt) + for _ in range(n_steps): + state = ts.nvt_nose_hoover_step(state, lj_model, dt=dt, kT=kT) + elif integrator == "npt_langevin": + state = ts.npt_langevin_init(cu_sim_state, lj_model, kT=kT, seed=42, dt=dt) + for _ in range(n_steps): + state = ts.npt_langevin_step( + state, + lj_model, + dt=dt, + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE), + ) + else: # npt_nose_hoover + state = ts.npt_nose_hoover_init(cu_sim_state, lj_model, kT=kT, dt=dt) + for _ in range(n_steps): + state = ts.npt_nose_hoover_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE), + ) + + # Verify constraint held + if isinstance(constraint, FixAtoms): + assert torch.allclose(state.positions[constraint.atom_idx], initial, atol=1e-6) + else: + final = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + assert torch.allclose(final, initial, atol=1e-5) + + +def test_multiple_constraints_and_dof( + cu_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + """Test multiple constraints together with correct DOF calculation.""" + # Test DOF calculation + n = cu_sim_state.n_atoms + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n) + cu_sim_state.constraints = [FixAtoms(atom_idx=[0])] + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 3) + cu_sim_state.constraints = [FixCom([0]), FixAtoms(atom_idx=[0])] + assert torch.all(cu_sim_state.get_number_of_degrees_of_freedom() == 3 * n - 6) + + # Verify both constraints hold during dynamics + initial_pos = cu_sim_state.positions[0].clone() + initial_com = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + state = ts.nvt_langevin_init( + cu_sim_state, + lj_model, + kT=torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature, + seed=42, + ) + for _ in range(200): + state = ts.nvt_langevin_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature, + ) + assert torch.allclose(state.positions[0], initial_pos, atol=1e-6) + final_com = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + assert torch.allclose(final_com, initial_com, atol=1e-5) + + +@pytest.mark.parametrize( + ("cell_filter", "fire_flavor"), + [ + (ts.CellFilter.unit, "ase_fire"), + (ts.CellFilter.frechet, "ase_fire"), + (ts.CellFilter.frechet, "vv_fire"), + ], +) +def test_cell_optimization_with_constraints( + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, + cell_filter: str, + fire_flavor: FireFlavor, +) -> None: + """Test cell filters work with constraints.""" + ar_supercell_sim_state.positions += ( + torch.randn_like(ar_supercell_sim_state.positions) * 0.05 + ) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=[0, 1])] + state = ts.fire_init( + ar_supercell_sim_state, + lj_model, + cell_filter=cell_filter, + fire_flavor=fire_flavor, + ) + for _ in range(50): + state = ts.fire_step(state, lj_model, dt_max=0.1) + if state.forces.abs().max() < 0.05: + break + assert len(state.constraints) > 0 + + +def test_batched_constraints(ar_double_sim_state: ts.SimState) -> None: + """Test system-specific constraints in batched states.""" + s1, s2 = ar_double_sim_state.split() + s1.constraints = [FixAtoms(atom_idx=[0, 1])] + s2.constraints = [FixCom([0])] + combined = ts.concatenate_states([s1, s2]) + assert len(combined.constraints) == 2 + assert isinstance(combined.constraints[0], FixAtoms) + assert torch.all(combined.constraints[0].atom_idx == torch.tensor([0, 1])) + assert isinstance(combined.constraints[1], FixCom) + assert torch.all(combined.constraints[1].system_idx == torch.tensor([1])) + + +def test_constraints_with_non_pbc(lj_model: LennardJonesModel) -> None: + """Test constraints work with non-periodic boundaries.""" + state = ts.SimState( + positions=torch.tensor( + [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]], + dtype=DTYPE, + ), + masses=torch.ones(4, dtype=DTYPE) * 39.948, + cell=torch.eye(3, dtype=DTYPE).unsqueeze(0) * 10.0, + pbc=False, + atomic_numbers=torch.full((4,), 18, dtype=torch.long), + system_idx=torch.zeros(4, dtype=torch.long), + ) + state.constraints = [FixCom([0])] + initial = get_centers_of_mass( + state.positions, state.masses, state.system_idx, state.n_systems + ) + md_state = ts.nve_init(state, lj_model, kT=torch.tensor(100.0, dtype=DTYPE), seed=42) + for _ in range(100): + md_state = ts.nve_step(md_state, lj_model, dt=torch.tensor(0.001, dtype=DTYPE)) + final = get_centers_of_mass( + md_state.positions, md_state.masses, md_state.system_idx, md_state.n_systems + ) + assert torch.allclose(final, initial, atol=1e-5) + + +def test_high_level_api_with_constraints( + cu_sim_state: ts.SimState, + ar_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test high-level integrate() and optimize() APIs with constraints.""" + # Test integrate() + cu_sim_state.constraints = [FixCom([0])] + initial_com = get_centers_of_mass( + cu_sim_state.positions, + cu_sim_state.masses, + cu_sim_state.system_idx, + cu_sim_state.n_systems, + ) + final = ts.integrate( + cu_sim_state, + lj_model, + integrator=ts.Integrator.nvt_langevin, + n_steps=50, + temperature=300.0, + timestep=0.001, + ) + final_com = get_centers_of_mass( + final.positions, final.masses, final.system_idx, final.n_systems + ) + assert torch.allclose(final_com, initial_com, atol=1e-5) + + # Test optimize() + ar_supercell_sim_state.positions += ( + torch.randn_like(ar_supercell_sim_state.positions) * 0.1 + ) + ar_supercell_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 2])] + initial_pos = ar_supercell_sim_state.positions[[0, 1, 2]].clone() + final = ts.optimize( + ar_supercell_sim_state, lj_model, optimizer=ts.Optimizer.fire, max_steps=500 + ) + assert torch.allclose(final.positions[[0, 1, 2]], initial_pos, atol=1e-5) + + +def test_temperature_with_constrained_dof( + cu_sim_state: ts.SimState, lj_model: LennardJonesModel +) -> None: + """Test temperature calculation uses constrained DOF.""" + target = 300.0 + cu_sim_state.constraints = [FixAtoms(atom_idx=[0, 1, 2])] + state = ts.nvt_langevin_init( + cu_sim_state, + lj_model, + kT=torch.tensor(target, dtype=DTYPE) * MetalUnits.temperature, + seed=42, + ) + temps = [] + for _ in range(4000): + state = ts.nvt_langevin_step( + state, + lj_model, + dt=torch.tensor(0.001, dtype=DTYPE), + kT=torch.tensor(target, dtype=DTYPE) * MetalUnits.temperature, + ) + temp = state.calc_kT() + temps.append(temp / MetalUnits.temperature) + avg = torch.mean(torch.stack(temps)[500:]) + assert abs(avg - target) / target < 0.30 + + +def test_system_constraint_update_and_select() -> None: + """Test select_constraint and select_sub_constraint for SystemConstraint.""" + # Create a FixCom constraint for systems 0, 1, 2 + constraint = FixCom([0, 1, 2]) + + # Test select_constraint with system_mask + # Keep systems 0 and 2 (drop system 1) + atom_mask = torch.ones(10, dtype=torch.bool) + system_mask = torch.tensor([True, False, True], dtype=torch.bool) + updated_constraint = constraint.select_constraint(atom_mask, system_mask) + + # System indices should be renumbered: [0, 2] -> [0, 1] + assert torch.all(updated_constraint.system_idx == torch.tensor([0, 1])) + + # Test select_sub_constraint + # Select system 1 from the original constraint + constraint = FixCom([0, 1, 2]) + atom_idx = torch.arange(5, 10) # Atoms for a specific system + sys_idx = 1 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with system_idx = [0] (renumbered from 1) + assert sub_constraint is not None + assert torch.all(sub_constraint.system_idx == torch.tensor([0])) + + # Test when system is not in constraint + constraint = FixCom([0, 2]) + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx=1) + assert sub_constraint is None + + +def test_atom_indexed_constraint_update_and_select() -> None: + """Test select_constraint and select_sub_constraint for AtomConstraint.""" + # Create a FixAtoms constraint for atoms 0, 1, 5, 8 + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) + + # Test select_constraint with atom_mask + # Keep atoms 0, 1, 2, 3, 5, 6, 7, 8 (drop atoms 4) + atom_mask = torch.tensor( + [True, True, True, True, False, True, True, True, True], dtype=torch.bool + ) + system_mask = torch.ones(2, dtype=torch.bool) + updated_constraint = constraint.select_constraint(atom_mask, system_mask) + + # Atom indices should be renumbered: + # Original: [0, 1, 5, 8] + # After dropping atom 4: [0, 1, 4, 7] (indices shift down by 1 after index 4) + assert torch.all(updated_constraint.atom_idx == torch.tensor([0, 1, 4, 7])) + + # Test select_sub_constraint + # Select atoms that belong to a specific system + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) + atom_idx = torch.tensor([0, 1, 2, 3, 4]) # Atoms for first system + sys_idx = 0 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with only atoms 0, 1 (within atom_idx range) + # Renumbered to start from 0 + assert sub_constraint is not None + assert torch.all(sub_constraint.atom_idx == torch.tensor([0, 1])) + + # Test with different atom range + constraint = FixAtoms(atom_idx=[0, 1, 5, 8]) + atom_idx = torch.tensor([5, 6, 7, 8, 9]) # Atoms for second system + sys_idx = 1 + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx) + + # Should return a constraint with atoms 5, 8 renumbered to [0, 3] + assert sub_constraint is not None + assert torch.all(sub_constraint.atom_idx == torch.tensor([0, 3])) + + # Test when no atoms in range + constraint = FixAtoms(atom_idx=[0, 1]) + atom_idx = torch.tensor([5, 6, 7, 8]) + sub_constraint = constraint.select_sub_constraint(atom_idx, sys_idx=1) + assert sub_constraint is None + + +def test_merge_constraints(ar_double_sim_state: ts.SimState) -> None: + """Test merge_constraints combines constraints from multiple systems.""" + # Split the double system state + s1, s2 = ar_double_sim_state.split() + n_atoms_s1 = s1.n_atoms + n_atoms_s2 = s2.n_atoms + + # Create constraints for each system + # System 1: Fix atoms 0, 1 and fix COM for system 0 + s1_constraints = [ + FixAtoms(atom_idx=[0, 1]), + FixCom([0]), + ] + + # System 2: Fix atoms 2, 3 and fix COM for system 0 + s2_constraints = [ + FixAtoms(atom_idx=[2, 3]), + FixCom([0]), + ] + + # Merge constraints + constraint_lists = [s1_constraints, s2_constraints] + num_atoms_per_state = torch.tensor([n_atoms_s1, n_atoms_s2]) + merged_constraints = merge_constraints(constraint_lists, num_atoms_per_state) + + # Should have 2 constraints: one FixAtoms and one FixCom + assert len(merged_constraints) == 2 + + # Find FixAtoms and FixCom in merged list + fix_atoms = None + fix_com = None + for constraint in merged_constraints: + if isinstance(constraint, FixAtoms): + fix_atoms = constraint + elif isinstance(constraint, FixCom): + fix_com = constraint + + assert fix_atoms is not None + assert fix_com is not None + + # FixAtoms should have indices [0, 1] from s1 and [2+n_atoms_s1, 3+n_atoms_s1] from s2 + expected_atom_indices = torch.tensor([0, 1, 2 + n_atoms_s1, 3 + n_atoms_s1]) + assert torch.all(fix_atoms.atom_idx == expected_atom_indices) + + # FixCom should have system_idx [0, 1] (one for each original system) + expected_system_indices = torch.tensor([0, 1]) + assert torch.all(fix_com.system_idx == expected_system_indices) + + # Test with three systems + s3 = s1.clone() + s3_constraints = [FixAtoms(atom_idx=[0])] + constraint_lists = [s1_constraints, s2_constraints, s3_constraints] + num_atoms_per_state = torch.tensor([n_atoms_s1, n_atoms_s2, s3.n_atoms]) + merged_constraints = merge_constraints(constraint_lists, num_atoms_per_state) + + # Find FixAtoms + fix_atoms = None + for constraint in merged_constraints: + if isinstance(constraint, FixAtoms): + fix_atoms = constraint + break + + assert fix_atoms is not None + # Should include atoms from all three systems with proper offsets + expected_atom_indices = torch.tensor( + [0, 1, 2 + n_atoms_s1, 3 + n_atoms_s1, 0 + n_atoms_s1 + n_atoms_s2] + ) + assert torch.all(fix_atoms.atom_idx == expected_atom_indices) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 16565b73..fe9e8c3f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -9,6 +9,8 @@ import torch_sim as ts import torch_sim.transforms as ft from tests.conftest import DEVICE, DTYPE +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.units import MetalUnits def test_inverse_box_scalar() -> None: @@ -1301,3 +1303,65 @@ def test_build_linked_cell_neighborhood_basic() -> None: # Verify that there are neighbors from both batches assert torch.any(system_mapping == 0) assert torch.any(system_mapping == 1) + + +def test_unwrap_positions(ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel): + n_steps = 50 + dt = torch.tensor(0.001, dtype=DTYPE) + kT = torch.tensor(300, dtype=DTYPE) * MetalUnits.temperature + + # Same cell + state = ts.nvt_langevin_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42 + ) + state.positions = ft.pbc_wrap_batched(state.positions, state.cell, state.system_idx) + positions = [state.positions.detach().clone()] + for _step in range(n_steps): + state = ts.nvt_langevin_step(model=lj_model, state=state, dt=dt, kT=kT) + positions.append(state.positions.detach().clone()) + + positions = torch.stack(positions) + wrapped_positions = torch.stack( + [ + ft.pbc_wrap_batched(positions, state.cell, state.system_idx) + for positions in positions + ] + ) + unwrapped_positions = ft.unwrap_positions( + wrapped_positions, + state.cell, + state.system_idx, + ) + assert torch.allclose(unwrapped_positions, positions, atol=1e-4) + + # Different cell + state = ts.npt_langevin_init( + state=ar_double_sim_state, model=lj_model, kT=kT, seed=42, dt=dt + ) + state.positions = ft.pbc_wrap_batched(state.positions, state.cell, state.system_idx) + positions = [state.positions.detach().clone()] + cells = [state.cell.detach().clone()] + for _step in range(n_steps): + state = ts.npt_langevin_step( + model=lj_model, + state=state, + dt=dt, + kT=kT, + external_pressure=torch.tensor(0.0, dtype=DTYPE, device=DEVICE), + ) + positions.append(state.positions.detach().clone()) + cells.append(state.cell.detach().clone()) + + positions = torch.stack(positions) + wrapped_positions = torch.stack( + [ + ft.pbc_wrap_batched(positions, cell, state.system_idx) + for positions, cell in zip(positions, cells, strict=True) + ] + ) + unwrapped_positions = ft.unwrap_positions( + wrapped_positions, + state.cell, + state.system_idx, + ) + assert torch.allclose(unwrapped_positions, positions, atol=1e-4) diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index 3bc1bc92..d25b32cc 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -8,6 +8,7 @@ import torch_sim as ts from torch_sim import ( autobatching, + constraints, elastic, io, math, diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py new file mode 100644 index 00000000..d352d539 --- /dev/null +++ b/torch_sim/constraints.py @@ -0,0 +1,610 @@ +"""Constraints for molecular dynamics simulations. + +This module implements constraints inspired by ASE's constraint system, +adapted for the torch-sim framework with support for batched operations +and PyTorch tensors. + +The constraints affect degrees of freedom counting and modify forces, momenta, +and positions during MD simulations. +""" + +from __future__ import annotations + +import warnings +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Self + +import torch + + +if TYPE_CHECKING: + from torch_sim.state import SimState + + +class Constraint(ABC): + """Base class for all constraints in torch-sim. + + This is the abstract base class that all constraints must inherit from. + It defines the interface that constraints must implement to work with + the torch-sim MD system. + """ + + @abstractmethod + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get the number of degrees of freedom removed by this constraint. + + Args: + state: The simulation state + + Returns: + Number of degrees of freedom removed by this constraint + """ + + @abstractmethod + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Adjust positions to satisfy the constraint. + + This method should modify new_positions in-place to ensure the + constraint is satisfied. + + Args: + state: Current simulation state + new_positions: Proposed new positions to be adjusted + """ + + def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: + """Adjust momenta to satisfy the constraint. + + This method should modify momenta in-place to ensure the constraint + is satisfied. By default, it calls adjust_forces with the momenta. + + Args: + state: Current simulation state + momenta: Momenta to be adjusted + """ + # Default implementation: treat momenta like forces + self.adjust_forces(state, momenta) + + @abstractmethod + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Adjust forces to satisfy the constraint. + + This method should modify forces in-place to ensure the constraint + is satisfied. + + Args: + state: Current simulation state + forces: Forces to be adjusted + """ + + @abstractmethod + def select_constraint( + self, atom_mask: torch.Tensor, system_mask: torch.Tensor + ) -> None | Self: + """Update the constraint to account for atom and system masks. + + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + + @abstractmethod + def select_sub_constraint(self, atom_idx: torch.Tensor, sys_idx: int) -> None | Self: + """Select a constraint for a given atom and system index. + + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + + Returns: + Constraint for the given atom and system index + """ + + +def _mask_constraint_indices(idx: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + cumsum_atom_mask = torch.cumsum(~mask, dim=0) + new_indices = idx - cumsum_atom_mask[idx] + mask_indices = torch.where(mask)[0] + drop_indices = ~torch.isin(idx, mask_indices) + return new_indices[~drop_indices] + + +class AtomConstraint(Constraint): + """Base class for constraints that act on specific atom indices. + + This class provides common functionality for constraints that operate + on a subset of atoms, identified by their indices. + """ + + def __init__( + self, + atom_idx: torch.Tensor | list[int] | None = None, + atom_mask: torch.Tensor | list[int] | None = None, + ) -> None: + """Initialize indexed constraint. + + Args: + atom_idx: Indices of atoms to constrain. Can be a tensor or list of integers. + atom_mask: Boolean mask for atoms to constrain. + + Raises: + ValueError: If both indices and mask are provided, or if indices have + wrong shape/type + """ + if atom_idx is not None and atom_mask is not None: + raise ValueError("Provide either atom_idx or atom_mask, not both.") + if atom_mask is not None: + atom_mask = torch.as_tensor(atom_mask) + atom_idx = torch.where(atom_mask)[0] + + # Convert to tensor if needed + atom_idx = torch.as_tensor(atom_idx) + + # Ensure we have the right shape and type + atom_idx = torch.atleast_1d(atom_idx) + if atom_idx.ndim != 1: + raise ValueError( + "atom_idx has wrong number of dimensions. " + f"Got {atom_idx.ndim}, expected ndim <= 1" + ) + + if torch.is_floating_point(atom_idx): + raise ValueError( + f"Indices must be integers or boolean mask, not dtype={atom_idx.dtype}" + ) + + self.atom_idx = atom_idx.long() + + def get_indices(self) -> torch.Tensor: + """Get the constrained atom indices. + + Returns: + Tensor of atom indices affected by this constraint + """ + return self.atom_idx.clone() + + def select_constraint( + self, + atom_mask: torch.Tensor, + system_mask: torch.Tensor, # noqa: ARG002 + ) -> None | Self: + """Update the constraint to account for atom and system masks. + + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + indices = self.atom_idx.clone() + indices = _mask_constraint_indices(indices, atom_mask) + if len(indices) == 0: + return None + return type(self)(indices) + + def select_sub_constraint( + self, + atom_idx: torch.Tensor, + sys_idx: int, # noqa: ARG002 + ) -> None | Self: + """Select a constraint for a given atom and system index. + + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + """ + mask = torch.isin(self.atom_idx, atom_idx) + masked_indices = self.atom_idx[mask] + new_atom_idx = masked_indices - atom_idx.min() + if len(new_atom_idx) == 0: + return None + return type(self)(new_atom_idx) + + +class SystemConstraint(Constraint): + """Base class for constraints that act on specific system indices. + + This class provides common functionality for constraints that operate + on a subset of systems, identified by their indices. + """ + + def __init__( + self, + system_idx: torch.Tensor | list[int] | None = None, + system_mask: torch.Tensor | list[int] | None = None, + ) -> None: + """Initialize indexed constraint. + + Args: + system_idx: Indices of systems to constrain. + Can be a tensor or list of integers. + system_mask: Boolean mask for systems to constrain. + + Raises: + ValueError: If both indices and mask are provided, or if indices have + wrong shape/type + """ + if system_idx is not None and system_mask is not None: + raise ValueError("Provide either system_idx or system_mask, not both.") + if system_mask is not None: + system_idx = torch.as_tensor(system_idx) + system_idx = torch.where(system_mask)[0] + + # Convert to tensor if needed + system_idx = torch.as_tensor(system_idx) + + # Ensure we have the right shape and type + system_idx = torch.atleast_1d(system_idx) + if system_idx.ndim != 1: + raise ValueError( + "system_idx has wrong number of dimensions. " + f"Got {system_idx.ndim}, expected ndim <= 1" + ) + + # Check for duplicates + if len(system_idx) != len(torch.unique(system_idx)): + raise ValueError("Duplicate system indices found in SystemConstraint.") + + if torch.is_floating_point(system_idx): + raise ValueError( + f"Indices must be integers or boolean mask, not dtype={system_idx.dtype}" + ) + + self.system_idx = system_idx.long() + + def select_constraint( + self, + atom_mask: torch.Tensor, # noqa: ARG002 + system_mask: torch.Tensor, + ) -> None | Self: + """Update the constraint to account for atom and system masks. + + Args: + atom_mask: Boolean mask for atoms to keep + system_mask: Boolean mask for systems to keep + """ + system_idx = self.system_idx.clone() + system_idx = _mask_constraint_indices(system_idx, system_mask) + if len(system_idx) == 0: + return None + return type(self)(system_idx) + + def select_sub_constraint( + self, + atom_idx: torch.Tensor, # noqa: ARG002 + sys_idx: int, + ) -> None | Self: + """Select a constraint for a given atom and system index. + + Args: + atom_idx: Atom indices for a single system + sys_idx: System index for a single system + """ + return type(self)(torch.tensor([0])) if sys_idx in self.system_idx else None + + +def merge_constraints( + constraint_lists: list[list[AtomConstraint | SystemConstraint]], + num_atoms_per_state: torch.Tensor, +) -> list[Constraint]: + """Merge constraints from multiple systems into a single list of constraints. + + Args: + constraint_lists: List of lists of constraints + num_atoms_per_state: Number of atoms per system + + Returns: + List of merged constraints + """ + from collections import defaultdict + + cumsum_atoms = torch.cumsum(num_atoms_per_state, dim=0) - num_atoms_per_state[0] + + # aggregate updated constraint indices by constraint type + constraint_indices: dict[type[Constraint], list[torch.Tensor]] = defaultdict(list) + for i, constraint_list in enumerate(constraint_lists): + for constraint in constraint_list: + if isinstance(constraint, AtomConstraint): + idxs = constraint.atom_idx + offset = cumsum_atoms[i] + elif isinstance(constraint, SystemConstraint): + idxs = constraint.system_idx + offset = i + else: + raise NotImplementedError( + f"Constraint type {type(constraint)} is not implemented" + ) + constraint_indices[type(constraint)].append(idxs + offset) + + return [ + constraint_type(torch.cat(idxs)) + for constraint_type, idxs in constraint_indices.items() + ] + + +class FixAtoms(AtomConstraint): + """Constraint that fixes specified atoms in place. + + This constraint prevents the specified atoms from moving by: + - Resetting their positions to original values + - Setting their forces to zero + - Removing 3 degrees of freedom per fixed atom + + Examples: + Fix atoms with indices [0, 1, 2]: + >>> constraint = FixAtoms(atom_idx=[0, 1, 2]) + + Fix atoms using a boolean mask: + >>> mask = torch.tensor([True, True, True, False, False]) + >>> constraint = FixAtoms(mask=mask) + """ + + def __init__( + self, + atom_idx: torch.Tensor | list[int] | None = None, + atom_mask: torch.Tensor | list[int] | None = None, + ) -> None: + """Initialize FixAtoms constraint and check for duplicate indices.""" + super().__init__(atom_idx=atom_idx, atom_mask=atom_mask) + # Check duplicates + if len(self.atom_idx) != len(torch.unique(self.atom_idx)): + raise ValueError("Duplicate atom indices found in FixAtoms constraint.") + + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get number of removed degrees of freedom. + + Each fixed atom removes 3 degrees of freedom (x, y, z motion). + + Args: + state: Simulation state + + Returns: + Number of degrees of freedom removed (3 * number of fixed atoms) + """ + fixed_atoms_system_idx = torch.bincount( + state.system_idx[self.atom_idx], minlength=state.n_systems + ) + return 3 * fixed_atoms_system_idx + + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Reset positions of fixed atoms to their current values. + + Args: + state: Current simulation state + new_positions: Proposed positions to be adjusted in-place + """ + new_positions[self.atom_idx] = state.positions[self.atom_idx] + + def adjust_forces( + self, + state: SimState, # noqa: ARG002 + forces: torch.Tensor, + ) -> None: + """Set forces on fixed atoms to zero. + + Args: + state: Current simulation state + forces: Forces to be adjusted in-place + """ + forces[self.atom_idx] = 0.0 + + def __repr__(self) -> str: + """String representation of the constraint.""" + if len(self.atom_idx) <= 10: + indices_str = self.atom_idx.tolist() + else: + indices_str = f"{self.atom_idx[:5].tolist()}...{self.atom_idx[-5:].tolist()}" + return f"FixAtoms(indices={indices_str})" + + +class FixCom(SystemConstraint): + """Constraint that fixes the center of mass of all atoms per system. + + This constraint prevents the center of mass from moving by: + - Adjusting positions to maintain center of mass position + - Removing center of mass velocity from momenta + - Adjusting forces to remove net force + - Removing 3 degrees of freedom (center of mass translation) + + The constraint is applied to all atoms in the system. + """ + + coms: torch.Tensor | None = None + + def get_removed_dof(self, state: SimState) -> torch.Tensor: + """Get number of removed degrees of freedom. + + Fixing center of mass removes 3 degrees of freedom (x, y, z translation). + + Args: + state: Simulation state + + Returns: + Always returns 3 (center of mass translation degrees of freedom) + """ + affected_systems = torch.zeros(state.n_systems, dtype=torch.long) + affected_systems[self.system_idx] = 1 + return 3 * affected_systems + + def adjust_positions(self, state: SimState, new_positions: torch.Tensor) -> None: + """Adjust positions to maintain center of mass position. + + Args: + state: Current simulation state + new_positions: Proposed positions to be adjusted in-place + """ + dtype = state.positions.dtype + system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, state.masses + ) + if self.coms is None: + self.coms = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + state.masses.unsqueeze(-1) * state.positions, + ) + self.coms /= system_mass.unsqueeze(-1) + + new_com = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + state.masses.unsqueeze(-1) * new_positions, + ) + new_com /= system_mass.unsqueeze(-1) + displacement = torch.zeros(state.n_systems, 3, dtype=dtype) + displacement[self.system_idx] = ( + -new_com[self.system_idx] + self.coms[self.system_idx] + ) + new_positions += displacement[state.system_idx] + + def adjust_momenta(self, state: SimState, momenta: torch.Tensor) -> None: + """Remove center of mass velocity from momenta. + + Args: + state: Current simulation state + momenta: Momenta to be adjusted in-place + """ + # Compute center of mass momenta + dtype = momenta.dtype + com_momenta = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + momenta, + ) + system_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, state.system_idx, state.masses + ) + velocity_com = com_momenta / system_mass.unsqueeze(-1) + velocity_change = torch.zeros(state.n_systems, 3, dtype=dtype) + velocity_change[self.system_idx] = velocity_com[self.system_idx] + momenta -= velocity_change[state.system_idx] * state.masses.unsqueeze(-1) + + def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: + """Remove net force to prevent center of mass acceleration. + + This implements the constraint from Eq. (3) and (7) in + https://doi.org/10.1021/jp9722824 + + Args: + state: Current simulation state + forces: Forces to be adjusted in-place + """ + dtype = state.positions.dtype + system_square_mass = torch.zeros(state.n_systems, dtype=dtype).scatter_add_( + 0, + state.system_idx, + torch.square(state.masses), + ) + lmd = torch.zeros((state.n_systems, 3), dtype=dtype).scatter_add_( + 0, + state.system_idx.unsqueeze(-1).expand(-1, 3), + forces * state.masses.unsqueeze(-1), + ) + lmd /= system_square_mass.unsqueeze(-1) + forces_change = torch.zeros(state.n_systems, 3, dtype=dtype) + forces_change[self.system_idx] = lmd[self.system_idx] + forces -= forces_change[state.system_idx] * state.masses.unsqueeze(-1) + + def __repr__(self) -> str: + """String representation of the constraint.""" + return f"FixCom(system_idx={self.system_idx})" + + +def count_degrees_of_freedom( + state: SimState, constraints: list[Constraint] | None = None +) -> int: + """Count the total degrees of freedom in a system with constraints. + + This function calculates the total number of degrees of freedom by starting + with the unconstrained count (n_atoms * 3) and subtracting the degrees of + freedom removed by each constraint. + + Args: + state: Simulation state + constraints: List of active constraints (optional) + + Returns: + Total number of degrees of freedom + """ + # Start with unconstrained DOF + total_dof = state.n_atoms * 3 + + # Subtract DOF removed by constraints + if constraints is not None: + for constraint in constraints: + total_dof -= constraint.get_removed_dof(state) + + return max(0, total_dof) # Ensure non-negative + + +def check_no_index_out_of_bounds( + indices: torch.Tensor, max_state_indices: int, constraint_name: str +) -> None: + """Check that constraint indices are within bounds of the state.""" + if (len(indices) > 0) and (indices.max() >= max_state_indices): + raise ValueError( + f"Constraint {constraint_name} has indices up to " + f"{indices.max()}, but state only has {max_state_indices} " + "atoms" + ) + + +def validate_constraints(constraints: list[Constraint], state: SimState) -> None: + """Validate constraints for potential issues and incompatibilities. + + This function checks for: + 1. Overlapping atom indices across multiple constraints + 2. AtomConstraints spanning multiple systems (requires state) + 3. Mixing FixCom with other constraints (warning only) + + Args: + constraints: List of constraints to validate + state: SimState to check against + + Raises: + ValueError: If constraints are invalid or span multiple systems + + Warns: + UserWarning: If constraints may lead to unexpected behavior + """ + if not constraints: + return + + indexed_constraints = [] + has_com_constraint = False + + for constraint in constraints: + if isinstance(constraint, AtomConstraint): + indexed_constraints.append(constraint) + + # Validate that atom indices exist in state if provided + check_no_index_out_of_bounds( + constraint.atom_idx, state.n_atoms, type(constraint).__name__ + ) + elif isinstance(constraint, SystemConstraint): + check_no_index_out_of_bounds( + constraint.system_idx, state.n_systems, type(constraint).__name__ + ) + + if isinstance(constraint, FixCom): + has_com_constraint = True + + # Check for overlapping atom indices + if len(indexed_constraints) > 1: + all_indices = torch.cat([c.atom_idx for c in indexed_constraints]) + unique_indices = torch.unique(all_indices) + if len(unique_indices) < len(all_indices): + warnings.warn( + "Multiple constraints are acting on the same atoms. " + "This may lead to unexpected behavior.", + UserWarning, + stacklevel=3, + ) + + # Warn about COM constraint with fixed atoms + if has_com_constraint and indexed_constraints: + warnings.warn( + "Using FixCom together with other constraints may lead to " + "unexpected behavior. The center of mass constraint is applied " + "to all atoms, including those that may be constrained by other means.", + UserWarning, + stacklevel=3, + ) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index b03fb7a3..2de04ed5 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -5,9 +5,8 @@ import torch -from torch_sim import transforms from torch_sim.models.interface import ModelInterface -from torch_sim.quantities import calc_temperature +from torch_sim.quantities import calc_kT from torch_sim.state import SimState from torch_sim.units import MetalUnits @@ -57,6 +56,12 @@ def velocities(self) -> torch.Tensor: """ return self.momenta / self.masses.unsqueeze(-1) + def set_momenta(self, new_momenta: torch.Tensor) -> None: + """Set new momenta, applying any constraints as needed.""" + for constraint in self.constraints: + constraint.adjust_momenta(self, new_momenta) + self.momenta = new_momenta + def calc_temperature( self, units: MetalUnits = MetalUnits.temperature ) -> torch.Tensor: @@ -68,12 +73,19 @@ def calc_temperature( Returns: torch.Tensor: Calculated temperature """ - return calc_temperature( + return self.calc_kT() / units.temperature + + def calc_kT(self) -> torch.Tensor: # noqa: N802 + """Calculate kT from momenta, masses, and system indices. + + Returns: + torch.Tensor: Calculated kT in energy units + """ + return calc_kT( masses=self.masses, momenta=self.momenta, system_idx=self.system_idx, dof_per_system=self.get_number_of_degrees_of_freedom(), - units=units, ) @@ -154,7 +166,7 @@ def momentum_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_momenta = state.momenta + state.forces * dt - state.momenta = new_momenta + state.set_momenta(new_momenta) return state @@ -174,17 +186,7 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_positions = state.positions + state.velocities * dt - - if state.pbc.any(): - # Split positions and cells by system - new_positions = transforms.pbc_wrap_batched( - new_positions, - state.cell, - state.system_idx, - state.pbc, - ) - - state.positions = new_positions + state.set_positions(new_positions) return state diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index db6e2b15..388a4427 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -1,5 +1,6 @@ """Implementations of NPT integrators.""" +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -364,14 +365,7 @@ def _npt_langevin_position_step( ) # Update positions with all contributions - state.positions = c_1 + c_2.unsqueeze(-1) * c_3 - - # Apply periodic boundary conditions if needed - if state.pbc.any(): - state.positions = ts.transforms.pbc_wrap_batched( - state.positions, state.cell, state.system_idx, state.pbc - ) - + state.set_positions(c_1 + c_2.unsqueeze(-1) * c_3) return state @@ -435,7 +429,8 @@ def _npt_langevin_velocity_step( # Update momenta (velocities * masses) with all contributions new_velocities = c_1 + c_2 + c_3 - state.momenta = new_velocities * state.masses.unsqueeze(-1) + # Apply constraints. + state.set_momenta(new_velocities * state.masses.unsqueeze(-1)) return state @@ -565,6 +560,9 @@ def npt_langevin_init( kT = torch.as_tensor(kT, device=device, dtype=dtype) dt = torch.as_tensor(dt, device=device, dtype=dtype) + if not isinstance(state, SimState): + state = SimState(**state) + if alpha.ndim == 0: alpha = alpha.expand(state.n_systems) if cell_alpha.ndim == 0: @@ -572,9 +570,6 @@ def npt_langevin_init( if b_tau.ndim == 0: b_tau = b_tau.expand(state.n_systems) - if not isinstance(state, SimState): - state = SimState(**state) - # Get model output to initialize forces and stress model_output = model(state) @@ -606,6 +601,16 @@ def npt_langevin_init( ) cell_masses = (n_atoms_per_system + 1) * batch_kT * b_tau * b_tau + if state.constraints: + # warn if constraints are present + warnings.warn( + "Constraints are present in the system. " + "Make sure they are compatible with NPT Langevin dynamics." + "We recommend not using constraints with NPT dynamics for now.", + UserWarning, + stacklevel=3, + ) + # Create the initial state return NPTLangevinState( positions=state.positions, @@ -625,6 +630,7 @@ def npt_langevin_init( cell_velocities=cell_velocities, cell_masses=cell_masses, cell_alpha=cell_alpha, + _constraints=state.constraints, ) @@ -1027,14 +1033,7 @@ def _npt_nose_hoover_exp_iL1( # noqa: N802 state.positions * (torch.exp(x_expanded) - 1) + dt * velocities * torch.exp(x_2_expanded) * sinh_expanded ) - new_positions = state.positions + new_positions - - # Apply periodic boundary conditions if needed - if state.pbc.any(): - return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.system_idx, pbc=state.pbc - ) - return new_positions + return state.positions + new_positions def _npt_nose_hoover_exp_iL2( # noqa: N802 @@ -1244,7 +1243,7 @@ def _npt_nose_hoover_inner_step( # Update particle positions and forces positions = _npt_nose_hoover_exp_iL1(state, state.velocities, cell_velocities, dt) - state.positions = positions + state.set_positions(positions) state.cell = cell model_output = model(state) @@ -1265,8 +1264,8 @@ def _npt_nose_hoover_inner_step( cell_momentum = cell_momentum + dt_2 * cell_force_val.unsqueeze(-1) # Return updated state - state.positions = positions - state.momenta = momenta + state.set_positions(positions) + state.set_momenta(momenta) state.forces = model_output["forces"] state.energy = model_output["energy"] state.cell_position = cell_position @@ -1411,6 +1410,16 @@ def npt_nose_hoover_init( forces = model_output["forces"] energy = model_output["energy"] + if state.constraints: + # warn if constraints are present + warnings.warn( + "Constraints are present in the system. " + "Make sure they are compatible with NPT Nosé Hoover dynamics." + "We recommend not using constraints with NPT dynamics for now.", + UserWarning, + stacklevel=3, + ) + # Create initial state return NPTNoseHooverState( positions=state.positions, @@ -1430,6 +1439,7 @@ def npt_nose_hoover_init( thermostat=thermostat_fns.initialize(dof_per_system, KE_thermostat, kT), barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, + _constraints=state.constraints, ) diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index d3773b3c..b4db4e6c 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -67,6 +67,7 @@ def nve_init( pbc=state.pbc, system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, + _constraints=state.constraints, ) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index e04984b4..1bcfe2cb 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -72,7 +72,7 @@ def _ou_step( c1.unsqueeze(-1) * state.momenta + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise ) - state.momenta = new_momenta + state.set_momenta(new_momenta) return state @@ -118,7 +118,6 @@ def nvt_langevin_init( "momenta", calculate_momenta(state.positions, state.masses, state.system_idx, kT, seed), ) - return MDState( positions=state.positions, momenta=momenta, @@ -129,6 +128,7 @@ def nvt_langevin_init( pbc=state.pbc, system_idx=state.system_idx, atomic_numbers=state.atomic_numbers, + _constraints=state.constraints, ) @@ -328,6 +328,7 @@ def nvt_nose_hoover_init( system_idx=state.system_idx, chain=chain_fns.initialize(dof_per_system, KE, kT), _chain_fns=chain_fns, # Store the chain functions + _constraints=state.constraints, ) @@ -372,7 +373,7 @@ def nvt_nose_hoover_step( # First half-step of chain evolution momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) - state.momenta = momenta + state.set_momenta(momenta) # Full velocity Verlet step state = velocity_verlet(state=state, dt=dt, model=model) @@ -385,7 +386,7 @@ def nvt_nose_hoover_step( # Second half-step of chain evolution momenta, chain = chain_fns.half_step(state.momenta, chain, kT, state.system_idx) - state.momenta = momenta + state.set_momenta(momenta) state.chain = chain return state diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index f75890a6..3e0ef47d 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -214,6 +214,7 @@ def swap_mc_init( system_idx=state.system_idx, energy=model_output["energy"], last_permutation=torch.arange(state.n_atoms, device=state.device), + _constraints=state.constraints, ) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 0a689432..9163fa71 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -80,6 +80,7 @@ def fire_init( "cell": state.cell.clone(), "atomic_numbers": state.atomic_numbers.clone(), "system_idx": state.system_idx.clone(), + "_constraints": state.constraints, "pbc": state.pbc, # Optimization state "forces": forces, @@ -211,13 +212,13 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) # Position update - state.positions = state.positions + atom_wise_dt * state.velocities + state.set_positions(state.positions + atom_wise_dt * state.velocities) # Cell position updates are handled in the velocity update step above # Get new forces and energy model_output = model(state) - state.forces = model_output["forces"] + state.set_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] @@ -419,7 +420,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 cur_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) - state.positions = ( + state.set_positions( torch.linalg.solve( cur_deform_grad[state.system_idx], state.positions.unsqueeze(-1) ).squeeze(-1) @@ -454,16 +455,18 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 new_deform_grad = cell_filters.deform_grad( state.reference_cell.mT, state.row_vector_cell ) - state.positions = torch.bmm( - state.positions.unsqueeze(1), - new_deform_grad[state.system_idx].transpose(-2, -1), - ).squeeze(1) + state.set_positions( + torch.bmm( + state.positions.unsqueeze(1), + new_deform_grad[state.system_idx].transpose(-2, -1), + ).squeeze(1) + ) else: - state.positions = state.positions + dr_atom + state.set_positions(state.positions + dr_atom) # Get new forces, energy, and stress model_output = model(state) - state.forces = model_output["forces"] + state.set_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index bfdfcf3f..4a563b86 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -61,6 +61,7 @@ def gradient_descent_init( "pbc": state.pbc, "atomic_numbers": state.atomic_numbers, "system_idx": state.system_idx, + "_constraints": state.constraints, } if cell_filter is not None: # Create cell optimization state @@ -107,7 +108,7 @@ def gradient_descent_step( atom_lr = pos_lr[state.system_idx].unsqueeze(-1) # Update atomic positions - state.positions = state.positions + atom_lr * state.forces + state.set_positions(state.positions + atom_lr * state.forces) # Update cell if using cell optimization if isinstance(state, CellOptimState): @@ -117,7 +118,7 @@ def gradient_descent_step( # Get updated forces, energy, and stress model_output = model(state) - state.forces = model_output["forces"] + state.set_forces(model_output["forces"]) state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index 2ab530db..bd652857 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -23,6 +23,16 @@ class OptimState(SimState): _atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001 _system_attributes = SimState._system_attributes | {"energy", "stress"} # noqa: SLF001 + def set_forces(self, new_forces: torch.Tensor) -> None: + """Set new forces in the optimization state.""" + for constraint in self._constraints: + constraint.adjust_forces(self, new_forces) + self.forces = new_forces + + def __post_init__(self) -> None: + """Post-initialization to ensure SimState setup.""" + self.set_forces(self.forces) + @dataclass(kw_only=True) class FireState(OptimState): diff --git a/torch_sim/quantities.py b/torch_sim/quantities.py index 35f7d6f0..c7b31e5b 100644 --- a/torch_sim/quantities.py +++ b/torch_sim/quantities.py @@ -57,7 +57,7 @@ def calc_kT( # noqa: N802 # Count degrees of freedom per system system_sizes = torch.bincount(system_idx) if dof_per_system is None: - dof_per_system = system_sizes * squared_term.shape[-1] + dof_per_system = system_sizes * squared_term.shape[-1] # multiply by n_dimensions # Calculate temperature per system system_sums = torch.segment_reduce( diff --git a/torch_sim/runners.py b/torch_sim/runners.py index b2059aac..43c33db3 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -165,7 +165,6 @@ def integrate[T: SimState]( # noqa: C901 f"integrator must be key from Integrator or a tuple of " f"(init_func, step_func), got {type(integrator)}" ) - # batch_iterator will be a list if autobatcher is False batch_iterator = _configure_batches_iterator( initial_state, model, autobatcher=autobatcher diff --git a/torch_sim/state.py b/torch_sim/state.py index 813354fe..1d14b1ee 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -23,6 +23,8 @@ from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure +from torch_sim.constraints import Constraint, merge_constraints, validate_constraints + @dataclass class SimState: @@ -53,6 +55,8 @@ class SimState: atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. + constraints (list["Constraint"] | None): List of constraints applied to the + system. Constraints affect degrees of freedom and modify positions. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary @@ -85,6 +89,7 @@ class SimState: pbc: torch.Tensor | list[bool] | bool atomic_numbers: torch.Tensor system_idx: torch.Tensor | None = field(default=None) + _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 if TYPE_CHECKING: @@ -107,7 +112,7 @@ def pbc(self) -> torch.Tensor: _system_attributes: ClassVar[set[str]] = {"cell"} _global_attributes: ClassVar[set[str]] = {"pbc"} - def __post_init__(self) -> None: + def __post_init__(self) -> None: # noqa: C901 """Initialize the SimState and validate the arguments.""" # Check that positions, masses and atomic numbers have compatible shapes shapes = [ @@ -136,6 +141,9 @@ def __post_init__(self) -> None: if not torch.all(counts == torch.bincount(initial_system_idx)): raise ValueError("System indices must be unique consecutive integers") + if self.constraints: + validate_constraints(self.constraints, state=self) + if self.cell.ndim != 3 and initial_system_idx is None: self.cell = self.cell.unsqueeze(0) @@ -208,6 +216,7 @@ def attributes(self) -> dict[str, torch.Tensor]: for attr in self._atom_attributes | self._system_attributes | self._global_attributes + | {"_constraints"} } @property @@ -238,6 +247,46 @@ def row_vector_cell(self, value: torch.Tensor) -> None: """ self.cell = value.mT + def set_positions(self, new_positions: torch.Tensor) -> None: + """Set the positions and apply constraints if they exist. + + Args: + new_positions: New positions tensor with shape (n_atoms, 3) + """ + # Apply constraints if they exist + for constraint in self.constraints: + constraint.adjust_positions(self, new_positions) + self.positions = new_positions + + @property + def constraints(self) -> list[Constraint]: + """Get the constraints for the SimState. + + Returns: + list["Constraint"]: List of constraints applied to the system. + """ + return self._constraints + + @constraints.setter + def constraints(self, constraints: list[Constraint] | Constraint) -> None: + """Set the constraints for the SimState. + + Args: + constraints (list["Constraint"] | None): List of constraints to apply. + If None, no constraints are applied. + + Raises: + ValueError: If constraints are invalid or span multiple systems + """ + # check it is a list + if isinstance(constraints, Constraint): + constraints = [constraints] + + # Validate new constraints before adding + validate_constraints(constraints, state=self) + + self._constraints = constraints + def get_number_of_degrees_of_freedom(self) -> torch.Tensor: """Calculate degrees of freedom accounting for constraints. @@ -247,7 +296,18 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: of freedom, minus any degrees removed by constraints. """ # Start with unconstrained DOF: 3 degrees per atom - return 3 * self.n_atoms_per_system + dof_per_system = 3 * self.n_atoms_per_system + + # Subtract DOF removed by constraints + if self.constraints is not None: + for constraint in self.constraints: + removed_dof = constraint.get_removed_dof(self) + dof_per_system -= removed_dof + + # Ensure non-negative DOF + if (dof_per_system <= 0).any(): + raise ValueError("Degrees of freedom cannot be zero or negative") + return dof_per_system def clone(self) -> Self: """Create a deep copy of the SimState. @@ -613,7 +673,7 @@ def _state_to_device[T: SimState]( attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) # type: ignore[invalid-return-type] + return type(state)(**attrs) def get_attrs_for_scope( @@ -664,6 +724,18 @@ def _filter_attrs_by_mask( # Copy global attributes directly filtered_attrs = dict(get_attrs_for_scope(state, "global")) + # take into account constraints that are AtomConstraint + filtered_attrs["_constraints"] = [ + constraint.select_constraint(atom_mask, system_mask) + for constraint in copy.deepcopy(state.constraints) + ] + # Remove any None constraints resulting from selection + filtered_attrs["_constraints"] = [ + constraint + for constraint in filtered_attrs["_constraints"] + if constraint is not None + ] + # Filter per-atom attributes for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": @@ -685,6 +757,7 @@ def _filter_attrs_by_mask( dtype=attr_value.dtype, ) filtered_attrs[attr_name] = new_system_idxs + else: filtered_attrs[attr_name] = attr_value[atom_mask] @@ -711,7 +784,7 @@ def _split_state[T: SimState](state: T) -> list[T]: list[SimState]: A list of SimState objects, each containing a single system """ - system_sizes = torch.bincount(state.system_idx).tolist() + system_sizes = state.n_atoms_per_system.tolist() split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): @@ -730,6 +803,8 @@ def _split_state[T: SimState](state: T) -> list[T]: # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) + zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64) + cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0))) for sys_idx in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system @@ -749,6 +824,15 @@ def _split_state[T: SimState](state: T) -> list[T]: # Add the global attributes **global_attrs, } + + atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1]) + new_constraints = [ + new_constraint + for constraint in state.constraints + if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx)) + ] + + system_attrs["_constraints"] = new_constraints states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type] return states @@ -881,6 +965,7 @@ def concatenate_states[T: SimState]( # noqa: C901 per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 + num_atoms_per_state = [] # Process all states in a single pass for state in states: @@ -903,6 +988,8 @@ def concatenate_states[T: SimState]( # noqa: C901 num_systems = state.n_systems new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) + num_atoms_per_state.append(state.n_atoms) + system_offset += num_systems # Concatenate collected tensors @@ -920,8 +1007,14 @@ def concatenate_states[T: SimState]( # noqa: C901 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + # Merge constraints + constraint_lists = [state.constraints for state in states] + constraints = merge_constraints( + constraint_lists, torch.tensor(num_atoms_per_state, device=target_device) + ) + # Create a new instance of the same class - return state_class(**concatenated) + return state_class(**concatenated, _constraints=constraints) def initialize_state( diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 2ab4ab2e..28acb977 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1175,3 +1175,97 @@ def safe_mask( """ masked = torch.where(mask, operand, torch.zeros_like(operand)) return torch.where(mask, fn(masked), torch.full_like(operand, placeholder)) + + +def unwrap_positions( + positions: torch.Tensor, cells: torch.Tensor, system_idx: torch.Tensor +) -> torch.Tensor: + """Vectorized unwrapping for multiple systems without explicit loops. + + Parameters + ---------- + positions : (T, N_tot, 3) + Wrapped cartesian positions for all systems concatenated. + cells : (n_systems, 3, 3) or (T, n_systems, 3, 3) + Box matrices, constant or time-dependent. + system_idx : (N_tot,) + For each atom, which system it belongs to (0..n_systems-1). + + Returns: + ------- + unwrapped_pos : (T, N_tot, 3) + Unwrapped cartesian positions. + """ + # -- Constant boxes per system + if cells.ndim == 3: + inv_box = torch.inverse(cells) # (n_systems, 3, 3) + + # Map each atom to its system's box + inv_box_atoms = inv_box[system_idx] # (N, 3, 3) + box_atoms = cells[system_idx] # (N, 3, 3) + + # Compute fractional coordinates + frac = torch.einsum("tni,nij->tnj", positions, inv_box_atoms) + + # Fractional displacements and unwrap + dfrac = frac[1:] - frac[:-1] + dfrac -= torch.round(dfrac) + + # Back to Cartesian + dcart = torch.einsum("tni,nij->tnj", dfrac, box_atoms) + + # -- Time-dependent boxes per system + elif cells.ndim == 4: + inv_box = torch.inverse(cells) # (T, n_systems, 3, 3) + + # Gather each atom's box per frame efficiently + inv_box_atoms = inv_box[:, system_idx] # (T, N, 3, 3) + box_atoms = cells[:, system_idx] # (T, N, 3, 3) + + # Compute fractional coordinates per frame + frac = torch.einsum("tni,tnij->tnj", positions, inv_box_atoms) + + dfrac = frac[1:] - frac[:-1] + dfrac -= torch.round(dfrac) + + dcart = torch.einsum("tni,tnij->tnj", dfrac, box_atoms[:-1]) + + else: + raise ValueError("box must have shape (n_systems,3,3) or (T,n_systems,3,3)") + + # Cumulative reconstruction + unwrapped = torch.empty_like(positions) + unwrapped[0] = positions[0] + unwrapped[1:] = torch.cumsum(dcart, dim=0) + unwrapped[0] + + return unwrapped + + +def get_centers_of_mass( + positions: torch.Tensor, + masses: torch.Tensor, + system_idx: torch.Tensor, + n_systems: int, +) -> torch.Tensor: + """Compute the centers of mass for each structure in the simulation state.s. + + Args: + positions (torch.Tensor): Atomic positions of shape (N, 3). + masses (torch.Tensor): Atomic masses of shape (N,). + system_idx (torch.Tensor): System indices for each atom of shape (N,). + n_systems (int): Total number of systems. + + Returns: + torch.Tensor: A tensor of shape (n_structures, 3) containing + the center of mass coordinates for each structure. + """ + coms = torch.zeros((n_systems, 3), dtype=positions.dtype).scatter_add_( + 0, + system_idx.unsqueeze(-1).expand(-1, 3), + masses.unsqueeze(-1) * positions, + ) + system_masses = torch.zeros((n_systems,), dtype=positions.dtype).scatter_add_( + 0, system_idx, masses + ) + coms /= system_masses.unsqueeze(-1) + return coms