
Commit bd005bf

lfr-0531 and litaotju authored and committed
[https://nvbugs/5485325][fix] Add a postprocess to the model engine to fix the CUDA graph warmup issue when using speculative decoding (NVIDIA#7373)
Signed-off-by: Fanrong Li <[email protected]>
Co-authored-by: Tao Li @ NVIDIA <[email protected]>
1 parent 24fc1f9 commit bd005bf

File tree

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tests/integration/defs/accuracy/test_llm_api_pytorch.py
tests/integration/test_lists/qa/llm_function_core.txt
tests/integration/test_lists/qa/llm_function_core_sanity.txt
tests/integration/test_lists/test-db/l0_dgx_b200.yml

6 files changed: +90 −3 lines changed

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 9 additions & 2 deletions
@@ -164,8 +164,11 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
     def needs_capture(self, batch_size: int):
         return (batch_size, self.draft_len) not in self.graph_outputs
 
-    def capture(self, batch_size: int, forward_fn: Callable,
-                initial_inputs: Dict[str, Any]):
+    def capture(self,
+                batch_size: int,
+                forward_fn: Callable,
+                initial_inputs: Dict[str, Any],
+                postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
         key = (batch_size, self.draft_len)
         # [CUDA graph spec decode padding]
@@ -203,8 +206,12 @@ def capture(self, batch_size: int, forward_fn: Callable,
         with with_multi_stream(True), piecewise_cuda_graph(False):
             for _ in range(self.WARMUP_STEPS):
                 forward_fn(capture_inputs)
+                if postprocess_fn is not None:
+                    postprocess_fn(capture_inputs)
             with torch.cuda.graph(graph, pool=self.memory_pool):
                 output = forward_fn(capture_inputs)
+            if postprocess_fn is not None:
+                postprocess_fn(capture_inputs)
 
         self.graphs[key] = graph
         self.graph_outputs[key] = make_weak_ref(output)
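The optional postprocess_fn exists because the warmup iterations run eagerly and the engine's forward mutates some of its input buffers in place, so without an undo step the mutation accumulates across WARMUP_STEPS before capture. Below is a minimal standalone sketch of that pattern; forward_fn, postprocess_fn, and the toy position_ids/offsets buffers are hypothetical stand-ins, not the repository's runner API.

import torch

# Hypothetical stand-ins for illustration; the real runner is handed the
# engine's forward step and ModelEngine._postprocess_inputs instead.
def forward_fn(inputs):
    # Mimics a forward whose spec-decode preprocessing shifts an input buffer
    # in place before the model consumes it.
    inputs["position_ids"] += inputs["offsets"]
    return inputs["position_ids"].float() * 2.0

def postprocess_fn(inputs):
    # Reverts the in-place shift so repeated warmup runs do not accumulate it.
    inputs["position_ids"] -= inputs["offsets"]

if torch.cuda.is_available():
    inputs = {
        "position_ids": torch.arange(4, device="cuda"),
        "offsets": torch.ones(4, dtype=torch.long, device="cuda"),
    }
    original = inputs["position_ids"].clone()

    # Warm up on a side stream, as the CUDA graph docs recommend before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(2):  # stand-in for WARMUP_STEPS
            forward_fn(inputs)
            postprocess_fn(inputs)  # without this call the offsets pile up each step
    torch.cuda.current_stream().wait_stream(side)
    assert torch.equal(inputs["position_ids"], original)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = forward_fn(inputs)  # recorded into the graph, not executed now
    postprocess_fn(inputs)  # the runner makes the same call right after capture

During the capture block the kernels are recorded rather than executed immediately, which is why the accumulation problem shows up in the warmup loop; the hook keeps the capture-time input buffers identical to what the replay path will later copy in.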

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 25 additions & 1 deletion
@@ -1196,6 +1196,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
 
         return inputs
 
+    def _postprocess_inputs(self, inputs: Dict[str, Any]):
+        """
+        Postprocess to make sure model forward doesn't change the inputs.
+        It is only used in cuda graph capture, because other cases will prepare
+        new inputs before the model forward.
+        """
+        if self.enable_spec_decode and not self._disable_overlap_scheduler:
+            if inputs['attn_metadata'].kv_cache_manager is not None:
+                num_seqs = inputs['attn_metadata'].num_seqs
+                num_ctx_requests = inputs['attn_metadata'].num_contexts
+                num_gen_requests = inputs['attn_metadata'].num_generations
+                num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens
+                previous_batch_tokens = inputs['input_ids'].shape[
+                    0] - num_ctx_tokens
+                inputs['position_ids'][0, num_ctx_tokens:] -= (
+                    self.previous_pos_id_offsets_cuda[:previous_batch_tokens])
+                inputs['attn_metadata'].kv_lens_cuda[
+                    num_ctx_requests:num_seqs] -= (
+                        self.previous_kv_lens_offsets_cuda[:num_gen_requests])
+
     def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
         if self.enable_attention_dp:
             return list(self.dist.tp_allgather(attn_metadata.num_tokens))
@@ -2298,8 +2318,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                         gather_ids=gather_ids,
                         gather_context_logits=gather_context_logits)
 
+                def capture_postprocess_fn(inputs: Dict[str, Any]):
+                    self._postprocess_inputs(inputs)
+
                 self.cuda_graph_runner.capture(batch_size,
-                                               capture_forward_fn, inputs)
+                                               capture_forward_fn, inputs,
+                                               capture_postprocess_fn)
 
                 # here we don't need to use context since cuda graph capture didn't run kernel.
                 # maybe we need a cleaner way to do this.
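_postprocess_inputs is effectively the inverse of the offset shifts that _preprocess_inputs applies to the generation portion of position_ids and to kv_lens_cuda, so a capture-time forward leaves those buffers unchanged. A toy sketch of that invariant follows; it assumes the preprocess side applies the matching +=, and the shapes and values are illustrative only.

import torch

# Illustrative buffers; real shapes come from attn_metadata and the engine's
# persistent previous_pos_id_offsets_cuda / previous_kv_lens_offsets_cuda.
position_ids = torch.tensor([[10, 11, 12, 13, 14]])
previous_pos_id_offsets = torch.tensor([3, 3, 3])
num_ctx_tokens = 2  # tokens 0-1 belong to context requests, the rest to generation

# Assumed preprocess: shift only the generation tokens by the previous batch's offsets.
position_ids[0, num_ctx_tokens:] += previous_pos_id_offsets
# Postprocess: the exact inverse, so one warmup forward leaves the buffer unchanged.
position_ids[0, num_ctx_tokens:] -= previous_pos_id_offsets

assert position_ids.tolist() == [[10, 11, 12, 13, 14]]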

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 36 additions & 0 deletions
@@ -1947,6 +1947,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
         # task.evaluate(llm,
         #               extra_evaluator_kwargs=dict(apply_chat_template=True))
 
+<<<<<<< HEAD
     @skip_pre_blackwell
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend",
@@ -2006,6 +2007,41 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size,
         assert llm.args.moe_config.backend == moe_backend
         assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
 
+=======
+    def test_nvfp4_multi_gpus_corner_case(self):
+        """
+        This test is used to test the corner case of the NVFP4 model.
+        When using the same value for max_seq_len and max_num_tokens, there will be no
+        enough kv block for the dummy requests in CUDA graph warmup when creating
+        the py_executor before estimating kv cache. Then CUDA graph capture will be
+        triggered when estimating kv cache. This may cause some errors.
+        More info in https://nvbugs/5485325.
+        """
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80,
+                                        dtype="fp8",
+                                        enable_block_reuse=False)
+        pytorch_config = dict(disable_overlap_scheduler=False,
+                              cuda_graph_config=CudaGraphConfig(
+                                  enable_padding=True, max_batch_size=1024),
+                              moe_config=MoeConfig(backend="TRTLLM"))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1)
+        with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4",
+                 tensor_parallel_size=8,
+                 pipeline_parallel_size=1,
+                 moe_expert_parallel_size=8,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config,
+                 max_seq_len=5120,
+                 max_num_tokens=5120) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+>>>>>>> 777679303 ([https://nvbugs/5485325][fix] Add a postprocess to the model engine to fix the CUDA graph warmup issue when using speculative decoding (#7373))
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 1 addition & 0 deletions
@@ -488,6 +488,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughp
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_chunked_prefill[throughput_tp4]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[throughput]

tests/integration/test_lists/qa/llm_function_core_sanity.txt

Lines changed: 4 additions & 0 deletions
@@ -37,11 +37,15 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput]
+<<<<<<< HEAD:tests/integration/test_lists/qa/llm_function_core_sanity.txt
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_chunked_prefill[throughput_tp4]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[throughput]
+=======
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case
+>>>>>>> 777679303 ([https://nvbugs/5485325][fix] Add a postprocess to the model engine to fix the CUDA graph warmup issue when using speculative decoding (#7373)):tests/integration/test_lists/qa/llm_function_sanity.txt
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 15 additions & 0 deletions
@@ -73,6 +73,21 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] TIMEOUT (180)
+- condition:
+    ranges:
+      system_gpu_count:
+        gte: 8
+        lte: 8
+    wildcards:
+      gpu:
+      - '*b200*'
+      linux_distribution_name: ubuntu*
+      cpu: x86_64
+  terms:
+    stage: post_merge
+    backend: pytorch
+  tests:
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180)
 - condition:
     ranges:
       system_gpu_count:
