From 22b37a43a65efb3eaeee6508ad2a9311b52ef8f4 Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 2 Dec 2025 12:35:36 +0800
Subject: [PATCH 1/3] modify muon.

---
 keras/src/optimizers/muon.py      | 59 +++++++++++++++++++++++++------
 keras/src/optimizers/muon_test.py | 40 ++++++++++++++++++---
 2 files changed, 84 insertions(+), 15 deletions(-)

diff --git a/keras/src/optimizers/muon.py b/keras/src/optimizers/muon.py
index 88d0dde3ee92..31ed311d90e6 100644
--- a/keras/src/optimizers/muon.py
+++ b/keras/src/optimizers/muon.py
@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
     The Muon optimizer can use both the Muon update step or the
     AdamW update step based on the following:
 
-    - For any variable that isn't 2D, 3D or 4D, the AdamW step
+    - For any variable that isn't 2D, the AdamW step
         will be used. This is not configurable.
     - If the argument `exclude_embeddings` (defaults to `True`) is set
         to `True`, the AdamW step will be used.
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
             that takes no arguments and returns the actual value to use. The
             exponential decay rate for the 1st moment estimates. Defaults to
             `0.9`.
-        adam_beta_2: A float value or a constant float tensor, ora callable
+        adam_beta_2: A float value or a constant float tensor, or a callable
             that takes no arguments and returns the actual value to use. The
             exponential decay rate for the 2nd moment estimates. Defaults to
             `0.999`.
+        adam_weight_decay: Float. If set, weight decay is applied when using
+            the Adam optimizer.
         epsilon: A small constant for numerical stability. This is
             "epsilon hat" in the Kingma and Ba paper
             (in the formula just before Section 2.1),
@@ -67,11 +69,16 @@ class Muon(optimizer.Optimizer):
             It is recommended to use the default value
         adam_lr_ratio: Float, the ratio of the learning rate when using Adam
             to the main learning rate.
-            it is recommended to set it to 0.1
+            it is recommended to set it to 1
         momentum: Float, momentum used by internal SGD.
         ns_steps: Integer, number of Newton-Schulz iterations to run.
         nesterov: Boolean, whether to use Nesterov-style momentum
         {{base_optimizer_keyword_args}}
+        `rms_rate`: A trick from https://arxiv.org/abs/2502.16982.
+            This parameter can enhance the stability of Muon,
+            allowing it to use the same learning rate and weight decay as Adam.
+            It is default to set it to `0.2`
+            If you wish to disable it, it is set None.
     """
 
     def __init__(
@@ -79,8 +86,9 @@ def __init__(
         learning_rate=0.001,
         adam_beta_1=0.9,
         adam_beta_2=0.999,
+        adam_weight_decay=0.004,
         epsilon=1e-7,
-        weight_decay=0.1,
+        weight_decay=0.004,
         clipnorm=None,
         clipvalue=None,
         global_clipnorm=None,
@@ -95,10 +103,11 @@ def __init__(
         muon_a=3.4445,
         muon_b=-4.7750,
         muon_c=2.0315,
-        adam_lr_ratio=0.1,
+        adam_lr_ratio=1,
         momentum=0.95,
-        ns_steps=6,
+        ns_steps=5,
         nesterov=True,
+        rms_rate=0.2,
         **kwargs,
     ):
         super().__init__(
@@ -127,12 +136,14 @@ def __init__(
         self.nesterov = nesterov
         self.exclude_embeddings = exclude_embeddings
         self.exclude_layers = exclude_layers or []
+        self.adam_weight_decay = adam_weight_decay
+        self.rms_rate = rms_rate
 
     def _should_use_adamw(self, variable):
         # To use it with 4D convolutional filters,
         # it works well to just flatten their last 3 dimensions.
         # any {0,1}-D parameters should all be optimized by adam
-        if not 1 < len(variable.shape) < 4:
+        if len(variable.shape) != 2:
             return True
         if self.exclude_embeddings and "embedding" in variable.path.lower():
             return True
@@ -185,17 +196,15 @@ def update_step(self, gradient, variable, learning_rate):
     def _muon_update_step(self, gradient, variable, lr):
         m = self.adam_momentums[variable.path]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
-        shape = variable.shape
         if self.nesterov:
             g = ops.add(gradient, self.momentum * m)
         else:
             g = m
+        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)
 
         self.assign_sub(
             variable,
-            lr
-            * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-            * max(1, shape[0] / shape[1]) ** 0.5,
+            self.lr_adjust(lr * update),
         )
 
     def _adamw_update_step(self, gradient, variable, learning_rate):
@@ -239,6 +248,18 @@ def transpose_last_axis(self, X):
             X = ops.transpose(X, temp_order)
         return X
 
+    def lr_adjust(self, x):
+        """
+        You can check the details at https://arxiv.org/pdf/2502.16982.
+        For a 2D matrix of size m,the analytical solution provided in the paper
+        rate * x * sqrt(max(n,m))
+        """
+        if self.rms_rate is None:
+            return x
+        # moonlight version
+        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
+
     def zeropower_via_newtonschulz5(self, x, steps: int):
         """We apply the Newton-Schulz iteration to compute matrix G.
 
@@ -268,6 +289,20 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
             x = self.transpose_last_axis(x)
         return x
 
+    def _apply_weight_decay(self, variables):
+        for variable in variables:
+            if self._use_weight_decay(variable):
+                if self._should_use_adamw(variable):
+                    if self.adam_weight_decay is None:
+                        continue
+                    wd = ops.cast(self.adam_weight_decay, variable.dtype)
+                else:
+                    if self.weight_decay is None:
+                        continue
+                    wd = ops.cast(self.weight_decay, variable.dtype)
+                lr = ops.cast(self.learning_rate, variable.dtype)
+                variable.assign(variable - variable * wd * lr)
+
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -284,6 +319,8 @@ def get_config(self):
                 "ns_steps": self.ns_steps,
                 "nesterov": self.nesterov,
                 "exclude_embeddings": self.exclude_embeddings,
+                "adam_weight_decay": self.adam_weight_decay,
+                "rms_rate": self.rms_rate,
             }
         )
         return config
diff --git a/keras/src/optimizers/muon_test.py b/keras/src/optimizers/muon_test.py
index f22423c34aae..09d57074dc20 100644
--- a/keras/src/optimizers/muon_test.py
+++ b/keras/src/optimizers/muon_test.py
@@ -38,11 +38,11 @@ def test_should_use_adamw(self):
             True,
             optimizer._should_use_adamw(vars),
         )
-        embeding = Embedding(2, 2)
-        embeding.build()
+        embedding = Embedding(2, 2)
+        embedding.build()
         self.assertAllClose(
             True,
-            optimizer._should_use_adamw(embeding.weights[0]),
+            optimizer._should_use_adamw(embedding.weights[0]),
         )
         vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
         optimizer = Muon()
@@ -67,7 +67,10 @@ def test_muon_single_step(self):
         optimizer.build([vars])
         optimizer._muon_update_step(grads, vars, 0.5)
         self.assertAllClose(
-            vars, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2
+            vars,
+            [[0.988775, 1.887053], [2.873428, 3.97035]],
+            rtol=1e-2,
+            atol=1e-2,
         )
 
     def test_clip_norm(self):
@@ -81,3 +84,32 @@ def test_clip_value(self):
         grad = [np.array([100.0, 100.0])]
         clipped_grad = optimizer._clip_gradients(grad)
         self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+
+    def test_muon_weight_decay(self):
+        variable = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        weight_decay = 0.01
+        expected_variable = variable - variable * weight_decay
+        optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay)
+        optimizer._apply_weight_decay([variable])
+        self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)
+
+    def test_adamw_weight_decay(self):
+        variable = backend.Variable(2.0)
+        weight_decay = 0.01
+        expected_variable = variable - variable * weight_decay
+        optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay)
+        optimizer._apply_weight_decay([variable])
+
+        self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)
+
+    def test_lr_adjust_none(self):
+        opt = Muon(rms_rate=None)
+        x = ops.ones((4, 4))
+        want = x
+        self.assertAllClose(opt.lr_adjust(x), want)
+
+    def test_lr_adjust_2d(self):
+        opt = Muon(rms_rate=0.2)
+        x = ops.ones((4, 2))
+        want = x * 0.2 * 2
+        self.assertAllClose(opt.lr_adjust(x), want)
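
Editor's illustration (not part of the patch): the `lr_adjust` method added in
patch 1 scales each orthogonalized Muon update by `rms_rate * sqrt(max(n, m))`,
following https://arxiv.org/abs/2502.16982 and the Moonlight reference code.
The sketch below only shows why that factor matches Adam's typical update RMS;
it uses plain NumPy, with an SVD standing in for the Newton-Schulz iteration,
and all names in it are illustrative.

import numpy as np

rng = np.random.default_rng(0)
n, m = 256, 64
gradient = rng.normal(size=(n, m))

# Orthogonalize the gradient; the optimizer itself uses
# zeropower_via_newtonschulz5 rather than an exact SVD.
u, _, vh = np.linalg.svd(gradient, full_matrices=False)
update = u @ vh

def rms(x):
    return float(np.sqrt(np.mean(x**2)))

# A semi-orthogonal (n, m) matrix has RMS ~ 1 / sqrt(max(n, m)) ...
print(rms(update))                             # ~0.0625 for (256, 64)
# ... so scaling by rms_rate * sqrt(max(n, m)) brings the RMS to ~rms_rate,
# roughly the RMS of a typical Adam update, which is why the same learning
# rate and weight decay can be reused.
print(rms(0.2 * np.sqrt(max(n, m)) * update))  # ~0.2
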
From 0737fc48773d2684774d3ad8a25c9447c11537bc Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Tue, 2 Dec 2025 12:44:22 +0800
Subject: [PATCH 2/3] modify gemini review.

---
 keras/src/optimizers/muon.py | 44 ++++++++++++++++++------------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/keras/src/optimizers/muon.py b/keras/src/optimizers/muon.py
index 31ed311d90e6..eeeeb7d48b49 100644
--- a/keras/src/optimizers/muon.py
+++ b/keras/src/optimizers/muon.py
@@ -69,16 +69,15 @@ class Muon(optimizer.Optimizer):
             It is recommended to use the default value
         adam_lr_ratio: Float, the ratio of the learning rate when using Adam
             to the main learning rate.
-            it is recommended to set it to 1
+            It is recommended to set it to 1
         momentum: Float, momentum used by internal SGD.
         ns_steps: Integer, number of Newton-Schulz iterations to run.
         nesterov: Boolean, whether to use Nesterov-style momentum
         {{base_optimizer_keyword_args}}
-        `rms_rate`: A trick from https://arxiv.org/abs/2502.16982.
-            This parameter can enhance the stability of Muon,
-            allowing it to use the same learning rate and weight decay as Adam.
-            It is default to set it to `0.2`
-            If you wish to disable it, it is set None.
+        rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
+            that can enhance the stability of Muon, allowing it to use the
+            same learning rate and weight decay as Adam. Defaults to `0.2`.
+            Set to `None` to disable this feature.
     """
 
     def __init__(
@@ -140,7 +139,6 @@ def __init__(
         self.rms_rate = rms_rate
 
     def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
         # it works well to just flatten their last 3 dimensions.
         # any {0,1}-D parameters should all be optimized by adam
         if len(variable.shape) != 2:
@@ -249,10 +247,12 @@ def transpose_last_axis(self, X):
         return X
 
     def lr_adjust(self, x):
-        """
-        You can check the details at https://arxiv.org/pdf/2502.16982.
-        For a 2D matrix of size m,the analytical solution provided in the paper
-        rate * x * sqrt(max(n,m))
+        """Adjusts learning rate based on the Moonlight implementation.
+        This method enhances the stability of Muon, allowing it to use the same
+        learning rate and weight decay as Adam. For details, see
+        https://arxiv.org/abs/2502.16982.
+        For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,
+        where `n` and `m` are the dimensions of the matrix.
         """
         if self.rms_rate is None:
             return x
@@ -291,17 +291,17 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
 
     def _apply_weight_decay(self, variables):
         for variable in variables:
-            if self._use_weight_decay(variable):
-                if self._should_use_adamw(variable):
-                    if self.adam_weight_decay is None:
-                        continue
-                    wd = ops.cast(self.adam_weight_decay, variable.dtype)
-                else:
-                    if self.weight_decay is None:
-                        continue
-                    wd = ops.cast(self.weight_decay, variable.dtype)
-                lr = ops.cast(self.learning_rate, variable.dtype)
-                variable.assign(variable - variable * wd * lr)
+            if not self._use_weight_decay(variable):
+                continue
+            if self._should_use_adamw(variable):
+                weight_decay_value = self.adam_weight_decay
+            else:
+                weight_decay_value = self.weight_decay
+            if weight_decay_value is None:
+                continue
+            wd = ops.cast(weight_decay_value, variable.dtype)
+            lr = ops.cast(self.learning_rate, variable.dtype)
+            variable.assign(variable - variable * wd * lr)
 
     def get_config(self):
         config = super().get_config()
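
Editor's illustration (not part of the patch): the `_apply_weight_decay`
override, as refactored in patch 2, applies decoupled decay
(`w <- w - w * wd * lr`) and picks the decay rate per variable, `weight_decay`
for the 2D kernels handled by the Muon step and `adam_weight_decay` for
everything routed to AdamW. The standalone sketch below mirrors only that
selection rule; it ignores the `exclude_embeddings` / `exclude_layers` checks
and uses plain NumPy arrays in place of Keras variables, so its names and
values are illustrative only.

import numpy as np

def apply_weight_decay(params, lr, weight_decay=0.004, adam_weight_decay=0.004):
    for w in params:
        # 2D parameters follow the Muon branch, everything else the AdamW one.
        wd = weight_decay if w.ndim == 2 else adam_weight_decay
        if wd is None:
            continue
        # Decoupled decay: shrink the variable directly, independent of the
        # gradient, scaled by the learning rate.
        w -= w * wd * lr

kernel = np.ones((4, 2))  # 2D -> decayed with weight_decay
bias = np.ones((2,))      # 1D -> decayed with adam_weight_decay
apply_weight_decay([kernel, bias], lr=1.0, weight_decay=0.01, adam_weight_decay=None)
print(kernel[0, 0], bias[0])  # 0.99 1.0 -- bias untouched, its decay is None
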
""" if self.rms_rate is None: return x @@ -291,17 +291,17 @@ def zeropower_via_newtonschulz5(self, x, steps: int): def _apply_weight_decay(self, variables): for variable in variables: - if self._use_weight_decay(variable): - if self._should_use_adamw(variable): - if self.adam_weight_decay is None: - continue - wd = ops.cast(self.adam_weight_decay, variable.dtype) - else: - if self.weight_decay is None: - continue - wd = ops.cast(self.weight_decay, variable.dtype) - lr = ops.cast(self.learning_rate, variable.dtype) - variable.assign(variable - variable * wd * lr) + if not self._use_weight_decay(variable): + continue + if self._should_use_adamw(variable): + weight_decay_value = self.adam_weight_decay + else: + weight_decay_value = self.weight_decay + if weight_decay_value is None: + continue + wd = ops.cast(weight_decay_value, variable.dtype) + lr = ops.cast(self.learning_rate, variable.dtype) + variable.assign(variable - variable * wd * lr) def get_config(self): config = super().get_config() From 4e4f3752120b46547bcbe8b5a30157ecac5ece8b Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Wed, 3 Dec 2025 11:52:29 +0800 Subject: [PATCH 3/3] modify --- keras/src/optimizers/muon.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/keras/src/optimizers/muon.py b/keras/src/optimizers/muon.py index eeeeb7d48b49..b1ba1e923f21 100644 --- a/keras/src/optimizers/muon.py +++ b/keras/src/optimizers/muon.py @@ -200,10 +200,7 @@ def _muon_update_step(self, gradient, variable, lr): g = m update = self.zeropower_via_newtonschulz5(g, self.ns_steps) - self.assign_sub( - variable, - self.lr_adjust(lr * update), - ) + self.assign_sub(variable, self.lr_adjust(lr * update)) def _adamw_update_step(self, gradient, variable, learning_rate): """Update step given gradient and the associated model variable."""