@@ -125,6 +125,17 @@ struct Shape_h192 {
125
125
using GmemTiledCopyV = cute::XE_2D_U16x16x32_LD_V;
126
126
using GmemTiledCopyO = cute::XE_2D_U16x8x16_ST_N;
127
127
};
128
+
129
+ template <class , class > class convert_fp8_to_fp16_name ;
130
+
131
+ template <typename SrcT, typename DstT>
132
+ void convert_fp8_to_fp16 (const SrcT* d_src, DstT* d_dst, size_t size) {
133
+ cutlasscompat::get_default_queue ().parallel_for <convert_fp8_to_fp16_name<SrcT, DstT>>(size, [=](auto indx) {
134
+ d_dst[indx] = static_cast <DstT>(d_src[indx]);
135
+ }).wait ();
136
+ }
137
+
138
+
128
139
// ///////////////////////////////////////////////////////////////////
129
140
130
141
template <typename ElementInputType, typename ElementAccumulatorType, typename ElementOutputType,
@@ -225,15 +236,6 @@ struct TestbedImpl {
225
236
//
226
237
// Methods
227
238
//
228
- template <class , class > class convert_fp8_to_fp16_name ;
229
-
230
- template <typename SrcT, typename DstT>
231
- void convert_fp8_to_fp16 (const SrcT* d_src, DstT* d_dst, size_t size) {
232
- cutlasscompat::get_default_queue ().parallel_for <convert_fp8_to_fp16_name<SrcT, DstT>>(size, [=](auto indx) {
233
- d_dst[indx] = static_cast <DstT>(d_src[indx]);
234
- }).wait ();
235
- }
236
-
237
239
template <typename T>
238
240
static constexpr bool is_fp8_v = cute::is_any_of_v<T, cute::float_e5m2_t , cute::float_e4m3_t >;
239
241
0 commit comments