From 172aa161e9733960e030095b9179b913add90dae Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Fri, 29 Aug 2025 02:34:17 -0700 Subject: [PATCH 1/5] fix cuda graph warmup when using speculative decoding. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/pyexecutor/cuda_graph_runner.py | 4 ++- .../_torch/pyexecutor/model_engine.py | 30 +++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 0007b99ebd2..3c831dc0c8e 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -137,7 +137,7 @@ def needs_capture(self, batch_size: int): return (batch_size, self.draft_len) not in self.graph_outputs def capture(self, batch_size: int, forward_fn: Callable, - initial_inputs: Dict[str, Any]): + postprocess_fn: Callable, initial_inputs: Dict[str, Any]): """Captures the forward pass for a given batch size.""" engine = self._get_engine() key = (batch_size, self.draft_len) @@ -181,8 +181,10 @@ def capture(self, batch_size: int, forward_fn: Callable, with with_multi_stream(True), piecewise_cuda_graph(False): for _ in range(self.WARMUP_STEPS): forward_fn(capture_inputs) + postprocess_fn(capture_inputs) with torch.cuda.graph(graph, pool=self.memory_pool): output = forward_fn(capture_inputs) + postprocess_fn(capture_inputs) self.graphs[key] = graph self.graph_outputs[key] = make_weak_ref(output) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 2920f5b6972..e7e5d7b4dad 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1125,6 +1125,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]): self.previous_kv_lens_offsets_cuda[:num_gen_requests]) return inputs + def _postprocess_inputs(self, inputs: Dict[str, Any]): + """ + Postprocess to make sure model forward doesn't change the inputs. + It is only used in cuda graph capture, because other cases will prepare + new inputs before the model forward. 
+ """ + if self.enable_spec_decode and not self._disable_overlap_scheduler: + if inputs['attn_metadata'].kv_cache_manager is not None: + num_seqs = inputs['attn_metadata'].num_seqs + num_ctx_requests = inputs['attn_metadata'].num_contexts + num_gen_requests = inputs['attn_metadata'].num_generations + num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens + previous_batch_tokens = inputs['input_ids'].shape[ + 0] - num_ctx_tokens + inputs['position_ids'][0, num_ctx_tokens:] -= ( + self.previous_pos_id_offsets_cuda[:previous_batch_tokens]) + inputs['attn_metadata'].kv_lens_cuda[ + num_ctx_requests:num_seqs] -= ( + self.previous_kv_lens_offsets_cuda[:num_gen_requests]) + def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata): if self.enable_attention_dp: return list(self.dist.tp_allgather(attn_metadata.num_tokens)) @@ -2193,13 +2213,19 @@ def forward( def capture_forward_fn(inputs: Dict[str, Any]): with MoeLoadBalancerIterContext(moe_load_balancer): - return self._forward_step( + outputs = self._forward_step( inputs, gather_ids=gather_ids, gather_context_logits=gather_context_logits) + return outputs + + def capture_postprocess_fn(inputs: Dict[str, Any]): + self._postprocess_inputs(inputs) self.cuda_graph_runner.capture(batch_size, - capture_forward_fn, inputs) + capture_forward_fn, + capture_postprocess_fn, + inputs) # here we don't need to use context since cuda graph capture didn't run kernel. # maybe we need a cleaner way to do this. From fc11211688d719bf47de253cfa468ffadeaed537 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Fri, 29 Aug 2025 03:36:14 -0700 Subject: [PATCH 2/5] fix. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e7e5d7b4dad..563a37d9f2b 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2213,11 +2213,10 @@ def forward( def capture_forward_fn(inputs: Dict[str, Any]): with MoeLoadBalancerIterContext(moe_load_balancer): - outputs = self._forward_step( + return self._forward_step( inputs, gather_ids=gather_ids, gather_context_logits=gather_context_logits) - return outputs def capture_postprocess_fn(inputs: Dict[str, Any]): self._postprocess_inputs(inputs) From a37dcdd49562d63ed5132f016bdc4917b9aadaa4 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 1 Sep 2025 07:03:05 -0700 Subject: [PATCH 3/5] add test. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../defs/accuracy/test_llm_api_pytorch.py | 28 +++++++++++++++++++ .../test_lists/test-db/l0_dgx_b200.yml | 15 ++++++++++ 2 files changed, 43 insertions(+) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 2667657a168..180608cae8b 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1828,6 +1828,34 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, # task.evaluate(llm, # extra_evaluator_kwargs=dict(apply_chat_template=True)) + def test_nvfp4_multi_gpus_corner_case(self): + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80, + dtype="fp8", + enable_block_reuse=False) + pytorch_config = dict(disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig( + enable_padding=True, max_batch_size=1024), + moe_config=MoeConfig(backend="TRTLLM")) + + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1) + with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4", + tensor_parallel_size=8, + pipeline_parallel_size=1, + moe_expert_parallel_size=8, + kv_cache_config=kv_cache_config, + **pytorch_config, + enable_attention_dp=False, + speculative_config=mtp_config, + max_seq_len=2200, + max_num_tokens=2200) as llm: + + assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 + + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.skip_less_mpi_world_size(8) @skip_pre_hopper @pytest.mark.parametrize( diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 179b50963bc..9f1fc1c68dc 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -70,6 +70,21 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (180) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] TIMEOUT (180) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] TIMEOUT (180) +- condition: + ranges: + system_gpu_count: + gte: 8 + lte: 8 + wildcards: + gpu: + - '*b200*' + linux_distribution_name: ubuntu* + cpu: x86_64 + terms: + stage: post_merge + backend: pytorch + tests: + - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180) - condition: ranges: system_gpu_count: From ee257951141a989b7dbdbbb703b18fb357299608 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 1 Sep 2025 23:56:10 -0700 Subject: [PATCH 4/5] add test. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tests/integration/test_lists/qa/llm_function_full.txt | 1 + tests/integration/test_lists/qa/llm_function_sanity.txt | 1 + tests/unittest/_torch/modeling/test_modeling_qwen.py | 5 ++++- tests/unittest/_torch/modeling/test_modeling_qwen_moe.py | 5 ++++- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_lists/qa/llm_function_full.txt b/tests/integration/test_lists/qa/llm_function_full.txt index 18d4faa2e89..f4b1d947520 100644 --- a/tests/integration/test_lists/qa/llm_function_full.txt +++ b/tests/integration/test_lists/qa/llm_function_full.txt @@ -525,6 +525,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4] +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False] diff --git a/tests/integration/test_lists/qa/llm_function_sanity.txt b/tests/integration/test_lists/qa/llm_function_sanity.txt index d6e54316a2e..e5fa2da0423 100644 --- a/tests/integration/test_lists/qa/llm_function_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_sanity.txt @@ -35,6 +35,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen.py b/tests/unittest/_torch/modeling/test_modeling_qwen.py index d1d129de083..24a876cc6e7 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen.py @@ -270,6 +270,9 @@ def test_qwen_allclose_to_hf(self, scenario: Scenario) -> None: mock_engine = create_mock_engine(1) graph_runner = CUDAGraphRunner(mock_engine) + def postprocess_fn(inputs): + pass + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -283,7 +286,7 @@ def run_forward(input_ids, position_ids, attn_metadata): "attn_metadata": attn_metadata, } graph_runner.capture(1, lambda inputs: qwen.forward(**inputs), - inputs) + postprocess_fn, inputs) for _ in range(2): # Run it twice. 
This helps us catch problems if buffers are accidentally reallocated diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py index 8658ae0e242..4efefda99e0 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py @@ -320,6 +320,9 @@ def test_qwen_moe_allclose_to_hf(self, scenario: Scenario): mock_engine = create_mock_engine(1) graph_runner = CUDAGraphRunner(mock_engine) + def postprocess_fn(inputs): + pass + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -334,7 +337,7 @@ def run_forward(input_ids, position_ids, attn_metadata): } graph_runner.capture(1, lambda inputs: qwen_moe.forward(**inputs), - inputs) + postprocess_fn, inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated From fc96c6903183dd7094963d027ad86463b1444c13 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Tue, 2 Sep 2025 02:56:05 -0700 Subject: [PATCH 5/5] fix test. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py | 13 +++++++++---- tensorrt_llm/_torch/pyexecutor/model_engine.py | 5 ++--- .../defs/accuracy/test_llm_api_pytorch.py | 12 ++++++++++-- .../unittest/_torch/modeling/test_modeling_qwen.py | 5 +---- .../_torch/modeling/test_modeling_qwen_moe.py | 5 +---- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 3c831dc0c8e..7bbbfa2c038 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -136,8 +136,11 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests): def needs_capture(self, batch_size: int): return (batch_size, self.draft_len) not in self.graph_outputs - def capture(self, batch_size: int, forward_fn: Callable, - postprocess_fn: Callable, initial_inputs: Dict[str, Any]): + def capture(self, + batch_size: int, + forward_fn: Callable, + initial_inputs: Dict[str, Any], + postprocess_fn: Optional[Callable] = None): """Captures the forward pass for a given batch size.""" engine = self._get_engine() key = (batch_size, self.draft_len) @@ -181,10 +184,12 @@ def capture(self, batch_size: int, forward_fn: Callable, with with_multi_stream(True), piecewise_cuda_graph(False): for _ in range(self.WARMUP_STEPS): forward_fn(capture_inputs) - postprocess_fn(capture_inputs) + if postprocess_fn is not None: + postprocess_fn(capture_inputs) with torch.cuda.graph(graph, pool=self.memory_pool): output = forward_fn(capture_inputs) - postprocess_fn(capture_inputs) + if postprocess_fn is not None: + postprocess_fn(capture_inputs) self.graphs[key] = graph self.graph_outputs[key] = make_weak_ref(output) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 016d63ca40b..2726e1d652c 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2230,9 +2230,8 @@ def capture_postprocess_fn(inputs: Dict[str, Any]): self._postprocess_inputs(inputs) self.cuda_graph_runner.capture(batch_size, - capture_forward_fn, - capture_postprocess_fn, - inputs) + capture_forward_fn, inputs, + capture_postprocess_fn) # here we don't need to use context since cuda 
graph capture didn't run kernel. # maybe we need a cleaner way to do this. diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 180608cae8b..5fc1ea538ae 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1829,6 +1829,14 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, # extra_evaluator_kwargs=dict(apply_chat_template=True)) def test_nvfp4_multi_gpus_corner_case(self): + """ + This test is used to test the corner case of the NVFP4 model. + When using the same value for max_seq_len and max_num_tokens, there will be no + enough kv block for the dummy requests in CUDA graph warmup when creating + the py_executor before estimating kv cache. Then CUDA graph capture will be + triggered when estimating kv cache. This may cause some errors. + More info in https://nvbugs/5485325. + """ kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80, dtype="fp8", enable_block_reuse=False) @@ -1846,8 +1854,8 @@ def test_nvfp4_multi_gpus_corner_case(self): **pytorch_config, enable_attention_dp=False, speculative_config=mtp_config, - max_seq_len=2200, - max_num_tokens=2200) as llm: + max_seq_len=5120, + max_num_tokens=5120) as llm: assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen.py b/tests/unittest/_torch/modeling/test_modeling_qwen.py index 24a876cc6e7..d1d129de083 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen.py @@ -270,9 +270,6 @@ def test_qwen_allclose_to_hf(self, scenario: Scenario) -> None: mock_engine = create_mock_engine(1) graph_runner = CUDAGraphRunner(mock_engine) - def postprocess_fn(inputs): - pass - def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -286,7 +283,7 @@ def run_forward(input_ids, position_ids, attn_metadata): "attn_metadata": attn_metadata, } graph_runner.capture(1, lambda inputs: qwen.forward(**inputs), - postprocess_fn, inputs) + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py index 4efefda99e0..8658ae0e242 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py @@ -320,9 +320,6 @@ def test_qwen_moe_allclose_to_hf(self, scenario: Scenario): mock_engine = create_mock_engine(1) graph_runner = CUDAGraphRunner(mock_engine) - def postprocess_fn(inputs): - pass - def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -337,7 +334,7 @@ def run_forward(input_ids, position_ids, attn_metadata): } graph_runner.capture(1, lambda inputs: qwen_moe.forward(**inputs), - postprocess_fn, inputs) + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated
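
The net effect of patches 1 and 5 is easiest to see in isolation. Below is a minimal, framework-free sketch of the warmup/capture control flow that `CUDAGraphRunner.capture` ends up with: the forward callable may mutate its inputs in place (as `_preprocess_inputs` does with the speculative-decoding position-id and KV-length offsets when the overlap scheduler is enabled), and an optional `postprocess_fn` reverses that mutation after every warmup step and after the captured step, so repeated warmup iterations do not compound the offsets. The names `capture_with_postprocess` and `WARMUP_STEPS`, and the toy forward/postprocess callables, are illustrative assumptions, not TensorRT-LLM APIs; in the real runner the final call happens inside `torch.cuda.graph(...)`.

from typing import Any, Callable, Dict, Optional

WARMUP_STEPS = 2  # assumed value; the real runner defines its own constant


def capture_with_postprocess(
        forward_fn: Callable[[Dict[str, Any]], Any],
        capture_inputs: Dict[str, Any],
        postprocess_fn: Optional[Callable[[Dict[str, Any]], None]] = None):
    """Sketch of the warmup/capture loop with the optional undo hook.

    forward_fn may mutate capture_inputs in place; postprocess_fn is expected
    to reverse that mutation so every warmup iteration and the final capture
    start from identical inputs.
    """
    for _ in range(WARMUP_STEPS):
        forward_fn(capture_inputs)
        if postprocess_fn is not None:
            postprocess_fn(capture_inputs)
    # In the real runner this last call runs under torch.cuda.graph(...).
    output = forward_fn(capture_inputs)
    if postprocess_fn is not None:
        postprocess_fn(capture_inputs)
    return output


if __name__ == "__main__":
    # Toy stand-in: "forward" adds an offset to the inputs in place and
    # "postprocess" subtracts it back, mirroring _postprocess_inputs.
    OFFSET = 7

    def forward_fn(inputs):
        inputs["position_ids"] = [p + OFFSET for p in inputs["position_ids"]]
        return sum(inputs["position_ids"])

    def postprocess_fn(inputs):
        inputs["position_ids"] = [p - OFFSET for p in inputs["position_ids"]]

    inputs = {"position_ids": [0, 1, 2, 3]}
    capture_with_postprocess(forward_fn, inputs, postprocess_fn)
    # Without the hook the offsets would compound across warmup iterations;
    # with it the inputs return to their original values after capture.
    assert inputs["position_ids"] == [0, 1, 2, 3]

Making `postprocess_fn` a trailing optional argument (patch 5) also keeps existing callers such as the Qwen unit tests unchanged, which is why the earlier test edits from patch 4 are reverted there.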