Skip to content

Commit 434438f

Browse files
committed
Add GGML_MUSA in Makefile
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent a59f8fd commit 434438f

File tree

1 file changed

+43
-9
lines changed

1 file changed

+43
-9
lines changed

Makefile

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,14 @@ ifndef GGML_NO_ACCELERATE
526526
endif
527527
endif # GGML_NO_ACCELERATE
528528

529+
ifdef GGML_MUSA
530+
CC := clang
531+
CXX := clang++
532+
GGML_CUDA := 1
533+
GGML_NO_OPENMP := 1
534+
MK_CPPFLAGS += -DGGML_USE_MUSA
535+
endif
536+
529537
ifndef GGML_NO_OPENMP
530538
MK_CPPFLAGS += -DGGML_USE_OPENMP
531539
MK_CFLAGS += -fopenmp
@@ -574,15 +582,27 @@ else
574582
endif # GGML_CUDA_FA_ALL_QUANTS
575583

576584
ifdef GGML_CUDA
577-
ifneq ('', '$(wildcard /opt/cuda)')
578-
CUDA_PATH ?= /opt/cuda
585+
ifdef GGML_MUSA
586+
ifneq ('', '$(wildcard /opt/musa)')
587+
CUDA_PATH ?= /opt/musa
588+
else
589+
CUDA_PATH ?= /usr/local/musa
590+
endif
591+
592+
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
593+
MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
594+
MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
579595
else
580-
CUDA_PATH ?= /usr/local/cuda
581-
endif
596+
ifneq ('', '$(wildcard /opt/cuda)')
597+
CUDA_PATH ?= /opt/cuda
598+
else
599+
CUDA_PATH ?= /usr/local/cuda
600+
endif
582601

583-
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
584-
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
585-
MK_NVCCFLAGS += -use_fast_math
602+
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
603+
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
604+
MK_NVCCFLAGS += -use_fast_math
605+
endif # GGML_MUSA
586606

587607
OBJ_GGML += ggml/src/ggml-cuda.o
588608
OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
@@ -592,9 +612,11 @@ ifdef LLAMA_FATAL_WARNINGS
592612
MK_NVCCFLAGS += -Werror all-warnings
593613
endif # LLAMA_FATAL_WARNINGS
594614

615+
ifndef GGML_MUSA
595616
ifndef JETSON_EOL_MODULE_DETECT
596617
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
597618
endif # JETSON_EOL_MODULE_DETECT
619+
endif # GGML_MUSA
598620

599621
ifdef LLAMA_DEBUG
600622
MK_NVCCFLAGS += -lineinfo
@@ -607,8 +629,12 @@ endif # GGML_CUDA_DEBUG
607629
ifdef GGML_CUDA_NVCC
608630
NVCC = $(CCACHE) $(GGML_CUDA_NVCC)
609631
else
610-
NVCC = $(CCACHE) nvcc
611-
endif #GGML_CUDA_NVCC
632+
ifdef GGML_MUSA
633+
NVCC = $(CCACHE) mcc
634+
else
635+
NVCC = $(CCACHE) nvcc
636+
endif # GGML_MUSA
637+
endif # GGML_CUDA_NVCC
612638

613639
ifdef CUDA_DOCKER_ARCH
614640
MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
@@ -679,9 +705,15 @@ define NVCC_COMPILE
679705
$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
680706
endef # NVCC_COMPILE
681707
else
708+
ifdef GGML_MUSA
709+
define NVCC_COMPILE
710+
$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -c $< -o $@
711+
endef # NVCC_COMPILE
712+
else
682713
define NVCC_COMPILE
683714
$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
684715
endef # NVCC_COMPILE
716+
endif # GGML_MUSA
685717
endif # JETSON_EOL_MODULE_DETECT
686718

687719
ggml/src/ggml-cuda/%.o: \
@@ -906,6 +938,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1))
906938
ifdef GGML_CUDA
907939
$(info I NVCC: $(shell $(NVCC) --version | tail -n 1))
908940
CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])')
941+
ifndef GGML_MUSA
909942
ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1)
910943

911944
ifndef CUDA_DOCKER_ARCH
@@ -915,6 +948,7 @@ endif # CUDA_POWER_ARCH
915948
endif # CUDA_DOCKER_ARCH
916949

917950
endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1)
951+
endif # GGML_MUSA
918952
endif # GGML_CUDA
919953
$(info )
920954

0 commit comments

Comments
 (0)