Skip to content

Commit 3f0a64b

Browse files
author
OptaxDev
committed
Merge pull request #1142 from gautierronan:complex-lbfgs
PiperOrigin-RevId: 702881265
2 parents ad71306 + db4ff3e commit 3f0a64b

File tree

7 files changed

+133
-13
lines changed

7 files changed

+133
-13
lines changed

optax/_src/alias.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2514,6 +2514,8 @@ def lbfgs(
25142514
constrain the trust-region of the first step to an Euclidean ball of radius
25152515
1 at the first iteration. The choice of :math:`\gamma_0` is not detailed in
25162516
the references above, so this is a heuristic choice.
2517+
2518+
.. note:: The algorithm can support complex inputs.
25172519
"""
25182520
if learning_rate is None:
25192521
base_scaling = transform.scale(-1.0)

optax/_src/alias_test.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import numpy as np
2828
from optax._src import alias
2929
from optax._src import base
30+
from optax._src import linesearch as _linesearch
3031
from optax._src import numerics
3132
from optax._src import transform
3233
from optax._src import update
@@ -333,6 +334,7 @@ def stopping_criterion(carry):
333334
def step(carry):
334335
params, state, count, _ = carry
335336
value, grad = value_and_grad_fun(params)
337+
grad = otu.tree_conj(grad)
336338
updates, state = opt.update(
337339
grad, state, params, value=value, grad=grad, value_fn=fun
338340
)
@@ -690,9 +692,7 @@ def test_against_plain_implementation(
690692
scale_init_precond=scale_init_precond,
691693
linesearch=None,
692694
)
693-
lbfgs_sol, _ = _run_opt(
694-
opt, fun, init_params, maxiter=maxiter, tol=tol
695-
)
695+
lbfgs_sol, _ = _run_opt(opt, fun, init_params, maxiter=maxiter, tol=tol)
696696
expected_lbfgs_sol = _plain_lbfgs(
697697
fun,
698698
init_params,
@@ -806,9 +806,7 @@ def test_against_scipy(self, problem_name: str):
806806
jnp_fun, np_fun = problem['fun'], problem['numpy_fun']
807807

808808
opt = alias.lbfgs()
809-
optax_sol, _ = _run_opt(
810-
opt, jnp_fun, init_params, maxiter=500, tol=tol
811-
)
809+
optax_sol, _ = _run_opt(opt, jnp_fun, init_params, maxiter=500, tol=tol)
812810
scipy_sol = scipy_optimize.minimize(np_fun, init_params, method='BFGS').x
813811

814812
# 1. Check minimizer obtained against known minimizer or scipy minimizer
@@ -865,6 +863,76 @@ def fun(x):
865863
sol, _ = _run_opt(opt, fun, init_params=jnp.ones(n), tol=tol)
866864
chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol)
867865

866+
@parameterized.product(
867+
linesearch=[
868+
_linesearch.scale_by_backtracking_linesearch(
869+
max_backtracking_steps=20
870+
),
871+
_linesearch.scale_by_zoom_linesearch(
872+
max_linesearch_steps=20, initial_guess_strategy='one'
873+
),
874+
],
875+
)
876+
def test_lbfgs_complex(self, linesearch):
877+
# Test that optimization over complex variable matches equivalent real case
878+
879+
tol = 1e-5
880+
mat = jnp.array([[1, -2], [3, 4], [-4 + 2j, 5 - 3j], [-2 - 2j, 6]])
881+
882+
def to_real(z):
883+
return jnp.stack((z.real, z.imag))
884+
885+
def to_complex(x):
886+
return x[..., 0, :] + 1j * x[..., 1, :]
887+
888+
def f_complex(z):
889+
return jnp.sum(jnp.abs(mat @ z) ** 1.5)
890+
891+
def f_real(x):
892+
return f_complex(to_complex(x))
893+
894+
z0 = jnp.array([1 - 1j, 0 + 1j])
895+
x0 = to_real(z0)
896+
897+
opt_complex = alias.lbfgs(linesearch=linesearch)
898+
opt_real = alias.lbfgs(linesearch=linesearch)
899+
sol_complex, _ = _run_opt(opt_complex, f_complex, init_params=z0, tol=tol)
900+
sol_real, _ = _run_opt(opt_real, f_real, init_params=x0, tol=tol)
901+
902+
chex.assert_trees_all_close(
903+
sol_complex, to_complex(sol_real), atol=tol, rtol=tol
904+
)
905+
906+
@parameterized.product(
907+
linesearch=[
908+
_linesearch.scale_by_backtracking_linesearch(
909+
max_backtracking_steps=20
910+
),
911+
_linesearch.scale_by_zoom_linesearch(
912+
max_linesearch_steps=20, initial_guess_strategy='one'
913+
),
914+
],
915+
)
916+
def test_lbfgs_complex_rosenbrock(self, linesearch):
917+
# Taken from previous jax tests
918+
tol = 1e-5
919+
complex_dim = 5
920+
921+
fun_real = _get_problem('rosenbrock')['fun']
922+
init_real = jnp.zeros((2 * complex_dim,), dtype=complex)
923+
expected_real = jnp.ones((2 * complex_dim,), dtype=complex)
924+
925+
def fun(z):
926+
x_real = jnp.concatenate([jnp.real(z), jnp.imag(z)])
927+
return fun_real(x_real)
928+
929+
init = init_real[:complex_dim] + 1.0j * init_real[complex_dim:]
930+
expected = expected_real[:complex_dim] + 1.0j * expected_real[complex_dim:]
931+
932+
opt = alias.lbfgs(linesearch=linesearch)
933+
got, _ = _run_opt(opt, fun, init, maxiter=500, tol=tol)
934+
chex.assert_trees_all_close(got, expected, atol=tol, rtol=tol)
935+
868936

869937
if __name__ == '__main__':
870938
absltest.main()

optax/_src/linesearch.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,8 @@ def scale_by_backtracking_linesearch(
240240
after the backtracking line-search doesn't necessarily need to satisfy the
241241
descent direction property (one could for example use momentum).
242242
243+
.. note:: The algorithm can support complex inputs.
244+
243245
.. seealso:: :func:`optax.value_and_grad_from_state` to make this method
244246
more efficient for non-stochastic objectives.
245247
@@ -319,7 +321,7 @@ def update_fn(
319321
# Slope of lr -> value_fn(params + lr * updates) at lr = 0
320322
# Should be negative to ensure that there exists a lr (potentially
321323
# infinitesimal) that satisfies the criterion.
322-
slope = otu.tree_vdot(updates, grad)
324+
slope = otu.tree_real(otu.tree_vdot(updates, otu.tree_conj(grad)))
323325

324326
def cond_fn(
325327
search_state: BacktrackingLineSearchState,
@@ -698,7 +700,7 @@ def _value_and_slope_on_line(
698700
"""
699701
step = otu.tree_add_scalar_mul(params, stepsize, updates)
700702
value_step, grad_step = value_and_grad_fn(step, **fn_kwargs)
701-
slope_step = otu.tree_vdot(grad_step, updates)
703+
slope_step = otu.tree_real(otu.tree_vdot(otu.tree_conj(grad_step), updates))
702704
return step, value_step, grad_step, slope_step
703705

704706
def _compute_decrease_error(
@@ -1205,7 +1207,7 @@ def init_fn(
12051207
f"Unknown initial guess strategy: {initial_guess_strategy}"
12061208
)
12071209

1208-
slope = otu.tree_vdot(updates, grad)
1210+
slope = otu.tree_real(otu.tree_vdot(updates, grad))
12091211
return ZoomLinesearchState(
12101212
count=jnp.asarray(0, dtype=jnp.int32),
12111213
#
@@ -1511,6 +1513,8 @@ def scale_by_zoom_linesearch(
15111513
This can be sufficient in practice and avoids having the linesearch spend
15121514
many iterations trying to satisfy the small curvature criterion.
15131515
1516+
.. note:: The algorithm can support complex inputs.
1517+
15141518
.. seealso:: :func:`optax.value_and_grad_from_state` to make this method
15151519
more efficient for non-stochastic objectives.
15161520
"""

optax/_src/transform.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,7 +1521,7 @@ def right_product(vec, idx):
15211521
dwi, dui = jax.tree.map(
15221522
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
15231523
)
1524-
alpha = rhos[idx] * otu.tree_vdot(dwi, vec)
1524+
alpha = rhos[idx] * otu.tree_real(otu.tree_vdot(dwi, vec))
15251525
vec = otu.tree_add_scalar_mul(vec, -alpha, dui)
15261526
return vec, alpha
15271527

@@ -1536,7 +1536,7 @@ def left_product(vec, idx_alpha):
15361536
dwi, dui = jax.tree.map(
15371537
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
15381538
)
1539-
beta = rhos[idx] * otu.tree_vdot(dui, vec)
1539+
beta = rhos[idx] * otu.tree_real(otu.tree_vdot(dui, vec))
15401540
vec = otu.tree_add_scalar_mul(vec, alpha - beta, dwi)
15411541
return vec, beta
15421542

@@ -1666,7 +1666,9 @@ def update_fn(
16661666
# 1. Updates the memory buffers given fresh params and gradients/updates
16671667
diff_params = otu.tree_sub(params, state.params)
16681668
diff_updates = otu.tree_sub(updates, state.updates)
1669-
vdot_diff_params_updates = otu.tree_vdot(diff_updates, diff_params)
1669+
vdot_diff_params_updates = otu.tree_real(
1670+
otu.tree_vdot(diff_updates, diff_params)
1671+
)
16701672
weight = jnp.where(
16711673
vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates
16721674
)
@@ -1691,7 +1693,7 @@ def update_fn(
16911693
# used to initialize the approximation of the inverse through the memory
16921694
# buffer.
16931695
if scale_init_precond:
1694-
numerator = otu.tree_vdot(diff_updates, diff_params)
1696+
numerator = otu.tree_real(otu.tree_vdot(diff_updates, diff_params))
16951697
denominator = otu.tree_l2_norm(diff_updates, squared=True)
16961698
identity_scale = jnp.where(
16971699
denominator > 0.0, numerator / denominator, 1.0

optax/tree_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from optax.tree_utils._tree_math import tree_add_scalar_mul
3030
from optax.tree_utils._tree_math import tree_bias_correction
3131
from optax.tree_utils._tree_math import tree_clip
32+
from optax.tree_utils._tree_math import tree_conj
3233
from optax.tree_utils._tree_math import tree_div
3334
from optax.tree_utils._tree_math import tree_full_like
3435
from optax.tree_utils._tree_math import tree_l1_norm
@@ -37,6 +38,7 @@
3738
from optax.tree_utils._tree_math import tree_max
3839
from optax.tree_utils._tree_math import tree_mul
3940
from optax.tree_utils._tree_math import tree_ones_like
41+
from optax.tree_utils._tree_math import tree_real
4042
from optax.tree_utils._tree_math import tree_scalar_mul
4143
from optax.tree_utils._tree_math import tree_sub
4244
from optax.tree_utils._tree_math import tree_sum

optax/tree_utils/_tree_math.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,30 @@ def tree_max(tree: Any) -> chex.Numeric:
183183
return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))
184184

185185

186+
def tree_conj(tree: Any) -> Any:
187+
"""Compute the conjugate of a pytree.
188+
189+
Args:
190+
tree: pytree.
191+
192+
Returns:
193+
a pytree with the same structure as ``tree``.
194+
"""
195+
return jax.tree.map(jnp.conj, tree)
196+
197+
198+
def tree_real(tree: Any) -> Any:
199+
"""Compute the real part of a pytree.
200+
201+
Args:
202+
tree: pytree.
203+
204+
Returns:
205+
a pytree with the same structure as ``tree``.
206+
"""
207+
return jax.tree.map(jnp.real, tree)
208+
209+
186210
def _square(leaf):
187211
return jnp.square(leaf.real) + jnp.square(leaf.imag)
188212

optax/tree_utils/_tree_math_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,24 @@ def test_tree_max(self, key):
152152
got = tu.tree_max(tree)
153153
np.testing.assert_allclose(expected, got)
154154

155+
def test_tree_conj(self):
156+
expected = jnp.conj(self.array_a)
157+
got = tu.tree_conj(self.array_a)
158+
np.testing.assert_array_almost_equal(expected, got)
159+
160+
expected = (jnp.conj(self.tree_a[0]), jnp.conj(self.tree_a[1]))
161+
got = tu.tree_conj(self.tree_a)
162+
chex.assert_trees_all_close(expected, got)
163+
164+
def test_tree_real(self):
165+
expected = jnp.real(self.array_a)
166+
got = tu.tree_real(self.array_a)
167+
np.testing.assert_array_almost_equal(expected, got)
168+
169+
expected = (jnp.real(self.tree_a[0]), jnp.real(self.tree_a[1]))
170+
got = tu.tree_real(self.tree_a)
171+
chex.assert_trees_all_close(expected, got)
172+
155173
def test_tree_l2_norm(self):
156174
expected = jnp.sqrt(jnp.vdot(self.array_a, self.array_a).real)
157175
got = tu.tree_l2_norm(self.array_a)

0 commit comments

Comments
 (0)