@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
251
251
#endif // AMD_MFMA_AVAILABLE
252
252
253
253
// Host-side query: number of warps per thread block to use when launching the MMQ kernels.
//
//   cc        - compute capability identifier of the target device
//               (only consulted in HIP builds, to detect MFMA support).
//   warp_size - number of threads per warp on the target device; must be > 0
//               (the result is computed by integer division, so 0 would divide by zero).
//
// Returns 8 on HIP devices for which amd_mfma_available(cc) is true; otherwise 256/warp_size.
// Callers in this file use the result as the y-dimension of the launch block
// (block_dims = warp_size x nwarps), i.e. 256 threads per block whenever warp_size divides 256.
#if defined(GGML_USE_HIP)
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
    return amd_mfma_available(cc) ? 8 : 256/warp_size;
}
#else
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
    return 256/warp_size;
}
#endif // (GGML_USE_HIP)
262
262
263
263
// Device-side (compile-time) query: number of warps per thread block for the MMQ kernels.
//
// Returns 8 when AMD_MFMA_AVAILABLE is defined; otherwise 256 divided by the value of
// ggml_cuda_get_physical_warp_size().
// NOTE(review): presumably this has to agree with the value mmq_get_nwarps_host returns
// for the same device, since the host value sizes the launch block - confirm against the kernels.
static constexpr __device__ int mmq_get_nwarps_device() {
#if defined(AMD_MFMA_AVAILABLE)
    return 8;
#else
    return 256/ggml_cuda_get_physical_warp_size();
#endif // AMD_MFMA_AVAILABLE
}
274
270
275
271
// ------------------------------------------------------------
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
3472
3468
const int cc = ggml_cuda_info ().devices [id].cc ;
3473
3469
const int nsm = ggml_cuda_info ().devices [id].nsm ;
3474
3470
const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
3475
- const int nwarps = mmq_get_nwarps_host (cc);
3471
+ const int nwarps = mmq_get_nwarps_host (cc, warp_size );
3476
3472
const int mmq_y = get_mmq_y_host (cc);
3477
3473
3478
3474
const dim3 block_dims (warp_size, nwarps, 1 );
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
3559
3555
const int cc = ggml_cuda_info ().devices [id].cc ;
3560
3556
const size_t smpbo = ggml_cuda_info ().devices [id].smpbo ;
3561
3557
const int warp_size = ggml_cuda_info ().devices [id].warp_size ;
3562
- const int nwarps = mmq_get_nwarps_host (cc);
3558
+ const int nwarps = mmq_get_nwarps_host (cc, warp_size );
3563
3559
3564
3560
const int mmq_x_max = get_mmq_x_max_host (cc);
3565
3561
const int mmq_y = get_mmq_y_host (cc);
0 commit comments