Skip to content

Commit be6972b

Browse files
committed
fix FLAGS_enable_api_kernel_fallback does not take effect in XPU
1 parent 519aee5 commit be6972b

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

paddle/phi/core/kernel_factory.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
355355
!phi::backends::xpu::is_xpu_support_op(TransToFluidOpName(kernel_name),
356356
kernel_key.dtype()) &&
357357
!phi::backends::xpu::is_xpu_support_op(kernel_name, kernel_key.dtype());
358-
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) ||
358+
if (FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end() &&
359359
is_xpu_unsupported
360360
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
361361
if (kernel_iter == iter->second.end() &&

test/xpu/test_fallback_flag.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle
18+
19+
20+
class TestFLAGSEnableApiKernelFallback(unittest.TestCase):
21+
def test_FLAGS_enable_api_kernel_fallback(self):
22+
FLAGS_enable_api_kernel_fallback_prev: bool = paddle.get_flags(
23+
["FLAGS_enable_api_kernel_fallback"]
24+
)["FLAGS_enable_api_kernel_fallback"]
25+
paddle.set_flags({"FLAGS_enable_api_kernel_fallback": False})
26+
x = paddle.to_tensor(1.0, dtype="float64")
27+
with self.assertRaisesRegex(RuntimeError, "not registered"):
28+
z = paddle.sqrt(x)
29+
paddle.set_flags(
30+
{
31+
'FLAGS_enable_api_kernel_fallback': FLAGS_enable_api_kernel_fallback_prev
32+
}
33+
)
34+
35+
36+
if __name__ == "__main__":
37+
unittest.main()

0 commit comments

Comments
 (0)