Skip to content

Commit b4408d1

Browse files
committed
fix quack availability check
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent 968420c commit b4408d1

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

thunder/tests/test_cutlass_dsl_ex.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import thunder
1111
from thunder.executors.cutlass_dsl_ex import cutlass_dsl_ex, is_device_quack_compat
12-
from thunder.tests.framework import requiresCUDA
1312

1413
if TYPE_CHECKING:
1514
from typing import Any
@@ -18,7 +17,7 @@
1817

1918
_quack_available = find_spec("quack") is not None
2019
quack_available = pytest.mark.skipif(
21-
not is_device_quack_compat() or not _quack_available,
20+
not torch.cuda.is_available() or not _quack_available or not is_device_quack_compat(),
2221
reason="quack requires SM9.0/10.0",
2322
)
2423
_DTYPES = (torch.float16, torch.bfloat16, torch.float32)
@@ -44,7 +43,6 @@ def jit_with_cutlass_dsl_ex(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
4443
return thunder.jit(fn, executors=[cutlass_dsl_ex], disable_torch_autograd=True)
4544

4645

47-
@requiresCUDA
4846
@quack_available
4947
@pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)
5048
def test_quack_cross_entropy(dtype: torch.dtype):
@@ -61,7 +59,6 @@ def test_quack_cross_entropy(dtype: torch.dtype):
6159
torch.testing.assert_close(expected, actual)
6260

6361

64-
@requiresCUDA
6562
@quack_available
6663
@pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS)
6764
@pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)
@@ -76,7 +73,6 @@ def test_quack_softmax(dtype: torch.dtype, shape: tuple[int, ...]):
7673
torch.testing.assert_close(expected, actual)
7774

7875

79-
@requiresCUDA
8076
@quack_available
8177
@pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS)
8278
@pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)
@@ -92,7 +88,6 @@ def test_quack_layernorm(dtype: torch.dtype, shape: tuple[int, ...]):
9288
torch.testing.assert_close(expected, actual)
9389

9490

95-
@requiresCUDA
9691
@quack_available
9792
@pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS)
9893
@pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)

0 commit comments

Comments
 (0)