Merged
11 changes: 9 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -136,8 +136,11 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
     def needs_capture(self, batch_size: int):
         return (batch_size, self.draft_len) not in self.graph_outputs
 
-    def capture(self, batch_size: int, forward_fn: Callable,
-                initial_inputs: Dict[str, Any]):
+    def capture(self,
+                batch_size: int,
+                forward_fn: Callable,
+                initial_inputs: Dict[str, Any],
+                postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
         engine = self._get_engine()
         key = (batch_size, self.draft_len)
@@ -181,8 +184,12 @@ def capture(self, batch_size: int, forward_fn: Callable,
         with with_multi_stream(True), piecewise_cuda_graph(False):
             for _ in range(self.WARMUP_STEPS):
                 forward_fn(capture_inputs)
+                if postprocess_fn is not None:
+                    postprocess_fn(capture_inputs)
             with torch.cuda.graph(graph, pool=self.memory_pool):
                 output = forward_fn(capture_inputs)
+                if postprocess_fn is not None:
+                    postprocess_fn(capture_inputs)
 
         self.graphs[key] = graph
         self.graph_outputs[key] = make_weak_ref(output)
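The contract this change introduces: every call to forward_fn — each warmup step and the forward pass recorded into the graph — is immediately followed by the optional postprocess_fn, so in-place edits to the shared capture inputs never accumulate across warmup iterations or leak into graph replays. Below is a minimal, CUDA-free sketch of that ordering; `capture_like`, `fwd`, and `post` are hypothetical stand-ins for illustration, not the real CUDAGraphRunner API.

```python
from typing import Any, Callable, Dict, Optional


def capture_like(warmup_steps: int,
                 forward_fn: Callable[[Dict[str, Any]], Any],
                 capture_inputs: Dict[str, Any],
                 postprocess_fn: Optional[Callable[[Dict[str, Any]], None]] = None):
    """Mimic the runner's ordering: every forward (warmup and the one that
    would be captured) is followed by the optional postprocess hook."""
    for _ in range(warmup_steps):
        forward_fn(capture_inputs)
        if postprocess_fn is not None:
            postprocess_fn(capture_inputs)
    # In the real runner this last forward runs under torch.cuda.graph(...),
    # so the postprocess ops are recorded into the graph as well.
    output = forward_fn(capture_inputs)
    if postprocess_fn is not None:
        postprocess_fn(capture_inputs)
    return output


# Toy forward/postprocess pair: the forward shifts a shared value in place,
# the postprocess undoes the shift.
inputs = {"step_offset": 0}


def fwd(x: Dict[str, Any]):
    x["step_offset"] += 1
    return x["step_offset"]


def post(x: Dict[str, Any]) -> None:
    x["step_offset"] -= 1


capture_like(3, fwd, inputs, post)
assert inputs["step_offset"] == 0  # inputs unchanged after warmup + capture
```

Without the hook, the in-place shift would compound once per warmup step and the captured graph would bake in stale buffer values.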
26 changes: 25 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1133,6 +1133,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
                         self.previous_kv_lens_offsets_cuda[:num_gen_requests])
         return inputs
 
+    def _postprocess_inputs(self, inputs: Dict[str, Any]):
+        """
+        Postprocess the inputs so that the model forward does not change them.
+        This is only needed during CUDA graph capture, because in all other
+        cases fresh inputs are prepared before each model forward.
+        """
+        if self.enable_spec_decode and not self._disable_overlap_scheduler:
+            if inputs['attn_metadata'].kv_cache_manager is not None:
+                num_seqs = inputs['attn_metadata'].num_seqs
+                num_ctx_requests = inputs['attn_metadata'].num_contexts
+                num_gen_requests = inputs['attn_metadata'].num_generations
+                num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens
+                previous_batch_tokens = inputs['input_ids'].shape[
+                    0] - num_ctx_tokens
+                inputs['position_ids'][0, num_ctx_tokens:] -= (
+                    self.previous_pos_id_offsets_cuda[:previous_batch_tokens])
+                inputs['attn_metadata'].kv_lens_cuda[
+                    num_ctx_requests:num_seqs] -= (
+                        self.previous_kv_lens_offsets_cuda[:num_gen_requests])
+
     def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
         if self.enable_attention_dp:
             return list(self.dist.tp_allgather(attn_metadata.num_tokens))
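_postprocess_inputs is the exact inverse of the speculative-decoding branch of _preprocess_inputs: the same slices of position_ids and kv_lens_cuda that get the previous batch's offsets added before the forward are decremented by the same offsets afterwards, so each warmup forward during CUDA graph capture starts from identical inputs. Here is a standalone sketch of that invariant with toy CPU tensors and made-up sizes (the real buffers are CUDA tensors owned by the engine), not the production code path.

```python
import torch

# Made-up sizes: 2 context requests with 5 context tokens in total, plus
# 3 generation requests with 2 tokens each (1 draft token per request).
num_ctx_tokens, num_ctx_requests, num_gen_requests = 5, 2, 3
num_seqs = num_ctx_requests + num_gen_requests
previous_batch_tokens = 6  # generation tokens carried over from the last step

position_ids = torch.arange(num_ctx_tokens + previous_batch_tokens).unsqueeze(0)
kv_lens = torch.tensor([10, 12, 20, 21, 22])
pos_id_offsets = torch.ones(previous_batch_tokens, dtype=position_ids.dtype)
kv_lens_offsets = torch.ones(num_gen_requests, dtype=kv_lens.dtype)
ref_position_ids, ref_kv_lens = position_ids.clone(), kv_lens.clone()

# Mirrors _preprocess_inputs: shift the generation-token position ids and
# the generation-request kv lens by the previous batch's offsets.
position_ids[0, num_ctx_tokens:] += pos_id_offsets[:previous_batch_tokens]
kv_lens[num_ctx_requests:num_seqs] += kv_lens_offsets[:num_gen_requests]

# Mirrors _postprocess_inputs: subtract the same slices to restore the
# buffers before the next warmup forward.
position_ids[0, num_ctx_tokens:] -= pos_id_offsets[:previous_batch_tokens]
kv_lens[num_ctx_requests:num_seqs] -= kv_lens_offsets[:num_gen_requests]

assert torch.equal(position_ids, ref_position_ids)
assert torch.equal(kv_lens, ref_kv_lens)
```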
@@ -2206,8 +2226,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                         gather_ids=gather_ids,
                         gather_context_logits=gather_context_logits)
 
+            def capture_postprocess_fn(inputs: Dict[str, Any]):
+                self._postprocess_inputs(inputs)
+
             self.cuda_graph_runner.capture(batch_size,
-                                           capture_forward_fn, inputs)
+                                           capture_forward_fn, inputs,
+                                           capture_postprocess_fn)
 
             # here we don't need to use context since cuda graph capture didn't run kernel.
             # maybe we need a cleaner way to do this.
36 changes: 36 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1871,6 +1871,42 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
             # task.evaluate(llm,
             #               extra_evaluator_kwargs=dict(apply_chat_template=True))
 
+    def test_nvfp4_multi_gpus_corner_case(self):
+        """
+        Test a corner case of the NVFP4 model. When max_seq_len and max_num_tokens
+        are set to the same value, there are not enough KV cache blocks for the
+        dummy requests during CUDA graph warmup while creating the py_executor,
+        i.e. before the KV cache size is estimated. CUDA graph capture is then
+        triggered during KV cache estimation, which may cause errors.
+        More info at https://nvbugs/5485325.
+        """
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80,
+                                        dtype="fp8",
+                                        enable_block_reuse=False)
+        pytorch_config = dict(disable_overlap_scheduler=False,
+                              cuda_graph_config=CudaGraphConfig(
+                                  enable_padding=True, max_batch_size=1024),
+                              moe_config=MoeConfig(backend="TRTLLM"))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1)
+        with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4",
+                 tensor_parallel_size=8,
+                 pipeline_parallel_size=1,
+                 moe_expert_parallel_size=8,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config,
+                 max_seq_len=5120,
+                 max_num_tokens=5120) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_mpi_world_size(8)
     @skip_pre_hopper
     @pytest.mark.parametrize(
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/llm_function_full.txt
@@ -526,6 +526,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/llm_function_sanity.txt
@@ -35,6 +35,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]
15 changes: 15 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -70,6 +70,21 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] TIMEOUT (180)
+- condition:
+    ranges:
+      system_gpu_count:
+        gte: 8
+        lte: 8
+    wildcards:
+      gpu:
+      - '*b200*'
+      linux_distribution_name: ubuntu*
+      cpu: x86_64
+    terms:
+      stage: post_merge
+      backend: pytorch
+  tests:
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180)
 - condition:
     ranges:
       system_gpu_count: