diff --git a/pysages/backends/__init__.py b/pysages/backends/__init__.py index da514558..81477bc1 100644 --- a/pysages/backends/__init__.py +++ b/pysages/backends/__init__.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: MIT # See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES -from .contexts import JaxMDContext, JaxMDContextState # noqa: E402, F401 +from .contexts import ( # noqa: E402, F401 + JaxMDContext, + JaxMDContextState, + QboxContextGenerator, +) from .core import SamplingContext, supported_backends # noqa: E402, F401 diff --git a/pysages/backends/contexts.py b/pysages/backends/contexts.py index bf2865b7..9fe0026b 100644 --- a/pysages/backends/contexts.py +++ b/pysages/backends/contexts.py @@ -6,9 +6,26 @@ class to hold the simulation data. """ -from pysages.typing import Any, Callable, JaxArray, NamedTuple, Optional +import weakref +from dataclasses import dataclass +from importlib import import_module +from pathlib import Path +from xml.etree import ElementTree as et + +from pysages.typing import ( + Any, + Callable, + Iterable, + JaxArray, + NamedTuple, + Optional, + Union, +) +from pysages.utils import dispatch, is_file, splitlines JaxMDState = Any +QboxInstance = Any +XMLElement = et.Element class JaxMDContextState(NamedTuple): @@ -58,3 +75,120 @@ class JaxMDContext(NamedTuple): step_fn: Callable[..., JaxMDContextState] box: JaxArray dt: float + + +@dataclass(frozen=True) +class QboxContextGenerator: + """ + Provides an interface for setting up Qbox-backed simulations. + + Arguments + --------- + + launch_command: str + Specifies the command that will be used to run Qbox in interactive mode, + e.g. `qb` or `mpirun -n 4 qb`. + + script: str + File or multile string with the Qbox input script. + + nitscf: Optional[int] + Same as Qbox's `run` command parameter. The maximum number of self-consistent + iterations. + + nite: Optional[int] + Same as Qbox's `run` command parameter. The number of electronic iterations + performed between updates of the charge density. + + logfile: Union[Path, str] + Name for the output file. It must not exist on the working directory. + Defaults to `qb.r`. + """ + + # NOTE: we leave `niter` as non-configurable for now. + # niter: int + # Same as Qbox's `run` command parameter. The number of steps during which atomic + # positions are updated. Defaults to 1. + + launch_command: str + script: str + nitscf: Optional[int] = None + nite: Optional[int] = None + logfile: Union[Path, str] = Path("qb.r") + + def __call__(self, **kwargs): + if is_file(self.logfile): + msg = f"Rename or delete {self.logfile}, or choose a different log file name" + raise FileExistsError(msg) + + return QboxContext( + self.launch_command, self.script, self.logfile, 1, self.nitscf, self.nite + ) + + +@dataclass(frozen=True) +class QboxContext: + instance: QboxInstance + niter: int + nitscf: Optional[int] + nite: Optional[int] + species_masses: dict + initial_state: XMLElement + state: XMLElement + + @dispatch + def __init__( + self, launch_command: str, script: str, logfile: Union[Path, str], niter, nitscf, nite + ): + pexpect = import_module("pexpect.popen_spawn") + + def finalize(qb): + if not qb.flag_eof: + qb.sendline("quit") + qb.expect(pexpect.EOF) + + qb = pexpect.PopenSpawn(launch_command) + weakref.finalize(qb, lambda: finalize(qb)) + qb.logfile_read = open(logfile, "wb") + i = qb.expect([r"\[qbox\] ", pexpect.EOF]) + + if i == 1: # EOF was written to the log file + preamble = ( + "The command:\n\n " + f"{launch_command}\n\n" + "for running Qbox failed, it returned the following:\n\n" + ) + raise ChildProcessError(preamble + qb.before.decode()) + + super().__setattr__("instance", qb) + super().__setattr__("niter", niter) + super().__setattr__("nitscf", "" if nitscf is None else nitscf) + super().__setattr__("nite", "" if nite is None else nite) + + initial_state = qb.before + state = self.process_input(script) # sets `self.state` + + if self.state.find("error") is not None: + try: + qb.expect(pexpect.EOF, timeout=3) + finally: + raise ChildProcessError("Qbox encountered the following error:\n" + state.decode()) + + initial_state += state + b"\n" + super().__setattr__("initial_state", et.fromstring(initial_state)) + + k = 1822.888486 # to convert amu to atomic units + species = self.initial_state.iter("species") + species_masses = {s.attrib["name"]: k * float(s.find("mass").text) for s in species} + super().__setattr__("species_masses", species_masses) + + def process_input(self, entries: Union[str, Iterable[str]], target=r"\[qbox\] ", timeout=None): + qb = self.instance + state = b"" + for entry in splitlines(entries): + qb.sendline(entry) + qb.expect(target, timeout=timeout) + state += qb.before + # We add tags to ensure that the state corresponds to a valid xml section + super().__setattr__("state", et.fromstring(b"\n" + state + b"\n")) + return state diff --git a/pysages/backends/core.py b/pysages/backends/core.py index ad80a776..109347ba 100644 --- a/pysages/backends/core.py +++ b/pysages/backends/core.py @@ -3,7 +3,7 @@ from importlib import import_module -from pysages.backends.contexts import JaxMDContext +from pysages.backends.contexts import JaxMDContext, QboxContext from pysages.typing import Callable, Optional @@ -38,6 +38,8 @@ def __init__( self._backend_name = "lammps" elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"): self._backend_name = "openmm" + elif isinstance(context, QboxContext): + self._backend_name = "qbox" if self._backend_name is None: backends = ", ".join(supported_backends()) @@ -74,4 +76,4 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def supported_backends(): - return ("ase", "hoomd", "jax-md", "lammps", "openmm") + return ("ase", "hoomd", "jax-md", "lammps", "openmm", "qbox") diff --git a/pysages/backends/qbox.py b/pysages/backends/qbox.py new file mode 100644 index 00000000..3c42347c --- /dev/null +++ b/pysages/backends/qbox.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +""" +This module defines the Sampler class, which enables any PySAGES SamplingMethod to be +hooked to a Qbox simulation instance. +""" + +from jax import jit +from jax import numpy as np +from plum import Val, type_parameter + +from pysages.backends.core import SamplingContext +from pysages.backends.snapshot import ( + Box, + HelperMethods, + Snapshot, + SnapshotMethods, + build_data_querier, +) +from pysages.typing import Callable, Optional +from pysages.utils import dispatch, last, parse_array + + +class Sampler: + """ + Allows performing enhanced sampling simulations with Qbox as a backend. + + Parameters + ---------- + + context: QboxContext + Contains a running instance of a Qbox simulation to which the PySAGES sampling + machinery will be hooked. + + sampling_method: SamplingMethod + The sampling method to be used. + + callbacks: Optional[Callback] + Some methods define callbacks for logging, but it can also be user-defined. + """ + + def __init__(self, context, sampling_method, callback: Optional[Callable]): + self.context = context + self.callback = callback + + self.snapshot = self.take_snapshot() + helpers, bias, atom_names, cv_indices = build_helpers(context, sampling_method) + _, initialize, method_update = sampling_method.build(self.snapshot, helpers) + + # Initialize external forces for each atom + for i in cv_indices: + name = atom_names[i] + # Initialize with zero force + cmd = f"extforce define atomic {name} {name} 0.0 0.0 0.0" + context.process_input(cmd) + + self.state = initialize() + self._update_box = lambda: self.snapshot.box + self._method_update = method_update + self._bias = bias + + def _pack_snapshot(self, masses, ids, box, dt): + """Returns the dynamic properties of the system.""" + positions = atom_property(self.context, "position") + velocities = atom_property(self.context, "velocity") + forces = atom_property(self.context, "force") + return Snapshot(positions, (velocities, masses), forces, ids, box, dt) + + def _update_snapshot(self): + """Updates the snapshot with the latest properties from Qbox.""" + snapshot = self.snapshot + _, masses = snapshot.vel_mass + return self._pack_snapshot(masses, snapshot.ids, self._update_box(), snapshot.dt) + + def restore(self, prev_snapshot): + """Replaces this sampler's snapshot with `prev_snapshot`.""" + context = self.context + names = atom_property(context, "name") + positions = prev_snapshot.positions + velocities, _ = prev_snapshot.vel_mass + + for name, x, v in zip(names, positions, velocities): + cmd = f"move {name} to {x[0]} {x[1]} {x[2]} {v[0]} {v[1]} {v[2]}" + context.process_input(cmd) + + # Recompute ground-state energies and forces. + # NOTE: Check in the future how to use Qbox `load` and `save` commands to also + # include the electronic wave function data. + context.process_input(f"run 0 {context.nitscf} {context.nite}") + self.snapshot = self._update_snapshot() + + def take_snapshot(self): + """Returns a copy of the current snapshot of the system.""" + masses = atom_property(self.context, "mass") + ids = np.arange(len(masses)) + snapshot_box = Box(*box(self.context)) + dt = timestep(self.context) + return self._pack_snapshot(masses, ids, snapshot_box, dt) + + def update(self, timestep): + """Update the sampling method state and apply bias.""" + self.snapshot = self._update_snapshot() + self.state = self._method_update(self.snapshot, self.state) + self._bias(self.snapshot, self.state) + if self.callback: + self.callback(self.snapshot, self.state, timestep) + + def run(self, nsteps: int): + """Run the Qbox simulation for nsteps.""" + cmd = f"run {self.context.niter} {self.context.nitscf} {self.context.nite}" + for step in range(nsteps): + # Send run command to Qbox for a single step + self.context.process_input(cmd) + # Update sampling method state after each step + self.update(step) + + +def build_snapshot_methods(sampling_method): + """ + Builds methods for retrieving snapshot properties in a format useful for collective + variable calculations. + """ + + def positions(snapshot): + return snapshot.positions + + def indices(snapshot): + return snapshot.ids + + def momenta(snapshot): + V, M = snapshot.vel_mass + return (M * V).flatten() + + def masses(snapshot): + _, M = snapshot.vel_mass + return M + + return SnapshotMethods(positions, indices, jit(momenta), masses) + + +def build_helpers(context, sampling_method): + """ + Builds helper methods used for restoring snapshots and biasing a simulation. + """ + # Precompute atom names since they won't change + atom_names = atom_property(context, "name") + + cv_indices = set() + for cv in sampling_method.cvs: + cv_indices.update(n.item() for n in cv.indices) + + def extforce_cmd(name, force): + return f"extforce set {name} {force[0]} {force[1]} {force[2]}" + + def bias(snapshot, state): + """Adds the computed bias to the forces using Qbox's extforce command.""" + if state.bias is None: + return + # Generate and send all extforce commands + context.process_input(extforce_cmd(atom_names[i], state.bias[i]) for i in cv_indices) + + snapshot_methods = build_snapshot_methods(sampling_method) + flags = sampling_method.snapshot_flags + helpers = HelperMethods(build_data_querier(snapshot_methods, flags), lambda: 3) + + return helpers, bias, atom_names, cv_indices + + +@dispatch +def atom_property(context, prop: str): + return atom_property(context, *property_handler(context, Val(prop))) + + +@dispatch +def atom_property(context, xml_tag, extract, gather): + atomset = last(context.state.iter("atomset")) + if atomset is None: + context.process_input("run 0") + atomset = last(context.state.iter("atomset")) + return gather(extract(elem) for elem in atomset.iter(xml_tag)) + + +@dispatch +def property_handler(context, prop: Val["name"]): # noqa: F821 + return ( + "atom", # xml_tag + (lambda s: s.attrib["name"]), # extract + list, # gather + ) + + +@dispatch +def property_handler(context, prop: Val["mass"]): # noqa: F821 + return ( + "atom", # xml_tag + (lambda s: context.species_masses[s.attrib["species"]]), # extract + (lambda d: np.array(list(d)).reshape(-1, 1)), # gather + ) + + +@dispatch +def property_handler(context, prop: Val): + return ( + type_parameter(prop), # xml_tag + (lambda s: s.text), # extract + (lambda d: parse_array(" ".join(d))), # gather + ) + + +def box(context): + elem = last(context.state.iter("unit_cell")) + if elem is None: + context.process_input("print cell") + elem = context.state.find("unit_cell") + cell_vecs = " ".join(elem.attrib.values()) + H = parse_array(cell_vecs, transpose=True) + origin = np.array([0.0, 0.0, 0.0]) + return Box(H, origin) + + +def timestep(context): + context.process_input("print dt") + elem = context.state.find("cmd") + return float(elem.tail.strip("\ndt= ")) + + +def bind(sampling_context: SamplingContext, callback: Optional[Callable], **kwargs): + """ + Sets up and returns a Sampler which enables performing enhanced sampling simulations. + + This function takes a `sampling_context` that has its context attribute as an instance + of a `QboxContext,` and creates a `Sampler` object that connects the PySAGES + sampling method to the Qbox simulation. It also modifies the `sampling_context`'s + `view` and `run` attributes to call the Qbox `run` command. + """ + context = sampling_context.context + sampler = Sampler(context, sampling_context.method, callback) + sampling_context.run = sampler.run + + return sampler