Skip to content

Commit 0ede12d

Browse files
fix handling of flash_attention
1 parent 976ede8 commit 0ede12d

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

keras/src/backend/openvino/nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ def dot_product_attention(
698698
"`dot_product_attention` with `bias` is not supported "
699699
"with openvino backend"
700700
)
701-
if flash_attention is not None:
701+
if flash_attention:
702702
raise NotImplementedError(
703703
"`dot_product_attention` with `flash_attention` is not supported "
704704
"with openvino backend"

keras/src/ops/nn_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2458,7 +2458,7 @@ def test_dot_product_attention(
24582458
)
24592459

24602460
if flash_attention:
2461-
if backend.backend() in ("tensorflow", "numpy"):
2461+
if backend.backend() in ("tensorflow", "numpy", "openvino"):
24622462
self.skipTest(
24632463
"Flash attention is not supported in tensorflow and numpy "
24642464
"backends."

0 commit comments

Comments
 (0)