
Commit b2fc9a2

[None][feat] AutoDeploy ONNX export
[none][feat] Add AutoDeploy export-onnx mode
Add a new mode "export-onnx" to AutoDeploy. The mode is almost identical to the default one, with two differences:
1. Fuse torch_rope_with_explicit_cos_sin & torch_cached_attention_with_cache into onnx_rope_attnetion.
2. The result is not a TRT engine but a .onnx file.
Files added:
- export_onnx.py: the transformation that fuses the ops
- graph_module_visualizer.py: convert a GraphModule to .dot
- examples/onnx_export_llm.py: example usage
- onnx_driveos_llm.yaml: the new mode config file
- onnx_attnetion.py: the definition of the fused op
[none][feat] Fix small graphviz bug, remove useless code
[none][feat] Rename mode from onnx_driveos_llm to export_driveos_llm_onnx
[none][feat] Rename export_onnx.py to fuse_rope_attention.py
[none][feat] Annotate .meta['val'] with add_graph_input()
[none][feat] Successfully export .onnx
[none][feat] Add set_kvcache_placeholder_metadata transform
[none][feat] Skip torch_cached_attention_prepare_metadata
[none][feat] Fix SetKVCachePlaceholderMetadata transform
[none][feat] Remove unused placeholder of prepare_metadata
[none][feat] Fix to run DeepSeek-R1
[none][feat] Add remove_graph_input, refactor remove_unused_placeholder()
[none][feat] Merge K&V cache placeholder
[none][feat] Replace sin_cos with input
[none][feat] Manually fuse rope & attn
[none][feat] Export torch_attention_bsnd_grouped_sdpa with dynamic shape
[none][feat] Manually match rope & attn, not replace yet
[none][feat] Successfully export ONNX with dynamic input
[none][feat] Hack out_spec to add graph output
[none][feat] Fix present_key_values shape
[none][feat] Fix input & output names
[none][feat] Change out_spec in add_graph_output
[none][feat] Fix export of torch_linear_simple
The original translation misses a transpose on the weight.
[none][feat] Fix present_key_values shape
[none][feat] Rewire reshape's new shape as a TRT-LLM edge
[none][feat] Fix non-text rebase conflicts
[none][feat] Fix AttentionPlugin domain; it should be "" not "ai.onnx"
[none][feat] Enhance visualize, use .meta["val"] instead of .meta["tensor_meta"]
[none][feat] Fix visualize tensor width calculation
When calculating the width of a tensor, check whether the dimension is an int or a SymInt. The original implementation accidentally introduced constraints on the symbolic int. I don't know exactly how that happens; it shouldn't introduce new constraints, but it does.
[none][feat] Fix output dynamic batch_size
Originally the max batch size was 2; however, for an unknown reason, when set to 2 the batch_size collapses to the literal static int 2 even though we explicitly mark it as a dynamic axis. Stranger still, when set to 13, the batch_size stays dynamic.
default=13,  # to enable dynamic batch_size, the batch size must be > 1
[none][feat] Rename fuse_rope_attention_manually to fuse_rope_attention
[none][feat] Remove fuse_rope_attention.py
[none][feat] Rewire reshape to make the graph like Luxiao's
[none][feat] Fix last_token_ids dtype from i32 to i64
[none][feat] Catch up with the up-to-date DriveOS LLM
- Add placeholder kvcache_start_index
- AttentionPlugin: add input kvcache_start_index
- Insert Unsqueeze -1 before GatherND
- rope_rotary_cos_sin dynamic axis name changed from rope_max_position_length to max_position_embeddings
- logits' dtype should be float32, insert a cast
- Insert cast to f16 before AttentionPlugin
- All casts to bf16 should be f16
[none][feat] Catch up with the up-to-date DriveOS LLM
- model.half() converts the whole model to f16, including weights
- Remove AttentionPlugin attributes kv_cache_capacity & max_batch_size
- AttentionPlugin output[1] shape inferred from seq_len + past_len
- AttentionPlugin domain changed from `onnx.ai` to `trt`
- Placeholder `kvcache_start_index` dynamic axes changed from `batch_size` to `kv_cache_start_batch_size`
[none][feat] Catch up with up-to-date main
[none][feat] Add test for fuse_rope_attention transform
- Add test for fuse_rope_attention
- Enhance run_test_transformed_gm to support Modules with multiple inputs
- Fix add_graph_output for graphs with only one _LEAF_SPEC
[none][feat] Add unit test for fuse_rope_attn
- Add a unit test
- Fix add_graph_output when out_spec is _LEAF_SPEC
[none][feat] Export .json files
[none][feat] Add AutoDeploy export-onnx end-to-end test
[none][feat] Export ONNX on CPU to reduce GPU memory footprint
[none][feat] Use model.config to get head_dim, instead of using a literal
Signed-off-by: Po-Han Huang <[email protected]>
Signed-off-by: yoco xiao <[email protected]>
[none][feat] Visualize graph only when env var AD_DEBUG_VISUALIZE_DIR is set
- We no longer visualize by default, only when AD_DEBUG_VISUALIZE_DIR is set.
- AD_DEBUG_VISUALIZE_DIR is also the output dir, so you can specify where the output goes.
- Simplify the logging messages; move many messages from info to debug.
- Add .cursor to .gitignore
Signed-off-by: yoco xiao <[email protected]>
1 parent 355e06d commit b2fc9a2
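The last item in the message above describes the new AD_DEBUG_VISUALIZE_DIR switch. Below is a minimal sketch of how it could be used, assuming only what the commit message states (the variable both enables visualization and names the output directory); the directory path is a placeholder:

import os

# Hypothetical usage: visualization is off by default; setting AD_DEBUG_VISUALIZE_DIR
# turns it on and selects where the GraphModule .dot files are written
# (per the commit message above).
dump_dir = "/tmp/ad_graph_dumps"  # placeholder path
os.makedirs(dump_dir, exist_ok=True)
os.environ["AD_DEBUG_VISUALIZE_DIR"] = dump_dir

# Any AutoDeploy export run started after this point (e.g. examples/onnx_export_llm.py)
# would then emit its graph visualizations into dump_dir.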

File tree: 22 files changed, +3931 / -16 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 __pycache__/
 .vscode
+.cursor
 *.engine
 *.engine.config
 *.cache

docker/common/install_base.sh

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ init_ubuntu() {
 gdb \
 git-lfs \
 clang \
+graphviz \
 lld \
 llvm \
 libclang-rt-dev \
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+
+from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        help="The HF model to use for ONNX export.",
+    )
+    parser.add_argument(
+        "--max_seq_len",
+        type=int,
+        default=4,
+        help="The max sequence length to use for the model.",
+    )
+    parser.add_argument(
+        "--max_batch_size",
+        type=int,
+        # NOTE(yoco): Originally this was 2. For an unknown reason, when set to 2 the
+        # batch_size collapses to a static int 2 even though we explicitly mark it as a
+        # dynamic axis. Stranger still, when set to 13 the batch_size stays dynamic.
+        default=13,  # to enable dynamic batch_size, the batch size must be > 1
+        help="The max batch size to use for the model.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="The device to use for the model.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default=None,
+        help="The directory to save the exported ONNX model.",
+    )
+    parser.add_argument(
+        "--output_name",
+        type=str,
+        default=None,
+        help="The name of the exported ONNX model.",
+    )
+    args = parser.parse_args()
+
+    print(f"Constructing model from {args.model}")
+
+    # Prepare the AutoDeploy config; the mode is export_driveos_llm_onnx.
+    ad_config = AutoDeployConfig(
+        model=args.model,
+        mode="export_driveos_llm_onnx",
+        max_batch_size=args.max_batch_size,
+        max_seq_len=args.max_seq_len,
+        device=args.device,
+    )
+    ad_config.attn_backend = "torch"
+    if args.output_dir is not None:
+        ad_config.transforms["export_to_onnx"]["output_dir"] = args.output_dir
+    if args.output_name is not None:
+        ad_config.transforms["export_to_onnx"]["output_name"] = args.output_name
+    _ = LLM(**ad_config.to_llm_kwargs())
+
+
+if __name__ == "__main__":
+    main()
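As a follow-up to the example script, here is a small sketch for sanity-checking the exported file. It assumes the defaults from the mode config further below (the model is written to ./model.onnx) and uses only the onnx package already pinned in requirements.txt:

import onnx

# Hypothetical post-export inspection; "model.onnx" is the default output_name
# of the export_to_onnx transform in the mode config shown below.
model = onnx.load("model.onnx")

# Print graph inputs/outputs to confirm the expected dynamic axes survived export.
for tensor in list(model.graph.input) + list(model.graph.output):
    dims = [d.dim_param or d.dim_value for d in tensor.type.tensor_type.shape.dim]
    print(tensor.name, dims)

# List custom-domain nodes (e.g. the fused attention op) separately from standard ONNX ops.
for node in model.graph.node:
    if node.domain not in ("", "ai.onnx"):
        print(node.domain, node.op_type)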

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ mpi4py
 numpy<2
 onnx>=1.18.0,<1.20.0
 onnx_graphsurgeon>=0.5.2
+onnxscript==0.5.4
+graphviz
 openai
 polygraphy
 psutil
Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+# This is the set of transforms running in "graph" mode. In this mode, we capture the full graph
+# of the model and optimize it for inference.
+transforms:
+  ############################################################################################
+  # BUILD MODEL, EXPORT TO GRAPH MODULE, AND CLEAN UP
+  ############################################################################################
+  build_model:
+    stage: factory
+    run_per_gm: false
+    device: meta
+    requires_clean_graph: false
+  export_to_gm:
+    stage: export
+    clone_state_dict: false
+    strict: false
+    run_per_gm: false
+    requires_clean_graph: false
+  cleanup_noop_slice:
+    stage: post_export
+  cleanup_noop_add:
+    stage: post_export
+  cleanup_input_constraints:
+    stage: post_export
+  ############################################################################################
+  # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
+  ############################################################################################
+  match_moe_pattern:
+    stage: pattern_matcher
+  match_dense_moe_pattern:
+    stage: pattern_matcher
+  match_repeat_kv:
+    stage: pattern_matcher
+    run_shape_prop: true
+  match_eager_attention:
+    stage: pattern_matcher
+    requires_shape_prop: true
+  match_sdpa_to_torch_attention:
+    stage: pattern_matcher
+  match_grouped_attention:
+    stage: pattern_matcher
+  match_attention_layout:
+    stage: pattern_matcher
+    attn_layout: bsnd
+  match_rope_pattern:
+    stage: pattern_matcher
+  match_rope_layout:
+    stage: pattern_matcher
+    expected_layout: bsnd
+  ############################################################################################
+  # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
+  ############################################################################################
+  eliminate_redundant_transposes:
+    stage: pattern_matcher
+  # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
+  # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
+  # NOTE (yoco): To export ONNX for DriveOS LLM, we don't need this optimization,
+  # because the rope will be fused into the AttentionPlugin operation
+  # in the fuse_rope_attention transform.
+  # optimize_rope:
+  #   stage: pattern_matcher
+  quantize_int4_linear_from_config:
+    stage: pattern_matcher
+  quantize_fp8_linear_from_config:
+    stage: pattern_matcher
+  quantize_nvfp4_linear_from_config:
+    stage: pattern_matcher
+  quantize_fp8_bmm_from_config:
+    stage: pattern_matcher
+  quantize_fp8_from_graph:
+    stage: pattern_matcher
+  quantize_nvfp4_from_graph:
+    stage: pattern_matcher
+  quantize_fp8_moe:
+    stage: pattern_matcher
+  quantize_nvfp4_moe:
+    stage: pattern_matcher
+  quantize_mxfp4_moe:
+    stage: pattern_matcher
+  detect_sharding:
+    stage: sharding
+    simple_shard_only: false
+    sharding_source: ["manual", "factory", "heuristic"]
+    support_partial_config: true
+    sharding_dims: ["tp", "ep", "bmm"]
+    allreduce_strategy: "AUTO"
+    requires_shape_prop: true
+  sharding_transform_executor:
+    stage: sharding
+    run_shape_prop: true
+  ############################################################################################
+  # MOVE MODEL AND LOAD WEIGHTS
+  ############################################################################################
+  load_weights:
+    stage: weight_load
+    run_per_gm: false
+    checkpoint_device: cpu
+  move_inputs_to_device:
+    stage: weight_load
+    checkpoint_device: cpu
+    run_per_gm: false
+  ############################################################################################
+  # RUN POST-LOAD FUSION AND OPTIMIZATIONS
+  ############################################################################################
+  fuse_gemms:
+    stage: post_load_fusion
+    enabled: false  # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
+  fuse_fp4_gemms:
+    stage: post_load_fusion
+    enabled: false  # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
+  fuse_fp8_gemms:
+    stage: post_load_fusion
+    enabled: false  # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
+  fuse_fp8_linear:
+    stage: post_load_fusion
+    backend: trtllm
+  fuse_nvfp4_linear:
+    stage: post_load_fusion
+    backend: trtllm
+  fuse_moe:
+    stage: post_load_fusion
+    enabled: true
+    backend: trtllm
+  fuse_fp8_moe:
+    stage: post_load_fusion
+    enabled: true
+    backend: trtllm
+
+  ############################################################################################
+  # VISUALIZE GRAPH
+  ############################################################################################
+  visualize_namespace:
+    stage: visualize
+    enabled: false  # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460
+  ############################################################################################
+  # FUSE Rope Attention & export to ONNX
+  ############################################################################################
+  fuse_rope_attention:
+    stage: export_onnx
+  short_reshape_attention_output:
+    stage: export_onnx
+  gather_last_token_ids:
+    stage: export_onnx
+  adapt_to_driveos_llm:
+    stage: export_onnx
+  export_to_onnx:
+    stage: export_onnx
+    output_dir: "."
+    output_name: "model.onnx"
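
To connect this config back to the example script earlier in the commit: the output_dir and output_name defaults of the export_to_onnx entry are exactly what the script overrides via --output_dir/--output_name. A minimal programmatic sketch using the same calls as the example (the concrete model, sizes, and paths are placeholders):

from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig

# Same pattern as the example script: select the export_driveos_llm_onnx mode
# and redirect the ONNX artifact (placeholder values, adjust as needed).
ad_config = AutoDeployConfig(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    mode="export_driveos_llm_onnx",
    max_batch_size=13,  # >1 so the batch dimension stays dynamic (see the note in the example)
    max_seq_len=4,
    device="cpu",
)
ad_config.attn_backend = "torch"
ad_config.transforms["export_to_onnx"]["output_dir"] = "./onnx_out"  # placeholder
ad_config.transforms["export_to_onnx"]["output_name"] = "llm.onnx"   # placeholder

# Constructing the LLM runs the transform pipeline above and writes the ONNX file.
_ = LLM(**ad_config.to_llm_kwargs())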
