diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py index cfe2339e2b56..c742927646b6 100644 --- a/tests/quantization/test_torch_compile_utils.py +++ b/tests/quantization/test_torch_compile_utils.py @@ -18,10 +18,10 @@ import torch from diffusers import DiffusionPipeline -from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device -@require_torch_gpu +@require_torch_accelerator @slow class QuantCompileTests: @property @@ -51,7 +51,7 @@ def _init_pipeline(self, quantization_config, torch_dtype): return pipe def _test_torch_compile(self, torch_dtype=torch.bfloat16): - pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda") + pipe = self._init_pipeline(self.quantization_config, torch_dtype).to(torch_device) # `fullgraph=True` ensures no graph breaks pipe.transformer.compile(fullgraph=True) @@ -71,7 +71,7 @@ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16 pipe = self._init_pipeline(self.quantization_config, torch_dtype) group_offload_kwargs = { - "onload_device": torch.device("cuda"), + "onload_device": torch.device(torch_device), "offload_device": torch.device("cpu"), "offload_type": "leaf_level", "use_stream": use_stream, @@ -81,7 +81,7 @@ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16 for name, component in pipe.components.items(): if name != "transformer" and isinstance(component, torch.nn.Module): if torch.device(component.device).type == "cpu": - component.to("cuda") + component.to(torch_device) # small resolutions to ensure speedy execution. pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 9d09fd2f1bab..5dcc207e655b 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -236,7 +236,7 @@ def test_quantization(self): ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), ] - if TorchAoConfig._is_cuda_capability_atleast_8_9(): + if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES_TO_TEST.extend([ ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), @@ -753,7 +753,7 @@ def test_quantization(self): ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])), ] - if TorchAoConfig._is_cuda_capability_atleast_8_9(): + if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES_TO_TEST.extend([ ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),