2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/utils.py
@@ -94,6 +94,8 @@ class CommonAttentionMetadata:
dcp_local_seq_lens: torch.Tensor | None = None
"""Sequence lengths of the local rank in decode context parallelism world"""

enable_kv_scales_calculation: bool = False
Collaborator:

Can you make sure this is removed from backend-specific attention metadata classes, if anywhere?
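
One quick way to verify this from the root of a vLLM checkout (an illustrative helper, not part of the change; it just lists backend files that still mention the flag):

    from pathlib import Path

    FLAG = "enable_kv_scales_calculation"
    for path in Path("vllm/v1/attention/backends").rglob("*.py"):
        if path.name == "utils.py":
            continue  # utils.py is where CommonAttentionMetadata now defines the flag
        if FLAG in path.read_text():
            print(path)  # any hit here is a candidate duplicate to remove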



def slice_query_start_locs(
query_start_loc: torch.Tensor,
14 changes: 14 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -249,6 +249,8 @@ def __init__(
# This will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
self.max_model_len = model_config.max_model_len

self.kv_scales_calculated = False
Collaborator:

Suggested change
- self.kv_scales_calculated = False
+ # Always set to false after the first forward pass
+ self.calculate_kv_scales = self.cache_config.calculate_kv_scales

self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
@@ -1328,6 +1330,12 @@ def _prepare_inputs(
kv_cache_group_id
]

# Determine if we need to calculate KV scales on this forward pass.
# Only True on the first pass when calculate_kv_scales is enabled.
enable_kv_scales_calculation = (
    self.cache_config.calculate_kv_scales
    and not self.kv_scales_calculated)
Collaborator:

Suggested change
- # Determine if we need to calculate KV scales on this forward pass.
- # Only True on the first pass when calculate_kv_scales is enabled.
- enable_kv_scales_calculation = (
-     self.cache_config.calculate_kv_scales
-     and not self.kv_scales_calculated)


common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
@@ -1347,6 +1355,7 @@ def _prepare_inputs(
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
enable_kv_scales_calculation=enable_kv_scales_calculation,
Collaborator:

Suggested change
- enable_kv_scales_calculation=enable_kv_scales_calculation,
+ enable_kv_scales_calculation=self.calculate_kv_scales,

)
Comment on lines 1355 to 1358

P1: Propagate KV scale flag to per-layer metadata

The new enable_kv_scales_calculation flag is set only on CommonAttentionMetadata here, but the per-layer metadata objects that Attention.forward actually receives are produced later by builders (split_attn_metadata and the various *AttentionMetadataBuilder.build) without copying this attribute. As a result getattr(attn_metadata, "enable_kv_scales_calculation", False) in vllm/attention/layer.py remains False and the KV scale calculation path never runs, so the bug this change is meant to fix still occurs whenever KV scales are enabled. The flag needs to be attached to the per-layer metadata returned by the builders (and preserved when splitting) so layers can see it.
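
A minimal sketch of the propagation being asked for; the class and function names below are illustrative stand-ins rather than the actual builder or split_attn_metadata signatures:

    from dataclasses import dataclass, replace

    @dataclass
    class CommonMetadataSketch:
        # Mirrors the new field on CommonAttentionMetadata.
        enable_kv_scales_calculation: bool = False

    @dataclass
    class PerLayerMetadataSketch:
        # The per-layer metadata handed to Attention.forward needs the same
        # flag, otherwise the getattr(..., "enable_kv_scales_calculation",
        # False) check in vllm/attention/layer.py never sees True.
        enable_kv_scales_calculation: bool = False

    def build(common: CommonMetadataSketch) -> PerLayerMetadataSketch:
        # A builder copies the flag onto the per-layer metadata it returns.
        return PerLayerMetadataSketch(
            enable_kv_scales_calculation=common.enable_kv_scales_calculation,
        )

    def split(meta: PerLayerMetadataSketch, n: int) -> list[PerLayerMetadataSketch]:
        # Splitting must preserve the flag on every piece as well.
        return [replace(meta) for _ in range(n)]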



if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -2525,6 +2534,11 @@ def execute_model(
**model_kwargs,
)

# Mark KV scales as calculated after the first forward pass
if (self.cache_config.calculate_kv_scales
        and not self.kv_scales_calculated):
    self.kv_scales_calculated = True
Contributor:

Severity: high

To improve maintainability and avoid re-evaluating the same condition, you can pass the enable_kv_scales_calculation flag from _prepare_inputs to execute_model. This makes the logic clearer and reduces redundancy.

Here's how you can do it:

  1. Update _prepare_inputs to return enable_kv_scales_calculation:

    # In _prepare_inputs function signature
    ) -> tuple[
        ...,
        bool,  # use_cascade_attn
        bool,  # enable_kv_scales_calculation
    ]:
    
    # In _prepare_inputs return statement
    return (
        ...,
        use_cascade_attn,
        enable_kv_scales_calculation,
    )
  2. Update the call to _prepare_inputs in execute_model:

    # In execute_model
    (
        ...,
        use_cascade_attn,
        enable_kv_scales_calculation,
    ) = self._prepare_inputs(scheduler_output)
  3. Then, you can simplify the logic for updating self.kv_scales_calculated as suggested below.

# Mark KV scales as calculated if they were computed in this pass.
if enable_kv_scales_calculation:
    self.kv_scales_calculated = True

Collaborator:

Even easier:

Suggested change
- if (self.cache_config.calculate_kv_scales
-         and not self.kv_scales_calculated):
-     self.kv_scales_calculated = True
+ self.calculate_kv_scales = False


with record_function_or_nullcontext("Postprocess"):
if self.use_aux_hidden_state_outputs:
# True when EAGLE 3 is used.