Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,17 +836,13 @@ def maybe_calc_kv_scales(
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata

if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]

if attn_metadata is None or not getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if not self.calculate_kv_scales:
return

self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)


Expand Down
20 changes: 10 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def __init__(
# This will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
self.max_model_len = model_config.max_model_len

# Set to False after the first forward pass
self.calculate_kv_scales = self.cache_config.calculate_kv_scales
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
Expand Down Expand Up @@ -2491,16 +2494,10 @@ def execute_model(
)

# Set cudagraph mode to NONE if calculate_kv_scales is True.
if attn_metadata is not None:
metadata_list = (
attn_metadata.values()
if isinstance(attn_metadata, dict)
else [attn_metadata]
)
if any(
getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
):
cudagraph_runtime_mode = CUDAGraphMode.NONE
# KV scales calculation involves dynamic operations that are incompatible
# with CUDA graph capture.
if self.calculate_kv_scales:
cudagraph_runtime_mode = CUDAGraphMode.NONE

# Run the model.
# Use persistent buffers for CUDA graphs.
Expand All @@ -2525,6 +2522,9 @@ def execute_model(
**model_kwargs,
)

# Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False

with record_function_or_nullcontext("Postprocess"):
if self.use_aux_hidden_state_outputs:
# True when EAGLE 3 is used.
Expand Down