diff --git a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py
index 28407c550e87..b1c65e8ba6b0 100644
--- a/src/transformers/generation/continuous_batching/requests.py
+++ b/src/transformers/generation/continuous_batching/requests.py
@@ -19,6 +19,7 @@
 
 import torch
 
+from ...utils import is_torch_xpu_available
 from ...utils.logging import logging
 from ...utils.metrics import traced
 
@@ -35,6 +36,13 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
         total_memory = torch.cuda.get_device_properties(device).total_memory
         reserved_memory = torch.cuda.memory_reserved(device)
         allocated_memory = torch.cuda.memory_allocated(device)
+    elif is_torch_xpu_available():
+        device = torch.device("xpu")
+        torch.xpu.empty_cache()
+        torch.xpu.synchronize()
+        total_memory = torch.xpu.get_device_properties(device).total_memory
+        reserved_memory = torch.xpu.memory_reserved(device)
+        allocated_memory = torch.xpu.memory_allocated(device)
     elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
         device = torch.device("mps")
         # MPS memory reporting (PyTorch 2.0+)
diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py
index 6a6ce1db17e7..6552e068aaaf 100644
--- a/src/transformers/integrations/mxfp4.py
+++ b/src/transformers/integrations/mxfp4.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
 
 
 if is_torch_available():
@@ -114,6 +114,9 @@ def convert_moe_packed_tensors(
     if not blocks.is_cuda and torch.cuda.is_available():
         blocks = blocks.cuda()
         scales = scales.cuda()
+    elif (blocks.device.type != "xpu") and is_torch_xpu_available():
+        blocks = blocks.to("xpu")
+        scales = scales.to("xpu")
 
     scales = scales.to(torch.int32) - 127  # TODO that's because 128=2**7
 
@@ -351,6 +354,8 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
                 dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
                 if target_device == "cpu" and torch.cuda.is_available():
                     torch.cuda.empty_cache()
+                elif target_device == "cpu" and is_torch_xpu_available():
+                    torch.xpu.empty_cache()
                 setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
                 delattr(module, blocks_attr)
                 delattr(module, scales_attr)
@@ -395,7 +400,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito
     else:
         blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
     if getattr(target_device, "type", target_device) == "cpu":
-        target_device = "cuda"
+        target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
     blocks = blocks.to(target_device).contiguous()
     scales = scales.to(target_device).contiguous()
     with on_device(target_device):
diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py
index 80da7886dccf..e10d714a1df1 100644
--- a/tests/generation/test_continuous_batching.py
+++ b/tests/generation/test_continuous_batching.py
@@ -20,7 +20,14 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
 from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
 from transformers.generation.continuous_batching.continuous_api import build_attention_mask
-from transformers.testing_utils import Expectations, require_kernels, require_torch_gpu, slow
+from transformers.testing_utils import (
+    Expectations,
+    require_kernels,
+    require_torch_accelerator,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 
 ALLOW_EXPECTED_OUTPUTS = True  # this is a debug flag when you want to measure deviation between CB and non-CB gen
@@ -148,7 +155,7 @@ def _continuous_batching_parity(
 
         # Generation with continuous batching
         model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype="auto")
-        model = model.cuda().eval()
+        model = model.to(torch_device).eval()
         model.generation_config.max_new_tokens = 40
         model.generation_config.do_sample = False
         model.generation_config.use_cuda_graph = False
@@ -156,9 +163,9 @@ def _continuous_batching_parity(
         cb_outputs = model.generate_batch(inputs=batched_inputs, generation_config=model.generation_config)
 
         # Generation without continuous batching
-        if attn_implementation == "sdpa_paged":
+        if attn_implementation == "paged|sdpa":
             non_cb_attn_implementation = "sdpa"
-        elif attn_implementation == "eager_paged":
+        elif attn_implementation == "paged|eager":
             non_cb_attn_implementation = "eager"
         elif attn_implementation == "paged_attention|kernels-community/flash-attn":
             non_cb_attn_implementation = "eager"
@@ -169,14 +176,14 @@ def _continuous_batching_parity(
         model = AutoModelForCausalLM.from_pretrained(
             model_id, attn_implementation=non_cb_attn_implementation, dtype="auto"
         )
-        model = model.cuda().eval()
+        model = model.to(torch_device).eval()
         model.generation_config.max_new_tokens = 40
         model.generation_config.do_sample = False
         model.generation_config.use_cuda_graph = False
 
         for request_id, request in cb_outputs.items():
             # Generate without continuous batching
-            input_ids = torch.tensor([request.prompt_ids]).cuda()
+            input_ids = torch.tensor([request.prompt_ids]).to(torch_device)
             attention_mask = torch.ones_like(input_ids)
             outputs = model.generate(
                 input_ids, attention_mask=attention_mask, generation_config=model.generation_config
@@ -208,7 +215,7 @@ def _continuous_batching_parity(
                 )
 
     # Eager tests
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_llama_eager(self) -> None:
         expected_outputs = Expectations({
@@ -218,11 +225,15 @@ def test_continuous_batching_parity_llama_eager(self) -> None:
             ("cuda", (9, 0)): {
                 "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5. The total number of bolts is 4.5. The total",
                 "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
-            }
+            },
+            ("xpu", None): {
+                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The answer is not 3.5 bolts of blue fiber and 1.5 bolts of white fiber. The answer'",
+                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
+            },
         }).get_expectation()  # fmt: skip
-        self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "eager_paged", expected_outputs)
+        self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|eager", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_gemma_eager(self) -> None:
         expected_outputs = Expectations({
@@ -232,53 +243,68 @@ def test_continuous_batching_parity_gemma_eager(self) -> None:
             ("cuda", (9, 0)): {
                 "req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
                 "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n "
-            }
+            },
+            ("xpu", None): {
+                "req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
+                "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
+                "req_2": "\n\n**$100,000**\n\n**Explanation:**\n\nHere's how to calculate the profit:\n\n1. **Calculate the total cost:** $80,00",
+            },
         }).get_expectation()  # fmt: skip
-        self._continuous_batching_parity("google/gemma-2-2b-it", "eager_paged", expected_outputs)
+        self._continuous_batching_parity("google/gemma-2-2b-it", "paged|eager", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
    @slow
     def test_continuous_batching_parity_qwen_eager(self) -> None:
         expected_outputs = {}
-        self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "eager_paged", expected_outputs)
+        self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|eager", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_gpt_oss_eager(self) -> None:
         expected_outputs = Expectations({
             ("cuda", (9, 0)): {
                 "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
                 "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
-            }
+            },
+            ("xpu", None): {
+                "req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
+                "req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
+            },
         }).get_expectation()  # fmt: skip
-        self._continuous_batching_parity("openai/gpt-oss-20b", "eager_paged", expected_outputs)
+        self._continuous_batching_parity("openai/gpt-oss-20b", "paged|eager", expected_outputs)
 
     # SDPA tests
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_llama_sdpa(self) -> None:
         expected_outputs = Expectations({
             ("rocm", (9, 4)): {
                 "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
-            }
+            },
+            ("xpu", None): {
+                "req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
+            },
         }).get_expectation()  # fmt: skip
-        self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "sdpa_paged", expected_outputs)
+        self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|sdpa", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_gemma_sdpa(self) -> None:
         expected_outputs = Expectations({
             ("cuda", (9, 0)): {
                 "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
-            }
+            },
+            ("xpu", None): {
+                "req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
+            },
         }).get_expectation()  # fmt: skip
-        self._continuous_batching_parity("google/gemma-2-2b-it", "sdpa_paged", expected_outputs)
+        self._continuous_batching_parity("google/gemma-2-2b-it", "paged|sdpa", expected_outputs)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_continuous_batching_parity_qwen_sdpa(self) -> None:
         expected_outputs = {}
-        self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "sdpa_paged", expected_outputs)
+        self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|sdpa", expected_outputs)
 
     # GPT-OSS is not compatible with SDPA because it has an attention sink. TODO: is this fixable?
 
@@ -336,7 +362,7 @@ def test_attn_implementation(self) -> None:
         manager = model.init_continuous_batching()
         assert "paged|eager" == manager.model.config._attn_implementation
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_streaming_request(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
@@ -368,7 +394,7 @@ def test_streaming_request(self) -> None:
 
         manager.stop(block=True)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_non_streaming_request(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
@@ -395,7 +421,7 @@ def test_non_streaming_request(self) -> None:
 
         manager.stop(block=True)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
         model_id = "Qwen/Qwen2.5-0.5B-Instruct"
         max_new_tokens = 3
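As an illustrative aside (not part of the patch): the one-line fallback added to load_and_swizzle_mxfp4 above picks whichever accelerator backend PyTorch reports instead of hard-coding "cuda". Below is a minimal standalone sketch of that device-selection logic, assuming PyTorch new enough to expose torch.accelerator; the resolve_target_device name is hypothetical and only used for illustration.

import torch

def resolve_target_device(target_device) -> str:
    # Hypothetical helper mirroring the patch's fallback: when weights would land on CPU,
    # swizzling still needs an accelerator, so use the currently active one (e.g. "cuda"
    # or "xpu"); older PyTorch without torch.accelerator keeps the previous "cuda" default.
    if getattr(target_device, "type", target_device) == "cpu":
        if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
            return torch.accelerator.current_accelerator().type
        return "cuda"
    return getattr(target_device, "type", target_device)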