@@ -96,14 +96,15 @@ class TllmGenFmhaKernel {
9696 inline uint64_t hashID (int qkvLayout, int maskType, int kernelType, int scheduler,
9797 int multiCtasKvMode, int headDimPerCtaV, int headDimQk, int headDimV,
9898 int tileSizeKv, int numTokensPerPage, int maxNumHeadsQPerKvInCta,
99- bool reuseSmemKForV, bool uses2CtaMma) const {
99+ bool reuseSmemKForV, bool uses2CtaMma, bool sparseMla ) const {
100100 FLASHINFER_CHECK ((headDimPerCtaV >= 32 ) && (headDimQk >= 32 ) && (headDimV >= 32 ) &&
101- (headDimPerCtaV <= 2048 ) && (headDimQk <= 2048 ) && (headDimV <= 2048 ) &&
102- (numTokensPerPage <= 128 ),
103- " Expect (32 <= headDim <= 2048) && (numTokensPerPage <= 128), "
104- " got headDimPerCtaV=%d, headDimQk=%d, "
105- " headDimV=%d, numTokensPerPage=%d" ,
106- headDimPerCtaV, headDimQk, headDimV, numTokensPerPage);
101+ (headDimPerCtaV <= 1024 ) && (headDimQk <= 1024 ) && (headDimV <= 1024 ),
102+ " Expect (32 <= headDim <= 1024), got headDimPerCtaV=%d, headDimQk=%d, "
103+ " headDimV=%d" ,
104+ headDimPerCtaV, headDimQk, headDimV);
105+ // The numTokensPerPage must be power of 2.
106+ FLASHINFER_CHECK ((numTokensPerPage & (numTokensPerPage - 1 )) == 0 ,
107+ " The numTokensPerPage must be power of 2." );
107108 FLASHINFER_CHECK (maxNumHeadsQPerKvInCta <= 128 ,
108109 " The maxNumHeadsQPerKvInCta <= 128 is required." );
109110 FLASHINFER_CHECK (tileSizeKv == 64 || tileSizeKv == 128 , " The tileSizeKv must be 64 or 128." );
@@ -113,25 +114,26 @@ class TllmGenFmhaKernel {
113114 // Bit 8 - 11: kernelType.
114115 // Bit 12 - 15: tileScheduler.
115116 // Bit 16 - 17: multiCtasKvMode.
116- // Bit 18 - 24 : (headDimPerCtaV >> 5 ).
117- // Bit 25 - 31 : (headDimQk >> 5 ).
118- // Bit 32 - 38 : (headDimV >> 5 ).
119- // Bit 39 - 40 : (tileSizeKv >> 6).
120- // Bit 41 - 48: numTokensPerPage.
117+ // Bit 18 - 25 : (headDimPerCtaV >> 3 ).
118+ // Bit 26 - 33 : (headDimQk >> 3 ).
119+ // Bit 34 - 41 : (headDimV >> 3 ).
120+ // Bit 42 - 43 : (tileSizeKv >> 6).
121+ // Bit 44 - 48: (log2( numTokensPerPage)) .
121122 // Bit 49 - 56: maxNumHeadsQPerKvInCta.
122123 // Bit 57 - 57: reuseSmemKForV.
123124 // Bit 58 - 58: uses2CtaMma.
125+ // Bit 59 - 59: sparseMla.
124126 return (static_cast <uint64_t >(qkvLayout) << 0 ) | (static_cast <uint64_t >(maskType) << 4 ) |
125127 (static_cast <uint64_t >(kernelType) << 8 ) | (static_cast <uint64_t >(scheduler) << 12 ) |
126128 (static_cast <uint64_t >(multiCtasKvMode) << 16 ) |
127- (static_cast <uint64_t >(headDimPerCtaV >> 5 ) << 18 ) |
128- (static_cast <uint64_t >(headDimQk >> 5 ) << 25 ) |
129- (static_cast <uint64_t >(headDimV >> 5 ) << 32 ) |
130- (static_cast <uint64_t >(tileSizeKv >> 6 ) << 39 ) |
131- (static_cast <uint64_t >(numTokensPerPage) << 41 ) |
129+ (static_cast <uint64_t >(headDimPerCtaV >> 3 ) << 18 ) |
130+ (static_cast <uint64_t >(headDimQk >> 3 ) << 26 ) |
131+ (static_cast <uint64_t >(headDimV >> 3 ) << 34 ) |
132+ (static_cast <uint64_t >(tileSizeKv >> 6 ) << 42 ) |
133+ (static_cast <uint64_t >(log2 ( numTokensPerPage)) << 44 ) |
132134 (static_cast <uint64_t >(maxNumHeadsQPerKvInCta) << 49 ) |
133135 (static_cast <uint64_t >(reuseSmemKForV) << 57 ) |
134- (static_cast <uint64_t >(uses2CtaMma) << 58 );
136+ (static_cast <uint64_t >(uses2CtaMma) << 58 ) | ( static_cast < uint64_t >(sparseMla) << 59 ) ;
135137 }
136138
137139 uint64_t hashID (KernelMeta const & kernelMeta) const {
@@ -140,7 +142,7 @@ class TllmGenFmhaKernel {
140142 kernelMeta.mHeadDimPerCtaV , kernelMeta.mHeadDimQk , kernelMeta.mHeadDimV ,
141143 kernelMeta.mTileSizeKv , kernelMeta.mNumTokensPerPage ,
142144 kernelMeta.mMaxNumHeadsQPerKvInCta , kernelMeta.mReuseSmemKForV ,
143- kernelMeta.m2CtaMma );
145+ kernelMeta.m2CtaMma , kernelMeta. mSparseMla );
144146 }
145147
146148 std::pair<bool , std::string> checkIfKernelExist (RunnerParams const & params) const {
@@ -552,7 +554,8 @@ class TllmGenFmhaKernel {
552554 static_cast <int >(selectKernelParams.mMultiCtasKvMode ),
553555 selectKernelParams.mHeadDimPerCtaV , params.mHeadDimQk , params.mHeadDimV ,
554556 selectKernelParams.mTileSizeKv , numTokensPerPage, maxNumHeadsQPerKvInCta,
555- selectKernelParams.mReuseSmemKForV , selectKernelParams.mUses2CtaMma ),
557+ selectKernelParams.mReuseSmemKForV , selectKernelParams.mUses2CtaMma ,
558+ /* sparseMla */ false ),
556559 info);
557560 }
558561
0 commit comments