Skip to content

Commit e054116

Browse files
CalCravenpre-commit-ci[bot]chrisjonesBSU
authored
Fill out virtual_sites position functionality (#934)
* Fill out virtual_sites position functionality * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move expression evaluation for position to gmso expression.py, and add and test for different exceptions * Identify units from gmso site positions, instead of hard coded as nm * Remove unncessary variable --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Chris Jones <[email protected]>
1 parent 99fcbb5 commit e054116

File tree

4 files changed

+236
-12
lines changed

4 files changed

+236
-12
lines changed

gmso/core/virtual_site.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import string
12
from typing import Callable, List, Optional, Union
23

34
import unyt as u
45
from pydantic import ConfigDict, Field
6+
from sympy import Matrix, symbols
57

68
from gmso.abc.abstract_site import Site
79
from gmso.core.virtual_type import VirtualType
8-
from gmso.exceptions import MissingPotentialError, NotYetImplementedWarning
10+
from gmso.exceptions import MissingPotentialError
911

1012

1113
class VirtualSite(Site):
@@ -58,8 +60,8 @@ def parent_sites(self) -> List[Site]:
5860
"""Reminder that the order of sites is fixed, such that site index 1 corresponds to ri in the self.virtual_type.virtual_position expression."""
5961
return self.__dict__.get("parent_sites_", [])
6062

61-
def position(self) -> str:
62-
"""Not yet implemented function to get position from virtual_type.virtual_position and parent_sites."""
63+
def position(self) -> u.unyt_array:
64+
"""On the fly position evaluation from virtual_type.virtual_position and parent_sites."""
6365
if not self.virtual_type:
6466
raise MissingPotentialError(
6567
"No VirtualType associated with this VirtualSite."
@@ -68,10 +70,26 @@ def position(self) -> str:
6870
raise MissingPotentialError(
6971
"No VirtualPositionType associated with this VirtualType."
7072
)
71-
# TODO: validate parent atoms matches virtual_type.virtual_position in terms of independent variables ri, rj, etc.
72-
# TODO: Generate position from atoms of parent_atoms and self.virtual_type.virtual_position.expression.
73-
raise NotYetImplementedWarning(
74-
"Need a functional to call from self.virtual_type.virtual_position, and plug in ri, rj, rk etc."
73+
74+
independent_namespace = {}
75+
for _, symbol in zip(range(len(self.parent_sites)), string.ascii_lowercase[8:]):
76+
x, y, z = symbols(f"r{symbol}1 r{symbol}2 r{symbol}3")
77+
independent_namespace[f"r{symbol}"] = Matrix([x, y, z])
78+
79+
independent_parameters = {}
80+
for symbol, site in zip(string.ascii_lowercase[8:], self.parent_sites):
81+
for i, pos in enumerate(site.position):
82+
independent_parameters[f"r{symbol}{i + 1}"] = float(pos.value)
83+
84+
# get units from parent sites
85+
unitsUnyt = self.parent_sites[0].position.units
86+
87+
# perform expression evaluation
88+
return (
89+
self.virtual_type.virtual_position.potential_expression.evaluate(
90+
independent_namespace, independent_parameters
91+
)
92+
* unitsUnyt
7593
)
7694

7795
def __repr__(self):

gmso/tests/test_expression.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import numpy as np
12
import pytest
23
import sympy
34
import unyt as u
45

6+
from gmso.exceptions import GMSOError
57
from gmso.tests.base_test import BaseTest
68
from gmso.utils.expression import PotentialExpression, _are_equal_parameters
79

@@ -331,3 +333,68 @@ def test_are_equal_parameters(self):
331333
u1 = {"a": 2.0 * u.nm, "b": 3.5 * u.nm}
332334
u2 = {"c": 2.0 * u.nm, "d": 3.5 * u.nm}
333335
assert _are_equal_parameters(u1, u2) is False
336+
337+
def test_evaluate_expression(self):
338+
import sympy
339+
340+
from gmso.utils.expression import PotentialExpression
341+
342+
expression = PotentialExpression(
343+
expression="4*epsilon*((sigma/r)**12 - (sigma/r)**6)",
344+
parameters={"sigma": 1 * u.nm, "epsilon": -32.507936507936506 * u.kJ},
345+
independent_variables="r",
346+
)
347+
r = sympy.symbols("r")
348+
independent_namespace = {"r": r}
349+
independent_parameters = {"r": 2}
350+
assert expression.evaluate(independent_namespace, independent_parameters) == 2
351+
352+
# test use of norm in expression on vectors
353+
expression = PotentialExpression(
354+
expression="norm(a)*r",
355+
parameters={"a": [3, 4, 0] * u.nm}, # norm is 5
356+
independent_variables="r",
357+
)
358+
x, y, z = sympy.symbols("x y z")
359+
independent_namespace["r"] = sympy.Matrix([x, y, z])
360+
independent_parameters = {"x": 0 * u.nm, "y": 1 * u.nm, "z": 2 * u.nm}
361+
output = expression.evaluate(independent_namespace, independent_parameters)
362+
assert all(output == np.array([0, 1, 2]) * 5)
363+
364+
with pytest.raises(GMSOError): # fail with missing r parameters
365+
expression = PotentialExpression(
366+
expression="4*epsilon*((sigma/r)**12 - (sigma/r)**6)",
367+
parameters={"sigma": 1 * u.nm, "epsilon": -32.507936507936506 * u.kJ},
368+
independent_variables="r",
369+
)
370+
expression.evaluate()
371+
372+
with pytest.raises(GMSOError): # fail with divide by 0
373+
expression = PotentialExpression(
374+
expression="a/b",
375+
parameters={"a": 1 * u.nm, "b": 0 * u.kJ},
376+
independent_variables=[],
377+
)
378+
expression.evaluate()
379+
380+
with pytest.raises(GMSOError): # fail in vector divide by 0
381+
expression = PotentialExpression(
382+
expression="a/r",
383+
parameters={"a": 1 * u.nm},
384+
independent_variables="r",
385+
)
386+
x, y, z = sympy.symbols("x y z")
387+
independent_namespace["r"] = sympy.Matrix([x, y, z])
388+
independent_parameters = {"x": 1, "y": 0, "z": 0}
389+
expression.evaluate(independent_namespace, independent_parameters)
390+
391+
with pytest.raises(GMSOError): # fail with bad expression
392+
expression = PotentialExpression(
393+
expression="normal(a/r)",
394+
parameters={"a": 1 * u.nm},
395+
independent_variables="r",
396+
)
397+
x, y, z = sympy.symbols("x y z")
398+
independent_namespace["r"] = sympy.Matrix([x, y, z])
399+
independent_parameters = {"x": 1, "y": 0, "z": 0}
400+
expression.evaluate(independent_namespace, independent_parameters)

gmso/tests/test_virtual_site.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23
import unyt as u
34
from sympy import symbols, sympify
@@ -16,7 +17,7 @@
1617
class TestVirturalSite(BaseTest):
1718
@pytest.fixture(scope="session")
1819
def virtual_site(self):
19-
site = Atom()
20+
site = Atom(position=[1, 1, 1])
2021
return VirtualSite(parent_sites=[site])
2122

2223
@pytest.fixture(scope="session")
@@ -34,8 +35,30 @@ def test_new_site(self, water_system):
3435
for site in v_site.parent_sites:
3536
assert site in water_system.sites
3637

37-
def test_virtual_position(self):
38-
# TODO: Check position as a function of virtual_position_type
38+
def test_virtual_position(self, virtual_site):
39+
# Check position as a function of virtual_position_type
40+
# TODO: check for arrays for b, sin norm
41+
# TODO: checkk all gromacs potential forms
42+
43+
v_pot = VirtualPotentialType(
44+
expression="5*a*b",
45+
independent_variables={"a"},
46+
parameters={"b": 1 * u.kJ},
47+
)
48+
v_pos = VirtualPositionType(
49+
expression="ri*cos(b)",
50+
independent_variables=["ri"],
51+
parameters={"b": np.pi * u.radian},
52+
)
53+
assert v_pos
54+
v_type = VirtualType(virtual_potential=v_pot, virtual_position=v_pos)
55+
virtual_site.virtual_type = v_type # assign virtual type
56+
assert_allclose_units(virtual_site.position(), -1 * ([1, 1, 1] * u.nm))
57+
58+
def test_tip4p_water(self):
59+
pass
60+
61+
def test_tip5p_water(self):
3962
pass
4063

4164
def test_virtual_type(self):

gmso/utils/expression.py

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,32 @@
55
from functools import lru_cache
66
from typing import Dict
77

8+
import numpy as np
89
import sympy
910
import unyt as u
11+
from sympy import Add, Function, Mul, Symbol, lambdify, sympify
12+
13+
from gmso.exceptions import GMSOError
1014

1115
logger = logging.getLogger(__name__)
1216

1317
__all__ = ["PotentialExpression"]
1418

1519

20+
class norm(Function):
21+
"""Sympy functions for use in lambdify"""
22+
23+
@classmethod
24+
def eval(cls, arg):
25+
return None
26+
27+
28+
# Evaluate vector norm
29+
def norm_evaluation(matrix_arg):
30+
"""Evaluation method for norm of a sympy matrix"""
31+
return np.linalg.norm(matrix_arg)
32+
33+
1634
def _are_equal_parameters(u1, u2):
1735
"""Compare two parameters of unyt quantities/arrays.
1836
@@ -407,10 +425,8 @@ def from_non_parametric(
407425
----------
408426
non_parametric: PotentialExpression
409427
The non-parametric potential expression to create the parametric one from
410-
411428
parameters: dict
412429
The dictionary of parameters for the newly created parametric expression.
413-
414430
valid: bool, default=False
415431
Whether to validate expression/independent_variables and with the parameters.
416432
@@ -442,3 +458,103 @@ def from_non_parametric(
442458
independent_variables=deepcopy(non_parametric.independent_variables),
443459
verify_validity=not valid,
444460
)
461+
462+
def evaluate(
463+
self, independent_namespace: dict = None, independent_parameters: dict = None
464+
):
465+
"""Evaluate the sympy expression with the given parameters
466+
467+
Parameters
468+
----------
469+
independent_namespace: dict, default None
470+
Dictionary with keys that are the string of the symbol to evaluate and values are the sympy Symbol of Function object
471+
independent_parameters: dict, default None
472+
Keys are strings, and values are the unyt value that corresponds to that symbol in the expression.
473+
474+
Notes
475+
-----
476+
Evaluate the LJ expression as follows:
477+
```python
478+
from gmso.utils.expression import PotentialExpression
479+
import sympy
480+
expression = PotentialExpression(
481+
expression="4*epsilon*((sigma/r)**12 - (sigma/r)**6)",
482+
parameters={"sigma":1*u.nm, "epsilon":-32.507936507936506*u.kJ},
483+
independent_variables="r"
484+
)
485+
r = sympy.symbols(f"r")
486+
independent_namespace = {"r": r} # key to symbol
487+
independent_parameters = {"r": 2} # key to value
488+
expression.evaluate(independent_namespace, independent_parameters)
489+
```
490+
The input "norm" in the expression will be evaluated as the Matrix normal of the variable, i.e. the result from np.linalg.norm. See
491+
class norm(Function) in the module for the definition and evaluation procedure for this symbol, and possibility to add more shorthand
492+
function evaluations in the lambdify namespace.
493+
494+
Returns
495+
-------
496+
result: numpy.ndarray or float
497+
The expression evaluated as a function with parameters plugged in.
498+
"""
499+
# prep namespace for sympify
500+
if independent_namespace is None:
501+
namespace = {}
502+
args = []
503+
else: # grab symbols from namespace
504+
args = [Symbol(key) for key in independent_parameters]
505+
namespace = (
506+
independent_namespace.copy()
507+
) # make copies to not overwrite dictionary objects
508+
namespace.update(
509+
{sym: Symbol(sym) for sym in self.parameters}
510+
) # expression parameters
511+
namespace["norm"] = norm # handle Matrix/Vector normalization
512+
args.extend([namespace[sym] for sym in self.parameters]) # args to lambdify
513+
expr_string = str(self.expression)
514+
if independent_parameters is None:
515+
parameters = {}
516+
else:
517+
parameters = independent_parameters.copy()
518+
519+
# parse expression
520+
try:
521+
expr = sympify(expr_string, locals=namespace)
522+
except (ValueError, TypeError):
523+
raise GMSOError(
524+
f"Expression {expr_string=} was not viable in sympy for PotentialExpression object:{self}."
525+
)
526+
527+
f = lambdify(args, expr, modules=[{"norm": norm_evaluation}, "numpy"])
528+
529+
# evaluate
530+
parameters.update(
531+
{param: val.to_value() for param, val in self.parameters.items()}
532+
)
533+
# apply parameters here into lamdify'ed object
534+
try:
535+
result = f(**parameters)
536+
except ZeroDivisionError:
537+
raise GMSOError(
538+
f"PotentialExpression {self} is unabel to be evaluated since the result was a divide by 0 issue."
539+
)
540+
except NameError as e:
541+
raise GMSOError(f"PotentialExpresion {self} raise: " + str(e))
542+
543+
# error handling
544+
if isinstance(result, float):
545+
return result # TODO: Attach units here also processed via sympy expression
546+
elif isinstance(result, (Symbol, Add, Mul)) or not np.issubdtype(
547+
result.dtype, np.floating
548+
):
549+
raise GMSOError(
550+
f"{self=} expression was not able to be fully evaluated. Unknown parameters left defined in {expr_string=}. Position evaluation is: {result=}"
551+
)
552+
elif any(np.isnan(result)):
553+
raise GMSOError(
554+
f"Failed evaluation of {self=} with {expr_string=} and {parameters=}."
555+
)
556+
if len(result) > 1: # return a vector
557+
result = np.array(result).T[0]
558+
else:
559+
result = np.array(result)
560+
return result

0 commit comments

Comments
 (0)