22import math
33import numpy as np
44import operator
5+ import sympy as sp
56from numpy .testing import assert_approx_equal
67from numpy .random import rand
78from pyadjoint import *
@@ -165,6 +166,8 @@ def test_float_abs(v, abs):
165166 assert_approx_equal (b , abs (v ))
166167
167168 rf = ReducedFunctional (b , Control (a ))
169+ assert_approx_equal (rf (v ), abs (v ))
170+ assert_approx_equal (rf .tlm (1.0 ), 1.0 if v >= 0 else - 1.0 )
168171 assert_approx_equal (rf .derivative (), 1.0 if v >= 0 else - 1.0 )
169172 assert_approx_equal (rf .hessian (1.0 ), 0.0 )
170173
@@ -177,6 +180,8 @@ def test_float_exp(v, exp):
177180 assert_approx_equal (b , math .exp (v ))
178181
179182 rf = ReducedFunctional (b , Control (a ))
183+ assert_approx_equal (rf (v ), math .exp (v ))
184+ assert_approx_equal (rf .tlm (1.0 ), math .exp (v ))
180185 assert_approx_equal (rf .derivative (), math .exp (v ))
181186 assert_approx_equal (rf .hessian (1.0 ), math .exp (v ))
182187
@@ -189,10 +194,109 @@ def test_float_loglog(log):
189194 assert_approx_equal (b , math .log (math .log (v )))
190195
191196 rf = ReducedFunctional (b , Control (a ))
197+ assert_approx_equal (rf (v ), math .log (math .log (v )))
198+ assert_approx_equal (rf .tlm (1.0 ), 1.0 / (v * math .log (v )))
192199 assert_approx_equal (rf .derivative (), 1.0 / (v * math .log (v )))
193200 assert_approx_equal (rf .hessian (1.0 ), - (1.0 + math .log (v )) / ((v * math .log (v )) ** 2 ))
194201
195202
203+ def compose (* args ):
204+ def fn (x ):
205+ for arg in reversed (args ):
206+ x = arg (x )
207+ return x
208+ return fn
209+
210+
211+ def sq (x ):
212+ return x ** 2
213+
214+
215+ def hypotsq (x ):
216+ return 1 + x ** 2
217+
218+
219+ @pytest .mark .parametrize ("np_operator, sp_operator" ,
220+ (
221+ (operator .abs , lambda x : sp .Piecewise ((x , x >= 0 ), (- x , True ))),
222+ (operator .pos , operator .pos ),
223+ (operator .neg , operator .neg ),
224+ (lambda x : 1 + x , lambda x : 1 + x ),
225+ (lambda x : x + 1 , lambda x : 1 + x ),
226+ (lambda x : 1 - x , lambda x : 1 - x ),
227+ (lambda x : x - 1 , lambda x : x - 1 ),
228+ (lambda x : 2 * x , lambda x : 2 * x ),
229+ (lambda x : x * 2 , lambda x : 2 * x ),
230+ (compose (lambda x : 2 / x , hypotsq ), compose (lambda x : 2 / x , hypotsq )),
231+ (lambda x : x / 2 , lambda x : x / 2 ),
232+ (lambda x : 2 ** x , lambda x : 2 ** x ),
233+ (lambda x : x ** 2 , sq ),
234+ (np .absolute , lambda x : sp .Piecewise ((x , x >= 0 ), (- x , True ))),
235+ (np .positive , operator .pos ),
236+ (np .negative , operator .neg ),
237+ (lambda x : np .add (x , 1 ), lambda x : 1 + x ),
238+ (lambda x : np .add (1 , x ), lambda x : 1 + x ),
239+ (lambda x : np .subtract (x , 1 ), lambda x : x - 1 ),
240+ (lambda x : np .subtract (1 , x ), lambda x : 1 - x ),
241+ (lambda x : np .multiply (x , 2 ), lambda x : 2 * x ),
242+ (lambda x : np .multiply (2 , x ), lambda x : 2 * x ),
243+ (lambda x : np .divide (x , 2 ), lambda x : x / 2 ),
244+ (compose (lambda x : np .divide (2 , x ), hypotsq ), compose (lambda x : 2 / x , hypotsq )),
245+ (lambda x : np .power (x , 2 ), sq ),
246+ (lambda x : np .power (2 , x ), lambda x : 2 ** x ),
247+ (np .sin , sp .sin ),
248+ (np .cos , sp .cos ),
249+ (np .tan , sp .tan ),
250+ (compose (np .arcsin , np .tanh ), compose (sp .asin , sp .tanh )),
251+ (compose (np .arccos , np .tanh ), compose (sp .acos , sp .tanh )),
252+ (np .arctan , sp .atan ),
253+ (lambda x : np .arctan2 (x , 1 ), sp .atan ),
254+ (lambda x : np .arctan2 (1 , x ), lambda x : sp .atan2 (1 , x )),
255+ (compose (sq , lambda x : np .hypot (1 , x )), hypotsq ),
256+ (compose (sq , lambda x : np .hypot (x , 1 )), hypotsq ),
257+ (np .sinh , sp .sinh ),
258+ (np .cosh , sp .cosh ),
259+ (np .tanh , sp .tanh ),
260+ (np .arcsinh , sp .asinh ),
261+ (compose (np .arccosh , hypotsq , hypotsq ), (compose (sp .acosh , hypotsq , hypotsq ))),
262+ (compose (np .arctanh , np .sin ), compose (sp .atanh , sp .sin )),
263+ (np .exp , sp .exp ),
264+ (np .exp2 , lambda x : 2 ** x ),
265+ (np .expm1 , lambda x : sp .exp (x ) - 1 ),
266+ (compose (np .log , hypotsq ), compose (sp .log , hypotsq )),
267+ (compose (np .log2 , hypotsq ), compose (lambda x : sp .log (x , 2 ), hypotsq )),
268+ (compose (np .log10 , hypotsq ), compose (lambda x : sp .log (x , 10 ), hypotsq )),
269+ (compose (np .log1p , sq ), compose (sp .log , hypotsq )),
270+ (compose (np .sqrt , hypotsq ), compose (sp .sqrt , hypotsq )),
271+ (np .square , sq ),
272+ (compose (np .cbrt , hypotsq ), compose (lambda x : x ** sp .Rational (1 , 3 ), hypotsq )),
273+ (compose (np .reciprocal , hypotsq ), compose (lambda x : 1 / x , hypotsq )),
274+ (lambda x : np .minimum (x , 1 ), lambda x : sp .Piecewise ((x , x <= 1 ), (sp .S .One , True ))),
275+ (lambda x : np .minimum (1 , x ), lambda x : sp .Piecewise ((x , x <= 1 ), (sp .S .One , True ))),
276+ (lambda x : np .maximum (x , 1 ), lambda x : sp .Piecewise ((x , x >= 1 ), (sp .S .One , True ))),
277+ (lambda x : np .maximum (1 , x ), lambda x : sp .Piecewise ((x , x >= 1 ), (sp .S .One , True ))),
278+ ))
279+ @pytest .mark .parametrize ("v" , (- np .sqrt (np .pi ), 0 , np .exp (0.5 )))
280+ def test_float_operators (np_operator , sp_operator , v ):
281+ np_operator = compose (np .exp , np_operator )
282+ sp_operator = compose (sp .exp , sp_operator )
283+
284+ a = AdjFloat (v )
285+ b = np_operator (a )
286+
287+ x = sp .Symbol ("x" , real = True )
288+ op_ref = sp .lambdify ((x ,), sp_operator (x ), modules = ["numpy" ])(v )
289+ dop_ref = sp .lambdify ((x ,), sp_operator (x ).diff (x ), modules = ["numpy" ])(v )
290+ ddop_ref = sp .lambdify ((x ,), sp_operator (x ).diff (x ).diff (x ), modules = ["numpy" ])(v )
291+
292+ assert_approx_equal (b , op_ref )
293+ rf = ReducedFunctional (b , Control (a ))
294+ assert_approx_equal (rf (v ), op_ref )
295+ assert_approx_equal (rf .tlm (1.0 ), dop_ref )
296+ assert_approx_equal (rf .derivative (), dop_ref )
297+ assert_approx_equal (rf .hessian (1.0 ), ddop_ref )
298+
299+
196300def test_float_logexp ():
197301 a = AdjFloat (3.0 )
198302 b = exp (a )
@@ -240,8 +344,6 @@ def test_float_exponentiation():
240344 # d(a**a)/da = dexp(a log(a))/da = a**a * (log(a) + 1)
241345 assert_approx_equal (rf .derivative (), 4.0 * (math .log (2.0 )+ 1.0 ))
242346
243- # TODO: __rpow__ is not yet implemented
244-
245347
246348@pytest .mark .parametrize ("B" , [3 ,4 ])
247349@pytest .mark .parametrize ("E" , [6 ,5 ])
0 commit comments