Skip to content

CUDA: fix MMQ nwarps for AMD with warp_size==32 #15014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 1, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
#endif // AMD_MFMA_AVAILABLE

#if defined(GGML_USE_HIP)
static int mmq_get_nwarps_host(const int cc) {
return amd_mfma_available(cc) ? 8 : 4;
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
return amd_mfma_available(cc) ? 8 : 256/warp_size;
}
#else
static int mmq_get_nwarps_host(const int /*cc*/) {
return 8;
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
return 256/warp_size;
}
#endif // (GGML_USE_HIP)

static constexpr __device__ int mmq_get_nwarps_device() {
#if defined(GGML_USE_HIP)
#if defined(AMD_MFMA_AVAILABLE)
return 8;
#else
return 4;
return 256/ggml_cuda_get_physical_warp_size();
#endif // AMD_MFMA_AVAILABLE
#else
return 8;
#endif // defined(GGML_USE_HIP)
}

// ------------------------------------------------------------
Expand Down Expand Up @@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const int nwarps = mmq_get_nwarps_host(cc);
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
const int mmq_y = get_mmq_y_host(cc);

const dim3 block_dims(warp_size, nwarps, 1);
Expand Down Expand Up @@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
const int cc = ggml_cuda_info().devices[id].cc;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
const int nwarps = mmq_get_nwarps_host(cc);
const int nwarps = mmq_get_nwarps_host(cc, warp_size);

const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
Expand Down
Loading