Skip to content

Commit 95ee61a

Browse files
committed
reallow collectives for pre-Turing
1 parent 136ecfb commit 95ee61a

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ enum vk_device_architecture {
222222
AMD_RDNA2,
223223
AMD_RDNA3,
224224
INTEL_XE2,
225+
NVIDIA_PRE_TURING,
225226
};
226227

227228
// HSK x HSV
@@ -315,6 +316,22 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
315316
// https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
316317
return vk_device_architecture::INTEL_XE2;
317318
}
319+
} else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
320+
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
321+
322+
bool cooperative_matrix = false;
323+
324+
// Detect "pre-turing" based on lack of coopmat support.
325+
for (const auto& properties : ext_props) {
326+
if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
327+
cooperative_matrix = true;
328+
break;
329+
}
330+
}
331+
332+
if (!cooperative_matrix) {
333+
return vk_device_architecture::NVIDIA_PRE_TURING;
334+
}
318335
}
319336
return vk_device_architecture::OTHER;
320337
}
@@ -3098,9 +3115,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
30983115
break;
30993116
}
31003117

3118+
// Use collectives on pre-Turing NVIDIA GPUs, which had slower integer math.
3119+
bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
3120+
device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
3121+
31013122
if (device->subgroup_shuffle &&
31023123
device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
3103-
device->vendor_id != VK_VENDOR_ID_NVIDIA) { // Collectives no faster on NVIDIA.
3124+
allow_collectives_nv) {
31043125
use_collectives = 1;
31053126
conv2d_BS_CRS = std::min(
31063127
device->subgroup_size,

0 commit comments

Comments
 (0)