
Commit c24b72b

feat: add flashinfer as kernel backend for cuda device.
1 parent ab09a45 commit c24b72b


61 files changed: +1354 −138 lines (only a subset of the changed files is shown below)

.gitmodules

Lines changed: 1 addition & 1 deletion

@@ -27,4 +27,4 @@
     url = https://gitcode.com/xLLM-AI/spdlog.git
 [submodule "third_party/Mooncake"]
     path = third_party/Mooncake
-    url = https://gitcode.com/xLLM-AI/Mooncake.git
+    url = https://gitcode.com/xLLM-AI/Mooncake.git

CMakeLists.txt

Lines changed: 71 additions & 2 deletions
@@ -3,6 +3,7 @@ set_property(GLOBAL PROPERTY USE_FOLDERS ON)
 
 option(USE_NPU "Enable NPU support" OFF)
 option(USE_MLU "Enable MLU support" OFF)
+option(USE_CUDA "Enable CUDA support" OFF)
 
 if(DEVICE_ARCH STREQUAL "ARM")
   set(CMAKE_SYSTEM_PROCESSOR aarch64)
@@ -101,7 +102,7 @@ set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS ON)
 
-if(USE_NPU)
+if(USE_NPU OR USE_CUDA)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
   add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
 elseif(USE_MLU)
@@ -178,6 +179,32 @@ if (DEFINED ENV{DEPENDENCES_ROOT})
   message(STATUS "Using DEPENDENCES_ROOT: $ENV{DEPENDENCES_ROOT}")
 endif()
 
+
+# Build TORCH_CUDA_ARCH_LIST
+if(USE_CUDA)
+  # set architecture for CUDA
+  if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+    set(CMAKE_CUDA_ARCHITECTURES 80)
+  endif()
+  # Build TORCH_CUDA_ARCH_LIST
+  set(TORCH_CUDA_ARCH_LIST "")
+  foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
+    if(CUDA_ARCH MATCHES "^([0-9])([0-9])a$")
+      set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}a")
+    elseif(CUDA_ARCH MATCHES "^([0-9])([0-9])*$")
+      set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}")
+    elseif(CUDA_ARCH STREQUAL "native")
+      set(TORCH_ARCH "Auto")
+    else()
+      message(FATAL_ERROR "${CUDA_ARCH} is not supported")
+    endif()
+    list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH})
+  endforeach()
+
+  message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
+  message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}")
+endif()
+
 # configure vcpkg
 # have to set CMAKE_TOOLCHAIN_FILE before first project call.
 # if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE)
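The hunk above converts CMake-style compute-capability codes ("80", "90a", "native") into the dotted form Torch expects ("8.0", "9.0a", "Auto"). A Python restatement of that mapping, for illustration only (the function name is ours, and the second pattern is written stricter than the CMake regex, which tolerates a trailing digit repetition):

import re

def to_torch_arch(cuda_arch: str) -> str:
    # "90a" -> "9.0a" (architecture-specific feature set)
    if m := re.fullmatch(r"([0-9])([0-9])a", cuda_arch):
        return f"{m.group(1)}.{m.group(2)}a"
    # "80" -> "8.0"
    if m := re.fullmatch(r"([0-9])([0-9])", cuda_arch):
        return f"{m.group(1)}.{m.group(2)}"
    # defer detection to the local GPU
    if cuda_arch == "native":
        return "Auto"
    raise ValueError(f"{cuda_arch} is not supported")

# the list setup.py passes for --device cuda (see below)
assert [to_torch_arch(a) for a in "80;89;90".split(";")] == ["8.0", "8.9", "9.0"]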
@@ -217,7 +244,12 @@ endif()
 set(CPPREST_EXCLUDE_WEBSOCKETS ON CACHE BOOL "Exclude websockets functionality." FORCE)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format-truncation")
 
-project("xllm" LANGUAGES C CXX)
+if(USE_CUDA)
+  project("xllm" LANGUAGES C CXX CUDA)
+  find_package(CUDAToolkit REQUIRED)
+else()
+  project("xllm" LANGUAGES C CXX)
+endif()
 
 # find_package(CUDAToolkit REQUIRED)
 
@@ -352,6 +384,43 @@ if(USE_MLU)
   )
 endif()
 
+if(USE_CUDA)
+  add_definitions(-DUSE_CUDA)
+  add_compile_definitions(TORCH_CUDA=1)
+  set(CMAKE_VERBOSE_MAKEFILE ON)
+  include_directories(
+    $ENV{PYTHON_INCLUDE_PATH}
+    $ENV{PYTORCH_INSTALL_PATH}/include
+    $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
+  )
+
+  link_directories(
+    $ENV{PYTHON_LIB_PATH}
+    $ENV{PYTORCH_INSTALL_PATH}/lib
+    $ENV{CUDA_TOOLKIT_ROOT_DIR}/lib64
+  )
+
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -O3)
+  # The following definitions must be undefined since half-precision operation is required.
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS}
+      -U__CUDA_NO_HALF_OPERATORS__
+      -U__CUDA_NO_HALF_CONVERSIONS__
+      -U__CUDA_NO_HALF2_OPERATORS__
+      -U__CUDA_NO_BFLOAT16_CONVERSIONS__)
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} --use_fast_math -Xfatbin -compress-all)
+  message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")
+
+  # find_package(NCCL REQUIRED)
+
+  # find cudnn
+  execute_process(COMMAND python -c "import nvidia.cudnn; print(nvidia.cudnn.__file__)" OUTPUT_VARIABLE CUDNN_PYTHON_PATH)
+  get_filename_component(CUDNN_ROOT_DIR "${CUDNN_PYTHON_PATH}" DIRECTORY)
+  link_directories(
+    ${CUDNN_ROOT_DIR}/lib64
+    ${CUDNN_ROOT_DIR}/lib
+  )
+endif()
+
 # check if USE_CXX11_ABI is set correctly
 # if (DEFINED USE_CXX11_ABI)
 #   parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS")
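The cuDNN lookup above shells out to Python to locate the pip-installed nvidia-cudnn wheel rather than requiring a system install. The equivalent lookup in plain Python (a sketch of what the execute_process call resolves):

import pathlib

import nvidia.cudnn

# execute_process prints nvidia.cudnn.__file__; DIRECTORY strips the filename
cudnn_root = pathlib.Path(nvidia.cudnn.__file__).parent
print(cudnn_root / "lib")  # one of the paths added to link_directories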

setup.py

Lines changed: 26 additions & 6 deletions
@@ -98,7 +98,6 @@ def get_python_include_path():
         return None
 
 
-# PYTORCH_INSTALL_PATH and LIBTORCH_ROOT
 def get_torch_root_path():
     try:
         import torch
@@ -115,6 +114,12 @@ def get_torch_mlu_root_path():
     except ImportError:
         return None
 
+def get_nccl_root_path():
+    try:
+        from nvidia import nccl
+        return str(Path(nccl.__file__).parent)
+    except ImportError:
+        return None
 
 def set_npu_envs():
     PYTORCH_NPU_INSTALL_PATH = os.getenv("PYTORCH_NPU_INSTALL_PATH")
@@ -212,7 +217,16 @@ def set_mlu_envs():
     os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
     os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
     os.environ["PYTORCH_MLU_INSTALL_PATH"] = get_torch_mlu_root_path()
-
+
+
+def set_cuda_envs():
+    os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path()
+    os.environ["PYTHON_LIB_PATH"] = get_torch_root_path()
+    os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
+    os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
+    os.environ["CUDA_TOOLKIT_ROOT_DIR"] = "/usr/local/cuda"
+    os.environ["NCCL_ROOT"] = get_nccl_root_path()
+    os.environ["NCCL_VERSION"] = "2"
+
 class CMakeExtension(Extension):
     def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
         super().__init__(name, sources=[])
@@ -223,7 +237,7 @@ def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
 class ExtBuild(build_ext):
     user_options = build_ext.user_options + [
         ("base-dir=", None, "base directory of xLLM project"),
-        ("device=", None, "target device type (a3 or a2 or mlu)"),
+        ("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
         ("arch=", None, "target arch type (x86 or arm)"),
         ("install-xllm-kernels=", None, "install xllm_kernels RPM package (true/false)"),
     ]
@@ -302,8 +316,14 @@ def build_extension(self, ext: CMakeExtension):
             cmake_args += ["-DUSE_MLU=ON"]
             # set mlu environment variables
             set_mlu_envs()
+        elif self.device == "cuda":
+            cuda_architectures = "80;89;90"
+            cmake_args += ["-DUSE_CUDA=ON",
+                           f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"]
+            # set cuda environment variables
+            set_cuda_envs()
         else:
-            raise ValueError("Please set --device to a2 or a3 or mlu.")
+            raise ValueError("Please set --device to a2 or a3 or mlu or cuda.")
 
 
         # Adding CMake arguments set as environment variable
@@ -353,7 +373,7 @@ def build_extension(self, ext: CMakeExtension):
 
 class BuildDistWheel(bdist_wheel):
     user_options = bdist_wheel.user_options + [
-        ("device=", None, "target device type (a3 or a2 or mlu)"),
+        ("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
         ("arch=", None, "target arch type (x86 or arm)"),
     ]

@@ -530,7 +550,7 @@ def apply_patch():
         idx = sys.argv.index('--device')
         if idx + 1 < len(sys.argv):
             device = sys.argv[idx+1].lower()
-            if device not in ('a2', 'a3', 'mlu'):
+            if device not in ('a2', 'a3', 'mlu', 'cuda'):
                 print("Error: --device must be a2 or a3 or mlu (case-insensitive)")
                 sys.exit(1)
     # Remove the arguments so setup() doesn't see them
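With these hooks in place, a CUDA build would presumably be invoked the same way as the existing targets, e.g. python setup.py bdist_wheel --device cuda --arch x86 (an invocation inferred from the options above, not one shown in this commit).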

xllm/core/common/global_flags.cpp

Lines changed: 15 additions & 1 deletion
@@ -276,6 +276,7 @@ DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTranfer listen port.");
 DEFINE_bool(enable_shm,
             false,
             "Whether to enable shared memory for executing model.");
+
 // --- function call config ---
 
 DEFINE_string(tool_call_parser,
@@ -353,6 +354,7 @@ DEFINE_int32(micro_batch_num,
             "Default use two micro batches for multi-stream parallel.");
 
 // --- dit config ---
+
 DEFINE_int32(max_requests_per_batch, 1, "Max number of request per batch.");
 
 // --- continuous kv cache config ---
@@ -377,22 +379,34 @@ DEFINE_int64(buffer_size_per_seq,
             "Buffer size per sequence in bytes, default 0.");
 
 // --- beam search config ---
+
 DEFINE_bool(enable_beam_search_kernel,
             false,
             "Whether to enable beam search kernel.");
 
 // --- reasoning parser config ---
+
 DEFINE_string(reasoning_parser,
              "",
              "Specify the reasoning parser for handling reasoning "
              "interactions(e.g. glm45, qwen3, deepseek-r1).");
 
 // --- qwen3 reranker config ---
+
 DEFINE_bool(enable_qwen3_reranker, false, "Whether to enable qwen3 reranker.");
 
+// --- flashinfer config ---
+
+DEFINE_int32(flashinfer_workspace_buffer_size,
+             128 * 1024 * 1024,
+             "The user reserved workspace buffer used to store intermediate "
+             "attention results in split-k algorithm for flashinfer.");
+
+// --- prefetch weight config ---
+
 DEFINE_bool(
     enable_prefetch_weight,
     false,
     "Whether to enable prefetch weight,only applicable to Qwen3-dense model."
     "The default prefetching ratio for gateup weight is 40%."
-    "If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");
+    "If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
@@ -204,3 +204,5 @@ DECLARE_string(reasoning_parser);
 DECLARE_bool(enable_shm);
 
 DECLARE_bool(enable_prefetch_weight);
+
+DECLARE_int32(flashinfer_workspace_buffer_size);

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 35 additions & 3 deletions
@@ -216,7 +216,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
       state_.q_seq_lens.insert(state_.q_seq_lens.end(),
                                state.q_seq_lens.begin(),
                                state.q_seq_lens.end());
-#elif defined(USE_MLU)
+#elif defined(USE_MLU) || defined(USE_CUDA)
       int32_t seq_len_offset = state_.seq_lens.back();
       // skip the first element which is 0
       for (size_t i = 1; i < state.seq_lens.size(); ++i) {
@@ -248,6 +248,16 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
                                            state.kv_cache_start_offsets.begin(),
                                            state.kv_cache_start_offsets.end());
     }
+    // for flashinfer
+    state_.paged_kv_indptr.insert(state_.paged_kv_indptr.end(),
+                                  state.paged_kv_indptr.begin(),
+                                  state.paged_kv_indptr.end());
+    state_.paged_kv_indices.insert(state_.paged_kv_indices.end(),
+                                   state.paged_kv_indices.begin(),
+                                   state.paged_kv_indices.end());
+    state_.paged_kv_last_page_len.insert(state_.paged_kv_last_page_len.end(),
+                                         state.paged_kv_last_page_len.begin(),
+                                         state.paged_kv_last_page_len.end());
   }
   for (const auto& write_block_ids : thread_write_block_ids) {
     write_block_ids_.insert(write_block_ids.begin(), write_block_ids.end());
@@ -288,7 +298,7 @@ void BatchInputBuilder::process_single_sequence(
 #if defined(USE_NPU)
   state.seq_lens.push_back(seq_len);
   state.q_seq_lens.push_back(q_seq_len);
-#elif defined(USE_MLU)
+#elif defined(USE_MLU) || defined(USE_CUDA)
   state.seq_lens.push_back(state.seq_lens.back() + seq_len);
   state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
 #endif
@@ -448,7 +458,12 @@ void BatchInputBuilder::setup_kv_cache_info(
     block_size = block.size();
     block_ids.push_back(block.id());
     u_block_ids.emplace_back(block.id());
+    state.paged_kv_indices.push_back(block.id());
   }
+  state.paged_kv_indptr.push_back(state.paged_kv_indptr.back() + blocks.size());
+  int32_t last_page_len =
+      (seq_len % block_size == 0) ? block_size : seq_len % block_size;
+  state.paged_kv_last_page_len.push_back(last_page_len);
 
   int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size;
   for (auto iter = block_ids.begin() + kv_cache_block_idx;
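The paged_kv_* triple built here is FlashInfer's CSR-style description of the paged KV cache: paged_kv_indptr delimits each sequence's slice of paged_kv_indices, and paged_kv_last_page_len records how many tokens occupy the sequence's final page. A worked sketch with made-up block IDs and block_size = 16 (values chosen only for illustration):

# seq 0: 40 tokens -> 3 blocks [7, 3, 9]; last page holds 40 - 2*16 = 8 tokens
# seq 1: 32 tokens -> 2 blocks [2, 5]; 32 % 16 == 0, so last page is full (16)
paged_kv_indices = [7, 3, 9, 2, 5]
paged_kv_indptr = [0, 3, 5]         # seq i owns indices[indptr[i]:indptr[i+1]]
paged_kv_last_page_len = [8, 16]    # (seq_len % block_size) or block_size

def blocks_of(seq_idx):
    lo, hi = paged_kv_indptr[seq_idx], paged_kv_indptr[seq_idx + 1]
    return paged_kv_indices[lo:hi]

assert blocks_of(0) == [7, 3, 9]
assert blocks_of(1) == [2, 5]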
@@ -517,12 +532,15 @@ void BatchInputBuilder::padding_decode_batch_size(
 #if defined(USE_NPU)
       state_.seq_lens.push_back(num_decoding_tokens);
       state_.q_seq_lens.push_back(num_decoding_tokens);
-#elif defined(USE_MLU)
+#elif defined(USE_MLU) || defined(USE_CUDA)
       state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens);
       state_.q_seq_lens.push_back(state_.q_seq_lens.back() +
                                   num_decoding_tokens);
 #endif
       state_.block_tables_vec.emplace_back();
+      state_.paged_kv_indices.push_back(0);
+      state_.paged_kv_indptr.push_back(state_.paged_kv_indptr.back() + 1);
+      state_.paged_kv_last_page_len.push_back(1);
     }
   }
 }
@@ -560,6 +578,14 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   input_params.decode_seq_range =
       util::find_ones_indices(input_params.q_seq_lens_vec);
 
+  // for flashinfer
+  input_params.paged_kv_indptr =
+      torch::tensor(state_.paged_kv_indptr, torch::kInt);
+  input_params.paged_kv_indices =
+      torch::tensor(state_.paged_kv_indices, torch::kInt);
+  input_params.paged_kv_last_page_len =
+      torch::tensor(state_.paged_kv_last_page_len, torch::kInt);
+
   // Setup multimodal data
   input_params.mm_data = MMData::batch(mm_data_vec_);
 
@@ -634,6 +660,12 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
   raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos);
   raw_forward_input.prefill_seq_len = state_.prefill_seq_len;
 
+  // for flashinfer
+  raw_forward_input.paged_kv_indptr = std::move(state_.paged_kv_indptr);
+  raw_forward_input.paged_kv_indices = std::move(state_.paged_kv_indices);
+  raw_forward_input.paged_kv_last_page_len =
+      std::move(state_.paged_kv_last_page_len);
+
   raw_forward_input.embedding_ids = std::move(state_.embedding_ids);
   raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids);
   // beam search kernel input

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 6 additions & 1 deletion
@@ -86,7 +86,7 @@ class BatchInputBuilder {
 #if defined(USE_NPU)
     std::vector<int32_t> seq_lens;
     std::vector<int32_t> q_seq_lens;
-#elif defined(USE_MLU)
+#elif defined(USE_MLU) || defined(USE_CUDA)
     std::vector<int32_t> seq_lens = {0};    // cu_seq_lens
     std::vector<int32_t> q_seq_lens = {0};  // q_cu_seq_len
 #endif
@@ -107,6 +107,11 @@ class BatchInputBuilder {
     // for continuous kvcache
     std::vector<int64_t> new_cache_slot_offsets;  //[n_tokens]
     std::vector<int64_t> kv_cache_start_offsets;  //[n_seq]
+
+    // for flashinfer
+    std::vector<int32_t> paged_kv_indptr = {0};
+    std::vector<int32_t> paged_kv_indices;
+    std::vector<int32_t> paged_kv_last_page_len;
   };
 
   // Helper methods for sequence processing

xllm/core/framework/batch/batch_test.cpp

Lines changed: 2 additions & 2 deletions
@@ -152,15 +152,15 @@ TEST(BatchTest, Basic) {
 
 #if defined(USE_NPU)
   const std::vector<int32_t> q_seq_lens = {9, 1, 1, 4};
-#elif defined(USE_MLU)
+#else
   const std::vector<int32_t> q_seq_lens = {0, 9, 10, 11, 15};
 #endif
   EXPECT_TRUE(equal(input_params.q_seq_lens, q_seq_lens));
 
   // seq4's kv_seq_len = q_len + num_cached_tokens (q_len<=max_allowed_tokens)
 #if defined(USE_NPU)
   const std::vector<int32_t> kv_seq_lens = {9, 8, 16, 8};
-#elif defined(USE_MLU)
+#else
   const std::vector<int32_t> kv_seq_lens = {0, 9, 17, 33, 41};
 #endif
   EXPECT_TRUE(equal(input_params.kv_seq_lens, kv_seq_lens));
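The expected values make the representation difference concrete: the NPU path stores per-sequence lengths, while the MLU/CUDA path stores cumulative lengths seeded with 0. A small sketch reproducing the test's q_seq_lens expectation:

# per-sequence query lengths (the NPU-style representation)
q_seq_lens = [9, 1, 1, 4]

# prefix-sum into the cumulative form used on MLU/CUDA
cu_q_seq_lens = [0]
for n in q_seq_lens:
    cu_q_seq_lens.append(cu_q_seq_lens[-1] + n)

assert cu_q_seq_lens == [0, 9, 10, 11, 15]  # the test's expected vector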
