@@ -526,6 +526,14 @@ ifndef GGML_NO_ACCELERATE
526
526
endif
527
527
endif # GGML_NO_ACCELERATE
528
528
529
+ ifdef GGML_MUSA
530
+ CC := clang
531
+ CXX := clang++
532
+ GGML_CUDA := 1
533
+ GGML_NO_OPENMP := 1
534
+ MK_CPPFLAGS += -DGGML_USE_MUSA
535
+ endif
536
+
529
537
ifndef GGML_NO_OPENMP
530
538
MK_CPPFLAGS += -DGGML_USE_OPENMP
531
539
MK_CFLAGS += -fopenmp
@@ -574,15 +582,27 @@ else
574
582
endif # GGML_CUDA_FA_ALL_QUANTS
575
583
576
584
ifdef GGML_CUDA
577
- ifneq ('', '$(wildcard /opt/cuda)')
578
- CUDA_PATH ?= /opt/cuda
585
+ ifdef GGML_MUSA
586
+ ifneq ('', '$(wildcard /opt/musa)')
587
+ CUDA_PATH ?= /opt/musa
588
+ else
589
+ CUDA_PATH ?= /usr/local/musa
590
+ endif
591
+
592
+ MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
593
+ MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
594
+ MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
579
595
else
580
- CUDA_PATH ?= /usr/local/cuda
581
- endif
596
+ ifneq ('', '$(wildcard /opt/cuda)')
597
+ CUDA_PATH ?= /opt/cuda
598
+ else
599
+ CUDA_PATH ?= /usr/local/cuda
600
+ endif
582
601
583
- MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
584
- MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
585
- MK_NVCCFLAGS += -use_fast_math
602
+ MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
603
+ MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
604
+ MK_NVCCFLAGS += -use_fast_math
605
+ endif # GGML_MUSA
586
606
587
607
OBJ_GGML += ggml/src/ggml-cuda.o
588
608
OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
@@ -592,9 +612,11 @@ ifdef LLAMA_FATAL_WARNINGS
592
612
MK_NVCCFLAGS += -Werror all-warnings
593
613
endif # LLAMA_FATAL_WARNINGS
594
614
615
+ ifndef GGML_MUSA
595
616
ifndef JETSON_EOL_MODULE_DETECT
596
617
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
597
618
endif # JETSON_EOL_MODULE_DETECT
619
+ endif # GGML_MUSA
598
620
599
621
ifdef LLAMA_DEBUG
600
622
MK_NVCCFLAGS += -lineinfo
@@ -607,8 +629,12 @@ endif # GGML_CUDA_DEBUG
607
629
ifdef GGML_CUDA_NVCC
608
630
NVCC = $(CCACHE) $(GGML_CUDA_NVCC)
609
631
else
610
- NVCC = $(CCACHE) nvcc
611
- endif # GGML_CUDA_NVCC
632
+ ifdef GGML_MUSA
633
+ NVCC = $(CCACHE) mcc
634
+ else
635
+ NVCC = $(CCACHE) nvcc
636
+ endif # GGML_MUSA
637
+ endif # GGML_CUDA_NVCC
612
638
613
639
ifdef CUDA_DOCKER_ARCH
614
640
MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
@@ -679,9 +705,15 @@ define NVCC_COMPILE
679
705
$(NVCC ) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS ) $(CPPFLAGS ) -Xcompiler "$(CUDA_CXXFLAGS ) " -c $< -o $@
680
706
endef # NVCC_COMPILE
681
707
else
708
+ ifdef GGML_MUSA
709
+ define NVCC_COMPILE
710
+ $(NVCC ) $(NVCCFLAGS ) $(CPPFLAGS ) -c $< -o $@
711
+ endef # NVCC_COMPILE
712
+ else
682
713
define NVCC_COMPILE
683
714
$(NVCC ) $(NVCCFLAGS ) $(CPPFLAGS ) -Xcompiler "$(CUDA_CXXFLAGS ) " -c $< -o $@
684
715
endef # NVCC_COMPILE
716
+ endif # GGML_MUSA
685
717
endif # JETSON_EOL_MODULE_DETECT
686
718
687
719
ggml/src/ggml-cuda/% .o : \
@@ -906,6 +938,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1))
906
938
ifdef GGML_CUDA
907
939
$(info I NVCC : $(shell $(NVCC ) --version | tail -n 1) )
908
940
CUDA_VERSION := $(shell $(NVCC ) --version | grep -oP 'release (\K[0-9]+\.[0-9]) ')
941
+ ifndef GGML_MUSA
909
942
ifeq ($(shell awk -v "v=$(CUDA_VERSION ) " 'BEGIN { print (v < 11.7) }'),1)
910
943
911
944
ifndef CUDA_DOCKER_ARCH
@@ -915,6 +948,7 @@ endif # CUDA_POWER_ARCH
915
948
endif # CUDA_DOCKER_ARCH
916
949
917
950
endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1)
951
+ endif # GGML_MUSA
918
952
endif # GGML_CUDA
919
953
$(info )
920
954
0 commit comments