Skip to content

Commit 853faf9

Browse files
authored
fix phi::dtype in data_type_transform.cc (#76629)
1 parent 7c2c71f commit 853faf9

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

paddle/phi/core/framework/data_type_transform.cc

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,15 @@ template <typename InType, typename OutType>
3131
struct CastDataTypeFunctor {
3232
HOSTDEVICE inline OutType operator()(InType in) const {
3333
#if defined(_MSC_VER)
34-
// Avoid unsupported convert of float/bfloat8/float16 -> complex
34+
// Avoid unsupported convert of float8/bfloat16/float16 -> complex
3535
if constexpr (
36-
(std::is_same_v<OutType, phi::dtype::complex<float>> ||
36+
(std::is_same_v<OutType, phi::complex64> ||
3737
std::is_same_v<
3838
OutType,
39-
phi::dtype::complex<
40-
double>>)&&(std::is_same_v<InType,
41-
phi::dtype::float8_e4m3fn> ||
42-
std::is_same_v<InType, phi::dtype::float8_e5m2> ||
43-
std::is_same_v<InType, phi::dtype::bfloat16> ||
44-
std::is_same_v<InType, phi::dtype::float16>)) {
39+
phi::complex128>)&&(std::is_same_v<InType, phi::float8_e4m3fn> ||
40+
std::is_same_v<InType, phi::float8_e5m2> ||
41+
std::is_same_v<InType, phi::bfloat16> ||
42+
std::is_same_v<InType, phi::float16>)) {
4543
// default value,only to avoid compile error
4644
return OutType(0);
4745
} else {
@@ -181,7 +179,7 @@ void TransDataType(const phi::DenseTensor& in,
181179
if (phi::is_xpu_place(in.place())) {
182180
switch (src_type) {
183181
case proto::VarType::FP16:
184-
XPUTransDataType<phi::dtype::float16>(in, out, dst_type, ctx);
182+
XPUTransDataType<phi::float16>(in, out, dst_type, ctx);
185183
break;
186184
case proto::VarType::FP32:
187185
XPUTransDataType<float>(in, out, dst_type, ctx);
@@ -211,20 +209,18 @@ void TransDataType(const phi::DenseTensor& in,
211209
#endif
212210
switch (src_type) {
213211
case proto::VarType::FP16:
214-
phi::VisitDataType(dst_type,
215-
CastDataType<phi::dtype::float16>(in, out, ctx));
212+
phi::VisitDataType(dst_type, CastDataType<phi::float16>(in, out, ctx));
216213
break;
217214
case proto::VarType::BF16:
218-
phi::VisitDataType(dst_type,
219-
CastDataType<phi::dtype::bfloat16>(in, out, ctx));
215+
phi::VisitDataType(dst_type, CastDataType<phi::bfloat16>(in, out, ctx));
220216
break;
221217
case proto::VarType::FP8_E4M3FN:
222-
phi::VisitDataType(
223-
dst_type, CastDataType<::phi::dtype::float8_e4m3fn>(in, out, ctx));
218+
phi::VisitDataType(dst_type,
219+
CastDataType<phi::float8_e4m3fn>(in, out, ctx));
224220
break;
225221
case proto::VarType::FP8_E5M2:
226222
phi::VisitDataType(dst_type,
227-
CastDataType<::phi::dtype::float8_e5m2>(in, out, ctx));
223+
CastDataType<phi::float8_e5m2>(in, out, ctx));
228224
break;
229225
case proto::VarType::FP32:
230226
phi::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));

0 commit comments

Comments
 (0)