|
27 | 27 | import numpy as np |
28 | 28 | from optax._src import alias |
29 | 29 | from optax._src import base |
| 30 | +from optax._src import linesearch as _linesearch |
30 | 31 | from optax._src import numerics |
31 | 32 | from optax._src import transform |
32 | 33 | from optax._src import update |
@@ -333,6 +334,7 @@ def stopping_criterion(carry): |
333 | 334 | def step(carry): |
334 | 335 | params, state, count, _ = carry |
335 | 336 | value, grad = value_and_grad_fun(params) |
| 337 | + grad = otu.tree_conj(grad) |
336 | 338 | updates, state = opt.update( |
337 | 339 | grad, state, params, value=value, grad=grad, value_fn=fun |
338 | 340 | ) |
@@ -690,9 +692,7 @@ def test_against_plain_implementation( |
690 | 692 | scale_init_precond=scale_init_precond, |
691 | 693 | linesearch=None, |
692 | 694 | ) |
693 | | - lbfgs_sol, _ = _run_opt( |
694 | | - opt, fun, init_params, maxiter=maxiter, tol=tol |
695 | | - ) |
| 695 | + lbfgs_sol, _ = _run_opt(opt, fun, init_params, maxiter=maxiter, tol=tol) |
696 | 696 | expected_lbfgs_sol = _plain_lbfgs( |
697 | 697 | fun, |
698 | 698 | init_params, |
@@ -806,9 +806,7 @@ def test_against_scipy(self, problem_name: str): |
806 | 806 | jnp_fun, np_fun = problem['fun'], problem['numpy_fun'] |
807 | 807 |
|
808 | 808 | opt = alias.lbfgs() |
809 | | - optax_sol, _ = _run_opt( |
810 | | - opt, jnp_fun, init_params, maxiter=500, tol=tol |
811 | | - ) |
| 809 | + optax_sol, _ = _run_opt(opt, jnp_fun, init_params, maxiter=500, tol=tol) |
812 | 810 | scipy_sol = scipy_optimize.minimize(np_fun, init_params, method='BFGS').x |
813 | 811 |
|
814 | 812 | # 1. Check minimizer obtained against known minimizer or scipy minimizer |
@@ -865,6 +863,76 @@ def fun(x): |
865 | 863 | sol, _ = _run_opt(opt, fun, init_params=jnp.ones(n), tol=tol) |
866 | 864 | chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol) |
867 | 865 |
|
| 866 | + @parameterized.product( |
| 867 | + linesearch=[ |
| 868 | + _linesearch.scale_by_backtracking_linesearch( |
| 869 | + max_backtracking_steps=20 |
| 870 | + ), |
| 871 | + _linesearch.scale_by_zoom_linesearch( |
| 872 | + max_linesearch_steps=20, initial_guess_strategy='one' |
| 873 | + ), |
| 874 | + ], |
| 875 | + ) |
| 876 | + def test_lbfgs_complex(self, linesearch): |
| 877 | + # Test that optimization over complex variable matches equivalent real case |
| 878 | + |
| 879 | + tol = 1e-5 |
| 880 | + mat = jnp.array([[1, -2], [3, 4], [-4 + 2j, 5 - 3j], [-2 - 2j, 6]]) |
| 881 | + |
| 882 | + def to_real(z): |
| 883 | + return jnp.stack((z.real, z.imag)) |
| 884 | + |
| 885 | + def to_complex(x): |
| 886 | + return x[..., 0, :] + 1j * x[..., 1, :] |
| 887 | + |
| 888 | + def f_complex(z): |
| 889 | + return jnp.sum(jnp.abs(mat @ z) ** 1.5) |
| 890 | + |
| 891 | + def f_real(x): |
| 892 | + return f_complex(to_complex(x)) |
| 893 | + |
| 894 | + z0 = jnp.array([1 - 1j, 0 + 1j]) |
| 895 | + x0 = to_real(z0) |
| 896 | + |
| 897 | + opt_complex = alias.lbfgs(linesearch=linesearch) |
| 898 | + opt_real = alias.lbfgs(linesearch=linesearch) |
| 899 | + sol_complex, _ = _run_opt(opt_complex, f_complex, init_params=z0, tol=tol) |
| 900 | + sol_real, _ = _run_opt(opt_real, f_real, init_params=x0, tol=tol) |
| 901 | + |
| 902 | + chex.assert_trees_all_close( |
| 903 | + sol_complex, to_complex(sol_real), atol=tol, rtol=tol |
| 904 | + ) |
| 905 | + |
| 906 | + @parameterized.product( |
| 907 | + linesearch=[ |
| 908 | + _linesearch.scale_by_backtracking_linesearch( |
| 909 | + max_backtracking_steps=20 |
| 910 | + ), |
| 911 | + _linesearch.scale_by_zoom_linesearch( |
| 912 | + max_linesearch_steps=20, initial_guess_strategy='one' |
| 913 | + ), |
| 914 | + ], |
| 915 | + ) |
| 916 | + def test_lbfgs_complex_rosenbrock(self, linesearch): |
| 917 | + # Taken from previous jax tests |
| 918 | + tol = 1e-5 |
| 919 | + complex_dim = 5 |
| 920 | + |
| 921 | + fun_real = _get_problem('rosenbrock')['fun'] |
| 922 | + init_real = jnp.zeros((2 * complex_dim,), dtype=complex) |
| 923 | + expected_real = jnp.ones((2 * complex_dim,), dtype=complex) |
| 924 | + |
| 925 | + def fun(z): |
| 926 | + x_real = jnp.concatenate([jnp.real(z), jnp.imag(z)]) |
| 927 | + return fun_real(x_real) |
| 928 | + |
| 929 | + init = init_real[:complex_dim] + 1.0j * init_real[complex_dim:] |
| 930 | + expected = expected_real[:complex_dim] + 1.0j * expected_real[complex_dim:] |
| 931 | + |
| 932 | + opt = alias.lbfgs(linesearch=linesearch) |
| 933 | + got, _ = _run_opt(opt, fun, init, maxiter=500, tol=tol) |
| 934 | + chex.assert_trees_all_close(got, expected, atol=tol, rtol=tol) |
| 935 | + |
868 | 936 |
|
869 | 937 | if __name__ == '__main__': |
870 | 938 | absltest.main() |
0 commit comments