Skip to content

Commit 0fb5e42

Browse files
committed
refactor: separate SM100 and legacy TRT-LLM comm modules
Restructure the compilation of the TensorRT-LLM communication module to improve hardware compatibility and portability. Previously, the module was compiled with SM100-specific flags only if a compatible GPU was detected during the build process. This made a single build non-portable across different GPU generations. This change introduces two distinct modules: - `trtllm_comm`: compiled with SM100 optimizations for SM100-class (Blackwell, compute capability 10.x) GPUs. - `trtllm_comm_legacy`: a fallback version for older GPU architectures. At runtime, `get_trtllm_comm_module` now detects the GPU's compute capability and dynamically loads the appropriate module. This allows a single FlashInfer build to support a wider range of NVIDIA GPUs and gracefully handles CPU-only environments. Signed-off-by: Emilien Macchi <[email protected]>
1 parent 28741b7 commit 0fb5e42

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

flashinfer/aot.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from .activation import act_func_def_str, gen_act_and_mul_module
1212
from .cascade import gen_cascade_module
1313
from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
14+
from .comm.nvshmem import gen_nvshmem_module
15+
from .comm.trtllm_ar import gen_trtllm_comm_legacy_module
1416
from .fp4_quantization import gen_fp4_quantization_sm100_module
1517
from .fused_moe import gen_fused_moe_sm100_module
1618
from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
@@ -325,7 +327,8 @@ def gen_all_modules(
325327
jit_specs.append(gen_fused_moe_sm100_module())
326328
jit_specs.append(gen_fp4_quantization_sm100_module())
327329
jit_specs.append(gen_gemm_sm100_module())
328-
jit_specs.append(gen_trtllm_comm_module())
330+
331+
jit_specs.append(gen_trtllm_comm_module(sm100=has_sm100))
329332

330333
jit_specs += [
331334
gen_cascade_module(),

flashinfer/comm/trtllm_ar.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,37 @@ class FP4QuantizationSFLayout:
9595
LINEAR = 1
9696

9797

98-
def gen_trtllm_comm_module(sm100: bool = True) -> JitSpec:
    """Build the JIT spec for the TensorRT-LLM communication kernels.

    When *sm100* is True the module is named 'trtllm_comm' and compiled
    with the SM100-specific nvcc flags; otherwise a flag-free fallback
    module named 'trtllm_comm_legacy' is produced so a single build can
    serve older GPU architectures.
    """
    # Pick name and compile flags together so they can never disagree.
    if sm100:
        module_name = "trtllm_comm"
        cuda_flags = sm100a_nvcc_flags
    else:
        module_name = "trtllm_comm_legacy"
        cuda_flags = []
    # Both variants compile the same kernel sources from the csrc tree.
    sources = [
        jit_env.FLASHINFER_CSRC_DIR / source
        for source in (
            "trtllm_allreduce.cu",
            "trtllm_allreduce_fusion.cu",
            "trtllm_moe_allreduce_fusion.cu",
        )
    ]
    return gen_jit_spec(module_name, sources, extra_cuda_cflags=cuda_flags)
109113

110114

111115
@functools.cache
112116
def get_trtllm_comm_module():
113-
module = gen_trtllm_comm_module().build_and_load()
117+
# Select the appropriate module based on device capability
118+
try:
119+
major, minor = torch.cuda.get_device_capability()
120+
use_sm100_module = major >= 10 and minor >= 0
121+
except RuntimeError:
122+
# If CUDA is not available (e.g., CPU-only mode), default to legacy module
123+
use_sm100_module = False
124+
125+
if use_sm100_module:
126+
module = gen_trtllm_comm_module().build_and_load()
127+
else:
128+
module = gen_trtllm_comm_legacy_module().build_and_load()
114129

115130
@register_custom_op(
116131
"flashinfer::trtllm_lamport_initialize", mutates_args=["buffer"]

0 commit comments

Comments
 (0)