8 changes: 8 additions & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -421,6 +421,11 @@ def schedule(self) -> SchedulerOutput:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Keep track of the number of tokens to load from remote
# for the request so that we can compute actual throughput.
request.num_external_computed_tokens = (
num_external_computed_tokens
)

# Total computed tokens (local + external).
num_computed_tokens = (
@@ -1042,6 +1047,7 @@ def update_from_output(
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
num_external_computed_tokens=request.num_external_computed_tokens,
)
)
else:
@@ -1487,6 +1493,8 @@ def _update_requests_with_invalid_blocks(
request.num_computed_tokens - request.num_cached_tokens
)
request.num_computed_tokens = request.num_cached_tokens
# Prefill is to be recomputed locally.
request.num_external_computed_tokens = 0
Collaborator Author:

@sdavidbd can you please double-check this? My understanding is that we now have to re-compute the whole prefill, so that we can track prompt throughput.

Member:

Even from the docstring, it seems clear we don't always re-compute the whole prefill:

This method scans the given requests, detects those with invalid blocks and adjusts their num_computed_tokens to the longest valid prefix.

A few things:

  1. I think num_external_computed_tokens = 0 should only happen inside the not marked_invalid_block clause - here we're saying all externally computed blocks are invalid
  2. (Unrelated to this PR - an observation) Setting request.num_computed_tokens = request.num_cached_tokens on line 1489 doesn't make sense to me - since num_cached_tokens includes both local and external computed tokens?
  3. We should update num_external_computed_tokens at # Truncate the computed tokens at the first failed block - to something like request.num_computed_tokens - local_computed_tokens (but not obvious how we calculated local_computed_tokens)

Contributor:

@NickLucche — as @markmc noted, we don’t recompute the entire prefill. Only the externally computed tokens starting from the first failed block are recomputed.

To correctly update num_external_computed_tokens, we should first determine how many externally computed tokens are affected. This can be derived from the delta between the original and truncated num_computed_tokens — the same tokens already aggregated in total_affected_tokens (lines 1473–1477):

# Truncate the computed tokens at the first failed block
request.num_computed_tokens = idx * self.block_size
num_affected_tokens = req_num_computed_tokens - request.num_computed_tokens
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
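For concreteness, a minimal worked example of that delta (the numbers, block_size, and block index below are illustrative assumptions only, not taken from a real trace):

# Illustrative numbers only: block_size = 16, the request had 48 computed
# tokens (32 of them loaded externally), and the first failed block is idx = 2.
block_size = 16
idx = 2
req_num_computed_tokens = 48
num_external_computed_tokens = 32

# Truncate the computed tokens at the first failed block.
num_computed_tokens = idx * block_size  # 32
num_affected_tokens = req_num_computed_tokens - num_computed_tokens  # 16
num_external_computed_tokens -= num_affected_tokens  # 32 - 16 = 16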

Contributor:

@markmc — regarding your points:

  1. The not marked_invalid_block condition covers the sync-loading edge case where a request is affected by externally computed tokens that failed to load but are shared with preceding requests that will handle their recomputation. In this situation, the affected request still treats those tokens as locally computed, so its num_external_computed_tokens remains unchanged.

For example, assuming block_size = 1 and the following prompts (with R1 preceding R2 in the batch):

R1: t1 t2 t3
R2: t1 t2 t4 t5

Suppose t1 is locally computed, t2 and t4 are externally computed, and t2 fails to load while t4 succeeds. Then:

Before failure:

Request    num_computed_tokens    num_external_computed_tokens
R1         2                      1
R2         3                      1

After failure:

Request    num_computed_tokens    num_external_computed_tokens
R1         1                      0
R2         3                      1

Both R1 and R2 are affected: R1 will recompute t2 and t3, while R2 will recompute t5; R2's total number of computed tokens, however, remains unchanged.

  2. Correct — num_cached_tokens represents the total number of computed tokens (both local and external). Setting num_computed_tokens = num_cached_tokens ensures that all new tokens are recomputed in the current iteration, since the previous num_computed_tokens value already included them.

  3. Agreed — see my suggested code changes above for how we update num_external_computed_tokens accordingly.


affected_req_ids.add(request.request_id)

2 changes: 2 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -121,6 +121,8 @@ class EngineCoreOutput(
trace_headers: Mapping[str, str] | None = None
# The number of tokens with prefix cache hits.
Member:

Yeah, this comment looks incorrect ... assuming "prefix cache" refers to the local cache?

# Total computed tokens (local + external).
num_computed_tokens = (
    num_new_local_computed_tokens + num_external_computed_tokens
)
...
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
    request.num_cached_tokens = num_computed_tokens

Collaborator Author:

I am not familiar with it cc @chaunceyjiang

num_cached_tokens: int = 0
# The number of tokens that have been computed remotely.
num_external_computed_tokens: int = 0
Member:

I'd be tempted to refactor these two into a PrefillStats object ... and only include that in the ECO when the prefill completes ... especially if we ever wanted to also send like num_locally_cached_tokens too
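A minimal sketch of what such a PrefillStats object might look like (hypothetical; the class, field layout, and how it would hang off EngineCoreOutput are assumptions, not something in this PR):

from dataclasses import dataclass

@dataclass
class PrefillStats:
    # Total prompt tokens with prefix-cache hits (local + external).
    num_cached_tokens: int = 0
    # Prompt tokens whose KV was computed remotely and loaded.
    num_external_computed_tokens: int = 0

# EngineCoreOutput could then carry a single optional field, populated only
# once prefill completes, e.g.:
#     prefill_stats: PrefillStats | None = None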

Collaborator Author:

I don't have a strong opinion on this tbh; we can probably wait until we have a few more things to bundle before executing the suggestion.


@property
def finished(self) -> bool:
2 changes: 1 addition & 1 deletion vllm/v1/metrics/loggers.py
@@ -113,7 +113,7 @@ def _reset(self, now):

def _track_iteration_stats(self, iteration_stats: IterationStats):
Member:

Presumably you want to update the Prometheus metric too?

Collaborator Author:

@markmc which one? I intentionally left self.counter_prompt_tokens unchanged to avoid replacing the actual prompt count.
Should I just make a new one for local tokens?
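If a separate counter were added, a minimal sketch could look like the following (the metric name, label, and where it is incremented are assumptions for illustration, not existing vLLM metrics):

from prometheus_client import Counter

# Hypothetical counter for prompt tokens actually computed on this instance,
# kept alongside the existing total prompt-token counter.
counter_local_prompt_tokens = Counter(
    name="vllm:local_prompt_tokens_total",
    documentation="Prefill tokens computed locally (excludes remote KV loads).",
    labelnames=["model_name"],
)

# Incremented once per iteration, e.g.:
# counter_local_prompt_tokens.labels("my-model").inc(
#     iteration_stats.num_local_prompt_tokens
# )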

# Save tracked stats for token counters.
self.num_prompt_tokens += iteration_stats.num_prompt_tokens
self.num_prompt_tokens += iteration_stats.num_local_prompt_tokens
self.num_generation_tokens += iteration_stats.num_generation_tokens

def _get_throughput(self, tracked_stats: int, now: float) -> float:
5 changes: 5 additions & 0 deletions vllm/v1/metrics/stats.py
@@ -221,6 +221,8 @@ def __init__(self):
self.num_generation_tokens = 0
self.num_prompt_tokens = 0
self.num_preempted_reqs = 0
# Num of prompt tokens that have been computed locally.
Member:

Is the naming here a bit confusing? By "computed locally", do we mean both computed and locally cached?

If you just tracked num_external_computed_tokens and then subtracted it in _track_iteration_stats() would that be more clear?
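A rough, self-contained sketch of that alternative (class and attribute names are illustrative, not from this PR): keep the full prompt count on the iteration stats, track the external portion separately, and subtract it only at aggregation time.

class IterationStatsSketch:
    def __init__(self):
        self.num_prompt_tokens = 0
        self.num_external_computed_tokens = 0

class LoggerSketch:
    def __init__(self):
        # Prompt tokens actually computed by this instance.
        self.num_prompt_tokens = 0

    def track(self, it: IterationStatsSketch) -> None:
        # Subtract remote-loaded tokens here instead of storing a separate
        # num_local_prompt_tokens field on the iteration stats.
        self.num_prompt_tokens += (
            it.num_prompt_tokens - it.num_external_computed_tokens
        )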

Collaborator Author:

By "computed locally" here we mean both computed and locally cached?

Yes, the behavior is unchanged; cached tokens would still result in higher throughput, even in a regular aggregated setup.

If you just tracked num_external_computed_tokens and then subtracted it in _track_iteration_stats() would that be more clear?

I think looking at the diff

self.num_prompt_tokens += iteration_stats.num_prompt_tokens
-->
self.num_prompt_tokens += iteration_stats.num_local_prompt_tokens

it's pretty clear that I just want to rule out the remote tokens, i.e. I assume this semantic was the intended one from the beginning; it's just that "local" used to be redundant.

self.num_local_prompt_tokens = 0
self.finished_requests: list[FinishedRequestStats] = []
self.max_num_generation_tokens_iter: list[int] = []
self.n_params_iter: list[int] = []
@@ -251,6 +253,9 @@ def update_from_output(
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
self.num_prompt_tokens += prompt_len
self.num_local_prompt_tokens += (
prompt_len - output.num_external_computed_tokens
)

first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
5 changes: 4 additions & 1 deletion vllm/v1/request.py
@@ -118,9 +118,12 @@ def __init__(
# indicates that the output is corrupted
self.num_nans_in_logits = 0

# The number of requests being preempted by the scheduler
# The number of requests being preempted by the scheduler.
self.num_preemptions = 0

# The number of tokens that have been computed remotely.
self.num_external_computed_tokens = 0

self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
if block_hasher is not None: