@@ -275,19 +275,19 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
275
275
copy (params.gmem_tiled_copy_q , tQgQ (_,_,_,k_tile), tQrQ);
276
276
copy (gmem_tiled_copy_k, tKgK (_,_,_,k_tile), tKrK);
277
277
if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
278
- auto tCrQ_ = make_fragment_like<half_t >(tCrQ);
279
- convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_ );
280
- auto tCrK_ = make_fragment_like<half_t >(tCrK);
281
- convert_FP8_to_FP16<ElementK>(tCrK, tCrK_ );
282
- cute::gemm (tiled_mma, accum, tCrQ_, tCrK_ , frag_src);
278
+ auto tCrQ_fp16 = make_fragment_like<half_t >(tCrQ);
279
+ convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_fp16 );
280
+ auto tCrK_fp16 = make_fragment_like<half_t >(tCrK);
281
+ convert_FP8_to_FP16<ElementK>(tCrK, tCrK_fp16 );
282
+ cute::gemm (tiled_mma, accum, tCrQ_fp16, tCrK_fp16 , frag_src);
283
283
} else if constexpr (is_fp8_v<ElementQ> && !is_fp8_v<ElementK>) {
284
- auto tCrQ_ = make_fragment_like<half_t >(tCrQ);
285
- convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_ );
286
- cute::gemm (tiled_mma, accum, tCrQ_ , tCrK, frag_src);
284
+ auto tCrQ_fp16 = make_fragment_like<half_t >(tCrQ);
285
+ convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_fp16 );
286
+ cute::gemm (tiled_mma, accum, tCrQ_fp16 , tCrK, frag_src);
287
287
} else if constexpr (!is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
288
- auto tCrK_ = make_fragment_like<half_t >(tCrK);
289
- convert_FP8_to_FP16<ElementK>(tCrK, tCrK_ );
290
- cute::gemm (tiled_mma, accum, tCrQ , tCrK_ , frag_src);
288
+ auto tCrK_fp16 = make_fragment_like<half_t >(tCrK);
289
+ convert_FP8_to_FP16<ElementK>(tCrK, tCrK_fp16 );
290
+ cute::gemm (tiled_mma, accum, tCrQ , tCrK_fp16 , frag_src);
291
291
} else {
292
292
cute::gemm (tiled_mma, accum, tCrQ , tCrK, frag_src);
293
293
}
@@ -343,9 +343,9 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
343
343
for (int i = 0 ; i< tile_count; i++) {
344
344
copy (gmem_tiled_copy_v, tVgV (_,_,_,i), tVrV);
345
345
if constexpr (is_fp8_v<ElementV>) {
346
- auto tCrV_ = make_fragment_like<half_t >(tCrV);
347
- convert_FP8_to_FP16<ElementV>(tCrV, tCrV_ );
348
- cute::gemm (tiled_mma, accum (_,_,_,i), tPr, tCrV_ , frag_src (_,_,_,i));
346
+ auto tCrV_fp16 = make_fragment_like<half_t >(tCrV);
347
+ convert_FP8_to_FP16<ElementV>(tCrV, tCrV_fp16 );
348
+ cute::gemm (tiled_mma, accum (_,_,_,i), tPr, tCrV_fp16 , frag_src (_,_,_,i));
349
349
} else {
350
350
cute::gemm (tiled_mma, accum (_,_,_,i), tPr, tCrV, frag_src (_,_,_,i));
351
351
}
0 commit comments