@@ -9,7 +9,6 @@
 
 import thunder
 from thunder.executors.cutlass_dsl_ex import cutlass_dsl_ex, is_device_quack_compat
-from thunder.tests.framework import requiresCUDA
 
 if TYPE_CHECKING:
     from typing import Any
@@ -18,7 +17,7 @@
 
 _quack_available = find_spec("quack") is not None
 quack_available = pytest.mark.skipif(
-    not is_device_quack_compat() or not _quack_available,
+    not torch.cuda.is_available() or not _quack_available or not is_device_quack_compat(),
     reason="quack requires SM9.0/10.0",
 )
 _DTYPES = (torch.float16, torch.bfloat16, torch.float32)
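
Note on the reordered skip condition above: `or` short-circuits left to right, so on a machine without CUDA the marker skips at `torch.cuda.is_available()` and `is_device_quack_compat()` is never evaluated. That presumably lets the compat helper query the device capability without guarding against a missing GPU itself, and it is what makes the standalone `@requiresCUDA` decorator redundant in every test below. A minimal sketch of the pattern, assuming (hypothetically) that `is_device_quack_compat` reads `torch.cuda.get_device_capability`, which raises when no CUDA device is present:

import pytest
import torch

def is_device_quack_compat() -> bool:
    # Hypothetical stand-in for the real helper: quack targets SM 9.0 and
    # SM 10.0. get_device_capability() raises if there is no CUDA device.
    major, _minor = torch.cuda.get_device_capability()
    return major in (9, 10)

# Evaluated left to right with short-circuiting: the availability check runs
# first, so the capability query is never reached on CUDA-less machines.
quack_available = pytest.mark.skipif(
    not torch.cuda.is_available() or not is_device_quack_compat(),
    reason="quack requires SM9.0/10.0",
)

@quack_available
def test_smoke():
    assert torch.cuda.is_available()
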
@@ -44,7 +43,6 @@ def jit_with_cutlass_dsl_ex(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
     return thunder.jit(fn, executors=[cutlass_dsl_ex], disable_torch_autograd=True)
 
 
-@requiresCUDA
 @quack_available
 @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)
 def test_quack_cross_entropy(dtype: torch.dtype):
@@ -61,7 +59,6 @@ def test_quack_cross_entropy(dtype: torch.dtype):
     torch.testing.assert_close(expected, actual)
 
 
-@requiresCUDA
 @quack_available
 @pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS)
 @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)
@@ -76,7 +73,6 @@ def test_quack_softmax(dtype: torch.dtype, shape: tuple[int, ...]):
     torch.testing.assert_close(expected, actual)
 
 
-@requiresCUDA
 @quack_available
 @pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS)
 @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)
@@ -92,7 +88,6 @@ def test_quack_layernorm(dtype: torch.dtype, shape: tuple[int, ...]):
     torch.testing.assert_close(expected, actual)
 
 
-@requiresCUDA
 @quack_available
 @pytest.mark.parametrize("shape", _SHAPES, ids=_SHAPE_IDS)
 @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)