Skip to content

Commit 5c6d0a9

Browse files
committed
apply review comments
1 parent 32f49cf commit 5c6d0a9

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -275,19 +275,19 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
275275
copy(params.gmem_tiled_copy_q, tQgQ(_,_,_,k_tile), tQrQ);
276276
copy(gmem_tiled_copy_k, tKgK(_,_,_,k_tile), tKrK);
277277
if constexpr (is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
278-
auto tCrQ_ = make_fragment_like<half_t>(tCrQ);
279-
convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_);
280-
auto tCrK_ = make_fragment_like<half_t>(tCrK);
281-
convert_FP8_to_FP16<ElementK>(tCrK, tCrK_);
282-
cute::gemm(tiled_mma, accum, tCrQ_, tCrK_, frag_src);
278+
auto tCrQ_fp16 = make_fragment_like<half_t>(tCrQ);
279+
convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_fp16);
280+
auto tCrK_fp16 = make_fragment_like<half_t>(tCrK);
281+
convert_FP8_to_FP16<ElementK>(tCrK, tCrK_fp16);
282+
cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK_fp16, frag_src);
283283
} else if constexpr (is_fp8_v<ElementQ> && !is_fp8_v<ElementK>) {
284-
auto tCrQ_ = make_fragment_like<half_t>(tCrQ);
285-
convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_);
286-
cute::gemm(tiled_mma, accum, tCrQ_ , tCrK, frag_src);
284+
auto tCrQ_fp16 = make_fragment_like<half_t>(tCrQ);
285+
convert_FP8_to_FP16<ElementQ>(tCrQ, tCrQ_fp16);
286+
cute::gemm(tiled_mma, accum, tCrQ_fp16 , tCrK, frag_src);
287287
} else if constexpr (!is_fp8_v<ElementQ> && is_fp8_v<ElementK>) {
288-
auto tCrK_ = make_fragment_like<half_t>(tCrK);
289-
convert_FP8_to_FP16<ElementK>(tCrK, tCrK_);
290-
cute::gemm(tiled_mma, accum, tCrQ , tCrK_, frag_src);
288+
auto tCrK_fp16 = make_fragment_like<half_t>(tCrK);
289+
convert_FP8_to_FP16<ElementK>(tCrK, tCrK_fp16);
290+
cute::gemm(tiled_mma, accum, tCrQ , tCrK_fp16, frag_src);
291291
} else {
292292
cute::gemm(tiled_mma, accum, tCrQ , tCrK, frag_src);
293293
}
@@ -343,9 +343,9 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
343343
for(int i = 0; i< tile_count; i++) {
344344
copy(gmem_tiled_copy_v, tVgV(_,_,_,i), tVrV);
345345
if constexpr (is_fp8_v<ElementV>) {
346-
auto tCrV_ = make_fragment_like<half_t>(tCrV);
347-
convert_FP8_to_FP16<ElementV>(tCrV, tCrV_);
348-
cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV_, frag_src(_,_,_,i));
346+
auto tCrV_fp16 = make_fragment_like<half_t>(tCrV);
347+
convert_FP8_to_FP16<ElementV>(tCrV, tCrV_fp16);
348+
cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV_fp16, frag_src(_,_,_,i));
349349
} else {
350350
cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i));
351351
}

examples/06_bmg_flash_attention/06_bmg_prefill_attention_prefill_cachedKV_fp8.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/***************************************************************************************************
2-
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
2+
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
33
* SPDX-License-Identifier: BSD-3-Clause
44
*
55
* Redistribution and use in source and binary forms, with or without

0 commit comments

Comments
 (0)