Skip to content

Conversation

@thomasloux
Copy link
Collaborator

Summary

First draft for constraints. Inspiration from ASE.
This is a rather big feature, I prefer to post a first mvp so that we can discuss some implementation and make sure constraints are compatible with Integrators and Optimizers.

Codes implementing constraints

  • JAX-MD: no
  • ASE: yes
  • LAMMPS: yes, but difficult to read (C++) and copy.
  • OpenMM: yes, but only bonds lengths or angles as far as I understand
    For the rest I did not check this much.

Implementation:

  • Constraints: FixAtoms, FixCom
    Change in the code:
  • Add an optional constraints variable in SimState
  • Turn every subclass dataclass(kw_only=True) to prevent errors (positional argument before non pos argument)
  • Adapt MD only, via setter (set_momenta, set_positions)

What's working perfectly

  • Contraints tests implemented
  • All tests pass
  • FixAtoms really fixes the atoms concerned

Future work

  • Adapt optimizer
  • Make sure that FixCom is working well. The Com drift is small, but I would expect an even smaller value. ASE seems to have a smaller value from first tests. I did not yet reproduce their tests as it is based on an optimization task.
  • Implement other constraints
  • make sure everything is compatible with Integrators and Optimizers. Essentially we should apply RATTLE algorithm.
  • Adapt the calculation of temperature to take into account the reduction of degrees of freedom.
  • For FixCom, one wants the center of mass to be fixed in unwrap coordinate. So of course, we need to check after unwrapping. And as a result, I now affect set_momenta, then apply the wrapping on the systems. Actually ASE does not wrap during the simulations, but the wrapping is applied from each calculator forward pass. I think that it is the case for most of TorchSim model implementations. So we may want to stop wrapping during the simulations. Of course if always possible to wrap afterwards.

Implementation discussions

  • I wanted first to have an hidden variable _positions and have a implicit setter and getter. But this is very painful to adapt with _atom_attributes and the copy of state sometimes performed directly accessing the variables using vars(state). At least it is the case if I set the init function to accept an positions argument and not a _positions argument. Actually I think that the setter is clearer to indicate that the constraints are imposed.
  • I decided to add constraints as a global attribute, but prevent list of constraint and looping over the constraints.
  • The disavantage: It's rather difficult to write the constraints and define the constraints over the batch system. For FixAtoms you need to provide the index of the atoms in the batch system. Probably we want to define actually for each system and then batch the constraint. For FixCom, we need to compute the COM in an efficient, which is not super easy to read. It will be even worse if one wants to fix the com of a subgroup in a batch systems.

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

We highly recommended installing the prek hooks running in CI locally to speedup the development process. Simply run pip install prek && prek install to install the hooks which will check your code before each commit.

Copy link
Collaborator

@curtischong curtischong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good!

I like how the abstract methods like adjust_positions are flexible enough for constraining specific indices/atoms and for constraining specific axes of movement (which'll probably be added later)



@dataclass
@dataclass(kw_only=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need to tag this PR with breaking bc of this. I think this is a reasonable change though

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm fine making state initialization kw only, you'd be crazy to do it with args

Copy link
Collaborator

@orionarcher orionarcher left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice job with this @thomasloux! I think this is most of the way there.

I have a few scattered comments but mostly I have thoughts on the API.

So we may want to stop wrapping during the simulations.

Agreed, this is a longstanding issue noted in #17.

I decided to add constraints as a global attribute, but prevent list of constraint and looping over the constraints.

I think a global attribute is the right approach. It doesn't make sense to concatenate or stack the constraints so atom and system attributes don't make sense. That said I think we'll need to add in some special handling for concatenate and split operations to make sure this doesn't break state handling.

The disavantage: It's rather difficult to write the constraints and define the constraints over the batch system. For FixAtoms you need to provide the index of the atoms in the batch system.

I think the usage pattern should be to define constraints for all the singular states and then concantenate them together, letting TorchSim handle making sure the indices are properly updated. This would also mean only allowing one constraint of each type on a given SimState. We could also imagine letting folks set constraints in ASE and then automatically porting them over.

For FixCom, we need to compute the COM in an efficient, which is not super easy to read. It will be even worse if one wants to fix the com of a subgroup in a batch systems.

Yeah this is a bit of a headache. I'll think on this too.

Broadly, I think the design decisions here are solid and just need a bit more buildout to get to a good API. In particular, I think we'll need a few more additions to this PR:

  1. Constraints update when state is modified. _split_state, _pop_states, _slice_state, and concatenate_states will need to be modified to correctly adjust the indices of the constraints when the state is mutated.
  2. Autovalidation of constraints. I think there should be a validate_constraints method that makes sure there are no overlapping constraints and that all contraints operate within a single system idx (not across multiple batches). This can be called both in the post_init method of SimState and in a set_constraint method if we add one.

I'm currently traveling but am happy to pair on this PR when I am back (Monday) to help with both of the above. This has been on my todo list for a while.

from torch_sim.state import SimState


class FixConstraint(ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is going to be the base class for all constraints I'd favor naming it Constraint



@dataclass
@dataclass(kw_only=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm fine making state initialization kw only, you'd be crazy to do it with args

Comment on lines +1162 to +1164
def unwrap_positions(
pos: torch.Tensor, box: torch.Tensor, system_idx: torch.Tensor
) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just used in the tests so I'd suggest moving it into the testing file and making it a private method. Unwrapping coordinates is quite tricky to do right, and I don't want to imply we've done it perfectly by adding it to the public API.

Copy link
Collaborator Author

@thomasloux thomasloux Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for now I can do that, but I consider that this implement should be right, at least assuming that the displacement at each step is small enough

@orionarcher orionarcher added breaking Breaking changes feature Entirely new features, not improvements to existing ones labels Oct 21, 2025
@thomasloux
Copy link
Collaborator Author

thomasloux commented Oct 21, 2025

@orionarcher thanks for the feedback. I'm starting to adapt for _split_state, _pop_states, _slice_state, and concatenate_states. _filter_attrs_by_mask seems to be slightly too general. Actually as it is accepting both an atom_maskand system_mask, although actually you probably expect only to depend on system_mask. I would suggest to determine the atom_mask in the function. Even if it is slightly less efficient.

EDIT: or find a way to be sure that atom_mask always correspond to the system_mask

@thomasloux
Copy link
Collaborator Author

Hey, I've changed quite a lot the API of the constraints. Actually I've almost reproduced the equivalent of atom_attribute or system_attribute at the level of constraints. The code is still a draft, they are probably better way reorganize the operation, especially for manipulation of states (pop, split, concatenante) which are probably difficult to read and too long.

  • Now support states manipulations
  • I removed wrap
  • Add a fix reference in FixCom (set at the first call)
  • Test for constraints for optimizers
  • Modify calc_kt to support new degrees of freedom

Regarding SystemConstraint, I need to implement a default slice(None) so that the constraint can function in the case the final user directly sets constraint (state.constraint = [FixCom()]). We could accept not to support this.

Also I realize the general_attribute is only supported for now for pbc, which should change soon. We should have constraint in its own category as it's not even propagate in similar ways as the rest of the attributes. This would also allow to have a private _constraint variable, and support the previous statement not to set directly constraint.

More than open for suggestion how to best add constraints.

By the way, in the meantime, we may want to have a dedicate branch so that people can try (a lot of demand for optimization #114). It's probably better to do before having extensive tests and proofs that our optimizers and integrators are compatible with the constraints.

Regarding calc_kt, I suggest to make it a MDState method (like in ASE). So that the user does not forget to add degrees of freedom.

I did not add for now a check of incompatible constraints for now, not sure it is relevant. For instance you may want in the future to fix the length of all atoms in a chain (think of a polymer). Then the index of an atom will appear multiple times.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

breaking Breaking changes feature Entirely new features, not improvements to existing ones

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants