Commit 3d8c391

Author: OptaxDev
Merge pull request #979 from evanatyourservice:sophia_h

PiperOrigin-RevId: 703157245
2 parents: 3f0a64b + 341c6c2

File tree: 4 files changed (+361, -12 lines)

docs/api/contrib.rst

Lines changed: 9 additions & 0 deletions
@@ -34,6 +34,8 @@ Experimental features and algorithms that don't meet the
     schedule_free_eval_params
     schedule_free_sgd
     ScheduleFreeState
+    sophia
+    SophiaState
     split_real_and_imaginary
     SplitRealAndImaginaryState
 
@@ -99,3 +101,10 @@ Sharpness aware minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: sam
 .. autoclass:: SAMState
+
+Sophia
+~~~~~~
+.. autofunction:: hutchinson_estimator_diag_hessian
+.. autoclass:: HutchinsonState
+.. autofunction:: sophia
+.. autoclass:: SophiaState

optax/contrib/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -51,3 +51,7 @@
 from optax.contrib._schedule_free import schedule_free_eval_params
 from optax.contrib._schedule_free import schedule_free_sgd
 from optax.contrib._schedule_free import ScheduleFreeState
+from optax.contrib._sophia import hutchinson_estimator_diag_hessian
+from optax.contrib._sophia import HutchinsonState
+from optax.contrib._sophia import sophia
+from optax.contrib._sophia import SophiaState
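
With these exports in place, here is a minimal usage sketch (not part of the diff; the toy quadratic objective, parameter values, and learning rate are illustrative). It mirrors how the new test entries in optax/contrib/_common_test.py below drive the optimizer: sophia's update takes the objective function itself as an extra obj_fn argument, presumably so hutchinson_estimator_diag_hessian can estimate the diagonal of the Hessian.

import jax
import jax.numpy as jnp
import optax
from optax import contrib

def obj_fn(params):
  # Toy quadratic loss; sophia needs the objective itself, not just its
  # gradient, because obj_fn is passed to opt.update as an extra argument.
  return jnp.sum(params ** 2)

params = jnp.array([1.0, 2.0, 3.0])
opt = contrib.sophia(learning_rate=1e-2)
state = opt.init(params)

value, grads = jax.value_and_grad(obj_fn)(params)
updates, state = opt.update(grads, state, params, obj_fn=obj_fn)
params = optax.apply_updates(params, updates)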

optax/contrib/_common_test.py

Lines changed: 35 additions & 12 deletions
@@ -27,12 +27,14 @@
 import jax.numpy as jnp
 from optax import contrib
 from optax._src import alias
+from optax._src import base
 from optax._src import combine
 from optax._src import numerics
 from optax._src import update
 from optax.schedules import _inject
 from optax.transforms import _accumulation
 from optax.tree_utils import _state_utils
+from optax.tree_utils import _tree_math
 
 # Testing contributions coded as GradientTransformations
 _MAIN_OPTIMIZERS_UNDER_TEST = [
@@ -53,6 +55,10 @@
         opt_name='schedule_free_adamw',
         opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000),
     ),
+    dict(
+        opt_name='sophia',
+        opt_kwargs=dict(learning_rate=1e-2),
+    ),
 ]
 for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST:
   optimizer['wrapper_name'] = None
@@ -144,11 +150,10 @@ def _setup_parabola(dtype):
   initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
   final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)
 
-  @jax.value_and_grad
-  def get_updates(params):
+  def obj_fn(params):
     return jnp.sum(numerics.abs_sq(params - final_params))
 
-  return initial_params, final_params, get_updates
+  return initial_params, final_params, obj_fn
 
 
 def _setup_rosenbrock(dtype):
@@ -159,13 +164,12 @@ def _setup_rosenbrock(dtype):
   initial_params = jnp.array([0.0, 0.0], dtype=dtype)
   final_params = jnp.array([a, a**2], dtype=dtype)
 
-  @jax.value_and_grad
-  def get_updates(params):
+  def obj_fn(params):
     return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq(
         params[1] - params[0] ** 2
     )
 
-  return initial_params, final_params, get_updates
+  return initial_params, final_params, obj_fn
 
 
 class ContribTest(chex.TestCase):
@@ -188,16 +192,18 @@ def test_optimizers(
     opt = _get_opt_factory(opt_name)(**opt_kwargs)
     if wrapper_name is not None:
       opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs)
-    initial_params, final_params, get_updates = target(dtype)
+    initial_params, final_params, obj_fn = target(dtype)
 
     @jax.jit
     def step(params, state):
-      value, updates = get_updates(params)
+      value, updates = jax.value_and_grad(obj_fn)(params)
       if (
           opt_name in ['momo', 'momo_adam']
           or wrapper_name == 'reduce_on_plateau'
       ):
         update_kwargs = {'value': value}
+      elif opt_name == 'sophia':
+        update_kwargs = {'obj_fn': obj_fn}
       else:
         update_kwargs = {}
       updates, state = opt.update(updates, state, params, **update_kwargs)
@@ -266,14 +272,21 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
       update_kwargs = {'value': jnp.array(1.0)}
     else:
       update_kwargs = {}
+    if opt_name == 'sophia':
+      obj_fn = lambda x: _tree_math.tree_l2_norm(x, squared=True)
+      update_fn = functools.partial(opt.update, obj_fn=obj_fn)
+      inject_update_fn = functools.partial(opt_inject.update, obj_fn=obj_fn)
+    else:
+      update_fn = opt.update
+      inject_update_fn = opt_inject.update
 
     state = self.variant(opt.init)(params)
-    updates, new_state = self.variant(opt.update)(
+    updates, new_state = self.variant(update_fn)(
         grads, state, params, **update_kwargs
     )
 
     state_inject = self.variant(opt_inject.init)(params)
-    updates_inject, new_state_inject = self.variant(opt_inject.update)(
+    updates_inject, new_state_inject = self.variant(inject_update_fn)(
         grads, state_inject, params, **update_kwargs
     )
 
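
The same pre-binding trick works outside the test harness. Below is a sketch (assuming a toy l2 objective; optax.tree_utils.tree_l2_norm is the public counterpart of the private _tree_math helper used above) of wrapping sophia in optax.inject_hyperparams and fixing obj_fn with functools.partial, so callers keep the usual (updates, state, params) signature.

import functools
import jax
import jax.numpy as jnp
import optax
from optax import contrib

# Toy objective, analogous to the lambda used in the test above.
obj_fn = lambda params: optax.tree_utils.tree_l2_norm(params, squared=True)

opt = optax.inject_hyperparams(contrib.sophia)(learning_rate=1e-2)
params = jnp.array([1.0, 2.0])
state = opt.init(params)

# Pre-bind obj_fn so the wrapped update keeps the standard signature.
update_fn = functools.partial(opt.update, obj_fn=obj_fn)
grads = jax.grad(obj_fn)(params)
updates, state = update_fn(grads, state, params)
params = optax.apply_updates(params, updates)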
@@ -320,7 +333,11 @@ def test_preserve_dtype(
       update_kwargs = {'value': value}
     else:
       update_kwargs = {}
-    updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
+    if opt_name == 'sophia':
+      update_fn = functools.partial(opt.update, obj_fn=fun)
+    else:
+      update_fn = opt.update
+    updates, _ = self.variant(update_fn)(grads, state, params, **update_kwargs)
     self.assertEqual(updates.dtype, params.dtype)
 
   @chex.variants(
@@ -339,10 +356,16 @@ def test_gradient_accumulation(
     opt = _get_opt_factory(opt_name)(**opt_kwargs)
     if wrapper_name is not None:
       opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs)
-    opt = _accumulation.MultiSteps(opt, every_k_schedule=4)
 
     fun = lambda x: jnp.sum(x**2)
 
+    if opt_name == 'sophia':
+      update_fn = functools.partial(opt.update, obj_fn=fun)
+    else:
+      update_fn = opt.update
+    opt = base.GradientTransformationExtraArgs(opt.init, update_fn)
+    opt = _accumulation.MultiSteps(opt, every_k_schedule=4)
+
    params = jnp.array([1.0, 2.0], dtype=dtype)
     value, grads = jax.value_and_grad(fun)(params)
     state = self.variant(opt.init)(params)
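
The same pattern, expressed with the public API: a sketch (the toy objective, parameter values, and every_k_schedule are illustrative) of pre-binding obj_fn and repackaging sophia's init/update pair as a GradientTransformationExtraArgs before handing it to optax.MultiSteps for gradient accumulation, mirroring what the test above does with the private modules.

import functools
import jax
import jax.numpy as jnp
import optax
from optax import contrib

fun = lambda x: jnp.sum(x ** 2)  # toy objective, as in the test above

opt = contrib.sophia(learning_rate=1e-2)
# Bind obj_fn up front and repackage init/update as a transformation
# before wrapping with MultiSteps.
update_fn = functools.partial(opt.update, obj_fn=fun)
opt = optax.GradientTransformationExtraArgs(opt.init, update_fn)
opt = optax.MultiSteps(opt, every_k_schedule=4)

params = jnp.array([1.0, 2.0])
state = opt.init(params)
grads = jax.grad(fun)(params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)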
