Revised the lowering pass according to Bo's suggestion #3756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main
192 changes: 169 additions & 23 deletions examples/dynamo/llama2_flashinfer_rmsnorm.py
@@ -15,11 +15,12 @@
This example illustrates advanced extensibility in Torch-TensorRT through automatic plugin generation and operator lowering customization.
"""

from typing import Any, Callable, Optional, Sequence, Union

import flashinfer
import torch
import torch_tensorrt
from torch._subclasses import FakeTensor
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
_aten_lowering_pass,
@@ -51,6 +52,8 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
def replace_rmsnorm(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
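    """Replace the aten ops decomposed from RMSNorm with the custom flashinfer
    rmsnorm plugin op, flattening the (batch, seq, dim) hidden states to 2D for
    the kernel and reshaping back, repairing shape metadata along the way."""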
print("before2\n")
print(gm.graph)
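# Scan every node for the aten._to_copy cast that begins the decomposed RMSNorm pattern.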
for node in gm.graph.nodes:
if (
node.target == torch.ops.aten._to_copy.default
@@ -90,13 +93,60 @@ def replace_rmsnorm(
weight_mul_node = list(copy_node.users)[0]

weight = weight_mul_node.args[0]
hidden_states_node = node.args[0]

original_meta = hidden_states_node.meta.get(
"tensor_meta", {}
)
memory_format = original_meta.memory_format
from torch.fx.experimental.symbolic_shapes import (
ShapeEnv,
)

shape_env = ShapeEnv()
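# A standalone ShapeEnv lets the pass mint fresh unbacked SymInts to stand in
# for dimensions that must stay symbolic through the rewrite.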

with gm.graph.inserting_after(weight_mul_node):
input_meta = node.args[0].meta["val"]
batch_size = input_meta.shape[0]
seq_len = input_meta.shape[1]
head_dim = input_meta.shape[2]

# Create symbolic ints for batch_size
if isinstance(batch_size, int):
batch_size_unbacked_symint = (
shape_env.create_unbacked_symint()
)
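# Pin the unbacked SymInt's value range: recording batch_size as both the
# lower and upper bound tells the shape machinery its concrete extent while
# keeping the graph node symbolic.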
torch._check(
batch_size_unbacked_symint >= batch_size
)
torch._check(
batch_size_unbacked_symint <= batch_size
)
elif isinstance(batch_size, torch.SymInt):
    # Already symbolic; reuse it directly.
    batch_size_unbacked_symint = batch_size
else:
    raise ValueError(
        "batch_size must be an int or SymInt"
    )

# Create symbolic ints for head_dim
if isinstance(head_dim, int):
head_dim_unbacked_symint = (
shape_env.create_unbacked_symint()
)
torch._check(
head_dim_unbacked_symint >= head_dim
)
torch._check(
head_dim_unbacked_symint <= head_dim
)
elif isinstance(head_dim, torch.SymInt):
    # Already symbolic; reuse it directly.
    head_dim_unbacked_symint = head_dim
else:
    raise ValueError(
        "head_dim must be an int or SymInt"
    )

b = gm.graph.create_node(
op="call_function",
target=torch.ops.aten.sym_size.int,
@@ -111,19 +161,24 @@ def replace_rmsnorm(
is_quantized=False,
qparams={},
)

b.meta["val"] = batch_size_unbacked_symint

s = gm.graph.create_node(
op="call_function",
target=torch.ops.aten.sym_size.int,
args=(node.args[0], 1),
)
s.meta.update(b.meta)

s.meta["val"] = seq_len
d = gm.graph.create_node(
op="call_function",
target=torch.ops.aten.sym_size.int,
args=(node.args[0], 2),
)
d.meta.update(b.meta)
d.meta["val"] = head_dim_unbacked_symint

with gm.graph.inserting_after(b):
new_first_dim = gm.graph.create_node(
@@ -150,11 +205,11 @@
[b_val * s_val, d_val]
),
dtype=original_meta.dtype,
stride=None,
memory_format=memory_format,
is_quantized=False,
qparams={},
requires_grad=False,
)
)

@@ -183,11 +238,22 @@ def replace_rmsnorm(
[b, s, d],
),
)
reshapback_node.meta["tensor_meta"] = (
TensorMetadata(
shape=torch.Size([b_val, s_val, d_val]),
dtype=original_meta.dtype,
stride=None,
memory_format=memory_format,
is_quantized=False,
qparams={},
requires_grad=False,
)
)

# Inherit the remaining metadata (e.g. "val") from the node being
# replaced without clobbering the tensor_meta built above.
saved_tensor_meta = reshapback_node.meta["tensor_meta"]
reshapback_node.meta.update(weight_mul_node.meta)
reshapback_node.meta["tensor_meta"] = saved_tensor_meta
weight_mul_node.replace_all_uses_with(
    reshapback_node
)

modified_graph = True

@@ -207,6 +273,43 @@
return gm


@_aten_lowering_pass
def set_copy_node_meta_data(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
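    """Back-fill missing tensor_meta on aten._to_copy nodes by propagating the
    input's metadata with dtype overridden by the cast's target dtype."""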
for node in gm.graph.nodes:
if node.target == torch.ops.aten._to_copy.default and (
"tensor_meta" not in node.meta
):
input_node = node.args[0]

# Check if input has metadata
if "tensor_meta" in input_node.meta:
# Copy the input's metadata, overriding dtype with the cast target.
output_meta = input_node.meta["tensor_meta"]
node.meta["tensor_meta"] = TensorMetadata(
shape=output_meta.shape,
dtype=node.kwargs.get("dtype"),
requires_grad=output_meta.requires_grad,
stride=None,
memory_format=input_node.meta["tensor_meta"].memory_format,
is_quantized=False,
qparams={},
)

else:
# Handle missing metadata (optional warning/logging)
print(f"Warning: Input node {input_node} has no tensor_meta")

gm = clean_up_graph_after_modifications(gm)

return gm


# 1. Create a custom config with 1 layer
config = LlamaConfig(
vocab_size=32000,
@@ -222,12 +325,14 @@ def replace_rmsnorm(
with torch.no_grad():
model = LlamaForCausalLM(config).eval().half()

MAX_TOKENS = 64
seq_len = torch.export.Dim("seq_len", min=2, max=MAX_TOKENS)
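# torch.export.Dim declares a symbolic dimension; seq_len may range from 2 to MAX_TOKENS at runtime.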
# 3. Export with a dynamic sequence length
input_ids = torch.randint(0, 32000, (1, 64))  # Sample input [batch=1, seq=64]
exported = torch.export.export(
model,
(input_ids,),
dynamic_shapes=({1: seq_len},),
)

# Test forward pass
@@ -238,20 +343,61 @@
# Export validation

DEVICE = torch.device("cuda:0")

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
with torch_tensorrt.dynamo.Debugger(
log_level="info",
# profile_format="trex",
# save_engine_profile=True,
capture_fx_graph_before=["remove_detach"],
capture_fx_graph_after=["replace_rmsnorm"],
logging_dir="/home/profile/logging/torchtrt",
engine_builder_monitor=False,
):
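    # capture_fx_graph_before/after dump the FX graph around the named
    # lowering passes so the RMSNorm rewrite can be inspected in the logs.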
trt_model = torch_tensorrt.dynamo.compile(
exported,
inputs=[input_ids],
enabled_precisions={torch.float32, torch.float16},
truncate_double=True,
device=DEVICE,
disable_tf32=True,
use_explicit_typing=False,
use_fp32_acc=True,
use_python_runtime=True,
)

input_ids = input_ids.to(DEVICE)

res = trt_model.forward(input_ids)

# Benchmark TensorRT models

import time

def benchmark_model(model, input_ids, label, n_runs=100):
torch.cuda.synchronize()
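    # Synchronize so previously queued GPU work doesn't leak into the timed region.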
start = time.time()
for _ in range(n_runs):
with torch.no_grad():
out = model(input_ids)
torch.cuda.synchronize()
end = time.time()
print(f"{label}: {n_runs} runs, total {(end - start):.4f} s")
return out

# Warmup
with torch.no_grad():
_ = trt_model(input_ids)

# Benchmark
trt_out = benchmark_model(trt_model, input_ids, "TensorRT model")

# Compare outputs

pytorch_logits = output.logits
trt_logits = trt_out.logits

pytorch_logits = pytorch_logits.to(DEVICE)
trt_logits = trt_logits.to(DEVICE)
print("Max abs diff:", (pytorch_logits - trt_logits).abs().max().item())
print("Mean abs diff:", (pytorch_logits - trt_logits).abs().mean().item())