From d2c677f1e8168c4f482b87a7863f0dcc9814f7a6 Mon Sep 17 00:00:00 2001
From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Date: Fri, 5 Sep 2025 12:04:42 +0800
Subject: [PATCH] [https://nvbugs/5485325][fix] Add a postprocess to the model
 engine to fix the CUDA graph warmup issue when using speculative decoding
 (#7373)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Co-authored-by: Tao Li @ NVIDIA
---
 .../_torch/pyexecutor/cuda_graph_runner.py | 11 ++++--
 .../_torch/pyexecutor/model_engine.py      | 26 +++++++++++++-
 .../defs/accuracy/test_llm_api_pytorch.py  | 36 +++++++++++++++++++
 .../test_lists/qa/llm_function_core.txt    |  1 +
 .../qa/llm_function_core_sanity.txt        |  1 +
 .../test_lists/test-db/l0_dgx_b200.yml     | 15 ++++++++
 6 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 15afa50e4df..8affdf55542 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -164,8 +164,11 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
     def needs_capture(self, batch_size: int):
         return (batch_size, self.draft_len) not in self.graph_outputs
 
-    def capture(self, batch_size: int, forward_fn: Callable,
-                initial_inputs: Dict[str, Any]):
+    def capture(self,
+                batch_size: int,
+                forward_fn: Callable,
+                initial_inputs: Dict[str, Any],
+                postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
         key = (batch_size, self.draft_len)
         # [CUDA graph spec decode padding]
@@ -203,8 +206,12 @@ def capture(self, batch_size: int, forward_fn: Callable,
         with with_multi_stream(True), piecewise_cuda_graph(False):
             for _ in range(self.WARMUP_STEPS):
                 forward_fn(capture_inputs)
+                if postprocess_fn is not None:
+                    postprocess_fn(capture_inputs)
             with torch.cuda.graph(graph, pool=self.memory_pool):
                 output = forward_fn(capture_inputs)
+                if postprocess_fn is not None:
+                    postprocess_fn(capture_inputs)
 
         self.graphs[key] = graph
         self.graph_outputs[key] = make_weak_ref(output)
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 426dcb29ec4..dab97d26b6b 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1196,6 +1196,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
 
         return inputs
 
+    def _postprocess_inputs(self, inputs: Dict[str, Any]):
+        """
+        Postprocess to make sure the model forward pass does not change the inputs.
+        It is only used during CUDA graph capture, because in other cases new
+        inputs are prepared before the model forward pass.
+ """ + if self.enable_spec_decode and not self._disable_overlap_scheduler: + if inputs['attn_metadata'].kv_cache_manager is not None: + num_seqs = inputs['attn_metadata'].num_seqs + num_ctx_requests = inputs['attn_metadata'].num_contexts + num_gen_requests = inputs['attn_metadata'].num_generations + num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens + previous_batch_tokens = inputs['input_ids'].shape[ + 0] - num_ctx_tokens + inputs['position_ids'][0, num_ctx_tokens:] -= ( + self.previous_pos_id_offsets_cuda[:previous_batch_tokens]) + inputs['attn_metadata'].kv_lens_cuda[ + num_ctx_requests:num_seqs] -= ( + self.previous_kv_lens_offsets_cuda[:num_gen_requests]) + def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata): if self.enable_attention_dp: return list(self.dist.tp_allgather(attn_metadata.num_tokens)) @@ -2298,8 +2318,12 @@ def capture_forward_fn(inputs: Dict[str, Any]): gather_ids=gather_ids, gather_context_logits=gather_context_logits) + def capture_postprocess_fn(inputs: Dict[str, Any]): + self._postprocess_inputs(inputs) + self.cuda_graph_runner.capture(batch_size, - capture_forward_fn, inputs) + capture_forward_fn, inputs, + capture_postprocess_fn) # here we don't need to use context since cuda graph capture didn't run kernel. # maybe we need a cleaner way to do this. diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index b5ad70d7bf6..2cd02dc2f9d 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -2009,6 +2009,42 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size, task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + def test_nvfp4_multi_gpus_corner_case(self): + """ + This test is used to test the corner case of the NVFP4 model. + When using the same value for max_seq_len and max_num_tokens, there will be no + enough kv block for the dummy requests in CUDA graph warmup when creating + the py_executor before estimating kv cache. Then CUDA graph capture will be + triggered when estimating kv cache. This may cause some errors. + More info in https://nvbugs/5485325. 
+ """ + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80, + dtype="fp8", + enable_block_reuse=False) + pytorch_config = dict(disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig( + enable_padding=True, max_batch_size=1024), + moe_config=MoeConfig(backend="TRTLLM")) + + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1) + with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4", + tensor_parallel_size=8, + pipeline_parallel_size=1, + moe_expert_parallel_size=8, + kv_cache_config=kv_cache_config, + **pytorch_config, + enable_attention_dp=False, + speculative_config=mtp_config, + max_seq_len=5120, + max_num_tokens=5120) as llm: + + assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 + + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.skip_less_mpi_world_size(8) @skip_pre_hopper @pytest.mark.parametrize( diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 82f17430a75..07c066c8c81 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -488,6 +488,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughp accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_chunked_prefill[latency] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_chunked_prefill[throughput_tp4] +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[latency] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[throughput] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index 87ebd731e80..3ce19c0f130 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -42,6 +42,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_chunked_ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[latency] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[throughput] +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index af66bf794f8..31fc2ee704c 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -73,6 +73,21 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (180) - 
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] TIMEOUT (180)
+- condition:
+    ranges:
+      system_gpu_count:
+        gte: 8
+        lte: 8
+    wildcards:
+      gpu:
+      - '*b200*'
+      linux_distribution_name: ubuntu*
+      cpu: x86_64
+    terms:
+      stage: post_merge
+      backend: pytorch
+  tests:
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180)
 - condition:
     ranges:
       system_gpu_count:
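
The sketch below is an illustrative, standalone restatement of the pattern the patch introduces; it is not part of the patch and uses simplified stand-in names (forward_fn, postprocess_fn, capture, capture_inputs). During CUDA graph capture the same input buffers are reused for every warmup step and for the capture itself, and the speculative-decoding preprocessing adds offsets to position_ids and kv_lens in place, so repeated warmup calls would accumulate those offsets; the postprocess hook reverts the mutation after each call.

# Minimal sketch of the capture-with-postprocess pattern, assuming simplified
# names; not the TensorRT-LLM implementation.
import torch

WARMUP_STEPS = 2


def forward_fn(inputs):
    # Stand-in for the model forward: mutates its input buffers in place, the
    # way the engine's speculative-decoding preprocessing does.
    inputs["position_ids"] += inputs["offsets"]
    return inputs["position_ids"] * 2


def postprocess_fn(inputs):
    # Revert the in-place mutation so the next warmup step (or the capture)
    # starts from the same input state.
    inputs["position_ids"] -= inputs["offsets"]


def capture(forward_fn, capture_inputs, postprocess_fn=None):
    graph = torch.cuda.CUDAGraph()

    # Warm up on a side stream, undoing the input mutation after each step.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(WARMUP_STEPS):
            forward_fn(capture_inputs)
            if postprocess_fn is not None:
                postprocess_fn(capture_inputs)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture a single forward pass; the postprocess is captured too, so each
    # replay leaves the static input buffers in their original state.
    with torch.cuda.graph(graph):
        output = forward_fn(capture_inputs)
        if postprocess_fn is not None:
            postprocess_fn(capture_inputs)
    return graph, output


if __name__ == "__main__" and torch.cuda.is_available():
    static_inputs = {
        "position_ids": torch.zeros(8, dtype=torch.long, device="cuda"),
        "offsets": torch.ones(8, dtype=torch.long, device="cuda"),
    }
    graph, static_output = capture(forward_fn, static_inputs, postprocess_fn)
    graph.replay()  # reuses the static buffers captured above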