Commit e5649e7

Author: OptaxDev
Add explanation to Newton Schulz step
PiperOrigin-RevId: 799665576
1 parent ac4d41d commit e5649e7

1 file changed: +5 -0 lines changed

optax/contrib/_muon.py

Lines changed: 5 additions & 0 deletions
@@ -120,6 +120,11 @@ def _shape_factor(x: jax.Array, dim_nums: MuonDimensionNumbers) -> float:


 def _newton_schulz_iterator(x: jax.Array, coeffs: jax.Array) -> jax.Array:
+  # Implements Newton-Schulz step f(X) = c_0 X + c_1 (XX^T)X + c_2 (XX^T)^2X,
+  # with quintic form f(X) = c_0 X + (c_1 A + c_2 AA)X, where A = XX^T.
+  # The NS step has the property f(X) = f(X^T)^T. That is, we get an equivalent
+  # result by transposing input and output. In particular, we may transpose X
+  # when rows > cols for efficiency.
   a = x @ x.T
   b = coeffs[1] * a + coeffs[2] * a @ a
   return coeffs[0] * x + b @ x
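
The transpose property stated in the new comment can be checked numerically. The sketch below is illustrative only and is not part of the commit: it uses NumPy in place of `jax.numpy` so that it runs on its own, the wrapper `newton_schulz_step_transposed` is a hypothetical helper, and the coefficient values are arbitrary placeholders rather than the ones Optax uses.

```python
import numpy as np


def newton_schulz_step(x, coeffs):
  """Same arithmetic as the body of _newton_schulz_iterator in the diff."""
  a = x @ x.T  # A = X X^T, shape (rows, rows)
  b = coeffs[1] * a + coeffs[2] * a @ a
  return coeffs[0] * x + b @ x


def newton_schulz_step_transposed(x, coeffs):
  """Hypothetical wrapper: apply the step to whichever orientation has fewer rows."""
  rows, cols = x.shape
  if rows > cols:
    # f(X) = f(X^T)^T, so the Gram matrix becomes (cols, cols)
    # instead of (rows, rows).
    return newton_schulz_step(x.T, coeffs).T
  return newton_schulz_step(x, coeffs)


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 8))  # tall matrix: rows > cols
x = x / np.linalg.norm(x)  # scale by the Frobenius norm to keep values small
coeffs = np.array([2.0, -1.5, 0.5])  # arbitrary illustrative coefficients

direct = newton_schulz_step(x, coeffs)
via_transpose = newton_schulz_step_transposed(x, coeffs)
agree = np.allclose(direct, via_transpose)
```

For the 64 x 8 input here, the direct call forms a 64 x 64 matrix `a`, while the transposed path forms an 8 x 8 one. That size difference is the efficiency gain the comment refers to.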
