1 change: 1 addition & 0 deletions vllm/benchmarks/datasets.py
@@ -1892,6 +1892,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
             enable_multimodal_chat=args.enable_multimodal_chat,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
+            skip_chat_template=args.skip_chat_template,
         ),
         "sharegpt": lambda: ShareGPTDataset(
             random_seed=args.seed,
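For context on what this plumbed-through flag usually controls: a minimal sketch of how a `skip_chat_template` option is typically consumed when a benchmark dataset builds its prompts. The helper name, model name, and exact call site below are assumptions for illustration only, not the actual vllm dataset code.

```python
# Hypothetical sketch: how a skip_chat_template flag might gate prompt
# formatting in a benchmark dataset. Names here are illustrative only.
from transformers import AutoTokenizer


def build_prompt(tokenizer, user_message: str, skip_chat_template: bool) -> str:
    if skip_chat_template:
        # Use the raw text as the prompt: no role markers or special tokens.
        return user_message
    # Otherwise render the model's chat template around the message.
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": user_message}],
        tokenize=False,
        add_generation_prompt=True,
    )


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
print(build_prompt(tokenizer, "Hello!", skip_chat_template=True))
print(build_prompt(tokenizer, "Hello!", skip_chat_template=False))
```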
22 changes: 16 additions & 6 deletions vllm/v1/attention/backends/mla/indexer.py
@@ -196,9 +196,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):


 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
-    _cudagraph_support: ClassVar[AttentionCGSupport] = (
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    )
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

     reorder_batch_threshold: int = 1

@@ -212,8 +210,14 @@ def __init__(self, *args, **kwargs):
             if self.vllm_config.speculative_config
             else 0
         )
-        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
-        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
+        if self.num_speculative_tokens > 1:
+            raise ValueError(
+                "Sparse MLA (DeepSeekV3.2 indexer) only supports "
+                "num_speculative_tokens <= 1 because the DeepGEMM "
+                "fp8_paged_mqa_logits kernel does not support next_n > 2. "
+                f"Got num_speculative_tokens={self.num_speculative_tokens}."
+            )
+        self.reorder_batch_threshold += self.num_speculative_tokens

         props = torch.cuda.get_device_properties(self.device)
         sm_count = props.multi_processor_count
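A standalone illustration of the constraint enforced in this hunk: with `num_speculative_tokens = k`, each decode step carries up to `k + 1` query tokens per request (`next_n = k + 1`), and the DeepGEMM `fp8_paged_mqa_logits` kernel only handles `next_n <= 2`, so at most one speculative token is allowed and the reorder threshold grows accordingly. This is a minimal sketch; the helper function name is invented for illustration.

```python
# Sketch of the constraint: next_n = num_speculative_tokens + 1 must stay <= 2
# for DeepGEMM's fp8_paged_mqa_logits. Helper name is illustrative only.
def check_indexer_spec_tokens(num_speculative_tokens: int) -> int:
    """Return the decode reorder threshold, rejecting unsupported configs."""
    next_n = num_speculative_tokens + 1  # query tokens per request per decode step
    if next_n > 2:
        raise ValueError(
            f"next_n={next_n} unsupported: DeepGEMM fp8_paged_mqa_logits "
            "only supports next_n <= 2 (at most one speculative token)."
        )
    # Base threshold of 1, plus one extra slot per speculative token.
    return 1 + num_speculative_tokens


assert check_indexer_spec_tokens(0) == 1  # plain decoding
assert check_indexer_spec_tokens(1) == 2  # MTP with one draft token
```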
@@ -342,8 +346,14 @@ def build(
             self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                 seq_lens, self.kv_cache_spec.block_size, self.num_sms
             )
+            block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
+            # Padded CUDA graph requests have block_table entries of -1.
+            # Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
+            # This is safe because padded requests have seq_lens=0, so the
+            # kernel produces no meaningful output for those rows.
+            block_table.clamp_(min=0)
             decode_metadata = DeepSeekV32IndexerDecodeMetadata(
-                block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
+                block_table=block_table,
                 seq_lens=common_attn_metadata.seq_lens[:num_decodes],
                 decode_lens=decode_lens,
                 requires_padding=requires_padding,
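A small self-contained sketch of the clamping behaviour described in the comment above: padded rows of the decode block table carry -1 entries, and clamping them to 0 keeps every block index in range, while the corresponding seq_lens of 0 leave those rows inert. The tensor shapes and values below are made up for illustration.

```python
import torch

# Toy decode block table: two real requests plus one CUDA-graph padding row.
# Padded rows are filled with -1, which would index out of bounds in the
# DeepGEMM paged-MQA kernel if passed through unchanged.
block_table = torch.tensor([
    [3, 7, 9],     # real decode request, 3 KV blocks
    [5, 2, 0],     # real decode request, 2 KV blocks (trailing slot unused)
    [-1, -1, -1],  # padded request added for CUDA graph batch uniformity
])
seq_lens = torch.tensor([48, 30, 0])  # the padded row has seq_len 0

# In-place clamp mirrors the change in build(): all indices become valid.
block_table.clamp_(min=0)
print(block_table)
# [[3, 7, 9],
#  [5, 2, 0],
#  [0, 0, 0]]
# Because seq_lens[2] == 0, the kernel reads no KV blocks for the padded row,
# so pointing its entries at block 0 never affects real outputs.
```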