Skip to content

Commit d38ddca

Browse files
committed
modify code .
1 parent 39020ea commit d38ddca

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

keras/src/backend/tensorflow/optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,9 @@ def weight_decay_fn(variable):
113113
def _backend_update_step(self, grads, trainable_variables, learning_rate):
114114
def _prepare_var(v):
115115
new_v = v.value if isinstance(v, backend.Variable) else v
116-
new_v._muon_use_adam_flag = v._muon_use_adam_flag
117-
new_v._muon_path_id = v._muon_path_id
116+
if hasattr(v, "_muon_use_adam_flag"):
117+
new_v._muon_use_adam_flag = v._muon_use_adam_flag
118+
new_v._muon_path_id = v._muon_path_id
118119
return new_v
119120

120121
trainable_variables = [_prepare_var(v) for v in trainable_variables]

keras/src/optimizers/muon_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def test_muon_weight_decay(self):
9292
weight_decay = 0.01
9393
expected_variable = variable - variable * weight_decay
9494
optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay)
95+
optimizer.build([variable])
9596
optimizer._apply_weight_decay([variable])
9697
self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)
9798

@@ -100,6 +101,7 @@ def test_adamw_weight_decay(self):
100101
weight_decay = 0.01
101102
expected_variable = variable - variable * weight_decay
102103
optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay)
104+
optimizer.build([variable])
103105
optimizer._apply_weight_decay([variable])
104106

105107
self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)

0 commit comments

Comments
 (0)