4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -137,7 +137,7 @@ def needs_capture(self, batch_size: int):
         return (batch_size, self.draft_len) not in self.graph_outputs
 
     def capture(self, batch_size: int, forward_fn: Callable,
-                initial_inputs: Dict[str, Any]):
+                postprocess_fn: Callable, initial_inputs: Dict[str, Any]):
         """Captures the forward pass for a given batch size."""
💡 Verification agent

🧩 Analysis chain

API change: make postprocess param explicit and align types with returned outputs.

  • The new parameter is good. Add a precise type for it. Also, capture now stores/returns dict outputs; update type hints accordingly and ensure replay returns the concrete object, not a wrapper.

Apply:

-    def capture(self, batch_size: int, forward_fn: Callable,
-                postprocess_fn: Callable, initial_inputs: Dict[str, Any]):
+    def capture(
+        self,
+        batch_size: int,
+        forward_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
+        postprocess_fn: Callable[[Dict[str, Any]], None],
+        initial_inputs: Dict[str, Any],
+    ) -> None:
-        self.graph_outputs: Dict[Tuple[int, int],
-                                 Callable[[], Optional[torch.Tensor]]] = {}
+        self.graph_outputs: Dict[
+            Tuple[int, int], Callable[[], Optional[Dict[str, Any]]]
+        ] = {}
-    def replay(self, batch_size: int,
-               current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
+    def replay(
+        self, batch_size: int, current_inputs: Dict[str, Any]
+    ) -> Optional[Dict[str, Any]]:
-        output_ref = self.graph_outputs[key]
-
-        return output_ref
+        output_ref = self.graph_outputs[key]
+        return output_ref() if callable(output_ref) else output_ref
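
Caller-side, the last hunk means `replay` hands back the dict produced by `forward_fn` rather than the stored weak-ref callable. A minimal usage sketch (the `runner` variable and the `"logits"` key are assumptions for illustration, not names taken from this PR):

```python
# Hypothetical caller; replay() now yields Optional[Dict[str, Any]] directly.
outputs = runner.replay(batch_size, current_inputs)
if outputs is not None:
    logits = outputs["logits"]  # concrete dict entry, no extra dereference needed
```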

Run to confirm no other callsites of capture need changes:

Also applies to: 45-47, 193-201


🏁 Script executed:

#!/bin/bash
rg -nP -C2 '\bCUDAGraphRunner\b.*\bcapture\s*\('

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Search for any calls to .capture( across the repository
rg -nP -C2 '\.capture\s*\(' .

Length of output: 7724


Include explicit postprocess_fn in all CUDAGraphRunner.capture callsites
All existing calls (e.g. in tests/unittest/_torch/modeling/*.py and model_engine.py) must pass a `postprocess_fn` of type `Callable[[Dict[str, Any]], None]` immediately before `initial_inputs`; for tests, a no-op such as `lambda _: None` is sufficient. This keeps every callsite aligned with the updated type hints and prevents test failures.
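
For example, a minimal sketch of one updated test callsite (the surrounding objects — `runner`, `forward_fn`, `initial_inputs` — are assumed for illustration, not quoted from the tests):

```python
# Hypothetical test callsite; the only change is the explicit no-op postprocess_fn.
runner.capture(
    batch_size,
    forward_fn,
    lambda _: None,  # no-op postprocess_fn, placed immediately before initial_inputs
    initial_inputs,
)
```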

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py around lines 139 to 141,
the capture signature now requires an explicit postprocess_fn parameter but
existing callsites don’t pass it; update every CUDAGraphRunner.capture(...) call
(including tests in tests/unittest/_torch/modeling/*.py and model_engine.py) to
pass a postprocess_fn: Callable[[Dict[str, Any]], None] as the argument
immediately before initial_inputs — in tests use a no-op lambda like `lambda _:
None`; ensure all updated callsites match the new parameter order and type to
satisfy the updated type hints and prevent test failures.

         engine = self._get_engine()
         key = (batch_size, self.draft_len)
@@ -181,8 +181,10 @@ def capture(self, batch_size: int, forward_fn: Callable,
         with with_multi_stream(True), piecewise_cuda_graph(False):
             for _ in range(self.WARMUP_STEPS):
                 forward_fn(capture_inputs)
+                postprocess_fn(capture_inputs)
             with torch.cuda.graph(graph, pool=self.memory_pool):
                 output = forward_fn(capture_inputs)
+                postprocess_fn(capture_inputs)
 
         self.graphs[key] = graph
         self.graph_outputs[key] = make_weak_ref(output)
27 changes: 26 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1133,6 +1133,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
                     self.previous_kv_lens_offsets_cuda[:num_gen_requests])
         return inputs
 
+    def _postprocess_inputs(self, inputs: Dict[str, Any]):
+        """
+        Postprocess to make sure model forward doesn't change the inputs.
+        It is only used in cuda graph capture, because other cases will prepare
+        new inputs before the model forward.
+        """
+        if self.enable_spec_decode and not self._disable_overlap_scheduler:
+            if inputs['attn_metadata'].kv_cache_manager is not None:
+                num_seqs = inputs['attn_metadata'].num_seqs
+                num_ctx_requests = inputs['attn_metadata'].num_contexts
+                num_gen_requests = inputs['attn_metadata'].num_generations
+                num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens
+                previous_batch_tokens = inputs['input_ids'].shape[
+                    0] - num_ctx_tokens
+                inputs['position_ids'][0, num_ctx_tokens:] -= (
+                    self.previous_pos_id_offsets_cuda[:previous_batch_tokens])
+                inputs['attn_metadata'].kv_lens_cuda[
+                    num_ctx_requests:num_seqs] -= (
+                        self.previous_kv_lens_offsets_cuda[:num_gen_requests])
+
     def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
         if self.enable_attention_dp:
             return list(self.dist.tp_allgather(attn_metadata.num_tokens))
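
The new helper mirrors `_preprocess_inputs`: whatever offsets preprocessing adds to `position_ids` and `kv_lens_cuda` in place, postprocessing subtracts again, so graph capture leaves the input buffers exactly as it found them. A stripped-down sketch of that pairing (the names `preprocess`, `postprocess`, and `offsets` are illustrative, not the engine's real fields):

```python
import torch

def preprocess(inputs: dict, offsets: torch.Tensor) -> None:
    # In-place update, analogous to _preprocess_inputs adding previous-batch offsets.
    inputs["position_ids"] += offsets

def postprocess(inputs: dict, offsets: torch.Tensor) -> None:
    # Undo the in-place update so capture leaves the buffers unchanged.
    inputs["position_ids"] -= offsets

inputs = {"position_ids": torch.arange(4)}
offsets = torch.ones(4, dtype=torch.long)
preprocess(inputs, offsets)
postprocess(inputs, offsets)
assert torch.equal(inputs["position_ids"], torch.arange(4))  # round-trips to the original values
```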
@@ -2206,8 +2226,13 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                             gather_ids=gather_ids,
                             gather_context_logits=gather_context_logits)
 
+                    def capture_postprocess_fn(inputs: Dict[str, Any]):
+                        self._postprocess_inputs(inputs)
+
                     self.cuda_graph_runner.capture(batch_size,
-                                                   capture_forward_fn, inputs)
+                                                   capture_forward_fn,
+                                                   capture_postprocess_fn,
+                                                   inputs)
 
                     # here we don't need to use context since cuda graph capture didn't run kernel.
                     # maybe we need a cleaner way to do this.