import jax.numpy as jnp
from optax import contrib
from optax._src import alias
+from optax._src import base
from optax._src import combine
from optax._src import numerics
from optax._src import update
from optax.schedules import _inject
from optax.transforms import _accumulation
from optax.tree_utils import _state_utils
+from optax.tree_utils import _tree_math

# Testing contributions coded as GradientTransformations
_MAIN_OPTIMIZERS_UNDER_TEST = [
    ...
        opt_name='schedule_free_adamw',
        opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000),
    ),
+    dict(
+        opt_name='sophia',
+        opt_kwargs=dict(learning_rate=1e-2),
+    ),
]
for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST:
  optimizer['wrapper_name'] = None
@@ -144,11 +150,10 @@ def _setup_parabola(dtype):
  initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
  final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)

-  @jax.value_and_grad
-  def get_updates(params):
+  def obj_fn(params):
    return jnp.sum(numerics.abs_sq(params - final_params))

-  return initial_params, final_params, get_updates
+  return initial_params, final_params, obj_fn


def _setup_rosenbrock(dtype):
@@ -159,13 +164,12 @@ def _setup_rosenbrock(dtype):
  initial_params = jnp.array([0.0, 0.0], dtype=dtype)
  final_params = jnp.array([a, a**2], dtype=dtype)

-  @jax.value_and_grad
-  def get_updates(params):
+  def obj_fn(params):
    return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq(
        params[1] - params[0] ** 2
    )

-  return initial_params, final_params, get_updates
+  return initial_params, final_params, obj_fn


class ContribTest(chex.TestCase):
@@ -188,16 +192,18 @@ def test_optimizers(
    opt = _get_opt_factory(opt_name)(**opt_kwargs)
    if wrapper_name is not None:
      opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs)
-    initial_params, final_params, get_updates = target(dtype)
+    initial_params, final_params, obj_fn = target(dtype)

    @jax.jit
    def step(params, state):
-      value, updates = get_updates(params)
+      value, updates = jax.value_and_grad(obj_fn)(params)
      if (
          opt_name in ['momo', 'momo_adam']
          or wrapper_name == 'reduce_on_plateau'
      ):
        update_kwargs = {'value': value}
+      elif opt_name == 'sophia':
+        update_kwargs = {'obj_fn': obj_fn}
      else:
        update_kwargs = {}
      updates, state = opt.update(updates, state, params, **update_kwargs)
@@ -266,14 +272,21 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
      update_kwargs = {'value': jnp.array(1.0)}
    else:
      update_kwargs = {}
+    if opt_name == 'sophia':
+      obj_fn = lambda x: _tree_math.tree_l2_norm(x, squared=True)
+      update_fn = functools.partial(opt.update, obj_fn=obj_fn)
+      inject_update_fn = functools.partial(opt_inject.update, obj_fn=obj_fn)
+    else:
+      update_fn = opt.update
+      inject_update_fn = opt_inject.update

    state = self.variant(opt.init)(params)
-    updates, new_state = self.variant(opt.update)(
+    updates, new_state = self.variant(update_fn)(
        grads, state, params, **update_kwargs
    )

    state_inject = self.variant(opt_inject.init)(params)
-    updates_inject, new_state_inject = self.variant(opt_inject.update)(
+    updates_inject, new_state_inject = self.variant(inject_update_fn)(
        grads, state_inject, params, **update_kwargs
    )

@@ -320,7 +333,11 @@ def test_preserve_dtype(
      update_kwargs = {'value': value}
    else:
      update_kwargs = {}
-    updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
+    if opt_name == 'sophia':
+      update_fn = functools.partial(opt.update, obj_fn=fun)
+    else:
+      update_fn = opt.update
+    updates, _ = self.variant(update_fn)(grads, state, params, **update_kwargs)
    self.assertEqual(updates.dtype, params.dtype)

  @chex.variants(
@@ -339,10 +356,16 @@ def test_gradient_accumulation(
    opt = _get_opt_factory(opt_name)(**opt_kwargs)
    if wrapper_name is not None:
      opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs)
-    opt = _accumulation.MultiSteps(opt, every_k_schedule=4)

    fun = lambda x: jnp.sum(x**2)

+    if opt_name == 'sophia':
+      update_fn = functools.partial(opt.update, obj_fn=fun)
+    else:
+      update_fn = opt.update
+    opt = base.GradientTransformationExtraArgs(opt.init, update_fn)
+    opt = _accumulation.MultiSteps(opt, every_k_schedule=4)
+
    params = jnp.array([1.0, 2.0], dtype=dtype)
    value, grads = jax.value_and_grad(fun)(params)
    state = self.variant(opt.init)(params)
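For reference, a minimal standalone sketch of the call pattern these tests exercise. It assumes the contrib entry point and update signature shown in the diff (a contrib.sophia(learning_rate=...) factory whose update accepts the objective as an extra obj_fn keyword, presumably used to estimate curvature); treat it as an illustration under those assumptions, not the definitive API.

import jax
import jax.numpy as jnp
import optax
from optax import contrib


def obj_fn(params):
  # Toy quadratic objective, for illustration only.
  return jnp.sum((params - 1.0) ** 2)


params = jnp.zeros(3)
opt = contrib.sophia(learning_rate=1e-2)  # factory name and kwargs as in the tests
state = opt.init(params)

for _ in range(10):
  value, grads = jax.value_and_grad(obj_fn)(params)
  # The extra obj_fn keyword mirrors update_kwargs = {'obj_fn': obj_fn} above.
  updates, state = opt.update(grads, state, params, obj_fn=obj_fn)
  params = optax.apply_updates(params, updates)

When the transformation is wrapped (for example in _accumulation.MultiSteps, as in test_gradient_accumulation), the tests first bind the extra keyword with functools.partial(opt.update, obj_fn=...) and rebuild a base.GradientTransformationExtraArgs, so the wrapper only ever sees the standard (updates, state, params) signature.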