From 24246aea1057e541040a4c18a97031735ffe9138 Mon Sep 17 00:00:00 2001
From: ShanningZhuang
Date: Tue, 25 Nov 2025 23:54:19 +0800
Subject: [PATCH] Fix #647 gradient_update_fn compatible with AdamW

---
 brax/training/gradients.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/brax/training/gradients.py b/brax/training/gradients.py
index 4c21d4c78..a616f0871 100644
--- a/brax/training/gradients.py
+++ b/brax/training/gradients.py
@@ -60,7 +60,7 @@ def gradient_update_fn(
   def f(*args, optimizer_state):
     value, grads = loss_and_pgrad_fn(*args)
-    params_update, optimizer_state = optimizer.update(grads, optimizer_state)
+    params_update, optimizer_state = optimizer.update(grads, optimizer_state, args[0])
     params = optax.apply_updates(args[0], params_update)
     return value, params, optimizer_state
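
Background on the change: optax transforms that apply decoupled weight decay, such as optax.adamw, need the current parameters inside update(), while optax.adam simply ignores the optional params argument. Passing args[0] (the params) therefore makes gradient_update_fn work with both. Below is a minimal standalone sketch of that calling convention; the params/grads values and hyperparameters are illustrative only and not part of this patch.

    import jax.numpy as jnp
    import optax

    # Hypothetical params and gradients, only to illustrate the update() signature.
    params = {'w': jnp.ones((3,))}
    grads = {'w': jnp.full((3,), 0.1)}

    # adamw applies decoupled weight decay inside update(), so it needs params.
    optimizer = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
    opt_state = optimizer.init(params)

    # Passing params as the third argument is required for adamw and harmless
    # for adam, which ignores it -- mirroring the one-line change above.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)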