@@ -154,15 +154,31 @@ void main() {
154
154
}
155
155
156
156
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
157
- tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
158
- tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
159
- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
157
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
160
158
161
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
159
+ if (nem1_bounds_check) {
160
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
161
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
162
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
162
163
163
- coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br , Br, j * Bc, Bc)) ;
164
+ coopmat<float16_t, gl_ScopeWorkgroup , Br, Bc, gl_MatrixUseAccumulator> mv ;
164
165
165
- S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
166
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
167
+
168
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
169
+ } else {
170
+ tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
171
+ // Don't clamp against nem1 when GQA is enabled
172
+ uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
173
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
174
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
175
+
176
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
177
+
178
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
179
+
180
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
181
+ }
166
182
}
167
183
168
184
// Clear padding elements to -inf, so they don't contribute to rowmax
0 commit comments