diff --git a/docs/releasehistory.md b/docs/releasehistory.md index e93f42d87..8bbb9f3cb 100644 --- a/docs/releasehistory.md +++ b/docs/releasehistory.md @@ -6,7 +6,7 @@ Releases follow the `major.minor.micro` scheme recommended by [PEP440](https://w * `minor` increments add features but do not break API compatibility * `micro` increments represent bugfix releases or improvements in documentation -## Current development +## 0.17.0 ### API-breaking changes @@ -15,6 +15,8 @@ Releases follow the `major.minor.micro` scheme recommended by [PEP440](https://w ### Bugfixes ### New features +- [PR #2048](https://github.com/openforcefield/openff-toolkit/pull/2048): Adds NAGLChargesHandler. See [SMIRNOFF EP 11](https://github.com/openforcefield/standards/pull/71) for the new SMIRNOFF specification section and discussion. + ### Improved documentation and warnings diff --git a/docs/typing.rst b/docs/typing.rst index 7157e4cb5..8c07100df 100644 --- a/docs/typing.rst +++ b/docs/typing.rst @@ -75,6 +75,7 @@ During ``System`` creation, each ``ParameterHandler`` registered to a ``ForceFie ElectrostaticsHandler LibraryChargeHandler ToolkitAM1BCCHandler + NAGLChargesHandler GBSAHandler ChargeIncrementModelHandler VirtualSiteHandler diff --git a/openff/toolkit/__init__.py b/openff/toolkit/__init__.py index 69a3ccf76..cf62c730a 100644 --- a/openff/toolkit/__init__.py +++ b/openff/toolkit/__init__.py @@ -21,6 +21,7 @@ GLOBAL_TOOLKIT_REGISTRY, AmberToolsToolkitWrapper, BuiltInToolkitWrapper, + NAGLToolkitWrapper, OpenEyeToolkitWrapper, RDKitToolkitWrapper, ToolkitRegistry, @@ -53,6 +54,7 @@ "GLOBAL_TOOLKIT_REGISTRY": "openff.toolkit.utils.toolkits", "AmberToolsToolkitWrapper": "openff.toolkit.utils.toolkits", "BuiltInToolkitWrapper": "openff.toolkit.utils.toolkits", + "NAGLToolkitWrapper": "openff.toolkit.utils.toolkits", "OpenEyeToolkitWrapper": "openff.toolkit.utils.toolkits", "RDKitToolkitWrapper": "openff.toolkit.utils.toolkits", "ToolkitRegistry": "openff.toolkit.utils.toolkits", diff 
--git a/openff/toolkit/_tests/test_nagl.py b/openff/toolkit/_tests/test_nagl.py index 9d0bc9a24..3a6779fb9 100644 --- a/openff/toolkit/_tests/test_nagl.py +++ b/openff/toolkit/_tests/test_nagl.py @@ -5,6 +5,7 @@ import pytest from openff.utilities import has_package, skip_if_missing +from openff.nagl_models._dynamic_fetch import BadFileSuffixError from openff.toolkit import Molecule, unit from openff.toolkit._tests.create_molecules import ( create_acetaldehyde, @@ -14,8 +15,8 @@ create_reversed_ethanol, ) from openff.toolkit._tests.utils import requires_openeye +from openff.toolkit.utils import GLOBAL_TOOLKIT_REGISTRY from openff.toolkit.utils.exceptions import ( - ChargeMethodUnavailableError, ToolkitUnavailableException, ) from openff.toolkit.utils.nagl_wrapper import NAGLToolkitWrapper @@ -38,6 +39,9 @@ def test_version(self): assert parsed_version == NAGLToolkitWrapper()._toolkit_version + def test_nagl_in_global_toolkit_registry(self): + assert NAGLToolkitWrapper in {type(tk) for tk in GLOBAL_TOOLKIT_REGISTRY.registered_toolkits} + @requires_openeye @pytest.mark.parametrize( "molecule_function", @@ -123,8 +127,8 @@ def test_conformer_argument(self): def test_unsupported_charge_method(self): with pytest.raises( - ChargeMethodUnavailableError, - match="Charge model hartree_fock not supported", + BadFileSuffixError, + match="Found an unrecognized file path extension on filename='hartree_fock'", ): create_ethanol().assign_partial_charges( partial_charge_method="hartree_fock", diff --git a/openff/toolkit/_tests/test_parameters.py b/openff/toolkit/_tests/test_parameters.py index bd9ae0836..2a667104e 100644 --- a/openff/toolkit/_tests/test_parameters.py +++ b/openff/toolkit/_tests/test_parameters.py @@ -14,6 +14,7 @@ from openff.toolkit import ForceField, Molecule, Quantity, Topology, unit from openff.toolkit._tests.mocking import VirtualSiteMocking from openff.toolkit._tests.utils import does_not_raise +from openff.toolkit.typing.engines.smirnoff import 
NAGLChargesHandler from openff.toolkit.typing.engines.smirnoff.parameters import ( AngleHandler, BondHandler, @@ -2722,6 +2723,201 @@ def test_charge_increment_one_ci_missing(self): ], ) +class TestNAGLChargesHandler: + def test_nagl_charges_handler_serialization(self): + handler = NAGLChargesHandler(model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", skip_version_check=True) + assert handler.model_file == "openff-gnn-am1bcc-0.1.0-rc.3.pt" + handler_dict = handler.to_dict() + assert handler_dict["model_file"] == "openff-gnn-am1bcc-0.1.0-rc.3.pt" + + def test_nagl_charges_handler_with_optional_fields(self): + # Test with model_file_hash + handler = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + skip_version_check=True + ) + assert handler.model_file == "openff-gnn-am1bcc-0.1.0-rc.3.pt" + assert handler.model_file_hash == "144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0" + assert handler.digital_object_identifier is None + + # Test with digital_object_identifier + handler = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + digital_object_identifier="10.5072/zenodo.203601", + skip_version_check=True + ) + assert handler.model_file == "openff-gnn-am1bcc-0.1.0-rc.3.pt" + assert handler.model_file_hash is None + assert handler.digital_object_identifier == "10.5072/zenodo.203601" + + # Test with both optional fields + handler = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + digital_object_identifier="10.5072/zenodo.203601", + skip_version_check=True + ) + assert handler.model_file == "openff-gnn-am1bcc-0.1.0-rc.3.pt" + assert handler.model_file_hash == "144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0" + assert handler.digital_object_identifier == "10.5072/zenodo.203601" + + def 
test_nagl_charges_handler_serialization_with_optional_fields(self): + # Test serialization with all fields + handler = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + digital_object_identifier="10.5072/zenodo.203601", + skip_version_check=True + ) + handler_dict = handler.to_dict() + assert handler_dict["model_file"] == "openff-gnn-am1bcc-0.1.0-rc.3.pt" + assert handler_dict["model_file_hash"] == "144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0" + assert handler_dict["digital_object_identifier"] == "10.5072/zenodo.203601" + + # Test deserialization via constructor + handler_from_dict = NAGLChargesHandler(**handler_dict) + assert handler_from_dict.model_file == "openff-gnn-am1bcc-0.1.0-rc.3.pt" + assert handler_from_dict.model_file_hash == "144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0" + assert handler_from_dict.digital_object_identifier == "10.5072/zenodo.203601" + + def test_nagl_charges_handler_compatibility(self): + # Test compatible handlers (same model_file) + handler1 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + skip_version_check=True + ) + handler2 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + skip_version_check=True + ) + # Should not raise exception + handler1.check_handler_compatibility(handler2) + + # Test incompatible handlers (different model_file) + handler3 = NAGLChargesHandler( + model_file="different-model-file.pt", + skip_version_check=True + ) + with pytest.raises(IncompatibleParameterError, match="different model_files"): + handler1.check_handler_compatibility(handler3) + + def test_nagl_charges_handler_defaults(self): + # Test that optional fields default to None + handler = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + 
skip_version_check=True + ) + assert handler.model_file_hash is None + assert handler.digital_object_identifier is None + + def test_nagl_charges_handler_hash_compatibility(self): + """Test compatibility checks for model_file_hash""" + # Test compatible handlers with same hash + handler1 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + skip_version_check=True + ) + handler2 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + skip_version_check=True + ) + # Should not raise exception + handler1.check_handler_compatibility(handler2) + + # Test incompatible handlers with different hashes + handler3 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="different_hash_value", + skip_version_check=True + ) + with pytest.raises(IncompatibleParameterError, match="different model_file_hash values"): + handler1.check_handler_compatibility(handler3) + + # Test compatibility when only one handler has hash (should be compatible) + handler4 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + skip_version_check=True + ) + # Should not raise exception + handler1.check_handler_compatibility(handler4) + handler4.check_handler_compatibility(handler1) + + def test_nagl_charges_handler_doi_compatibility(self): + """Test compatibility checks for digital_object_identifier""" + # Test compatible handlers with same DOI + handler1 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + digital_object_identifier="10.5072/zenodo.203601", + skip_version_check=True + ) + handler2 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + digital_object_identifier="10.5072/zenodo.203601", + skip_version_check=True + ) + # Should not raise exception + handler1.check_handler_compatibility(handler2) + + # 
Test incompatible handlers with different DOIs + handler3 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + digital_object_identifier="10.5072/zenodo.999999", + skip_version_check=True + ) + with pytest.raises(IncompatibleParameterError, match="different digital_object_identifier values"): + handler1.check_handler_compatibility(handler3) + + # Test compatibility when only one handler has DOI (should be compatible) + handler4 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + skip_version_check=True + ) + # Should not raise exception + handler1.check_handler_compatibility(handler4) + handler4.check_handler_compatibility(handler1) + + def test_nagl_charges_handler_combined_compatibility(self): + """Test compatibility checks with both hash and DOI""" + # Test compatible handlers with same hash and DOI + handler1 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + digital_object_identifier="10.5072/zenodo.203601", + skip_version_check=True + ) + handler2 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + digital_object_identifier="10.5072/zenodo.203601", + skip_version_check=True + ) + # Should not raise exception + handler1.check_handler_compatibility(handler2) + + # Test incompatible with same hash but different DOI + handler3 = NAGLChargesHandler( + model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + digital_object_identifier="10.5072/zenodo.999999", + skip_version_check=True + ) + with pytest.raises(IncompatibleParameterError, match="different digital_object_identifier values"): + handler1.check_handler_compatibility(handler3) + + # Test incompatible with different hash but same DOI + handler4 = NAGLChargesHandler( + 
model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + model_file_hash="different_hash_value", + digital_object_identifier="10.5072/zenodo.203601", + skip_version_check=True + ) + with pytest.raises(IncompatibleParameterError, match="different model_file_hash values"): + handler1.check_handler_compatibility(handler4) + class TestGBSAHandler: def test_create_default_gbsahandler(self): diff --git a/openff/toolkit/typing/engines/smirnoff/__init__.py b/openff/toolkit/typing/engines/smirnoff/__init__.py index ae14b1671..d881c950d 100644 --- a/openff/toolkit/typing/engines/smirnoff/__init__.py +++ b/openff/toolkit/typing/engines/smirnoff/__init__.py @@ -18,6 +18,7 @@ IndexedParameterAttribute, LibraryChargeHandler, MappedParameterAttribute, + NAGLChargesHandler, ParameterAttribute, ParameterHandler, ParameterList, diff --git a/openff/toolkit/typing/engines/smirnoff/parameters.py b/openff/toolkit/typing/engines/smirnoff/parameters.py index b42fcba8b..be5f7a7b3 100644 --- a/openff/toolkit/typing/engines/smirnoff/parameters.py +++ b/openff/toolkit/typing/engines/smirnoff/parameters.py @@ -29,6 +29,7 @@ "LibraryChargeHandler", "LibraryChargeType", "MappedParameterAttribute", + "NAGLChargesHandler", "NotEnoughPointsForInterpolationError", "ParameterAttribute", "ParameterHandler", @@ -3253,6 +3254,115 @@ def find_matches(self, entity, unique=False): unique=unique, ) +class NAGLChargesHandler(_NonbondedHandler): + """ParameterHandler for applying partial charges from a pretrained NAGL model. + + This handler processes the NAGLCharges section of SMIRNOFF force fields, which + specifies a pre-trained NAGL model for computing + partial charges on molecules. + + Parameters + ---------- + model_file : str + Path to the PyTorch model file (e.g., "openff-gnn-am1bcc-0.1.0-rc.3.pt"). + This is the model that will be used for charge assignment. + model_file_hash : str, optional + SHA-256 hash of the model file for integrity verification. 
When provided, + the hash will be validated against the actual model file. + digital_object_identifier : str, optional + Zenodo DOI that can be used to retrieve the model file if it's not found + locally. Must point to a Zenodo record with an attached file matching + the model_file name. + version : str, optional + The version of the NAGLCharges section specification. + skip_version_check : bool, optional, default=False + If True, skips validation of the version parameter and sets it to the highest + supported version. + allow_cosmetic_attributes : bool, optional, default=False + If True, allows non-specification attributes to be present. + + Examples + -------- + Create a handler with just the model file: + + >>> handler = NAGLChargesHandler( + ... model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + ... skip_version_check=True + ... ) + + Create a handler with hash verification: + + >>> handler = NAGLChargesHandler( + ... model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + ... model_file_hash="144ed56e46c5b3ad80157b342c8c0f8f7340e4d382a678e30dd300c811646bd0", + ... skip_version_check=True + ... ) + + Create a handler with DOI for model retrieval: + + >>> handler = NAGLChargesHandler( + ... model_file="openff-gnn-am1bcc-0.1.0-rc.3.pt", + ... digital_object_identifier="10.5072/zenodo.203601", + ... skip_version_check=True + ... ) + + Notes + ----- + NAGLChargesHandler compatibility requires both handlers to specify the same model_file. + If both handlers also define model_file_hash, or both define digital_object_identifier, + those values must match too; a value defined on only one handler is not compared. + + The actual model loading, hash verification, and DOI-based retrieval are + handled by the openff-nagl-models package, not by this handler directly. 
+ """ + + _TAGNAME = "NAGLCharges" + _DEPENDENCIES = [vdWHandler, ElectrostaticsHandler, LibraryChargeHandler] + _INFOTYPE = None # No separate parameter types; just a model path + _MAX_SUPPORTED_SECTION_VERSION = Version("0.3") + model_file = ParameterAttribute(converter=str) + model_file_hash = ParameterAttribute(default=None, converter=str) + digital_object_identifier = ParameterAttribute(default=None, converter=str) + + def check_handler_compatibility( + self, + other_handler: "NAGLChargesHandler", + assume_missing_is_default: bool = True, + ): + """ + Checks whether this ParameterHandler encodes compatible physics as another ParameterHandler. This is + called if a second handler is attempted to be initialized for the same tag. + + Parameters + ---------- + other_handler + The handler to compare to. + assume_missing_is_default + + Raises + ------ + IncompatibleParameterError if handler_kwargs are incompatible with existing parameters. + """ + if self.model_file != other_handler.model_file: + raise IncompatibleParameterError("Attempted to initialize two NAGLCharges sections with different " + "model_files: " + f"{self.model_file=} is not identical to {other_handler.model_file=}") + + # If both handlers have model_file_hashes defined, ensure they're identical + if self.model_file_hash and other_handler.model_file_hash and \ + self.model_file_hash != other_handler.model_file_hash: + raise IncompatibleParameterError("Attempted to initialize two NAGLCharges sections with different " + "model_file_hash values: " + f"{self.model_file_hash=} is not identical to " + f"{other_handler.model_file_hash=}") + + # If both handlers have digital_object_identifiers defined, ensure they're identical + if self.digital_object_identifier and other_handler.digital_object_identifier and \ + self.digital_object_identifier != other_handler.digital_object_identifier: + raise IncompatibleParameterError("Attempted to initialize two NAGLCharges sections with different " + 
"digital_object_identifier values: " + f"{self.digital_object_identifier=} is not identical to " + f"{other_handler.digital_object_identifier=}") class ToolkitAM1BCCHandler(_NonbondedHandler): """Handle SMIRNOFF ```` tags @@ -3261,7 +3371,7 @@ class ToolkitAM1BCCHandler(_NonbondedHandler): """ _TAGNAME = "ToolkitAM1BCC" # SMIRNOFF tag name to process - _DEPENDENCIES = [vdWHandler, ElectrostaticsHandler, LibraryChargeHandler] + _DEPENDENCIES = [vdWHandler, ElectrostaticsHandler, LibraryChargeHandler, NAGLChargesHandler] _KWARGS = ["toolkit_registry"] # Kwargs to catch when create_force is called def check_handler_compatibility( @@ -3382,6 +3492,7 @@ def find_matches(self, entity, unique=False): return matches + class GBSAHandler(ParameterHandler): """Handle SMIRNOFF ```` tags diff --git a/openff/toolkit/utils/nagl_wrapper.py b/openff/toolkit/utils/nagl_wrapper.py index eaec04ebe..2fa08e03d 100644 --- a/openff/toolkit/utils/nagl_wrapper.py +++ b/openff/toolkit/utils/nagl_wrapper.py @@ -6,7 +6,6 @@ from openff.toolkit import Quantity, unit from openff.toolkit.utils.base_wrapper import ToolkitWrapper from openff.toolkit.utils.exceptions import ( - ChargeMethodUnavailableError, ToolkitUnavailableException, ) @@ -68,6 +67,8 @@ def assign_partial_charges( use_conformers: Optional[list["Quantity"]] = None, strict_n_conformers: bool = False, normalize_partial_charges: bool = True, + doi: Optional[str] = None, + file_hash: Optional[str] = None, _cls: Optional[type["FrozenMolecule"]] = None, ): """ @@ -93,6 +94,14 @@ def assign_partial_charges( formal charge of the molecule. This is used to prevent accumulation of rounding errors when the partial charge generation method has low precision. + doi + Zenodo DOI to check if NAGL model file needs to be fetched. Passed + directly to openff.nagl_models._dynamic_fetch.get_model, see docs + on that method for more details. + file_hash + sha256 hash to check against NAGL model file. 
Passed + directly to openff.nagl_models._dynamic_fetch.get_model, see docs + on that method for more details. _cls : class Molecule constructor @@ -105,7 +114,12 @@ def assign_partial_charges( if the charge method is supported by this toolkit, but fails """ from openff.nagl import GNNModel - from openff.nagl_models import validate_nagl_model_path + from openff.nagl_models._dynamic_fetch import get_model + + if partial_charge_method == "" or partial_charge_method == "None": + raise FileNotFoundError("NAGLToolkitWrapper.assign_partial_charges can not accept " + "a blank model file name. There is no default model, one must be " + "explicitly defined when being called.") if _cls is None: from openff.toolkit.topology.molecule import Molecule @@ -130,13 +144,9 @@ def assign_partial_charges( stacklevel=2, ) - try: - model_path = validate_nagl_model_path(model=partial_charge_method) - except FileNotFoundError as error: - raise ChargeMethodUnavailableError( - f"Charge model {partial_charge_method} not supported by " - f"{self.__class__.__name__}." - ) from error + model_path = get_model(filename=partial_charge_method, + doi=doi, + file_hash=file_hash) model = GNNModel.load(model_path, eval_mode=True) charges = model.compute_property( diff --git a/openff/toolkit/utils/toolkits.py b/openff/toolkit/utils/toolkits.py index 97875076f..6955b234e 100644 --- a/openff/toolkit/utils/toolkits.py +++ b/openff/toolkit/utils/toolkits.py @@ -101,6 +101,7 @@ # Create global toolkit registry, where all available toolkits are registered GLOBAL_TOOLKIT_REGISTRY = ToolkitRegistry( toolkit_precedence=[ + NAGLToolkitWrapper, OpenEyeToolkitWrapper, RDKitToolkitWrapper, AmberToolsToolkitWrapper,