We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ac4d41d commit e5649e7Copy full SHA for e5649e7
optax/contrib/_muon.py
@@ -120,6 +120,11 @@ def _shape_factor(x: jax.Array, dim_nums: MuonDimensionNumbers) -> float:
120
121
122
def _newton_schulz_iterator(x: jax.Array, coeffs: jax.Array) -> jax.Array:
123
+ # Implements Newton-Schulz step f(X) = c_0 X + c_1 (XX^T)X + c_2 (XX^T)^2X,
124
+ # with quintic form f(X) = c_0 X + (c_1 A + c_2 AA)X, where A = XX^T.
125
+ # The NS step has the property f(X) = f(X^T)^T. That is, we can get equivalent
126
+ # result by tranposing input and output. In particular, we may tranpose X
127
+ # when rows > cols for effciency.
128
a = x @ x.T
129
b = coeffs[1] * a + coeffs[2] * a @ a
130
return coeffs[0] * x + b @ x
0 commit comments