@@ -227,8 +227,11 @@ void main() {
227227
228228 coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
229229
230+ // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
231+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
232+
230233 L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
231- M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0 );
234+ M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2 );
232235
233236 coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
234237
@@ -256,7 +259,7 @@ void main() {
256259 }
257260
258261 if (p.mask != 0) {
259- tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV > tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV );
262+ tensorLayoutNV<2, Clamp > tensorLayoutM = createTensorLayoutNV(2, Clamp );
260263 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
261264 // When using grouped query attention, all rows use the same mask.
262265 if (p.gqa_ratio > 1) {
@@ -278,7 +281,7 @@ void main() {
278281 uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
279282 uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
280283
281- coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0 ), R, C);
284+ coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2 ), R, C);
282285 }
283286
284287 coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
0 commit comments