
Commit 74281ed

[feat] Refactor trtllmgen MOE and add Bf16 trtllmgen moe (#2014)
## 📌 Description

- Refactor `trtllm_fused_moe_kernel_launcher.cu` to use a class structure for code cleanliness and readability
- Add BF16 MOE, initial PR (#1859) from @aleozlx and @nekorobov
- Add BF16 MOE autotune

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

- **New Features**
  - BF16 Mixture-of-Experts (MoE) pathway added with autotuning and public API access.
- **Improvements**
  - Unified BF16/FP8/FP4/FP16 pathways with clearer dtype-compatibility checks and corrected operator return semantics.
  - Routing selection now respects token size and input packing, and diagnostics produce more descriptive error messages.
- **Tests**
  - Expanded BF16 test coverage across routing modes, weight layouts, and token sizes.
- **Chores**
  - Updated artifact metadata and checksums.

Signed-off-by: jiahanc <[email protected]>
1 parent ba011d1 · commit 74281ed

File tree: 7 files changed, +1932 −1143 lines


csrc/trtllm_batched_gemm_runner.cu

Lines changed: 10 additions & 8 deletions
```diff
@@ -116,14 +116,16 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
     }
   }
 
-  FLASHINFER_CHECK(
-      !mPassingConfigIndices.empty(),
-      "No kernel found for the given options: mDtypeA: %s, mDtypeB: %s, mDtypeC: %s, "
-      "mUseDeepSeekFp8: %d, "
-      "mTransposeMmaOutput: %d, mRouteAct: %d, mFusedAct: %d, mIsStaticBatch: %d, mTileSize: %d",
-      tg::dtypeToString(mOptions.dtypeA).c_str(), tg::dtypeToString(mOptions.dtypeB).c_str(),
-      tg::dtypeToString(mOptions.dtypeC).c_str(), mOptions.deepSeekFp8, mOptions.transposeMmaOutput,
-      mOptions.routeAct, mOptions.fusedAct, mOptions.staticBatch, mOptions.tileSize);
+  std::ostringstream error_msg;
+  error_msg << "No kernel found for the given options: "
+            << "mDtypeA: " << tg::dtypeToString(mOptions.dtypeA)
+            << ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB)
+            << ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC)
+            << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8
+            << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput
+            << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct
+            << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize;
+  FLASHINFER_CHECK(!mPassingConfigIndices.empty(), error_msg.str());
 }
 
 size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
```

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 1418 additions & 1106 deletions
Large diffs are not rendered by default.

csrc/trtllm_fused_moe_routing_renormalize.cu

Lines changed: 2 additions & 2 deletions
```diff
@@ -435,8 +435,8 @@ void run(Data const& data, void* stream) {
       << "Routing kernel expects #experts " << data.mNumExperts << " to be a multiple of 4.";
 
   // FIXME: routingIndicesBlockKernel breaks the vllm + gpt-oss DeepEP
-  // bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens;
-  bool const useSingleBlock = false;
+  bool const useSingleBlock =
+      data.mNumTokens <= BlockKernelMaxNumTokens && data.mPtrTopKPacked == nullptr;
 
   bool const useSingleCluster =
       data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr)
```
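This hunk re-enables the single-block routing path, but gates it on both the token count and the absence of packed top-k input (the configuration the FIXME notes broke vllm + gpt-oss DeepEP). A minimal Python sketch of that predicate, with hypothetical parameter names standing in for the C++ fields (`data.mNumTokens`, `data.mPtrTopKPacked`, `BlockKernelMaxNumTokens`):

```python
def use_single_block_kernel(num_tokens: int,
                            topk_packed_supplied: bool,
                            block_kernel_max_num_tokens: int) -> bool:
    """Hypothetical mirror of the C++ predicate above.

    The single-block routing kernel is selected only when every token fits
    within one block AND the caller did not supply packed top-k input
    (i.e. mPtrTopKPacked == nullptr in the C++ code).
    """
    return num_tokens <= block_kernel_max_num_tokens and not topk_packed_supplied
```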

flashinfer/artifacts.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -89,7 +89,7 @@ class ArtifactPath:
 
     TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
-        "23daeee32b60bde7947ce1ee7a58d4ab701f134b/batched_gemm-0d28130-add42d1"
+        "c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988"
     )
     TRTLLM_GEN_GEMM: str = (
         "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
@@ -105,7 +105,7 @@ class MetaInfoHash:
         "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a"
     )
     TRTLLM_GEN_BMM: str = (
-        "6cfade1395f9648aba5dcf2c329114619e175c0f238882555178f98c8f5c1968"
+        "26c51b75921be90235d193675facdea5d8341c4c52c73bd0a7c8e787c0388beb"
     )
     TRTLLM_GEN_GEMM: str = (
         "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340"
@@ -123,7 +123,7 @@ class CheckSumHash:
         "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
     )
     TRTLLM_GEN_BMM: str = (
-        "46ccf0492e3ed10135c2861a4f4ef9bb45846610f9a9d2ccaf2d5bf01d2006fd"
+        "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf"
     )
     DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
     TRTLLM_GEN_GEMM: str = (
```
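The 64-hex-character values above are consistent with SHA-256 digests. Below is a hedged sketch of how a downloaded artifact could be checked against such a pinned digest; `verify_artifact` and the file name are illustrative assumptions, not FlashInfer's actual loader code:

```python
import hashlib


def verify_artifact(path: str, expected_sha256: str) -> None:
    """Hypothetical helper: stream the file through SHA-256 and compare
    the digest to the pinned checksum (assumed to be SHA-256)."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks
            h.update(chunk)
    if h.hexdigest() != expected_sha256:
        raise ValueError(f"checksum mismatch for {path}")


# Example with the new TRTLLM_GEN_BMM checksum from this commit
# ("batched_gemm.bin" is a made-up file name):
# verify_artifact(
#     "batched_gemm.bin",
#     "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf",
# )
```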

flashinfer/fused_moe/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -29,6 +29,7 @@
     trtllm_fp4_block_scale_routed_moe,
     trtllm_fp8_block_scale_moe,
     trtllm_fp8_per_tensor_scale_moe,
+    trtllm_bf16_moe,
 )
 
 __all__ = [
@@ -40,8 +41,11 @@
     "gen_cutlass_fused_moe_sm120_module",
     "gen_cutlass_fused_moe_sm100_module",
     "gen_cutlass_fused_moe_sm90_module",
+    "gen_trtllm_gen_fused_moe_sm100_module",
     "reorder_rows_for_gated_act_gemm",
+    "trtllm_bf16_moe",
     "trtllm_fp4_block_scale_moe",
+    "trtllm_fp4_block_scale_routed_moe",
     "trtllm_fp8_block_scale_moe",
     "trtllm_fp8_per_tensor_scale_moe",
 ]
```
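With these exports in place, the new BF16 entry point becomes importable from the public package. Only the import below is grounded in this diff; the commit does not show `trtllm_bf16_moe`'s signature, so no call is sketched:

```python
# Grounded in this hunk: trtllm_bf16_moe is now a public re-export.
from flashinfer.fused_moe import trtllm_bf16_moe

# Per the PR description, the BF16 pathway participates in autotuning; the
# exact parameters (routing logits, expert weights, activations, ...) are
# not visible in this commit, so consult the library docs before calling.
```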
