Skip to content

Commit 919df17

Browse files
committed
Fix icpx failure on test_unit_flash_attention_prefill
1 parent 1f769d4 commit 919df17

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,17 @@ struct Shape_h192 {
125125
using GmemTiledCopyV = cute::XE_2D_U16x16x32_LD_V;
126126
using GmemTiledCopyO = cute::XE_2D_U16x8x16_ST_N;
127127
};
128+
129+
// Kernel-name tag: declared but never defined. It is used only as the template
// argument of parallel_for in convert_fp8_to_fp16 below, so each (SrcT, DstT)
// pair gets a distinct kernel name.
template <class, class> class convert_fp8_to_fp16_name;
130+
131+
/// @brief Element-wise conversion of `size` values from `d_src` to `d_dst`.
///
/// Launches a parallel_for over `size` work-items on the queue returned by
/// cutlasscompat::get_default_queue(); each work-item static_casts one element
/// from SrcT to DstT. The call blocks until the kernel completes (`.wait()`).
///
/// @tparam SrcT  Source element type (the name suggests an fp8 type, but the
///               code accepts any type that static_casts to DstT).
/// @tparam DstT  Destination element type.
/// @param d_src  Input buffer holding at least `size` elements.
///               NOTE(review): presumably device-accessible memory - confirm.
/// @param d_dst  Output buffer with room for at least `size` elements.
/// @param size   Number of elements to convert.
///
/// NOTE(review): this helper sits after the closing brace of Shape_h192, outside
/// TestbedImpl; the commit title ties the move to an icpx failure - confirm.
template <typename SrcT, typename DstT>
132+
void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) {
133+
cutlasscompat::get_default_queue().parallel_for<convert_fp8_to_fp16_name<SrcT, DstT>>(size, [=](auto indx) {
134+
d_dst[indx] = static_cast<DstT>(d_src[indx]);
135+
}).wait();
136+
}
137+
138+
128139
/////////////////////////////////////////////////////////////////////
129140

130141
template<typename ElementInputType, typename ElementAccumulatorType, typename ElementOutputType,
@@ -225,15 +236,6 @@ struct TestbedImpl {
225236
//
226237
// Methods
227238
//
228-
// Kernel-name tag: declared but never defined; used only as the parallel_for
// template argument in convert_fp8_to_fp16 below.
template <class, class> class convert_fp8_to_fp16_name;
229-
230-
// Converts `size` elements from d_src to d_dst with a per-element static_cast,
// submitted via parallel_for on cutlasscompat::get_default_queue(); blocks
// until the kernel finishes (.wait()).
// NOTE(review): d_src / d_dst are presumably device-accessible pointers - confirm.
template <typename SrcT, typename DstT>
231-
void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) {
232-
cutlasscompat::get_default_queue().parallel_for<convert_fp8_to_fp16_name<SrcT, DstT>>(size, [=](auto indx) {
233-
d_dst[indx] = static_cast<DstT>(d_src[indx]);
234-
}).wait();
235-
}
236-
237239
template <typename T>
238240
static constexpr bool is_fp8_v = cute::is_any_of_v<T, cute::float_e5m2_t, cute::float_e4m3_t>;
239241

0 commit comments

Comments
 (0)