@@ -31,17 +31,15 @@ template <typename InType, typename OutType>
3131struct CastDataTypeFunctor {
3232 HOSTDEVICE inline OutType operator ()(InType in) const {
3333#if defined(_MSC_VER)
34- // Avoid unsupported convert of float/bfloat8 /float16 -> complex
34+ // Avoid unsupported convert of float8/bfloat16 /float16 -> complex
3535 if constexpr (
36- (std::is_same_v<OutType, phi::dtype:: complex < float > > ||
36+ (std::is_same_v<OutType, phi::complex64 > ||
3737 std::is_same_v<
3838 OutType,
39- phi::dtype::complex <
40- double >>)&&(std::is_same_v<InType,
41- phi::dtype::float8_e4m3fn> ||
42- std::is_same_v<InType, phi::dtype::float8_e5m2> ||
43- std::is_same_v<InType, phi::dtype::bfloat16> ||
44- std::is_same_v<InType, phi::dtype::float16>)) {
39+ phi::complex128>)&&(std::is_same_v<InType, phi::float8_e4m3fn> ||
40+ std::is_same_v<InType, phi::float8_e5m2> ||
41+ std::is_same_v<InType, phi::bfloat16> ||
42+ std::is_same_v<InType, phi::float16>)) {
4543 // default value,only to avoid compile error
4644 return OutType (0 );
4745 } else {
@@ -181,7 +179,7 @@ void TransDataType(const phi::DenseTensor& in,
181179 if (phi::is_xpu_place (in.place ())) {
182180 switch (src_type) {
183181 case proto::VarType::FP16:
184- XPUTransDataType<phi::dtype:: float16>(in, out, dst_type, ctx);
182+ XPUTransDataType<phi::float16>(in, out, dst_type, ctx);
185183 break ;
186184 case proto::VarType::FP32:
187185 XPUTransDataType<float >(in, out, dst_type, ctx);
@@ -211,20 +209,18 @@ void TransDataType(const phi::DenseTensor& in,
211209#endif
212210 switch (src_type) {
213211 case proto::VarType::FP16:
214- phi::VisitDataType (dst_type,
215- CastDataType<phi::dtype::float16>(in, out, ctx));
212+ phi::VisitDataType (dst_type, CastDataType<phi::float16>(in, out, ctx));
216213 break ;
217214 case proto::VarType::BF16:
218- phi::VisitDataType (dst_type,
219- CastDataType<phi::dtype::bfloat16>(in, out, ctx));
215+ phi::VisitDataType (dst_type, CastDataType<phi::bfloat16>(in, out, ctx));
220216 break ;
221217 case proto::VarType::FP8_E4M3FN:
222- phi::VisitDataType (
223- dst_type, CastDataType<:: phi::dtype ::float8_e4m3fn>(in, out, ctx));
218+ phi::VisitDataType (dst_type,
219+ CastDataType<phi::float8_e4m3fn>(in, out, ctx));
224220 break ;
225221 case proto::VarType::FP8_E5M2:
226222 phi::VisitDataType (dst_type,
227- CastDataType<:: phi::dtype ::float8_e5m2>(in, out, ctx));
223+ CastDataType<phi::float8_e5m2>(in, out, ctx));
228224 break ;
229225 case proto::VarType::FP32:
230226 phi::VisitDataType (dst_type, CastDataType<float >(in, out, ctx));
0 commit comments