Skip to content

Commit 5afd925

Browse files
committed
Tests
1 parent ff13e83 commit 5afd925

File tree

2 files changed

+121
-5
lines changed

2 files changed

+121
-5
lines changed

pyadjoint/adjfloat.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@ def register(cls):
2020
return register
2121

2222

23+
@register_function(np.power)
24+
class _pyadjoint_power(sp.Function):
25+
def fdiff(self, argindex=1):
26+
if argindex == 1:
27+
return sp.Piecewise(
28+
# Let SymPy decide how to handle indeterminate form
29+
((self.args[0] ** self.args[1]).diff(self.args[0]), sp.And(self.args[0] == 0, self.args[1] == 0)),
30+
# Otherwise simplify
31+
(sp.S.Zero, self.args[1] == 0),
32+
(_pyadjoint_power(self.args[0], self.args[1] - 1) * self.args[1], True))
33+
elif argindex == 2:
34+
return (self.args[0] ** self.args[1]).diff(self.args[1])
35+
36+
2337
@register_function(np.hypot)
2438
class _pyadjoint_hypot(sp.Function):
2539
def fdiff(self, argindex=1):
@@ -228,11 +242,11 @@ def __sub__(self, other):
228242
def __rsub__(self, other):
229243
return super().__rsub__(other)
230244

231-
@annotate_operator(Operator(operator.pow, 2), operator.pow)
245+
@annotate_operator(Operator(_pyadjoint_power, 2), operator.pow)
232246
def __pow__(self, other):
233247
return super().__pow__(other)
234248

235-
@annotate_operator(Operator(roperator(operator.pow), 2), roperator(operator.pow))
249+
@annotate_operator(Operator(roperator(_pyadjoint_power), 2), roperator(operator.pow))
236250
def __rpow__(self, other):
237251
return super().__rpow__(other)
238252

@@ -243,7 +257,7 @@ def __rpow__(self, other):
243257
divide = register_operator(np.divide, operator.truediv, 2)
244258
add = register_operator(np.add, operator.add, 2)
245259
subtract = register_operator(np.subtract, operator.sub, 2)
246-
power = register_operator(np.power, operator.pow, 2)
260+
power = register_operator(np.power, _pyadjoint_power, 2)
247261
minimum = register_operator(
248262
np.minimum,
249263
lambda self, other: sp.Piecewise((self, self <= other),

tests/pyadjoint/test_floats.py

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import math
33
import numpy as np
44
import operator
5+
import sympy as sp
56
from numpy.testing import assert_approx_equal
67
from numpy.random import rand
78
from pyadjoint import *
@@ -165,6 +166,8 @@ def test_float_abs(v, abs):
165166
assert_approx_equal(b, abs(v))
166167

167168
rf = ReducedFunctional(b, Control(a))
169+
assert_approx_equal(rf(v), abs(v))
170+
assert_approx_equal(rf.tlm(1.0), 1.0 if v >= 0 else -1.0)
168171
assert_approx_equal(rf.derivative(), 1.0 if v >= 0 else -1.0)
169172
assert_approx_equal(rf.hessian(1.0), 0.0)
170173

@@ -177,6 +180,8 @@ def test_float_exp(v, exp):
177180
assert_approx_equal(b, math.exp(v))
178181

179182
rf = ReducedFunctional(b, Control(a))
183+
assert_approx_equal(rf(v), math.exp(v))
184+
assert_approx_equal(rf.tlm(1.0), math.exp(v))
180185
assert_approx_equal(rf.derivative(), math.exp(v))
181186
assert_approx_equal(rf.hessian(1.0), math.exp(v))
182187

@@ -189,10 +194,109 @@ def test_float_loglog(log):
189194
assert_approx_equal(b, math.log(math.log(v)))
190195

191196
rf = ReducedFunctional(b, Control(a))
197+
assert_approx_equal(rf(v), math.log(math.log(v)))
198+
assert_approx_equal(rf.tlm(1.0), 1.0 / (v * math.log(v)))
192199
assert_approx_equal(rf.derivative(), 1.0 / (v * math.log(v)))
193200
assert_approx_equal(rf.hessian(1.0), -(1.0 + math.log(v)) / ((v * math.log(v)) ** 2))
194201

195202

203+
def compose(*args):
204+
def fn(x):
205+
for arg in reversed(args):
206+
x = arg(x)
207+
return x
208+
return fn
209+
210+
211+
def sq(x):
212+
return x ** 2
213+
214+
215+
def hypotsq(x):
216+
return 1 + x ** 2
217+
218+
219+
@pytest.mark.parametrize("np_operator, sp_operator",
220+
(
221+
(operator.abs, lambda x: sp.Piecewise((x, x >= 0), (-x, True))),
222+
(operator.pos, operator.pos),
223+
(operator.neg, operator.neg),
224+
(lambda x: 1 + x, lambda x: 1 + x),
225+
(lambda x: x + 1, lambda x: 1 + x),
226+
(lambda x: 1 - x, lambda x: 1 - x),
227+
(lambda x: x - 1, lambda x: x - 1),
228+
(lambda x: 2 * x, lambda x: 2 * x),
229+
(lambda x: x * 2, lambda x: 2 * x),
230+
(compose(lambda x: 2 / x, hypotsq), compose(lambda x: 2 / x, hypotsq)),
231+
(lambda x: x / 2, lambda x: x / 2),
232+
(lambda x: 2 ** x, lambda x: 2 ** x),
233+
(lambda x: x ** 2, sq),
234+
(np.absolute, lambda x: sp.Piecewise((x, x >= 0), (-x, True))),
235+
(np.positive, operator.pos),
236+
(np.negative, operator.neg),
237+
(lambda x: np.add(x, 1), lambda x: 1 + x),
238+
(lambda x: np.add(1, x), lambda x: 1 + x),
239+
(lambda x: np.subtract(x, 1), lambda x: x - 1),
240+
(lambda x: np.subtract(1, x), lambda x: 1 - x),
241+
(lambda x: np.multiply(x, 2), lambda x: 2 * x),
242+
(lambda x: np.multiply(2, x), lambda x: 2 * x),
243+
(lambda x: np.divide(x, 2), lambda x: x / 2),
244+
(compose(lambda x: np.divide(2, x), hypotsq), compose(lambda x: 2 / x, hypotsq)),
245+
(lambda x: np.power(x, 2), sq),
246+
(lambda x: np.power(2, x), lambda x: 2 ** x),
247+
(np.sin, sp.sin),
248+
(np.cos, sp.cos),
249+
(np.tan, sp.tan),
250+
(compose(np.arcsin, np.tanh), compose(sp.asin, sp.tanh)),
251+
(compose(np.arccos, np.tanh), compose(sp.acos, sp.tanh)),
252+
(np.arctan, sp.atan),
253+
(lambda x: np.arctan2(x, 1), sp.atan),
254+
(lambda x: np.arctan2(1, x), lambda x: sp.atan2(1, x)),
255+
(compose(sq, lambda x: np.hypot(1, x)), hypotsq),
256+
(compose(sq, lambda x: np.hypot(x, 1)), hypotsq),
257+
(np.sinh, sp.sinh),
258+
(np.cosh, sp.cosh),
259+
(np.tanh, sp.tanh),
260+
(np.arcsinh, sp.asinh),
261+
(compose(np.arccosh, hypotsq, hypotsq), (compose(sp.acosh, hypotsq, hypotsq))),
262+
(compose(np.arctanh, np.sin), compose(sp.atanh, sp.sin)),
263+
(np.exp, sp.exp),
264+
(np.exp2, lambda x: 2 ** x),
265+
(np.expm1, lambda x: sp.exp(x) - 1),
266+
(compose(np.log, hypotsq), compose(sp.log, hypotsq)),
267+
(compose(np.log2, hypotsq), compose(lambda x: sp.log(x, 2), hypotsq)),
268+
(compose(np.log10, hypotsq), compose(lambda x: sp.log(x, 10), hypotsq)),
269+
(compose(np.log1p, sq), compose(sp.log, hypotsq)),
270+
(compose(np.sqrt, hypotsq), compose(sp.sqrt, hypotsq)),
271+
(np.square, sq),
272+
(compose(np.cbrt, hypotsq), compose(lambda x: x ** sp.Rational(1, 3), hypotsq)),
273+
(compose(np.reciprocal, hypotsq), compose(lambda x: 1 / x, hypotsq)),
274+
(lambda x: np.minimum(x, 1), lambda x: sp.Piecewise((x, x <= 1), (sp.S.One, True))),
275+
(lambda x: np.minimum(1, x), lambda x: sp.Piecewise((x, x <= 1), (sp.S.One, True))),
276+
(lambda x: np.maximum(x, 1), lambda x: sp.Piecewise((x, x >= 1), (sp.S.One, True))),
277+
(lambda x: np.maximum(1, x), lambda x: sp.Piecewise((x, x >= 1), (sp.S.One, True))),
278+
))
279+
@pytest.mark.parametrize("v", (-np.sqrt(np.pi), 0, np.exp(0.5)))
280+
def test_float_operators(np_operator, sp_operator, v):
281+
np_operator = compose(np.exp, np_operator)
282+
sp_operator = compose(sp.exp, sp_operator)
283+
284+
a = AdjFloat(v)
285+
b = np_operator(a)
286+
287+
x = sp.Symbol("x", real=True)
288+
op_ref = sp.lambdify((x,), sp_operator(x), modules=["numpy"])(v)
289+
dop_ref = sp.lambdify((x,), sp_operator(x).diff(x), modules=["numpy"])(v)
290+
ddop_ref = sp.lambdify((x,), sp_operator(x).diff(x).diff(x), modules=["numpy"])(v)
291+
292+
assert_approx_equal(b, op_ref)
293+
rf = ReducedFunctional(b, Control(a))
294+
assert_approx_equal(rf(v), op_ref)
295+
assert_approx_equal(rf.tlm(1.0), dop_ref)
296+
assert_approx_equal(rf.derivative(), dop_ref)
297+
assert_approx_equal(rf.hessian(1.0), ddop_ref)
298+
299+
196300
def test_float_logexp():
197301
a = AdjFloat(3.0)
198302
b = exp(a)
@@ -240,8 +344,6 @@ def test_float_exponentiation():
240344
# d(a**a)/da = dexp(a log(a))/da = a**a * (log(a) + 1)
241345
assert_approx_equal(rf.derivative(), 4.0 * (math.log(2.0)+1.0))
242346

243-
# TODO: __rpow__ is not yet implemented
244-
245347

246348
@pytest.mark.parametrize("B", [3,4])
247349
@pytest.mark.parametrize("E", [6,5])

0 commit comments

Comments
 (0)