Commit e4d7f46

update trtllm-gen to fix several issues
Signed-off-by: Perkz Zheng <[email protected]>
1 parent: f566d49

File tree

3 files changed: +46 -23 lines changed


flashinfer/artifacts.py

Lines changed: 3 additions & 3 deletions
@@ -87,7 +87,7 @@ class ArtifactPath:

     When compiling new cubins for backend directories, update the corresponding path.
     """

-    TRTLLM_GEN_FMHA: str = "463def7494c9fc6792b5aa5b5beef34025e247ac/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "b793e1b2cf7c419f070372ba55bbe53ca6fb9016/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "23daeee32b60bde7947ce1ee7a58d4ab701f134b/batched_gemm-0d28130-add42d1"
     )
@@ -102,7 +102,7 @@ class ArtifactPath:
 class MetaInfoHash:
     DEEPGEMM: str = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb"
     TRTLLM_GEN_FMHA: str = (
-        "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a"
+        "bf45e2c21de9fbf5209bec3975b5ffe24b1d7a2e00aa40c548c992281864009f"
     )
     TRTLLM_GEN_BMM: str = (
         "6cfade1395f9648aba5dcf2c329114619e175c0f238882555178f98c8f5c1968"
@@ -123,7 +123,7 @@ class CheckSumHash:
         "639c534614e9fdf5a9cfa91f7ea8f53989613019c0e1f8b755f461e1fcc7546f"
     )
     TRTLLM_GEN_BMM: str = (
-        "46ccf0492e3ed10135c2861a4f4ef9bb45846610f9a9d2ccaf2d5bf01d2006fd"
+        "1ebace613389a4f2e10b14315da5d522642c5dcaae23f01213d56c59068f148b"
     )
     DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
     TRTLLM_GEN_GEMM: str = (

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 23 additions & 20 deletions
@@ -96,14 +96,15 @@ class TllmGenFmhaKernel {
   inline uint64_t hashID(int qkvLayout, int maskType, int kernelType, int scheduler,
                          int multiCtasKvMode, int headDimPerCtaV, int headDimQk, int headDimV,
                          int tileSizeKv, int numTokensPerPage, int maxNumHeadsQPerKvInCta,
-                         bool reuseSmemKForV, bool uses2CtaMma) const {
+                         bool reuseSmemKForV, bool uses2CtaMma, bool sparseMla) const {
     FLASHINFER_CHECK((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) &&
-                         (headDimPerCtaV <= 2048) && (headDimQk <= 2048) && (headDimV <= 2048) &&
-                         (numTokensPerPage <= 128),
-                     "Expect (32 <= headDim <= 2048) && (numTokensPerPage <= 128), "
-                     "got headDimPerCtaV=%d, headDimQk=%d, "
-                     "headDimV=%d, numTokensPerPage=%d",
-                     headDimPerCtaV, headDimQk, headDimV, numTokensPerPage);
+                         (headDimPerCtaV <= 1024) && (headDimQk <= 1024) && (headDimV <= 1024),
+                     "Expect (32 <= headDim <= 1024), got headDimPerCtaV=%d, headDimQk=%d, "
+                     "headDimV=%d",
+                     headDimPerCtaV, headDimQk, headDimV);
+    // The numTokensPerPage must be power of 2.
+    FLASHINFER_CHECK((numTokensPerPage & (numTokensPerPage - 1)) == 0,
+                     "The numTokensPerPage must be power of 2.");
     FLASHINFER_CHECK(maxNumHeadsQPerKvInCta <= 128,
                      "The maxNumHeadsQPerKvInCta <= 128 is required.");
     FLASHINFER_CHECK(tileSizeKv == 64 || tileSizeKv == 128, "The tileSizeKv must be 64 or 128.");
@@ -113,25 +114,26 @@ class TllmGenFmhaKernel {
     // Bit 8 - 11: kernelType.
     // Bit 12 - 15: tileScheduler.
     // Bit 16 - 17: multiCtasKvMode.
-    // Bit 18 - 24: (headDimPerCtaV >> 5).
-    // Bit 25 - 31: (headDimQk >> 5).
-    // Bit 32 - 38: (headDimV >> 5).
-    // Bit 39 - 40: (tileSizeKv >> 6).
-    // Bit 41 - 48: numTokensPerPage.
+    // Bit 18 - 25: (headDimPerCtaV >> 3).
+    // Bit 26 - 33: (headDimQk >> 3).
+    // Bit 34 - 41: (headDimV >> 3).
+    // Bit 42 - 43: (tileSizeKv >> 6).
+    // Bit 44 - 48: (log2(numTokensPerPage)).
     // Bit 49 - 56: maxNumHeadsQPerKvInCta.
     // Bit 57 - 57: reuseSmemKForV.
     // Bit 58 - 58: uses2CtaMma.
+    // Bit 59 - 59: sparseMla.
     return (static_cast<uint64_t>(qkvLayout) << 0) | (static_cast<uint64_t>(maskType) << 4) |
            (static_cast<uint64_t>(kernelType) << 8) | (static_cast<uint64_t>(scheduler) << 12) |
            (static_cast<uint64_t>(multiCtasKvMode) << 16) |
-           (static_cast<uint64_t>(headDimPerCtaV >> 5) << 18) |
-           (static_cast<uint64_t>(headDimQk >> 5) << 25) |
-           (static_cast<uint64_t>(headDimV >> 5) << 32) |
-           (static_cast<uint64_t>(tileSizeKv >> 6) << 39) |
-           (static_cast<uint64_t>(numTokensPerPage) << 41) |
+           (static_cast<uint64_t>(headDimPerCtaV >> 3) << 18) |
+           (static_cast<uint64_t>(headDimQk >> 3) << 26) |
+           (static_cast<uint64_t>(headDimV >> 3) << 34) |
+           (static_cast<uint64_t>(tileSizeKv >> 6) << 42) |
+           (static_cast<uint64_t>(log2(numTokensPerPage)) << 44) |
            (static_cast<uint64_t>(maxNumHeadsQPerKvInCta) << 49) |
            (static_cast<uint64_t>(reuseSmemKForV) << 57) |
-           (static_cast<uint64_t>(uses2CtaMma) << 58);
+           (static_cast<uint64_t>(uses2CtaMma) << 58) | (static_cast<uint64_t>(sparseMla) << 59);
   }

   uint64_t hashID(KernelMeta const& kernelMeta) const {
@@ -140,7 +142,7 @@ class TllmGenFmhaKernel {
                   kernelMeta.mHeadDimPerCtaV, kernelMeta.mHeadDimQk, kernelMeta.mHeadDimV,
                   kernelMeta.mTileSizeKv, kernelMeta.mNumTokensPerPage,
                   kernelMeta.mMaxNumHeadsQPerKvInCta, kernelMeta.mReuseSmemKForV,
-                  kernelMeta.m2CtaMma);
+                  kernelMeta.m2CtaMma, kernelMeta.mSparseMla);
   }

   std::pair<bool, std::string> checkIfKernelExist(RunnerParams const& params) const {
@@ -552,7 +554,8 @@ class TllmGenFmhaKernel {
                   static_cast<int>(selectKernelParams.mMultiCtasKvMode),
                   selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV,
                   selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta,
-                  selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma),
+                  selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma,
+                  /* sparseMla */ false),
               info);
   }
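
Aside: the reworked hash layout can be sanity-checked offline. The headDim cap drops from 2048 to 1024 so that (headDim >> 3) still fits an 8-bit field, the raw numTokensPerPage value is replaced by its 5-bit log2 (hence the new power-of-2 check), and the freed bit 59 takes the new sparseMla flag. Below is a minimal standalone C++ sketch, not part of the commit, that packs the same fields at the new positions and asserts each value fits its range; the packField helper is hypothetical, and __builtin_clzll is a GCC/Clang builtin.

#include <cassert>
#include <cstdint>

// Hypothetical helper: place `value` in bits [lowBit, highBit] after
// checking that it fits, so neighboring fields cannot collide.
static uint64_t packField(uint64_t value, int lowBit, int highBit) {
  int const width = highBit - lowBit + 1;
  assert(value < (uint64_t{1} << width));
  return value << lowBit;
}

int main() {
  // Extremes permitted by the new FLASHINFER_CHECKs above.
  int const headDimPerCtaV = 1024, headDimQk = 1024, headDimV = 1024;
  int const tileSizeKv = 128, numTokensPerPage = 64;
  // log2 of a power of 2 via count-leading-zeros (GCC/Clang builtin).
  int const numTokensPerPageLog2 = 63 - __builtin_clzll((uint64_t)numTokensPerPage);

  // (1024 >> 3) == 128 fits the 8-bit fields at bits 18-25 / 26-33 / 34-41;
  // log2(numTokensPerPage) <= 31 fits the 5-bit field at bits 44-48.
  uint64_t const h = packField(headDimPerCtaV >> 3, 18, 25) |
                     packField(headDimQk >> 3, 26, 33) |
                     packField(headDimV >> 3, 34, 41) |
                     packField(tileSizeKv >> 6, 42, 43) |
                     packField(numTokensPerPageLog2, 44, 48);
  (void)h;  // Bits 49-59 (maxNumHeadsQPerKvInCta, flags) pack the same way.
  return 0;
}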

include/flashinfer/trtllm/fmha/kernelParams.h

Lines changed: 20 additions & 0 deletions
@@ -104,6 +104,8 @@ struct KernelParams {
   // The sequence lengths for K/V. Required by pagedKv kernels to avoid unnecessary computation
   // based on (ptrCumSeqLensKv[batchIdx + 1] - ptrCumSeqLensKv[batchIdx]).
   int32_t const* ptrSeqLensKv;
+  // The reserved memory buffer.
+  int32_t* ptrReservedMem;
   // The softmax stats buffer.
   float2* ptrSoftmaxStats;

@@ -139,6 +141,8 @@ struct KernelParams {
   int64_t mNumHiddenEltsO;
   // The total number of pages in the paged-kv memory pool.
   int32_t mNumPagesInMemPool;
+  // The number of tokens per page (used if dynamic numTokensPerPage is enabled).
+  int32_t mNumTokensPerPageLog2;
   // The output scale for FP8 quantization.
   float mOutputScale;
   // The scaling factor for softmax (multiplied by log2 to use faster exp2).
@@ -147,11 +151,15 @@ struct KernelParams {
   float mScaleSfKv;
   // The SF scale for O.
   float mScaleSfO;
+  // The reserved parameter.
+  float mReservedParam;
   // The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase
   // kernel when inflight batching is enabled in TRT-LLM.
   int32_t mStartTokenIdxSfO;
   // The sum of sequence lengths for Q and K/V.
   int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;
+  // The sparseMla topK value.
+  int32_t mSparseMlaTopK;
   // The flag to use block sparse attention.
   bool mUseBlockSparseAttention;

@@ -537,6 +545,8 @@ struct KernelParams {
                   int32_t maxNumCtasQ, int32_t maxNumCtasKv) {
     // Create the return struct.
     KernelParams params;
+    // Memset the kernel parameters to 0.
+    memset(&params, 0, sizeof(KernelParams));

     // Get the device pointers for TMA descriptors.
     auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bytes(kernelMeta.mDataTypeKv));
@@ -681,6 +691,16 @@ struct KernelParams {
       // Default 0 means that chunked attention is disabled.
       params.mChunkedAttentionSizeLog2 = 0;
     }
+
+    // Compute the log of numTokensPerPage
+    int32_t numTokensPerPageLog2{-1};
+    if (isPagedKv(options.mQkvLayout)) {
+      FLASHINFER_CHECK((options.mNumTokensPerPage & (options.mNumTokensPerPage - 1)) == 0,
+                       "NumTokensPerPage must be power of 2");
+      numTokensPerPageLog2 = (int)log2f((float)options.mNumTokensPerPage);
+    }
+    params.mNumTokensPerPageLog2 = numTokensPerPageLog2;
+
     params.mMaxSeqLenQ = options.mMaxSeqLenQ;
     params.mMaxSeqLenKv = options.mMaxSeqLenKv;
     params.mMaxNumCtasQ = maxNumCtasQ;
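
Aside: the new mNumTokensPerPageLog2 field, together with the power-of-2 check, suggests the kernel can address the paged-KV pool with shifts and masks rather than integer division. A hedged sketch of that addressing pattern follows; the names are illustrative, not taken from the kernel source.

#include <cstdint>

// Illustrative paged-KV view; on device only the log2 page size is needed.
struct PagedKvView {
  int32_t numTokensPerPageLog2;  // e.g. 6 for 64-token pages; -1 if not paged.
};

// Map a flat KV token index to (pageIdx, tokenInPage) without divides:
// with a power-of-2 page size, / and % become >> and &.
inline void locateToken(PagedKvView const& view, int32_t tokenIdx,
                        int32_t& pageIdx, int32_t& tokenInPage) {
  pageIdx = tokenIdx >> view.numTokensPerPageLog2;
  tokenInPage = tokenIdx & ((1 << view.numTokensPerPageLog2) - 1);
}

The memset at the top of the factory also ensures that fields a caller never touches, including the new reserved slots (ptrReservedMem, mReservedParam) and mSparseMlaTopK, read as zero in the kernels rather than as stack garbage.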
