Skip to content

Commit 9c35706

Browse files
CUDA: fix MMQ nwarps for AMD with warp_size==32 (#15014)
1 parent c76b420 commit 9c35706

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
251251
#endif // AMD_MFMA_AVAILABLE
252252

253253
#if defined(GGML_USE_HIP)
254-
static int mmq_get_nwarps_host(const int cc) {
255-
return amd_mfma_available(cc) ? 8 : 4;
254+
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
255+
return amd_mfma_available(cc) ? 8 : 256/warp_size;
256256
}
257257
#else
258-
static int mmq_get_nwarps_host(const int /*cc*/) {
259-
return 8;
258+
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
259+
return 256/warp_size;
260260
}
261261
#endif // (GGML_USE_HIP)
262262

263263
static constexpr __device__ int mmq_get_nwarps_device() {
264-
#if defined(GGML_USE_HIP)
265264
#if defined(AMD_MFMA_AVAILABLE)
266265
return 8;
267266
#else
268-
return 4;
267+
return 256/ggml_cuda_get_physical_warp_size();
269268
#endif // AMD_MFMA_AVAILABLE
270-
#else
271-
return 8;
272-
#endif // defined(GGML_USE_HIP)
273269
}
274270

275271
// ------------------------------------------------------------
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
34723468
const int cc = ggml_cuda_info().devices[id].cc;
34733469
const int nsm = ggml_cuda_info().devices[id].nsm;
34743470
const int warp_size = ggml_cuda_info().devices[id].warp_size;
3475-
const int nwarps = mmq_get_nwarps_host(cc);
3471+
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
34763472
const int mmq_y = get_mmq_y_host(cc);
34773473

34783474
const dim3 block_dims(warp_size, nwarps, 1);
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
35593555
const int cc = ggml_cuda_info().devices[id].cc;
35603556
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
35613557
const int warp_size = ggml_cuda_info().devices[id].warp_size;
3562-
const int nwarps = mmq_get_nwarps_host(cc);
3558+
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
35633559

35643560
const int mmq_x_max = get_mmq_x_max_host(cc);
35653561
const int mmq_y = get_mmq_y_host(cc);

0 commit comments

Comments
 (0)