Skip to content

Commit 4a46ff5

Browse files
committed
musa: disable mudnnMemcpyAsync by default
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 448889f commit 4a46ff5

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

ggml/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,8 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental,
174174
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
175175
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
176176
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
177-
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental" OFF)
177+
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
178+
option(GGML_MUSA_MUDNN_COPY "ggml: enable MUDNN for accelerated copy" OFF)
178179
option(GGML_VULKAN "ggml: use Vulkan" OFF)
179180
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
180181
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)

ggml/src/ggml-cuda/cpy.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
#include "cpy.cuh"
22
#include "dequantize.cuh"
33
#include "cpy-utils.cuh"
4-
#ifdef GGML_USE_MUSA
4+
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
55
#include "ggml-musa/mudnn.cuh"
6-
#endif // GGML_USE_MUSA
6+
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
77

88
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
99

@@ -363,7 +363,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
363363
#endif
364364
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
365365
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
366-
#ifdef GGML_USE_MUSA
366+
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
367367
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
368368
CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
369369
} else

ggml/src/ggml-musa/CMakeLists.txt

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,12 @@ if (MUSAToolkit_FOUND)
3434
list(APPEND GGML_SOURCES_MUSA ${SRCS})
3535
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
3636
list(APPEND GGML_SOURCES_MUSA ${SRCS})
37-
file(GLOB SRCS "../ggml-musa/*.cu")
38-
list(APPEND GGML_SOURCES_MUSA ${SRCS})
37+
38+
if (GGML_MUSA_MUDNN_COPY)
39+
file(GLOB SRCS "../ggml-musa/*.cu")
40+
list(APPEND GGML_SOURCES_MUSA ${SRCS})
41+
add_compile_definitions(GGML_MUSA_MUDNN_COPY)
42+
endif()
3943

4044
if (GGML_CUDA_FA_ALL_QUANTS)
4145
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -101,10 +105,16 @@ if (MUSAToolkit_FOUND)
101105
endif()
102106

103107
if (GGML_STATIC)
104-
# TODO: mudnn has not provided static libraries yet
105108
target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
109+
# TODO: mudnn has not provided static libraries yet
110+
# if (GGML_MUSA_MUDNN_COPY)
111+
# target_link_libraries(ggml-musa PRIVATE mudnn_static)
112+
# endif()
106113
else()
107-
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
114+
target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
115+
if (GGML_MUSA_MUDNN_COPY)
116+
target_link_libraries(ggml-musa PRIVATE mudnn)
117+
endif()
108118
endif()
109119

110120
if (GGML_CUDA_NO_VMM)

0 commit comments

Comments
 (0)