src/transformers/generation/continuous_batching/requests.py (8 additions, 0 deletions)
@@ -19,6 +19,7 @@

import torch

from ...utils import is_torch_xpu_available
from ...utils.logging import logging
from ...utils.metrics import traced

@@ -35,6 +36,13 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
total_memory = torch.cuda.get_device_properties(device).total_memory
reserved_memory = torch.cuda.memory_reserved(device)
allocated_memory = torch.cuda.memory_allocated(device)
elif is_torch_xpu_available():
device = torch.device("xpu")
torch.xpu.empty_cache()
torch.xpu.synchronize()
total_memory = torch.xpu.get_device_properties(device).total_memory
reserved_memory = torch.xpu.memory_reserved(device)
allocated_memory = torch.xpu.memory_allocated(device)
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")
# MPS memory reporting (PyTorch 2.0+)
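For context, a minimal usage sketch of the patched helper on an XPU host. This is not part of the diff: the import path mirrors the file location above, and the return order `(device, total, reserved, allocated)` is assumed from the variable names in the hunk.

```python
# Hypothetical usage sketch (not from this PR). Assumes get_device_and_memory_breakdown
# stays importable from the module above and returns (device, total, reserved, allocated).
from transformers.generation.continuous_batching.requests import get_device_and_memory_breakdown

device, total, reserved, allocated = get_device_and_memory_breakdown()
# With the new branch, an XPU host reports torch.xpu memory instead of falling through to MPS/CPU.
print(f"{device}: total={total}, reserved={reserved}, allocated={allocated}")
```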
src/transformers/integrations/mxfp4.py (7 additions, 2 deletions)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_accelerate_available, is_torch_available, logging
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging


if is_torch_available():
@@ -114,6 +114,9 @@ def convert_moe_packed_tensors(
if not blocks.is_cuda and torch.cuda.is_available():
blocks = blocks.cuda()
scales = scales.cuda()
elif (blocks.device.type != "xpu") and is_torch_xpu_available():
blocks = blocks.to("xpu")
scales = scales.to("xpu")

scales = scales.to(torch.int32) - 127 # TODO that's because 128=2**7

@@ -351,6 +354,8 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
if target_device == "cpu" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif target_device == "cpu" and is_torch_xpu_available():
torch.xpu.empty_cache()
setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
delattr(module, blocks_attr)
delattr(module, scales_attr)
@@ -395,7 +400,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito
else:
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
if getattr(target_device, "type", target_device) == "cpu":
target_device = "cuda"
target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
blocks = blocks.to(target_device).contiguous()
scales = scales.to(target_device).contiguous()
with on_device(target_device):
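The `load_and_swizzle_mxfp4` change above swaps the hard-coded `"cuda"` fallback for the generic accelerator API. A standalone sketch of that selection logic follows; the helper name is hypothetical, `torch.accelerator` only exists in recent PyTorch releases (hence the `hasattr` guard), and the extra `None` check is an addition of the sketch, not of the diff.

```python
import torch


def _pick_swizzle_device() -> str:
    """Hypothetical helper mirroring the device fallback in load_and_swizzle_mxfp4 above."""
    if hasattr(torch, "accelerator"):
        accelerator = torch.accelerator.current_accelerator()  # e.g. device("cuda") or device("xpu")
        if accelerator is not None:  # None when no accelerator backend is present
            return accelerator.type
    return "cuda"  # previous hard-coded behavior, kept as the fallback
```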
tests/generation/test_continuous_batching.py (54 additions, 28 deletions)
@@ -20,7 +20,14 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
from transformers.generation.continuous_batching.continuous_api import build_attention_mask
from transformers.testing_utils import Expectations, require_kernels, require_torch_gpu, slow
from transformers.testing_utils import (
Expectations,
require_kernels,
require_torch_accelerator,
require_torch_gpu,
slow,
torch_device,
)


ALLOW_EXPECTED_OUTPUTS = True # this is a debug flag when you want to measure deviation between CB and non-CB gen
@@ -148,17 +155,17 @@ def _continuous_batching_parity(

# Generation with continuous batching
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype="auto")
model = model.cuda().eval()
model = model.to(torch_device).eval()
model.generation_config.max_new_tokens = 40
model.generation_config.do_sample = False
model.generation_config.use_cuda_graph = False

cb_outputs = model.generate_batch(inputs=batched_inputs, generation_config=model.generation_config)

# Generation without continuous batching
if attn_implementation == "sdpa_paged":
if attn_implementation == "paged|sdpa":
non_cb_attn_implementation = "sdpa"
elif attn_implementation == "eager_paged":
elif attn_implementation == "paged|eager":
non_cb_attn_implementation = "eager"
elif attn_implementation == "paged_attention|kernels-community/flash-attn":
non_cb_attn_implementation = "eager"
@@ -169,14 +176,14 @@
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=non_cb_attn_implementation, dtype="auto"
)
model = model.cuda().eval()
model = model.to(torch_device).eval()
model.generation_config.max_new_tokens = 40
model.generation_config.do_sample = False
model.generation_config.use_cuda_graph = False

for request_id, request in cb_outputs.items():
# Generate without continuous batching
input_ids = torch.tensor([request.prompt_ids]).cuda()
input_ids = torch.tensor([request.prompt_ids]).to(torch_device)
attention_mask = torch.ones_like(input_ids)
outputs = model.generate(
input_ids, attention_mask=attention_mask, generation_config=model.generation_config
@@ -208,7 +215,7 @@ def _continuous_batching_parity(
)

# Eager tests
@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_llama_eager(self) -> None:
expected_outputs = Expectations({
@@ -218,11 +225,15 @@ def test_continuous_batching_parity_llama_eager(self) -> None:
("cuda", (9, 0)): {
"req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5. The total number of bolts is 4.5. The total",
"req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
}
},
("xpu", None): {
"req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The answer is not 3.5 bolts of blue fiber and 1.5 bolts of white fiber. The answer'",
"req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "eager_paged", expected_outputs)
self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|eager", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_gemma_eager(self) -> None:
expected_outputs = Expectations({
@@ -232,53 +243,68 @@ def test_continuous_batching_parity_gemma_eager(self) -> None:
("cuda", (9, 0)): {
"req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
"req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n "
}
},
("xpu", None): {
"req_0": "\n\n**$12**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 13",
"req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
"req_2": "\n\n**$100,000**\n\n**Explanation:**\n\nHere's how to calculate the profit:\n\n1. **Calculate the total cost:** $80,00",
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("google/gemma-2-2b-it", "eager_paged", expected_outputs)
self._continuous_batching_parity("google/gemma-2-2b-it", "paged|eager", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_qwen_eager(self) -> None:
expected_outputs = {}
self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "eager_paged", expected_outputs)
self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|eager", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_gpt_oss_eager(self) -> None:
expected_outputs = Expectations({
("cuda", (9, 0)): {
"req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
"req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
}
},
("xpu", None): {
"req_1": " 2.5 bolts. The question: \"What is the name of the puzzle that involves a robe taking 2 bolts of blue fiber and half that much white fiber?\" The answer: \"The",
"req_2": " 50%.\"\n\nWe need to parse: He buys a house for $80,000. He puts in $50,000 in repairs. This increased the value of the house by 150%."
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("openai/gpt-oss-20b", "eager_paged", expected_outputs)
self._continuous_batching_parity("openai/gpt-oss-20b", "paged|eager", expected_outputs)

# SDPA tests
@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_llama_sdpa(self) -> None:
expected_outputs = Expectations({
("rocm", (9, 4)): {
"req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
}
},
("xpu", None): {
"req_2": " $50,000. This is because the value of the house increased by 150%, which means that the value of the house increased by $50,000. This is because the value of the"
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "sdpa_paged", expected_outputs)
self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|sdpa", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_gemma_sdpa(self) -> None:
expected_outputs = Expectations({
("cuda", (9, 0)): {
"req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
}
},
("xpu", None): {
"req_1": " \n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("google/gemma-2-2b-it", "sdpa_paged", expected_outputs)
self._continuous_batching_parity("google/gemma-2-2b-it", "paged|sdpa", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@slow
def test_continuous_batching_parity_qwen_sdpa(self) -> None:
expected_outputs = {}
self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "sdpa_paged", expected_outputs)
self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|sdpa", expected_outputs)

# GPT-OSS is not compatible with SDPA because it has an attention sink. TODO: is this fixable?

@@ -336,7 +362,7 @@ def test_attn_implementation(self) -> None:
manager = model.init_continuous_batching()
assert "paged|eager" == manager.model.config._attn_implementation

@require_torch_gpu
@require_torch_accelerator
def test_streaming_request(self) -> None:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
max_new_tokens = 3
@@ -368,7 +394,7 @@ def test_streaming_request(self) -> None:

manager.stop(block=True)

@require_torch_gpu
@require_torch_accelerator
def test_non_streaming_request(self) -> None:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
max_new_tokens = 3
@@ -395,7 +421,7 @@ def test_non_streaming_request(self) -> None:

manager.stop(block=True)

@require_torch_gpu
@require_torch_accelerator
def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
max_new_tokens = 3
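The test changes apply one pattern throughout: `require_torch_gpu` becomes `require_torch_accelerator`, hard-coded `.cuda()` placement becomes `.to(torch_device)`, attention implementation names move to the `paged|...` form, and the `Expectations` tables gain `("xpu", None)` entries. A minimal sketch of the device-agnostic placement part, with an invented test name for illustration:

```python
import torch

from transformers.testing_utils import require_torch_accelerator, torch_device


@require_torch_accelerator  # runs on any supported accelerator (CUDA, XPU, ...), not just CUDA
def test_device_agnostic_placement():  # hypothetical test, not part of the PR
    # Placement via torch_device instead of a hard-coded .cuda() call.
    input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
    assert input_ids.device.type == torch.device(torch_device).type
```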