From 99933ab6f69cd18efb6e12c2215bea3e23f321f5 Mon Sep 17 00:00:00 2001 From: CalCraven Date: Fri, 12 Sep 2025 14:33:18 -0400 Subject: [PATCH 01/11] Add bond/angle wildcard handling --- gmso/core/forcefield.py | 54 +++++++++++++++++++ gmso/tests/files/alkanes_wildcards.xml | 54 +++++-------------- ...opls_gmso.py => test_parameterizations.py} | 22 ++++++++ 3 files changed, 89 insertions(+), 41 deletions(-) rename gmso/tests/parameterization/{test_opls_gmso.py => test_parameterizations.py} (65%) diff --git a/gmso/core/forcefield.py b/gmso/core/forcefield.py index 9316abe24..f7a7bf048 100644 --- a/gmso/core/forcefield.py +++ b/gmso/core/forcefield.py @@ -351,6 +351,33 @@ def _get_bond_type(self, atom_types, return_match_order=False, warn=False): if reverse in self.bond_types: match = self.bond_types[reverse], (1, 0) + if match: + if return_match_order: + return match + else: + return match[0] + + for i in range(1, 3): + forward_patterns = mask_with(atom_types, i) + reverse_patterns = mask_with(reversed(atom_types), i) + + for forward_pattern, reverse_pattern in zip( + forward_patterns, reverse_patterns + ): + forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) + reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) + + if forward_match_key in self.bond_types: + match = self.bond_types[forward_match_key], (0, 1) + break + + if reverse_match_key in self.bond_types: + match = self.bond_types[reverse_match_key], (2, 1) + break + + if match: + break + msg = ( f"BondType between atoms {atom_types[0]} and {atom_types[1]} " f"is missing from the ForceField" @@ -382,6 +409,33 @@ def _get_angle_type(self, atom_types, return_match_order=False, warn=False): if reverse in self.angle_types: match = self.angle_types[reverse], (2, 1, 0) + if match: + if return_match_order: + return match + else: + return match[0] + + for i in range(1, 4): + forward_patterns = mask_with(atom_types, i) + reverse_patterns = mask_with(reversed(atom_types), i) + + for forward_pattern, reverse_pattern in zip( + forward_patterns, reverse_patterns + ): + forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) + reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) + + if forward_match_key in self.angle_types: + match = self.angle_types[forward_match_key], (0, 1, 2) + break + + if reverse_match_key in self.angle_types: + match = self.angle_types[reverse_match_key], (2, 1, 0) + break + + if match: + break + msg = ( f"AngleType between atoms {atom_types[0]}, {atom_types[1]} " f"and {atom_types[2]} is missing from the ForceField" diff --git a/gmso/tests/files/alkanes_wildcards.xml b/gmso/tests/files/alkanes_wildcards.xml index 7cd7f3f2b..c7fed57e5 100644 --- a/gmso/tests/files/alkanes_wildcards.xml +++ b/gmso/tests/files/alkanes_wildcards.xml @@ -46,38 +46,20 @@ - + - - - - - - - - + + - - - - - - - + - - - - - - - - + + @@ -88,24 +70,14 @@ - - - - - - - - - - - + - - - - - - + + + + + + diff --git a/gmso/tests/parameterization/test_opls_gmso.py b/gmso/tests/parameterization/test_parameterizations.py similarity index 65% rename from gmso/tests/parameterization/test_opls_gmso.py rename to gmso/tests/parameterization/test_parameterizations.py index 085ac2a28..3e3728c05 100644 --- a/gmso/tests/parameterization/test_opls_gmso.py +++ b/gmso/tests/parameterization/test_parameterizations.py @@ -4,11 +4,13 @@ import parmed as pmd import pytest +from gmso import ForceField from gmso.external.convert_parmed import from_parmed from gmso.parameterization.parameterize import apply from gmso.tests.parameterization.parameterization_base_test import ( ParameterizationBaseTest, ) +from gmso.tests.utils import get_path def get_foyer_opls_test_dirs(): @@ -56,3 +58,23 @@ def test_foyer_oplsaa_files( assert_same_connection_params(gmso_top, gmso_top_from_pmd) assert_same_connection_params(gmso_top, gmso_top_from_pmd, "angles") assert_same_connection_params(gmso_top, gmso_top_from_pmd, "dihedrals") + + +class TestGeneralParameterizations(ParameterizationBaseTest): + def test_wildcards(self, ethane_methane_top): + from gmso.core.views import PotentialFilters + + ff = ForceField(get_path("alkanes_wildcards.xml")) + ptop = apply(ethane_methane_top, ff, identify_connections=True) + assert ptop.is_fully_typed() + assert len(ptop.bond_types) == 11 + assert len(ptop.bond_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 1 + assert ptop.bonds[0].bond_type.member_types == ("*", "*") + + assert len(ptop.angle_types) == 12 + 6 # ethane + methane + assert len(ptop.angle_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 1 + assert ptop.angles[0].angle_type.member_types == ("*", "*", "*") + + assert len(ptop.dihedral_types) == 9 + assert len(ptop.dihedral_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 1 + assert ptop.dihedrals[0].dihedral_type.member_types == ("*", "*", "*", "*") From 7128362231df5141ecbcf91bbb109ca457d7a48b Mon Sep 17 00:00:00 2001 From: CalCraven Date: Tue, 16 Sep 2025 11:07:12 -0500 Subject: [PATCH 02/11] change dihedral sample value so that c5 is 0 and rb torsion conversion is allowed --- gmso/tests/files/alkanes_wildcards.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gmso/tests/files/alkanes_wildcards.xml b/gmso/tests/files/alkanes_wildcards.xml index c7fed57e5..a22651478 100644 --- a/gmso/tests/files/alkanes_wildcards.xml +++ b/gmso/tests/files/alkanes_wildcards.xml @@ -77,7 +77,7 @@ - + From d8da81e985bc2d85c2c34ad7a483de721db6d330 Mon Sep 17 00:00:00 2001 From: CalCraven Date: Tue, 16 Sep 2025 13:51:52 -0500 Subject: [PATCH 03/11] Add tests for reversed version of wildcard ordering --- gmso/core/forcefield.py | 2 +- gmso/tests/files/alkanes_wildcards.xml | 16 ++++++++++++++-- .../parameterization/test_parameterizations.py | 8 +++++--- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/gmso/core/forcefield.py b/gmso/core/forcefield.py index f7a7bf048..fc5069755 100644 --- a/gmso/core/forcefield.py +++ b/gmso/core/forcefield.py @@ -372,7 +372,7 @@ def _get_bond_type(self, atom_types, return_match_order=False, warn=False): break if reverse_match_key in self.bond_types: - match = self.bond_types[reverse_match_key], (2, 1) + match = self.bond_types[reverse_match_key], (1, 0) break if match: diff --git a/gmso/tests/files/alkanes_wildcards.xml b/gmso/tests/files/alkanes_wildcards.xml index a22651478..393d7e873 100644 --- a/gmso/tests/files/alkanes_wildcards.xml +++ b/gmso/tests/files/alkanes_wildcards.xml @@ -46,17 +46,29 @@ - + + + + + + + - + + + + + + + diff --git a/gmso/tests/parameterization/test_parameterizations.py b/gmso/tests/parameterization/test_parameterizations.py index 3e3728c05..f160346a6 100644 --- a/gmso/tests/parameterization/test_parameterizations.py +++ b/gmso/tests/parameterization/test_parameterizations.py @@ -68,12 +68,14 @@ def test_wildcards(self, ethane_methane_top): ptop = apply(ethane_methane_top, ff, identify_connections=True) assert ptop.is_fully_typed() assert len(ptop.bond_types) == 11 - assert len(ptop.bond_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 1 - assert ptop.bonds[0].bond_type.member_types == ("*", "*") + assert len(ptop.bond_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 2 + assert ptop.bonds[0].bond_type.member_types == ("*", "opls_135") # ethane + assert ptop.bonds[8].bond_type.member_types == ("*", "*") # methane assert len(ptop.angle_types) == 12 + 6 # ethane + methane - assert len(ptop.angle_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 1 + assert len(ptop.angle_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 2 assert ptop.angles[0].angle_type.member_types == ("*", "*", "*") + assert ptop.angles[2].angle_type.member_types == ("*", "*", "opls_135") assert len(ptop.dihedral_types) == 9 assert len(ptop.dihedral_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 1 From 272fb395d36622d21651e290b162365d15679a9e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Sep 2025 18:52:05 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- gmso/tests/parameterization/test_parameterizations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gmso/tests/parameterization/test_parameterizations.py b/gmso/tests/parameterization/test_parameterizations.py index f160346a6..8986a439a 100644 --- a/gmso/tests/parameterization/test_parameterizations.py +++ b/gmso/tests/parameterization/test_parameterizations.py @@ -69,8 +69,8 @@ def test_wildcards(self, ethane_methane_top): assert ptop.is_fully_typed() assert len(ptop.bond_types) == 11 assert len(ptop.bond_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 2 - assert ptop.bonds[0].bond_type.member_types == ("*", "opls_135") # ethane - assert ptop.bonds[8].bond_type.member_types == ("*", "*") # methane + assert ptop.bonds[0].bond_type.member_types == ("*", "opls_135") # ethane + assert ptop.bonds[8].bond_type.member_types == ("*", "*") # methane assert len(ptop.angle_types) == 12 + 6 # ethane + methane assert len(ptop.angle_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 2 From e5bb131d267d7ed09b5b18a66fa080f6c0a05f59 Mon Sep 17 00:00:00 2001 From: CalCraven Date: Wed, 24 Sep 2025 11:22:38 -0500 Subject: [PATCH 05/11] Change get_forcefield methods and parameterizations to apply wildcard parameters --- gmso/__init__.py | 4 +- gmso/abc/abstract_connection.py | 18 + gmso/core/angle.py | 45 ++- gmso/core/bond.py | 15 + gmso/core/dihedral.py | 49 ++- gmso/core/forcefield.py | 363 ++++++++---------- gmso/core/improper.py | 57 ++- gmso/core/views.py | 7 + gmso/external/convert_mbuild.py | 4 +- gmso/external/convert_parmed.py | 1 + gmso/formats/top.py | 2 +- .../topology_parameterizer.py | 37 +- gmso/tests/files/bond-order.xml | 156 ++++++++ gmso/tests/files/restrained_benzene_ua.top | 12 +- .../test_parameterization_options.py | 21 + gmso/tests/test_forcefield.py | 91 ++--- gmso/tests/test_top.py | 2 +- gmso/utils/connectivity.py | 21 +- gmso/utils/misc.py | 16 + 19 files changed, 623 insertions(+), 298 deletions(-) create mode 100644 gmso/tests/files/bond-order.xml diff --git a/gmso/__init__.py b/gmso/__init__.py index 814fd6875..8a982eeb4 100644 --- a/gmso/__init__.py +++ b/gmso/__init__.py @@ -137,5 +137,5 @@ def print_level(self, level: str): # Example usage in __init__.py -gmso_logger = GMSOLogger() -gmso_logger.library_logger.setLevel(logging.INFO) +# gmso_logger = GMSOLogger() +# gmso_logger.library_logger.setLevel(logging.WARNING) diff --git a/gmso/abc/abstract_connection.py b/gmso/abc/abstract_connection.py index 516ed5ce6..242d837f9 100644 --- a/gmso/abc/abstract_connection.py +++ b/gmso/abc/abstract_connection.py @@ -1,3 +1,4 @@ +import itertools from typing import Optional, Sequence from pydantic import ConfigDict, Field, model_validator @@ -111,3 +112,20 @@ def __repr__(self): def __str__(self): return f"<{self.__class__.__name__} {self.name}, id: {id(self)}> " + + def get_connection_identifiers(self): + from gmso.core.bond import Bond + + borderDict = {1: "-", 2: "=", 3: "#", 0: "~", None: "~", 1.5: ":"} + site_identifiers = [ + (site.atom_type.atomclass, site.atom_type.name) + for site in self.connection_members + ] + if isinstance(self, Bond): + bond_identifiers = [borderDict[self.bond_order]] + else: + bond_identifiers = [borderDict[b.bond_order] for b in self.bonds] + choices = [(aclass, atype, "*") for aclass, atype in site_identifiers] + choices += [(val, "~") for val in bond_identifiers] + all_combinations = itertools.product(*choices) + return all_combinations diff --git a/gmso/core/angle.py b/gmso/core/angle.py index 14b1ad439..923b273c4 100644 --- a/gmso/core/angle.py +++ b/gmso/core/angle.py @@ -2,11 +2,12 @@ from typing import Callable, ClassVar, Optional, Tuple -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, model_validator from gmso.abc.abstract_connection import Connection from gmso.core.angle_type import AngleType from gmso.core.atom import Atom +from gmso.core.bond import Bond class Angle(Connection): @@ -26,12 +27,24 @@ class Angle(Connection): __members_creator__: ClassVar[Callable] = Atom.model_validate + connectivity: ClassVar[Tuple[Tuple[int]]] = ((0, 1), (1, 2)) + connection_members_: Tuple[Atom, Atom, Atom] = Field( ..., description="The 3 atoms involved in the angle.", alias="connection_members", ) + bonds_: Tuple[Bond, Bond] = Field( + default=None, + description=""" + List of connection bonds. + Ordered to align with connection_members, such that self.bonds_[0] is + the bond between (self.connection_members[0], self.connection_members[1]). + """, + alias="bonds", + ) + angle_type_: Optional[AngleType] = Field( default=None, description="AngleType of this angle.", @@ -48,6 +61,7 @@ class Angle(Connection): """, alias="restraint", ) + model_config = ConfigDict( alias_to_fields=dict( **Connection.model_config["alias_to_fields"], @@ -73,6 +87,24 @@ def restraint(self): """Return the restraint of this angle.""" return self.__dict__.get("restraint_") + @property + def bonds(self): + """Return the bond_order symbol of this bond.""" + return self.__dict__.get("bonds_") + + @bonds.setter + def bonds(self, bonds): + """Return the bonds that makeup this Improper. + + Connectivity is ((0,1), (0,2), (0,3)) + """ + self._bonds = bonds + + @property + def bonds_orders(self): + """Return the bond_order strings of this angle.""" + return "".join([str(b.bond_order) for b in self.bonds]) + def equivalent_members(self): """Return a set of the equivalent connection member tuples. @@ -99,3 +131,14 @@ def __setattr__(self, key, value): super(Angle, self).__setattr__("angle_type", value) else: super(Angle, self).__setattr__(key, value) + + @model_validator(mode="before") + @classmethod + def set_dependent_value_default(cls, data): + if "bonds" not in data and "connection_members" in data: + atoms = data["connection_members"] + data["bonds"] = ( + Bond(connection_members=(atoms[0], atoms[1])), + Bond(connection_members=(atoms[1], atoms[2])), + ) + return data diff --git a/gmso/core/bond.py b/gmso/core/bond.py index cca75ea3e..c9c80682e 100644 --- a/gmso/core/bond.py +++ b/gmso/core/bond.py @@ -26,6 +26,8 @@ class Bond(Connection): __members_creator__: ClassVar[Callable] = Atom.model_validate + connectivity: ClassVar[Tuple[Tuple[int]]] = ((0, 1),) + connection_members_: Tuple[Atom, Atom] = Field( ..., description="The 2 atoms involved in the bond.", @@ -46,12 +48,20 @@ class Bond(Connection): """, alias="restraint", ) + + bond_order_: Optional[float] = Field( + default=None, + description="Bond order of this bond.", + alias="bond_order", + ) + model_config = ConfigDict( alias_to_fields=dict( **Connection.model_config["alias_to_fields"], **{ "bond_type": "bond_type_", "restraint": "restraint_", + "bond_order": "bond_order_", }, ) ) @@ -71,6 +81,11 @@ def restraint(self): """Return the restraint of this bond.""" return self.__dict__.get("restraint_") + @property + def bond_order(self): + """Return the bond_order of this bond.""" + return self.__dict__.get("bond_order_") + def equivalent_members(self): """Get a set of the equivalent connection member tuples. diff --git a/gmso/core/dihedral.py b/gmso/core/dihedral.py index d8891fac0..b3825aaad 100644 --- a/gmso/core/dihedral.py +++ b/gmso/core/dihedral.py @@ -1,9 +1,10 @@ from typing import Callable, ClassVar, Optional, Tuple -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, model_validator from gmso.abc.abstract_connection import Connection from gmso.core.atom import Atom +from gmso.core.bond import Bond from gmso.core.dihedral_type import DihedralType @@ -28,12 +29,24 @@ class Dihedral(Connection): __members_creator__: ClassVar[Callable] = Atom.model_validate + connectivity: ClassVar[Tuple[Tuple[int]]] = ((0, 1), (1, 2), (2, 3)) + connection_members_: Tuple[Atom, Atom, Atom, Atom] = Field( ..., description="The 4 atoms involved in the dihedral.", alias="connection_members", ) + bonds_: Tuple[Bond, Bond, Bond] = Field( + default=None, + description=""" + List of connection bonds. + Ordered to align with connection_members, such that self.bonds_[0] is + the bond between (self.connection_members[0], self.connection_members[1]). + """, + alias="bonds", + ) + dihedral_type_: Optional[DihedralType] = Field( default=None, description="DihedralType of this dihedral.", @@ -50,6 +63,7 @@ class Dihedral(Connection): """, alias="restraint", ) + model_config = ConfigDict( alias_to_fields=dict( **Connection.model_config["alias_to_fields"], @@ -74,6 +88,27 @@ def restraint(self): """Return the restraint of this dihedral.""" return self.__dict__.get("restraint_") + @property + def bonds(self): + """Return the bonds that makeup this dihedral. + + Connectivity is ((0,1), (1,2), (2,3)) + """ + return self.__dict__.get("bonds_") + + @bonds.setter + def bonds(self, bonds): + """Return the bonds that makeup this Improper. + + Connectivity is ((0,1), (0,2), (0,3)) + """ + self._bonds = bonds + + @property + def bonds_orders(self): + """Return the bond_order strings of this dihedral.""" + return "".join([b.bond_order for b in self.bonds]) + def equivalent_members(self): """Get a set of the equivalent connection member tuples @@ -99,3 +134,15 @@ def __setattr__(self, key, value): super(Dihedral, self).__setattr__("dihedral_type_", value) else: super(Dihedral, self).__setattr__(key, value) + + @model_validator(mode="before") + @classmethod + def set_dependent_value_default(cls, data): + if "bonds" not in data and "connection_members" in data: + atoms = data["connection_members"] + data["bonds"] = ( + Bond(connection_members=(atoms[0], atoms[1])), + Bond(connection_members=(atoms[1], atoms[2])), + Bond(connection_members=(atoms[2], atoms[3])), + ) + return data diff --git a/gmso/core/forcefield.py b/gmso/core/forcefield.py index fc5069755..b78c3b16c 100644 --- a/gmso/core/forcefield.py +++ b/gmso/core/forcefield.py @@ -11,8 +11,9 @@ from pydantic import ValidationError from gmso.core.element import element_by_symbol -from gmso.exceptions import GMSOError, MissingPotentialError +from gmso.exceptions import GMSOError from gmso.utils._constants import FF_TOKENS_SEPARATOR +from gmso.utils.connectivity import create_pattern from gmso.utils.decorators import deprecate_kwargs from gmso.utils.ff_utils import ( parse_ff_atomtypes, @@ -22,7 +23,7 @@ parse_ff_virtual_types, validate, ) -from gmso.utils.misc import mask_with, validate_type +from gmso.utils.misc import mask_with, reverse_identifier, validate_type logger = logging.getLogger(__name__) @@ -260,7 +261,7 @@ def group_pairpotential_types_by_expression(self): """ return _group_by_expression(self.pairpotential_types) - def get_potential(self, group, key, return_match_order=False, warn=False): + def get_potential(self, group, key, exact_match=False): """Return a specific potential by key in this ForceField. Parameters @@ -269,10 +270,8 @@ def get_potential(self, group, key, return_match_order=False, warn=False): The potential group to perform this search on key: str (for atom type) or list of str (for connection types) The key to lookup for this potential group - return_match_order : bool, default=False - If true, return the order of connection member types/classes that got matched - warn: bool, default=False - If true, raise a warning instead of Error if no match found + exact_match : bool, default=False + If False, use wildcard matching to check for valid matches in the forcefield. Returns ------- @@ -304,11 +303,14 @@ def get_potential(self, group, key, return_match_order=False, warn=False): str, ) + if group == "atom_type": + return potential_extractors[group](key) return potential_extractors[group]( - key, return_match_order=return_match_order, warn=warn + key, + exact_match=exact_match, ) - def get_parameters(self, group, key, warn=False, copy=False): + def get_parameters(self, group, key, exact_match=False, copy=False): """Return parameters for a specific potential by key in this ForceField. This function uses the `get_potential` function to get Parameters @@ -318,48 +320,51 @@ def get_parameters(self, group, key, warn=False, copy=False): gmso.ForceField.get_potential Get specific potential/parameters from a forcefield potential group by key """ - potential = self.get_potential(group, key, warn=warn) - return potential.get_parameters(copy=copy) + potential = self.get_potential(group, key, exact_match=exact_match) + if potential is None: + return None + if group == "atom_type": + return potential.get_parameters(copy=copy) + return potential[0].get_parameters(copy=copy) - def _get_atom_type(self, atom_type, return_match_order=False, warn=False): + def _get_atom_type(self, atom_type): """Get a particular atom_type with given `atom_type` from this ForceField.""" if isinstance(atom_type, list): atom_type = atom_type[0] - if not self.atom_types.get(atom_type): - msg = f"AtomType {atom_type} is not present in the ForceField" - if warn: - logger.warning(msg) - else: - raise MissingPotentialError(msg) - return self.atom_types.get(atom_type) - def _get_bond_type(self, atom_types, return_match_order=False, warn=False): + def _get_bond_type(self, identifier, exact_match=False): """Get a particular bond_type between `atom_types` from this ForceField.""" - if len(atom_types) != 2: - raise ValueError( - f"BondType potential can only " - f"be extracted for two atoms. Provided {len(atom_types)}" - ) - - forward = FF_TOKENS_SEPARATOR.join(atom_types) - reverse = FF_TOKENS_SEPARATOR.join(reversed(atom_types)) - match = None + if isinstance(identifier, str): + forward = identifier + reverse = reverse_identifier(forward) + else: + if len(identifier) == 2: # add wildcard bond + identifier.append("~") + elif len(identifier) != 3: + raise ValueError( + f"BondType potential can only " + f"be extracted for two atoms. Provided {len(identifier)}" + ) + forward = create_pattern(identifier) + reverse = create_pattern(list(reversed(identifier[:2])) + [identifier[2]]) if forward in self.bond_types: - match = self.bond_types[forward], (0, 1) + return self.bond_types[forward], (0, 1) if reverse in self.bond_types: - match = self.bond_types[reverse], (1, 0) + return self.bond_types[reverse], (1, 0) - if match: - if return_match_order: - return match - else: - return match[0] + if exact_match: + return None for i in range(1, 3): - forward_patterns = mask_with(atom_types, i) - reverse_patterns = mask_with(reversed(atom_types), i) + forward_patterns = mask_with(identifier[:2], i) + reverse_patterns = mask_with( + reversed( + identifier[:2], + ), + i, + ) for forward_pattern, reverse_pattern in zip( forward_patterns, reverse_patterns @@ -368,56 +373,43 @@ def _get_bond_type(self, atom_types, return_match_order=False, warn=False): reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) if forward_match_key in self.bond_types: - match = self.bond_types[forward_match_key], (0, 1) - break + return self.bond_types[forward_match_key], (0, 1) if reverse_match_key in self.bond_types: - match = self.bond_types[reverse_match_key], (1, 0) - break - - if match: - break + return self.bond_types[reverse_match_key], (1, 0) + return None - msg = ( - f"BondType between atoms {atom_types[0]} and {atom_types[1]} " - f"is missing from the ForceField" - ) - if match: - if return_match_order: - return match - else: - return match[0] - elif warn: - logger.warning(msg) - return None - else: - raise MissingPotentialError(msg) - - def _get_angle_type(self, atom_types, return_match_order=False, warn=False): + def _get_angle_type(self, identifier, exact_match=False): """Get a particular angle_type between `atom_types` from this ForceField.""" - if len(atom_types) != 3: - raise ValueError( - f"AngleType potential can only " - f"be extracted for three atoms. Provided {len(atom_types)}" + if isinstance(identifier, str): + forward = identifier + reverse = reverse_identifier(forward) + else: + if len(identifier) == 3: + identifier.append("~") + identifier.append("~") + elif len(identifier) != 5: + raise ValueError( + f"AngleType potential can only " + f"be extracted for three atoms. Provided {len(identifier)}" + ) + forward = create_pattern(identifier) + reverse = create_pattern( + list(reversed(identifier[:3])) + list(reversed(identifier[3:])) ) - forward = FF_TOKENS_SEPARATOR.join(atom_types) - reverse = FF_TOKENS_SEPARATOR.join(reversed(atom_types)) - match = None if forward in self.angle_types: - match = self.angle_types[forward], (0, 1, 2) + return self.angle_types[forward], (0, 1, 2) if reverse in self.angle_types: - match = self.angle_types[reverse], (2, 1, 0) - - if match: - if return_match_order: - return match - else: - return match[0] + return self.angle_types[reverse], (2, 1, 0) + if exact_match: + return None for i in range(1, 4): - forward_patterns = mask_with(atom_types, i) - reverse_patterns = mask_with(reversed(atom_types), i) + forward_patterns = mask_with(identifier, i) + reverse_patterns = mask_with( + list(reversed(identifier[:3])) + list(reversed(identifier[3:])), i + ) for forward_pattern, reverse_pattern in zip( forward_patterns, reverse_patterns @@ -426,57 +418,43 @@ def _get_angle_type(self, atom_types, return_match_order=False, warn=False): reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) if forward_match_key in self.angle_types: - match = self.angle_types[forward_match_key], (0, 1, 2) - break + return self.angle_types[forward_match_key], (0, 1, 2) if reverse_match_key in self.angle_types: - match = self.angle_types[reverse_match_key], (2, 1, 0) - break - - if match: - break + return self.angle_types[reverse_match_key], (2, 1, 0) - msg = ( - f"AngleType between atoms {atom_types[0]}, {atom_types[1]} " - f"and {atom_types[2]} is missing from the ForceField" - ) - if match: - if return_match_order: - return match - else: - return match[0] - elif warn: - logger.warning(msg) - return None - else: - raise MissingPotentialError(msg) + return None - def _get_dihedral_type(self, atom_types, return_match_order=False, warn=False): + def _get_dihedral_type(self, identifier, exact_match=False): """Get a particular dihedral_type between `atom_types` from this ForceField.""" - if len(atom_types) != 4: - raise ValueError( - f"DihedralType potential can only " - f"be extracted for four atoms. Provided {len(atom_types)}" + if isinstance(identifier, str): + forward = identifier + reverse = reverse_identifier(forward) + else: + if len(identifier) == 4: + identifier.append("~") + identifier.append("~") + identifier.append("~") + elif len(identifier) != 7: + raise ValueError( + f"DihedralType potential can only " + f"be extracted for four atoms and three bonds. Provided {len(identifier)}" + ) + forward = create_pattern(identifier) + reverse = create_pattern( + list(reversed(identifier[:4])) + list(reversed(identifier[4:])) ) - forward = FF_TOKENS_SEPARATOR.join(atom_types) - reverse = FF_TOKENS_SEPARATOR.join(reversed(atom_types)) - - match = None if forward in self.dihedral_types: - match = self.dihedral_types[forward], (0, 1, 2, 3) + return self.dihedral_types[forward], (0, 1, 2, 3) if reverse in self.dihedral_types: - match = self.dihedral_types[reverse], (3, 2, 1, 0) - - if match: - if return_match_order: - return match - else: - return match[0] + return self.dihedral_types[reverse], (3, 2, 1, 0) + if exact_match: + return None for i in range(1, 5): - forward_patterns = mask_with(atom_types, i) - reverse_patterns = mask_with(reversed(atom_types), i) + forward_patterns = mask_with(identifier, i) + reverse_patterns = mask_with(reversed(identifier[:4] + identifier[4:]), i) for forward_pattern, reverse_pattern in zip( forward_patterns, reverse_patterns @@ -485,100 +463,82 @@ def _get_dihedral_type(self, atom_types, return_match_order=False, warn=False): reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) if forward_match_key in self.dihedral_types: - match = self.dihedral_types[forward_match_key], (0, 1, 2, 3) - break + return self.dihedral_types[forward_match_key], (0, 1, 2, 3) if reverse_match_key in self.dihedral_types: - match = self.dihedral_types[reverse_match_key], (3, 2, 1, 0) - break + return self.dihedral_types[reverse_match_key], (3, 2, 1, 0) - if match: - break + return None - msg = ( - f"DihedralType between atoms {atom_types[0]}, {atom_types[1]}, " - f"{atom_types[2]} and {atom_types[3]} is missing from the ForceField." - ) - if match: - if return_match_order: - return match - else: - return match[0] - elif warn: - logger.warning(msg) - return None + def _get_improper_type(self, identifier, exact_match=False): + """Get a particular improper_type between `identifier` from this ForceField.""" + if isinstance(identifier, str): + forward = identifier + equivalent = None # TODO + # equivalent = equivalent_identifier else: - raise MissingPotentialError(msg) - - def _get_improper_type(self, atom_types, return_match_order=False, warn=False): - """Get a particular improper_type between `atom_types` from this ForceField.""" - if len(atom_types) != 4: - raise ValueError( - f"ImproperType potential can only " - f"be extracted for four atoms. Provided {len(atom_types)}" - ) + if len(identifier) == 4: # add wildcard bonds + identifier.append("~") + identifier.append("~") + identifier.append("~") + elif len(identifier) != 7: + raise ValueError( + f"ImproperType potential can only " + f"be extracted for four atoms and three bonds. Provided {len(identifier)}" + ) - forward = FF_TOKENS_SEPARATOR.join(atom_types) + forward = create_pattern(identifier) + equiv_idx = [ + (0, i, j, k) for (i, j, k) in itertools.permutations((1, 2, 3), 3) + ] + equivalent = [ + [ + identifier[m], + identifier[n], + identifier[o], + identifier[p], + identifier[n + 3], + identifier[o + 3], + identifier[p + 3], + ] + for (m, n, o, p) in equiv_idx + ] if forward in self.improper_types: - if return_match_order: - return self.improper_types[forward], (0, 1, 2, 3) - else: - return self.improper_types[forward] + return self.improper_types[forward], (0, 1, 2, 3) - equiv_idx = [(0, i, j, k) for (i, j, k) in itertools.permutations((1, 2, 3), 3)] - equivalent = [ - [atom_types[m], atom_types[n], atom_types[o], atom_types[p]] - for (m, n, o, p) in equiv_idx - ] + if not equivalent: + return None for eq, order in zip(equivalent, equiv_idx): - eq_key = FF_TOKENS_SEPARATOR.join(eq) + eq_key = create_pattern(eq) if eq_key in self.improper_types: - if return_match_order: - return self.improper_types[eq_key], order - else: - return self.improper_types[eq_key] + return self.improper_types[eq_key], order + + if exact_match: + return None - match = None for i in range(1, 5): - forward_patterns = mask_with(atom_types, i) + forward_patterns = mask_with(identifier, i) for forward_pattern in forward_patterns: forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) if forward_match_key in self.improper_types: - match = self.improper_types[forward_match_key], (0, 1, 2, 3) - break - if match: - break - if not match: - for i in range(1, 5): - for eq, order in zip(equivalent, equiv_idx): - equiv_patterns = mask_with(eq, i) - for equiv_pattern in equiv_patterns: - equiv_pattern_key = FF_TOKENS_SEPARATOR.join(equiv_pattern) - if equiv_pattern_key in self.improper_types: - match = ( - self.improper_types[equiv_pattern_key], - order, - ) - break - if match: - break - if match: - break - - msg = ( - f"ImproperType between atoms {atom_types[0]}, {atom_types[1]}, " - f"{atom_types[2]} and {atom_types[3]} is missing from the ForceField." - ) - if match: - return match - elif warn: - logger.warning(msg) - return None - else: - raise MissingPotentialError(msg) + return self.improper_types[forward_match_key], (0, 1, 2, 3) + for i in range(1, 5): + for eq, order in zip(equivalent, equiv_idx): + equiv_patterns = mask_with(eq, i) + for equiv_pattern in equiv_patterns: + equiv_pattern_key = FF_TOKENS_SEPARATOR.join(equiv_pattern) + if equiv_pattern_key in self.improper_types: + return ( + self.improper_types[equiv_pattern_key], + order, + ) + + return None - def _get_virtual_type(self, atom_types, return_match_order=False, warn=False): + def _get_virtual_type( + self, atom_types, return_match_order=False, exact_match=False + ): """Get a particular virtual_type between `atom_types` from this ForceField.""" forward = FF_TOKENS_SEPARATOR.join(atom_types) @@ -586,18 +546,7 @@ def _get_virtual_type(self, atom_types, return_match_order=False, warn=False): match = None if forward in self.virtual_types: match = self.virtual_types[forward], tuple(range(n_elements)) - - msg = f"VirtualType between atoms {tuple(atype for atype in atom_types)} is missing from the ForceField" - if match: - if return_match_order: - return match - else: - return match[0] # only return the atoms, not their order - elif warn: - logger.warning(msg) - return None - else: - raise MissingPotentialError(msg) + return match def __repr__(self): """Return a formatted representation of the Forcefield.""" diff --git a/gmso/core/improper.py b/gmso/core/improper.py index 58f6c414d..0b223beba 100644 --- a/gmso/core/improper.py +++ b/gmso/core/improper.py @@ -2,10 +2,11 @@ from typing import Callable, ClassVar, Optional, Tuple -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, model_validator from gmso.abc.abstract_connection import Connection from gmso.core.atom import Atom +from gmso.core.bond import Bond from gmso.core.improper_type import ImproperType @@ -31,6 +32,8 @@ class Improper(Connection): __members_creator__: ClassVar[Callable] = Atom.model_validate + connectivity: ClassVar[Tuple[Tuple[int]]] = ((0, 1), (0, 2), (0, 3)) + connection_members_: Tuple[Atom, Atom, Atom, Atom] = Field( ..., description="The 4 atoms of this improper. Central atom first, " @@ -38,16 +41,35 @@ class Improper(Connection): alias="connection_members", ) + bonds_: Tuple[Bond, Bond, Bond] = Field( + default=None, + description=""" + List of connection bonds. + Ordered to align with connection_members, such that self.bonds_[0] is + the bond between (self.connection_members[0], self.connection_members[1]). + """, + alias="bonds", + ) + improper_type_: Optional[ImproperType] = Field( default=None, description="ImproperType of this improper.", alias="improper_type", ) + + bond_orders_: Optional[Tuple[str, str]] = Field( + default=None, + description=""" + List of connection members bond orders. + """, + alias="bond_orders", + ) model_config = ConfigDict( alias_to_fields=dict( **Connection.model_config["alias_to_fields"], **{ "improper_type": "improper_type_", + "bond_orders": "bond_orders_", }, ) ) @@ -63,6 +85,27 @@ def connection_type(self): # ToDo: Deprecate this? return self.__dict__.get("improper_type_") + @property + def bonds(self): + """Return the bonds that makeup this Improper. + + Connectivity is ((0,1), (0,2), (0,3)) + """ + return self.__dict__.get("bonds_") + + @bonds.setter + def bonds(self, new_bonds): + """Return the bonds that makeup this Improper. + + Connectivity is ((0,1), (0,2), (0,3)) + """ + self.bonds_ = new_bonds + + @property + def bonds_orders(self): + """Return the bond_order strings of this improper.""" + return [b.bond_order for b in self.bonds] + def equivalent_members(self): """Get a set of the equivalent connection member tuples. @@ -94,3 +137,15 @@ def __setattr__(self, key, value): super(Improper, self).__setattr__("improper_type_", value) else: super(Improper, self).__setattr__(key, value) + + @model_validator(mode="before") + @classmethod + def set_dependent_value_default(cls, data): + if "bonds" not in data and "connection_members" in data: + atoms = data["connection_members"] + data["bonds"] = ( + Bond(connection_members=(atoms[0], atoms[1])), + Bond(connection_members=(atoms[0], atoms[2])), + Bond(connection_members=(atoms[0], atoms[3])), + ) + return data diff --git a/gmso/core/views.py b/gmso/core/views.py index 45d4b66ff..845725cb9 100644 --- a/gmso/core/views.py +++ b/gmso/core/views.py @@ -44,6 +44,11 @@ def get_parameters(potential): ) +def get_name(potential): + """Return the string name of the object""" + return potential.name + + def filtered_potentials(potential_types, identifier): """Filter and return unique potentials based on pre-defined identifier function.""" visited = defaultdict(set) @@ -63,6 +68,7 @@ class PotentialFilters: UNIQUE_PARAMETERS = "unique_parameters" UNIQUE_ID = "unique_id" REPEAT_DUPLICATES = "repeat_duplicates" + NAME = "name" @staticmethod def all(): @@ -80,6 +86,7 @@ def all(): PotentialFilters.UNIQUE_PARAMETERS: get_parameters, PotentialFilters.UNIQUE_ID: lambda p: id(p), PotentialFilters.REPEAT_DUPLICATES: lambda _: str(uuid.uuid4()), + PotentialFilters.NAME: get_name, } diff --git a/gmso/external/convert_mbuild.py b/gmso/external/convert_mbuild.py index 10f8f2719..1983e829a 100644 --- a/gmso/external/convert_mbuild.py +++ b/gmso/external/convert_mbuild.py @@ -21,6 +21,7 @@ import logging logger = logging.getLogger(__name__) +borderDict = {"single": 1, "double": 2, "triple": 3, "default": "0", "aromatic": 1.5} def from_mbuild( @@ -113,10 +114,11 @@ def from_mbuild( site = _parse_site(site_map, part, search_method, infer_element=infer_elements) top.add_site(site) - for b1, b2 in compound.bonds(): + for b1, b2, border in compound.bonds(return_bond_order=True): assert site_map[b1]["site"].molecule == site_map[b2]["site"].molecule new_bond = Bond( connection_members=[site_map[b1]["site"], site_map[b2]["site"]], + bond_order=borderDict[border["bond_order"]], ) top.add_connection(new_bond, update_types=False) diff --git a/gmso/external/convert_parmed.py b/gmso/external/convert_parmed.py index bdf07a0f4..458c00d2a 100644 --- a/gmso/external/convert_parmed.py +++ b/gmso/external/convert_parmed.py @@ -123,6 +123,7 @@ def from_parmed(structure, refer_type=True): for angle in structure.angles: # Generate angles and harmonic parameters # If typed, assumed to be harmonic angles + # import pdb; pdb.set_trace() top_connection = gmso.Angle( connection_members=_sort_angle_members( top, site_map, *attrgetter("atom1", "atom2", "atom3")(angle) diff --git a/gmso/formats/top.py b/gmso/formats/top.py index e318b022a..85a90e0e5 100644 --- a/gmso/formats/top.py +++ b/gmso/formats/top.py @@ -564,8 +564,8 @@ def _position_restraints_writer(top, site, shifted_idx_map): def _bond_restraint_writer(top, bond, shifted_idx_map): """Write bond restraint information.""" line = "{0:8s}{1:8s}{2:4s}{3:15.5f}{4:15.5f}\n".format( - str(shifted_idx_map[top.get_index(bond.connection_members[1])] + 1), str(shifted_idx_map[top.get_index(bond.connection_members[0])] + 1), + str(shifted_idx_map[top.get_index(bond.connection_members[1])] + 1), "6", bond.restraint["r_eq"].in_units(u.nm).value, bond.restraint["k"].in_units(u.Unit("kJ/(mol * nm**2)")).value, diff --git a/gmso/parameterization/topology_parameterizer.py b/gmso/parameterization/topology_parameterizer.py index 3620e13fa..92c5d7b1a 100644 --- a/gmso/parameterization/topology_parameterizer.py +++ b/gmso/parameterization/topology_parameterizer.py @@ -194,9 +194,13 @@ def _parameterize_virtual_sites(self, top, sites, bonds, ff): def _apply_connection_parameters(self, connections, ff, error_on_missing=True): """Find and assign potentials from the forcefield for the provided connections.""" visited = dict() - for connection in connections: group, connection_identifiers = self.connection_identifier(connection) + # TODO: sort connection_identifiers in connection_identifiers + connection_identifiers = sorted( + list(connection_identifiers), + key=lambda item: 100 * item.count("*") + item.count("~"), + ) match = None for identifier_key in connection_identifiers: if tuple(identifier_key) in visited: @@ -206,8 +210,7 @@ def _apply_connection_parameters(self, connections, ff, error_on_missing=True): match = ff.get_potential( group=group, key=identifier_key, - return_match_order=True, - warn=True, + exact_match=True, ) if match: visited[tuple(identifier_key)] = match @@ -222,6 +225,20 @@ def _apply_connection_parameters(self, connections, ff, error_on_missing=True): setattr(connection, group, match[0].clone(self.config.fast_copy)) matched_order = [connection.connection_members[i] for i in match[1]] connection.connection_members = matched_order + if group == "angle_type": # reverse angle.bonds + if match[1][0] == 2: + connection.bonds = [connection.bonds[1], connection.bonds[0]] + elif group == "dihedral_type": # reorder dihedral.bonds + if match[1][0] == 3: + connection.bonds = [ + connection.bonds[2], + connection.bonds[1], + connection.bonds[0], + ] + elif group == "improper_type": # reorder improper.bonds + improper_bonds = [connection.bonds[i - 1] for i in match[1][1:]] + connection.bonds = improper_bonds + if not match[0].member_types: connection.connection_type.member_types = tuple( member.atom_type.name for member in matched_order @@ -246,8 +263,6 @@ def _apply_virtual_site_parameters(self, virtual_sites, ff, error_on_missing=Tru match = ff.get_potential( group=group, key=identifier_key, - return_match_order=True, - warn=True, ) if match: visited[tuple(identifier_key)] = match @@ -442,12 +457,12 @@ def connection_identifier( ): # This can extended to incorporate a pluggable object from the forcefield. """Return the group and list of identifiers for a connection to query the forcefield for its potential.""" group = POTENTIAL_GROUPS[type(connection)] - return group, [ - list(member.atom_type.name for member in connection.connection_members), - list( - member.atom_type.atomclass for member in connection.connection_members - ), - ] + return ( + group, + [ + *connection.get_connection_identifiers(), # the viable keys made up of the bond orders + ], + ) @staticmethod def virtual_site_identifier( diff --git a/gmso/tests/files/bond-order.xml b/gmso/tests/files/bond-order.xml new file mode 100644 index 000000000..e2b592cf6 --- /dev/null +++ b/gmso/tests/files/bond-order.xml @@ -0,0 +1,156 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gmso/tests/files/restrained_benzene_ua.top b/gmso/tests/files/restrained_benzene_ua.top index 0a70d14dc..3b083eea8 100644 --- a/gmso/tests/files/restrained_benzene_ua.top +++ b/gmso/tests/files/restrained_benzene_ua.top @@ -56,12 +56,12 @@ BenzeneUA 3 [ angle_restraints ] ; ai aj ai ak funct theta_eq k multiplicity -1 6 1 2 1 120.00000 1000.00000 1 -2 3 2 1 1 120.00000 1000.00000 1 -3 4 3 2 1 120.00000 1000.00000 1 -4 5 4 3 1 120.00000 1000.00000 1 -5 6 5 4 1 120.00000 1000.00000 1 -6 5 6 1 1 120.00000 1000.00000 1 +1 2 1 6 1 120.00000 1000.00000 1 +2 1 2 3 1 120.00000 1000.00000 1 +3 2 3 4 1 120.00000 1000.00000 1 +4 3 4 5 1 120.00000 1000.00000 1 +5 4 5 6 1 120.00000 1000.00000 1 +6 1 6 5 1 120.00000 1000.00000 1 [ dihedrals ] ; ai aj ak al funct c0 c1 c2 c3 c4 c5 diff --git a/gmso/tests/parameterization/test_parameterization_options.py b/gmso/tests/parameterization/test_parameterization_options.py index 5e02f71bf..a7efe93c8 100644 --- a/gmso/tests/parameterization/test_parameterization_options.py +++ b/gmso/tests/parameterization/test_parameterization_options.py @@ -270,3 +270,24 @@ def test_hierarchical_mol_structure( speedup_by_moltag=speedup_by_moltag, match_ff_by=match_ff_by, ) + + def test_bond_order(self): + from gmso.core.views import PotentialFilters + + cpd = mb.load("C=CCC#CC(=O)O", smiles=True) + top = from_mbuild(cpd) + ff = ForceField(get_path("bond-order.xml")) + ptop = apply(top, ff, identify_connections=True) + assert len(ptop.atom_types) == 14 + assert len(ptop.bond_types) == 13 + # assert len(ptop.bond_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 5 + assert len(ptop.bond_types(PotentialFilters.NAME)) == 5 + assert len(ptop.angle_types) == 18 + # assert len(ptop.angle_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 5 + assert len(ptop.angle_types(PotentialFilters.NAME)) == 5 + assert len(ptop.dihedral_types) == 18 + # assert len(ptop.dihedral_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 3 + assert len(ptop.dihedral_types(PotentialFilters.NAME)) == 3 + assert len(ptop.improper_types) == 7 + # assert len(ptop.improper_types(PotentialFilters.UNIQUE_NAME_CLASS)) == 3 + assert len(ptop.improper_types(PotentialFilters.NAME)) == 3 diff --git a/gmso/tests/test_forcefield.py b/gmso/tests/test_forcefield.py index 085a0c714..4035d5b4a 100644 --- a/gmso/tests/test_forcefield.py +++ b/gmso/tests/test_forcefield.py @@ -1,5 +1,3 @@ -import logging - import lxml import pytest import unyt as u @@ -13,7 +11,6 @@ ForceFieldParseError, GMSOError, MissingAtomTypesError, - MissingPotentialError, ) from gmso.tests.base_test import BaseTest from gmso.tests.utils import allclose_units_mixed, get_path @@ -424,7 +421,9 @@ def test_forcefield_get_parameters_atom_type_copy(self, opls_ethane_foyer): assert allclose_units_mixed(params.values(), params_copy.values()) def test_forcefield_get_potential_bond_type(self, opls_ethane_foyer): - bt = opls_ethane_foyer.get_potential("bond_type", key=["opls_135", "opls_140"]) + bt, _ = opls_ethane_foyer.get_potential( + "bond_type", key=["opls_135", "opls_140"] + ) assert bt.name == "BondType-Harmonic-2" params = bt.parameters assert "k" in params @@ -437,9 +436,10 @@ def test_forcefield_get_potential_bond_type(self, opls_ethane_foyer): ) def test_forcefield_get_potential_bond_type_reversed(self, opls_ethane_foyer): - assert opls_ethane_foyer.get_potential( - "bond_type", ["opls_135", "opls_140"] - ) == opls_ethane_foyer.get_potential("bond_type", ["opls_140", "opls_135"]) + bt1, b1 = opls_ethane_foyer.get_potential("bond_type", ["opls_135", "opls_140"]) + bt2, b2 = opls_ethane_foyer.get_potential("bond_type", ["opls_140", "opls_135"]) + assert bt1 == bt2 + assert tuple(reversed(b1)) == b2 def test_forcefield_get_parameters_bond_type(self, opls_ethane_foyer): params = opls_ethane_foyer.get_parameters( @@ -451,7 +451,7 @@ def test_forcefield_get_parameters_bond_type(self, opls_ethane_foyer): ) def test_forcefield_get_potential_angle_type(self, opls_ethane_foyer): - at = opls_ethane_foyer.get_potential( + at, _ = opls_ethane_foyer.get_potential( "angle_type", key=["opls_135", "opls_135", "opls_140"] ) assert at.name == "AngleType-Harmonic-1" @@ -467,11 +467,14 @@ def test_forcefield_get_potential_angle_type(self, opls_ethane_foyer): ) def test_forcefield_get_potential_angle_type_reversed(self, opls_ethane_foyer): - assert opls_ethane_foyer.get_potential( + a1, members1 = opls_ethane_foyer.get_potential( "angle_type", ["opls_135", "opls_135", "opls_140"] - ) == opls_ethane_foyer.get_potential( + ) + a2, members2 = opls_ethane_foyer.get_potential( "angle_type", ["opls_140", "opls_135", "opls_135"] ) + assert a1 == a2 + assert tuple(reversed(members1)) == members2 def test_forcefield_get_parameters_angle_type(self, opls_ethane_foyer): params = opls_ethane_foyer.get_parameters( @@ -484,7 +487,7 @@ def test_forcefield_get_parameters_angle_type(self, opls_ethane_foyer): ) def test_forcefield_get_potential_dihedral_type(self, opls_ethane_foyer): - dt = opls_ethane_foyer.get_potential( + dt, _ = opls_ethane_foyer.get_potential( "dihedral_type", key=["opls_140", "opls_135", "opls_135", "opls_140"], ) @@ -538,76 +541,38 @@ def test_forcefield_get_potential_non_string_key(self, opls_ethane_foyer): opls_ethane_foyer.get_potential("atom_type", key=[111]) def test_get_atom_type_missing(self, opls_ethane_foyer, caplog): - with pytest.raises(MissingPotentialError): - opls_ethane_foyer._get_atom_type("opls_359", warn=False) - - with caplog.at_level(logging.WARNING, logger="gmso"): - opls_ethane_foyer._get_atom_type("opls_359", warn=True) - assert "AtomType opls_359 is not present in the ForceField" in caplog.text + assert opls_ethane_foyer._get_atom_type("opls_359") is None def test_get_bond_type_missing(self, opls_ethane_foyer, caplog): - with pytest.raises(MissingPotentialError): - opls_ethane_foyer._get_bond_type(["opls_359", "opls_600"], warn=False) - - with caplog.at_level(logging.WARNING, logger="gmso"): - opls_ethane_foyer._get_bond_type(["opls_359", "opls_600"], warn=True) - assert "BondType between atoms opls_359 and opls_600 is missing" in caplog.text + assert opls_ethane_foyer._get_bond_type(["opls_359", "opls_600"]) is None def test_get_angle_type_missing(self, opls_ethane_foyer, caplog): - with pytest.raises(MissingPotentialError): - opls_ethane_foyer._get_angle_type( - ["opls_359", "opls_600", "opls_700"], warn=False - ) - - with caplog.at_level(logging.WARNING, logger="gmso"): - opls_ethane_foyer._get_angle_type( - ["opls_359", "opls_600", "opls_700"], warn=True - ) assert ( - "AngleType between atoms opls_359, opls_600 and opls_700 is missing" - in caplog.text + opls_ethane_foyer._get_angle_type(["opls_359", "opls_600", "opls_700"]) + is None ) def test_get_dihedral_type_missing(self, opls_ethane_foyer, caplog): - with pytest.raises(MissingPotentialError): - opls_ethane_foyer._get_dihedral_type( - ["opls_359", "opls_600", "opls_700", "opls_800"], warn=False - ) - - with caplog.at_level(logging.WARNING, logger="gmso"): + assert ( opls_ethane_foyer._get_dihedral_type( - ["opls_359", "opls_600", "opls_700", "opls_800"], warn=True + ["opls_359", "opls_600", "opls_700", "opls_800"] ) - assert ( - "DihedralType between atoms opls_359, opls_600, opls_700 and opls_800" - in caplog.text + is None ) def test_get_improper_type_missing(self, opls_ethane_foyer, caplog): - with pytest.raises(MissingPotentialError): - opls_ethane_foyer._get_improper_type( - ["opls_359", "opls_600", "opls_700", "opls_800"], warn=False - ) - - with caplog.at_level(logging.WARNING, logger="gmso"): + assert ( opls_ethane_foyer._get_improper_type( - ["opls_359", "opls_600", "opls_700", "opls_800"], warn=True + ["opls_359", "opls_600", "opls_700", "opls_800"] ) - assert ( - "ImproperType between atoms opls_359, opls_600, opls_700 and opls_800" - in caplog.text + is None ) def test_get_virtual_type_missing(self, caplog): ff = ForceField(get_path("ff-example0.xml"), backend="gmso") - with pytest.raises(MissingPotentialError): - ff._get_virtual_type(["Missing"], warn=False) - - with caplog.at_level(logging.WARNING, logger="gmso"): - ff._get_virtual_type(["Missing", "Missing"], warn=True) - assert "VirtualType between atoms ('Missing', 'Missing')" in caplog.text + assert ff._get_virtual_type(["Missing"]) is None - match = ff._get_virtual_type(["Xe"], warn=False) + match = ff._get_virtual_type(["Xe"]) assert match def test_non_element_types(self, non_element_ff, opls_ethane_foyer): @@ -639,10 +604,10 @@ def test_forcefield_get_impropers_combinations(self): "CT~CT~HC~HC": ImproperType(name="imp1"), "CT~HC~HC~HC": ImproperType(name="imp2"), } - imp1 = ff_with_impropers.get_potential( + imp1, _ = ff_with_impropers.get_potential( "improper_type", ["CT", "HC", "HC", "CT"] ) - imp2 = ff_with_impropers.get_potential( + imp2, _ = ff_with_impropers.get_potential( "improper_type", ["CT", "HC", "CT", "HC"] ) assert imp1.name == imp2.name diff --git a/gmso/tests/test_top.py b/gmso/tests/test_top.py index 54d3dd8b4..885fe9822 100644 --- a/gmso/tests/test_top.py +++ b/gmso/tests/test_top.py @@ -235,7 +235,7 @@ def test_benzene_restraints(self, typed_benzene_ua_system): for section, ref_section in zip(sections, ref_sections): assert section == ref_section if "dihedral" in section: - # Need to deal with these separatelt due to member's order issue + # Need to deal with these separately due to member's order issue # Each dict will have the keys be members and values be their parameters members = dict() ref_members = dict() diff --git a/gmso/utils/connectivity.py b/gmso/utils/connectivity.py index ae099d627..8cb87808b 100644 --- a/gmso/utils/connectivity.py +++ b/gmso/utils/connectivity.py @@ -78,9 +78,14 @@ def identify_connections(top, index_only=False): def _add_connections(top, matches, conn_type): """Add connections to the topology.""" for sorted_conn in matches: - to_add_conn = CONNS[conn_type]( - connection_members=[top.sites[idx] for idx in sorted_conn] - ) + cmembers = [top.sites[idx] for idx in sorted_conn] + bonds = list() + for i, j in CONNS[conn_type].connectivity: + bond = (cmembers[i], cmembers[j]) + key = frozenset([bond, tuple(reversed(bond))]) + bonds.append(top._unique_connections[key]) + to_add_conn = CONNS[conn_type](connection_members=cmembers, bonds=tuple(bonds)) + # import pdb; pdb.set_trace() top.add_connection(to_add_conn, update_types=False) @@ -452,3 +457,13 @@ def _graph_from_vtype(vtype): virtual_type_graph.add_edge(i, i + 1) return virtual_type_graph + + +def create_pattern(combination): + bonds_cutoff = len(combination) // 2 + sites = combination[: bonds_cutoff + 1] + bonds = combination[bonds_cutoff + 1 :] + pattern = sites[0] + for b, sit in zip(bonds, sites[1:]): + pattern += b + sit + return pattern diff --git a/gmso/utils/misc.py b/gmso/utils/misc.py index ad0a799af..afe49da5a 100644 --- a/gmso/utils/misc.py +++ b/gmso/utils/misc.py @@ -135,3 +135,19 @@ def get_xml_representation(value): return ",".join(value) else: return str(value) + + +def reverse_identifier(identifier: str): + bond_tokens = ["~", "-", "=", "#"] + outStr = "" + currentNode = "" + for letter in identifier[::-1]: + if letter in bond_tokens: + outStr += currentNode[::-1] + currentNode = "" + outStr += letter # should be a bond + else: + currentNode += letter + if currentNode: + outStr += currentNode[::-1] + return outStr From 63c09e9b153062657f9d7c31490bca1ee1629764 Mon Sep 17 00:00:00 2001 From: CalCraven Date: Wed, 22 Oct 2025 14:49:47 -0500 Subject: [PATCH 06/11] Add classes and types to gmso xsd schema, proprely get mbuild bond_orders --- gmso/external/convert_mbuild.py | 3 +-- gmso/parameterization/foyer_utils.py | 2 +- gmso/utils/schema/ff-gmso.xsd | 29 ++++++++++++++++------------ 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/gmso/external/convert_mbuild.py b/gmso/external/convert_mbuild.py index 1983e829a..eb02da542 100644 --- a/gmso/external/convert_mbuild.py +++ b/gmso/external/convert_mbuild.py @@ -21,7 +21,6 @@ import logging logger = logging.getLogger(__name__) -borderDict = {"single": 1, "double": 2, "triple": 3, "default": "0", "aromatic": 1.5} def from_mbuild( @@ -118,7 +117,7 @@ def from_mbuild( assert site_map[b1]["site"].molecule == site_map[b2]["site"].molecule new_bond = Bond( connection_members=[site_map[b1]["site"], site_map[b2]["site"]], - bond_order=borderDict[border["bond_order"]], + bond_order=border["bond_order"], ) top.add_connection(new_bond, update_types=False) diff --git a/gmso/parameterization/foyer_utils.py b/gmso/parameterization/foyer_utils.py index 01477c4fe..63a334c22 100644 --- a/gmso/parameterization/foyer_utils.py +++ b/gmso/parameterization/foyer_utils.py @@ -84,7 +84,7 @@ def get_topology_graph( atoms_indices = [ atom_index_map[id(atom)] for atom in top_bond.connection_members ] - top_graph.add_bond(atoms_indices[0], atoms_indices[1]) + top_graph.add_bond(atoms_indices[0], atoms_indices[1], getattr(top_bond, "bond_order", 0.0)) return top_graph diff --git a/gmso/utils/schema/ff-gmso.xsd b/gmso/utils/schema/ff-gmso.xsd index ee28825f8..2d93a4ab6 100644 --- a/gmso/utils/schema/ff-gmso.xsd +++ b/gmso/utils/schema/ff-gmso.xsd @@ -48,6 +48,8 @@ + + @@ -58,8 +60,8 @@ - - + + @@ -68,9 +70,8 @@ - - - + + @@ -86,6 +87,8 @@ + + @@ -102,9 +105,9 @@ - - - + + + @@ -159,6 +162,8 @@ + + @@ -177,10 +182,10 @@ - - - - + + + + From b6ba7bd167568a72c809b309fbe2ae5abe7edac8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Oct 2025 19:50:54 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- gmso/parameterization/foyer_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gmso/parameterization/foyer_utils.py b/gmso/parameterization/foyer_utils.py index 63a334c22..80b56db99 100644 --- a/gmso/parameterization/foyer_utils.py +++ b/gmso/parameterization/foyer_utils.py @@ -84,7 +84,9 @@ def get_topology_graph( atoms_indices = [ atom_index_map[id(atom)] for atom in top_bond.connection_members ] - top_graph.add_bond(atoms_indices[0], atoms_indices[1], getattr(top_bond, "bond_order", 0.0)) + top_graph.add_bond( + atoms_indices[0], atoms_indices[1], getattr(top_bond, "bond_order", 0.0) + ) return top_graph From 4e9836b1c0bcfe6c2de2c98d10b7e7a2e64ae730 Mon Sep 17 00:00:00 2001 From: CalCraven Date: Tue, 28 Oct 2025 18:35:12 -0500 Subject: [PATCH 08/11] Add doc strings for connections bonds functions, remove bonds setter, and dynamically set bonds based on connection_members when initialized with set_dependent_value_default --- gmso/__init__.py | 4 ++-- gmso/abc/abstract_connection.py | 4 +--- gmso/core/angle.py | 13 ++++--------- gmso/core/bond.py | 5 +++++ gmso/core/dihedral.py | 9 +-------- gmso/core/forcefield.py | 2 -- gmso/core/improper.py | 9 +-------- gmso/parameterization/topology_parameterizer.py | 13 ------------- gmso/utils/connectivity.py | 17 +++++++++++++++++ 9 files changed, 31 insertions(+), 45 deletions(-) diff --git a/gmso/__init__.py b/gmso/__init__.py index 8a982eeb4..68e7e660f 100644 --- a/gmso/__init__.py +++ b/gmso/__init__.py @@ -137,5 +137,5 @@ def print_level(self, level: str): # Example usage in __init__.py -# gmso_logger = GMSOLogger() -# gmso_logger.library_logger.setLevel(logging.WARNING) +gmso_logger = GMSOLogger() +gmso_logger.library_logger.setLevel(logging.WARNING) diff --git a/gmso/abc/abstract_connection.py b/gmso/abc/abstract_connection.py index 242d837f9..99f745bf3 100644 --- a/gmso/abc/abstract_connection.py +++ b/gmso/abc/abstract_connection.py @@ -114,14 +114,12 @@ def __str__(self): return f"<{self.__class__.__name__} {self.name}, id: {id(self)}> " def get_connection_identifiers(self): - from gmso.core.bond import Bond - borderDict = {1: "-", 2: "=", 3: "#", 0: "~", None: "~", 1.5: ":"} site_identifiers = [ (site.atom_type.atomclass, site.atom_type.name) for site in self.connection_members ] - if isinstance(self, Bond): + if not getattr(self, "bonds", None): bond_identifiers = [borderDict[self.bond_order]] else: bond_identifiers = [borderDict[b.bond_order] for b in self.bonds] diff --git a/gmso/core/angle.py b/gmso/core/angle.py index 923b273c4..77b8c80b4 100644 --- a/gmso/core/angle.py +++ b/gmso/core/angle.py @@ -23,6 +23,8 @@ class Angle(Connection): __eq__, __repr__, _validate methods Additional _validate methods are presented. + + self.connectivity is the associated indices that defines the way connection_members are bonded to match self.bonds. """ __members_creator__: ClassVar[Callable] = Atom.model_validate @@ -89,17 +91,9 @@ def restraint(self): @property def bonds(self): - """Return the bond_order symbol of this bond.""" + """Return a tuple of gmso.core.Bond objects that correspond to this angle.""" return self.__dict__.get("bonds_") - @bonds.setter - def bonds(self, bonds): - """Return the bonds that makeup this Improper. - - Connectivity is ((0,1), (0,2), (0,3)) - """ - self._bonds = bonds - @property def bonds_orders(self): """Return the bond_order strings of this angle.""" @@ -135,6 +129,7 @@ def __setattr__(self, key, value): @model_validator(mode="before") @classmethod def set_dependent_value_default(cls, data): + """Automatically set bonds for this angle if connection_members is defined.""" if "bonds" not in data and "connection_members" in data: atoms = data["connection_members"] data["bonds"] = ( diff --git a/gmso/core/bond.py b/gmso/core/bond.py index c9c80682e..e5a41055d 100644 --- a/gmso/core/bond.py +++ b/gmso/core/bond.py @@ -86,6 +86,11 @@ def bond_order(self): """Return the bond_order of this bond.""" return self.__dict__.get("bond_order_") + @bond_order.setter + def bond_order(self, order): + """Set the bond_order of this bond.""" + self._bond_order = order + def equivalent_members(self): """Get a set of the equivalent connection member tuples. diff --git a/gmso/core/dihedral.py b/gmso/core/dihedral.py index b3825aaad..69ffae193 100644 --- a/gmso/core/dihedral.py +++ b/gmso/core/dihedral.py @@ -96,14 +96,6 @@ def bonds(self): """ return self.__dict__.get("bonds_") - @bonds.setter - def bonds(self, bonds): - """Return the bonds that makeup this Improper. - - Connectivity is ((0,1), (0,2), (0,3)) - """ - self._bonds = bonds - @property def bonds_orders(self): """Return the bond_order strings of this dihedral.""" @@ -138,6 +130,7 @@ def __setattr__(self, key, value): @model_validator(mode="before") @classmethod def set_dependent_value_default(cls, data): + """Automatically set bonds for this dihedral if connection_members is defined.""" if "bonds" not in data and "connection_members" in data: atoms = data["connection_members"] data["bonds"] = ( diff --git a/gmso/core/forcefield.py b/gmso/core/forcefield.py index b78c3b16c..097701767 100644 --- a/gmso/core/forcefield.py +++ b/gmso/core/forcefield.py @@ -534,8 +534,6 @@ def _get_improper_type(self, identifier, exact_match=False): order, ) - return None - def _get_virtual_type( self, atom_types, return_match_order=False, exact_match=False ): diff --git a/gmso/core/improper.py b/gmso/core/improper.py index 0b223beba..d0b26a7e1 100644 --- a/gmso/core/improper.py +++ b/gmso/core/improper.py @@ -93,14 +93,6 @@ def bonds(self): """ return self.__dict__.get("bonds_") - @bonds.setter - def bonds(self, new_bonds): - """Return the bonds that makeup this Improper. - - Connectivity is ((0,1), (0,2), (0,3)) - """ - self.bonds_ = new_bonds - @property def bonds_orders(self): """Return the bond_order strings of this improper.""" @@ -141,6 +133,7 @@ def __setattr__(self, key, value): @model_validator(mode="before") @classmethod def set_dependent_value_default(cls, data): + """Automatically set bonds for this improper if connection_members is defined.""" if "bonds" not in data and "connection_members" in data: atoms = data["connection_members"] data["bonds"] = ( diff --git a/gmso/parameterization/topology_parameterizer.py b/gmso/parameterization/topology_parameterizer.py index 92c5d7b1a..19788a831 100644 --- a/gmso/parameterization/topology_parameterizer.py +++ b/gmso/parameterization/topology_parameterizer.py @@ -225,19 +225,6 @@ def _apply_connection_parameters(self, connections, ff, error_on_missing=True): setattr(connection, group, match[0].clone(self.config.fast_copy)) matched_order = [connection.connection_members[i] for i in match[1]] connection.connection_members = matched_order - if group == "angle_type": # reverse angle.bonds - if match[1][0] == 2: - connection.bonds = [connection.bonds[1], connection.bonds[0]] - elif group == "dihedral_type": # reorder dihedral.bonds - if match[1][0] == 3: - connection.bonds = [ - connection.bonds[2], - connection.bonds[1], - connection.bonds[0], - ] - elif group == "improper_type": # reorder improper.bonds - improper_bonds = [connection.bonds[i - 1] for i in match[1][1:]] - connection.bonds = improper_bonds if not match[0].member_types: connection.connection_type.member_types = tuple( diff --git a/gmso/utils/connectivity.py b/gmso/utils/connectivity.py index 8cb87808b..b660928e2 100644 --- a/gmso/utils/connectivity.py +++ b/gmso/utils/connectivity.py @@ -460,6 +460,23 @@ def _graph_from_vtype(vtype): def create_pattern(combination): + """Take a list of [site1, site2, bond1] and reorder into a string identifier. + + Parameters + ---------- + combination : tuple, list + The identifier for a given connection with a list of sites and bonds. + For example, a dihedral would look like: + combination = dihedral.connection_members + dihedral.bonds + + Returns + ------- + pattern : str + The identifying pattern for the list of sites. An improper might look like: + `central_atom-atom2-atom3=atom4` + where the combination was: + ["central_atom", "atom2", "atom3", "atom4", "-", "-", "="] + """ bonds_cutoff = len(combination) // 2 sites = combination[: bonds_cutoff + 1] bonds = combination[bonds_cutoff + 1 :] From 5c6f7bc2377ac0bcf3507db5352f7991c353f77b Mon Sep 17 00:00:00 2001 From: CalCraven Date: Wed, 29 Oct 2025 23:30:41 -0500 Subject: [PATCH 09/11] rename some of the utility functions used in identifying connections via masking, remove unused functions --- gmso/abc/abstract_connection.py | 8 +- gmso/core/forcefield.py | 165 ++++++++---------- .../topology_parameterizer.py | 5 - gmso/utils/connectivity.py | 27 ++- gmso/utils/misc.py | 70 +------- 5 files changed, 105 insertions(+), 170 deletions(-) diff --git a/gmso/abc/abstract_connection.py b/gmso/abc/abstract_connection.py index 99f745bf3..01072ddea 100644 --- a/gmso/abc/abstract_connection.py +++ b/gmso/abc/abstract_connection.py @@ -115,15 +115,13 @@ def __str__(self): def get_connection_identifiers(self): borderDict = {1: "-", 2: "=", 3: "#", 0: "~", None: "~", 1.5: ":"} - site_identifiers = [ - (site.atom_type.atomclass, site.atom_type.name) + choices = [ + (site.atom_type.name, site.atom_type.atomclass, "*") for site in self.connection_members ] if not getattr(self, "bonds", None): bond_identifiers = [borderDict[self.bond_order]] else: bond_identifiers = [borderDict[b.bond_order] for b in self.bonds] - choices = [(aclass, atype, "*") for aclass, atype in site_identifiers] choices += [(val, "~") for val in bond_identifiers] - all_combinations = itertools.product(*choices) - return all_combinations + return itertools.product(*choices) diff --git a/gmso/core/forcefield.py b/gmso/core/forcefield.py index 097701767..51e281226 100644 --- a/gmso/core/forcefield.py +++ b/gmso/core/forcefield.py @@ -13,7 +13,10 @@ from gmso.core.element import element_by_symbol from gmso.exceptions import GMSOError from gmso.utils._constants import FF_TOKENS_SEPARATOR -from gmso.utils.connectivity import create_pattern +from gmso.utils.connectivity import ( + connection_identifier_to_string, + yield_connection_identifiers, +) from gmso.utils.decorators import deprecate_kwargs from gmso.utils.ff_utils import ( parse_ff_atomtypes, @@ -23,7 +26,7 @@ parse_ff_virtual_types, validate, ) -from gmso.utils.misc import mask_with, reverse_identifier, validate_type +from gmso.utils.misc import reverse_string_identifier, validate_type logger = logging.getLogger(__name__) @@ -338,7 +341,7 @@ def _get_bond_type(self, identifier, exact_match=False): """Get a particular bond_type between `atom_types` from this ForceField.""" if isinstance(identifier, str): forward = identifier - reverse = reverse_identifier(forward) + reverse = reverse_string_identifier(forward) else: if len(identifier) == 2: # add wildcard bond identifier.append("~") @@ -347,8 +350,10 @@ def _get_bond_type(self, identifier, exact_match=False): f"BondType potential can only " f"be extracted for two atoms. Provided {len(identifier)}" ) - forward = create_pattern(identifier) - reverse = create_pattern(list(reversed(identifier[:2])) + [identifier[2]]) + forward = connection_identifier_to_string(identifier) + reverse = connection_identifier_to_string( + list(reversed(identifier[:2])) + [identifier[2]] + ) if forward in self.bond_types: return self.bond_types[forward], (0, 1) if reverse in self.bond_types: @@ -357,35 +362,24 @@ def _get_bond_type(self, identifier, exact_match=False): if exact_match: return None - for i in range(1, 3): - forward_patterns = mask_with(identifier[:2], i) - reverse_patterns = mask_with( - reversed( - identifier[:2], - ), - i, - ) - - for forward_pattern, reverse_pattern in zip( - forward_patterns, reverse_patterns - ): - forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) - reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) - - if forward_match_key in self.bond_types: - return self.bond_types[forward_match_key], (0, 1) - - if reverse_match_key in self.bond_types: - return self.bond_types[reverse_match_key], (1, 0) - return None + # try matching if exact_match is False + masked_forward = yield_connection_identifiers(forward) + masked_reverse = yield_connection_identifiers(reverse) + for mask_for, mask_rev in zip(masked_forward, masked_reverse): + if mask_for in self.bond_types: + return self.bond_types[mask_for], (0, 1) + elif mask_rev in self.bond_types: + return self.bond_types[mask_rev], (1, 0) + return None # return no match def _get_angle_type(self, identifier, exact_match=False): """Get a particular angle_type between `atom_types` from this ForceField.""" if isinstance(identifier, str): forward = identifier - reverse = reverse_identifier(forward) + reverse = reverse_string_identifier(forward) else: if len(identifier) == 3: + identifier = list(identifier) identifier.append("~") identifier.append("~") elif len(identifier) != 5: @@ -393,8 +387,8 @@ def _get_angle_type(self, identifier, exact_match=False): f"AngleType potential can only " f"be extracted for three atoms. Provided {len(identifier)}" ) - forward = create_pattern(identifier) - reverse = create_pattern( + forward = connection_identifier_to_string(identifier) + reverse = connection_identifier_to_string( list(reversed(identifier[:3])) + list(reversed(identifier[3:])) ) @@ -405,31 +399,21 @@ def _get_angle_type(self, identifier, exact_match=False): if exact_match: return None - for i in range(1, 4): - forward_patterns = mask_with(identifier, i) - reverse_patterns = mask_with( - list(reversed(identifier[:3])) + list(reversed(identifier[3:])), i - ) - - for forward_pattern, reverse_pattern in zip( - forward_patterns, reverse_patterns - ): - forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) - reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) - - if forward_match_key in self.angle_types: - return self.angle_types[forward_match_key], (0, 1, 2) - - if reverse_match_key in self.angle_types: - return self.angle_types[reverse_match_key], (2, 1, 0) - - return None + # try matching if exact_match is False + masked_forward = yield_connection_identifiers(forward) + masked_reverse = yield_connection_identifiers(reverse) + for mask_for, mask_rev in zip(masked_forward, masked_reverse): + if mask_for in self.angle_types: + return self.angle_types[mask_for], (0, 1, 2) + elif mask_rev in self.angle_types: + return self.angle_types[mask_rev], (2, 1, 0) + return None # return no match def _get_dihedral_type(self, identifier, exact_match=False): """Get a particular dihedral_type between `atom_types` from this ForceField.""" if isinstance(identifier, str): forward = identifier - reverse = reverse_identifier(forward) + reverse = reverse_string_identifier(forward) else: if len(identifier) == 4: identifier.append("~") @@ -440,8 +424,8 @@ def _get_dihedral_type(self, identifier, exact_match=False): f"DihedralType potential can only " f"be extracted for four atoms and three bonds. Provided {len(identifier)}" ) - forward = create_pattern(identifier) - reverse = create_pattern( + forward = connection_identifier_to_string(identifier) + reverse = connection_identifier_to_string( list(reversed(identifier[:4])) + list(reversed(identifier[4:])) ) @@ -452,30 +436,21 @@ def _get_dihedral_type(self, identifier, exact_match=False): if exact_match: return None - for i in range(1, 5): - forward_patterns = mask_with(identifier, i) - reverse_patterns = mask_with(reversed(identifier[:4] + identifier[4:]), i) - - for forward_pattern, reverse_pattern in zip( - forward_patterns, reverse_patterns - ): - forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) - reverse_match_key = FF_TOKENS_SEPARATOR.join(reverse_pattern) - - if forward_match_key in self.dihedral_types: - return self.dihedral_types[forward_match_key], (0, 1, 2, 3) - - if reverse_match_key in self.dihedral_types: - return self.dihedral_types[reverse_match_key], (3, 2, 1, 0) - - return None + # try matching if exact_match is False + masked_forward = yield_connection_identifiers(forward) + masked_reverse = yield_connection_identifiers(reverse) + for mask_for, mask_rev in zip(masked_forward, masked_reverse): + if mask_for in self.dihedral_types: + return self.dihedral_types[mask_for], (0, 1, 2, 3) + elif mask_rev in self.dihedral_types: + return self.dihedral_types[mask_rev], (3, 2, 1, 0) + return None # return no match def _get_improper_type(self, identifier, exact_match=False): """Get a particular improper_type between `identifier` from this ForceField.""" if isinstance(identifier, str): forward = identifier - equivalent = None # TODO - # equivalent = equivalent_identifier + reverse = reverse_string_identifier(forward, is_improper=True) else: if len(identifier) == 4: # add wildcard bonds identifier.append("~") @@ -487,7 +462,7 @@ def _get_improper_type(self, identifier, exact_match=False): f"be extracted for four atoms and three bonds. Provided {len(identifier)}" ) - forward = create_pattern(identifier) + forward = connection_identifier_to_string(identifier) equiv_idx = [ (0, i, j, k) for (i, j, k) in itertools.permutations((1, 2, 3), 3) ] @@ -507,32 +482,33 @@ def _get_improper_type(self, identifier, exact_match=False): if forward in self.improper_types: return self.improper_types[forward], (0, 1, 2, 3) - if not equivalent: - return None - for eq, order in zip(equivalent, equiv_idx): - eq_key = create_pattern(eq) - if eq_key in self.improper_types: - return self.improper_types[eq_key], order + if equivalent: + for eq, order in zip(equivalent, equiv_idx): + eq_key = connection_identifier_to_string(eq) + if eq_key in self.improper_types: + return self.improper_types[eq_key], order if exact_match: return None - - for i in range(1, 5): - forward_patterns = mask_with(identifier, i) - for forward_pattern in forward_patterns: - forward_match_key = FF_TOKENS_SEPARATOR.join(forward_pattern) - if forward_match_key in self.improper_types: - return self.improper_types[forward_match_key], (0, 1, 2, 3) - for i in range(1, 5): - for eq, order in zip(equivalent, equiv_idx): - equiv_patterns = mask_with(eq, i) - for equiv_pattern in equiv_patterns: - equiv_pattern_key = FF_TOKENS_SEPARATOR.join(equiv_pattern) - if equiv_pattern_key in self.improper_types: - return ( - self.improper_types[equiv_pattern_key], - order, - ) + # try matching if exact_match is False + masked_forward = yield_connection_identifiers(forward) + if equivalent: + masked_equivalents = [ + yield_connection_identifiers(equiv) for equiv in equivalent + ] + for mask_for in masked_forward: + if mask_for in self.improper_types: + return self.improper_types[mask_for], (0, 1, 2, 3) + for check_equivalent, order in zip(masked_equivalents, equiv_idx): + # iterate through all equivalent impropers as iterators, in sync with forward mask + if next(check_equivalent) in self.improper_types: + return self.improper_types[check_equivalent], order + else: # no zip if no equivalents + for mask_for in masked_forward: + if mask_for in self.improper_types: + return self.improper_types[mask_for], (0, 1, 2, 3) + + return None # return no match def _get_virtual_type( self, atom_types, return_match_order=False, exact_match=False @@ -545,6 +521,7 @@ def _get_virtual_type( if forward in self.virtual_types: match = self.virtual_types[forward], tuple(range(n_elements)) return match + return None def __repr__(self): """Return a formatted representation of the Forcefield.""" diff --git a/gmso/parameterization/topology_parameterizer.py b/gmso/parameterization/topology_parameterizer.py index 19788a831..bb7523968 100644 --- a/gmso/parameterization/topology_parameterizer.py +++ b/gmso/parameterization/topology_parameterizer.py @@ -196,11 +196,6 @@ def _apply_connection_parameters(self, connections, ff, error_on_missing=True): visited = dict() for connection in connections: group, connection_identifiers = self.connection_identifier(connection) - # TODO: sort connection_identifiers in connection_identifiers - connection_identifiers = sorted( - list(connection_identifiers), - key=lambda item: 100 * item.count("*") + item.count("~"), - ) match = None for identifier_key in connection_identifiers: if tuple(identifier_key) in visited: diff --git a/gmso/utils/connectivity.py b/gmso/utils/connectivity.py index b660928e2..8eea9328a 100644 --- a/gmso/utils/connectivity.py +++ b/gmso/utils/connectivity.py @@ -1,5 +1,7 @@ """Module supporting various connectivity methods and operations.""" +import itertools +import re from typing import TYPE_CHECKING, List import networkx as nx @@ -459,12 +461,12 @@ def _graph_from_vtype(vtype): return virtual_type_graph -def create_pattern(combination): +def connection_identifier_to_string(identifier): """Take a list of [site1, site2, bond1] and reorder into a string identifier. Parameters ---------- - combination : tuple, list + identifier : tuple, list The identifier for a given connection with a list of sites and bonds. For example, a dihedral would look like: combination = dihedral.connection_members + dihedral.bonds @@ -477,10 +479,25 @@ def create_pattern(combination): where the combination was: ["central_atom", "atom2", "atom3", "atom4", "-", "-", "="] """ - bonds_cutoff = len(combination) // 2 - sites = combination[: bonds_cutoff + 1] - bonds = combination[bonds_cutoff + 1 :] + bonds_cutoff = len(identifier) // 2 + sites = identifier[: bonds_cutoff + 1] + bonds = identifier[bonds_cutoff + 1 :] pattern = sites[0] for b, sit in zip(bonds, sites[1:]): pattern += b + sit return pattern + + +def yield_connection_identifiers(identifier): + """Yield all possible bond identifiers from a tuple or string identifier.""" + n_sites = len(identifier) // 2 + 1 + # decide if identifier is string or tuple + if isinstance(identifier, str): + bond_tokens = r"([\=\~\-\#\:])" + identifier = re.split(bond_tokens, identifier) + identifier = identifier[::2] + identifier[1::2] + site_identifiers = identifier[:n_sites] + bond_identifiers = identifier[n_sites:] + choices = [(site_identifier, "*") for site_identifier in site_identifiers] + choices += [(val, "~") for val in bond_identifiers] + return itertools.product(*choices) diff --git a/gmso/utils/misc.py b/gmso/utils/misc.py index afe49da5a..116ac055d 100644 --- a/gmso/utils/misc.py +++ b/gmso/utils/misc.py @@ -1,5 +1,6 @@ """Miscellaneous helper methods for GMSO.""" +import re from functools import lru_cache import unyt as u @@ -80,53 +81,6 @@ def validate_type(iterator, type_): ) -def mask_with(iterable, window_size=1, mask="*"): - """Mask an iterable with the `mask` in a circular sliding window of size `window_size`. - - This method masks an iterable elements with a mask object in a circular sliding window - - Parameters - ---------- - iterable: Iterable - The iterable to mask with - window_size: int, default=1 - The window size for the mask to be applied - mask: Any, default='*' - The mask to apply - Examples - -------- - >>> from gmso.utils.misc import mask_with - >>> list(mask_with(['Ar', 'Ar'], 1)) - [['*', 'Ar'], ['Ar', '*']] - >>> for masked_list in mask_with(['Ar', 'Xe', 'Xm', 'CH'], 2, mask='_'): - ... print('~'.join(masked_list)) - _~_~Xm~CH - Ar~_~_~CH - Ar~Xe~_~_ - _~Xe~Xm~_ - - Yields - ------ - list - The masked iterable - """ - input_list = list(iterable) - idx = 0 - first = None - while idx < len(input_list): - mask_idxes = set((idx + j) % len(input_list) for j in range(window_size)) - to_yield = [ - mask if j in mask_idxes else input_list[j] for j in range(len(input_list)) - ] - if to_yield == first: - break - if idx == 0: - first = to_yield - - idx += 1 - yield to_yield - - def get_xml_representation(value): """Given a value, get its XML representation.""" if isinstance(value, u.unyt_quantity): @@ -137,17 +91,11 @@ def get_xml_representation(value): return str(value) -def reverse_identifier(identifier: str): - bond_tokens = ["~", "-", "=", "#"] - outStr = "" - currentNode = "" - for letter in identifier[::-1]: - if letter in bond_tokens: - outStr += currentNode[::-1] - currentNode = "" - outStr += letter # should be a bond - else: - currentNode += letter - if currentNode: - outStr += currentNode[::-1] - return outStr +def reverse_string_identifier(identifier: str, is_improper=False): + """Change string identifier for a forcefield key.""" + tokens = r"([\=\~\-\#\:])" + items = re.split(tokens, identifier) + if is_improper: # only reverse middle two tokens and keep bonds + return "".join((items[:1] + items[3:5] + items[1:3] + items[5:])) + else: # flip full + return "".join(items[::-1]) From a41cbe6aeac989f84e0295f12e7c244abbfee00d Mon Sep 17 00:00:00 2001 From: CalCraven Date: Thu, 30 Oct 2025 08:42:58 -0500 Subject: [PATCH 10/11] pin openmm to less than 8.4 --- environment-dev.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment-dev.yml b/environment-dev.yml index e1fba7b3e..fb530416e 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -18,7 +18,7 @@ dependencies: - pytest - garnett>=0.7.1 - openff-toolkit-base>0.16.7 - - openmm + - openmm<8.4.0 - gsd>=2.9 - freud>=3.2 - parmed>=3.4.3 From 956809c47ca30855731eb12224099e8d6d496ce9 Mon Sep 17 00:00:00 2001 From: CalCraven Date: Thu, 30 Oct 2025 15:21:24 -0500 Subject: [PATCH 11/11] Use equivalents for _get_improper_type strings --- gmso/core/forcefield.py | 12 +++++++----- gmso/utils/misc.py | 20 +++++++++++++++----- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/gmso/core/forcefield.py b/gmso/core/forcefield.py index 51e281226..1841e7f63 100644 --- a/gmso/core/forcefield.py +++ b/gmso/core/forcefield.py @@ -26,7 +26,11 @@ parse_ff_virtual_types, validate, ) -from gmso.utils.misc import reverse_string_identifier, validate_type +from gmso.utils.misc import ( + improper_equivalents_string_identifier, + reverse_string_identifier, + validate_type, +) logger = logging.getLogger(__name__) @@ -448,9 +452,10 @@ def _get_dihedral_type(self, identifier, exact_match=False): def _get_improper_type(self, identifier, exact_match=False): """Get a particular improper_type between `identifier` from this ForceField.""" + equiv_idx = [(0, i, j, k) for (i, j, k) in itertools.permutations((1, 2, 3), 3)] if isinstance(identifier, str): forward = identifier - reverse = reverse_string_identifier(forward, is_improper=True) + equivalent = improper_equivalents_string_identifier(identifier) else: if len(identifier) == 4: # add wildcard bonds identifier.append("~") @@ -463,9 +468,6 @@ def _get_improper_type(self, identifier, exact_match=False): ) forward = connection_identifier_to_string(identifier) - equiv_idx = [ - (0, i, j, k) for (i, j, k) in itertools.permutations((1, 2, 3), 3) - ] equivalent = [ [ identifier[m], diff --git a/gmso/utils/misc.py b/gmso/utils/misc.py index 116ac055d..7ab0ee89f 100644 --- a/gmso/utils/misc.py +++ b/gmso/utils/misc.py @@ -91,11 +91,21 @@ def get_xml_representation(value): return str(value) -def reverse_string_identifier(identifier: str, is_improper=False): +def reverse_string_identifier(identifier: str): """Change string identifier for a forcefield key.""" tokens = r"([\=\~\-\#\:])" items = re.split(tokens, identifier) - if is_improper: # only reverse middle two tokens and keep bonds - return "".join((items[:1] + items[3:5] + items[1:3] + items[5:])) - else: # flip full - return "".join(items[::-1]) + return "".join(items[::-1]) + + +def improper_equivalents_string_identifier(identifier: str): + # only reverse middle two tokens and keep bonds + tokens = r"([\=\~\-\#\:])" + items = re.split(tokens, identifier) + return [ + "".join((items[:1] + items[3:5] + items[1:3] + items[5:])), + "".join((items[:1] + items[3:5] + items[1:3] + items[5:])), + "".join((items[:1] + items[5:] + items[3:5] + items[1:3])), + "".join((items[:1] + items[5:] + items[1:3] + items[3:5])), + "".join((items[:1] + items[1:3] + items[5:] + items[3:5])), + ]