
Commit e07f056

truongp-aws authored and awsjoshir committed
Neuron SDK Release 2.26.0 changes
1 parent fc415aa commit e07f056

87 files changed: +16106 −778 lines changed


examples/generate_flux.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
import os
import argparse
import time
import torch
from neuronx_distributed_inference.models.diffusers.flux.application import NeuronFluxApplication
from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.models.diffusers.flux.clip.modeling_clip import CLIPInferenceConfig
from neuronx_distributed_inference.models.diffusers.flux.t5.modeling_t5 import T5InferenceConfig
from neuronx_distributed_inference.models.diffusers.flux.modeling_flux import FluxBackboneInferenceConfig
from neuronx_distributed_inference.models.diffusers.flux.vae.modeling_vae import VAEDecoderInferenceConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
from neuronx_distributed_inference.utils.diffusers_adapter import load_diffusers_config
from neuronx_distributed_inference.utils.random import set_random_seed

set_random_seed(0)

# Working directory for compiler artifacts
BASE_COMPILE_WORK_DIR = "/tmp/flux/compiler_workdir/"


def create_flux_config(model_path, world_size, backbone_tp_degree, dtype, height, width):
    text_encoder_path = os.path.join(model_path, "text_encoder")
    text_encoder_2_path = os.path.join(model_path, "text_encoder_2")
    backbone_path = os.path.join(model_path, "transformer")
    vae_decoder_path = os.path.join(model_path, "vae")

    clip_neuron_config = NeuronConfig(
        tp_degree=1,
        world_size=world_size,
        torch_dtype=dtype,
    )
    clip_config = CLIPInferenceConfig(
        neuron_config=clip_neuron_config,
        load_config=load_pretrained_config(text_encoder_path),
    )

    t5_neuron_config = NeuronConfig(
        tp_degree=world_size,  # T5: TP degree = world_size
        world_size=world_size,
        torch_dtype=dtype,
    )
    t5_config = T5InferenceConfig(
        neuron_config=t5_neuron_config,
        load_config=load_pretrained_config(text_encoder_2_path),
    )

    backbone_neuron_config = NeuronConfig(
        tp_degree=backbone_tp_degree,
        world_size=world_size,
        torch_dtype=dtype,
    )
    backbone_config = FluxBackboneInferenceConfig(
        neuron_config=backbone_neuron_config,
        load_config=load_diffusers_config(backbone_path),
        height=height,
        width=width,
    )

    decoder_neuron_config = NeuronConfig(
        tp_degree=1,
        world_size=world_size,
        torch_dtype=dtype,
    )
    decoder_config = VAEDecoderInferenceConfig(
        neuron_config=decoder_neuron_config,
        load_config=load_diffusers_config(vae_decoder_path),
        height=height,
        width=width,
        transformer_in_channels=backbone_config.in_channels,
    )

    setattr(backbone_config, "vae_scale_factor", decoder_config.vae_scale_factor)

    return (clip_config, t5_config, backbone_config, decoder_config)


def run_flux_generate(args):
    print(f"run_flux_generate with args: {args}")
    world_size = 8
    backbone_tp_degree = 8
    if args.instance_type == "trn1":
        if args.context_parallel_enabled:
            world_size = 16
            backbone_tp_degree = 8
        else:
            world_size = 8
            backbone_tp_degree = 8
    elif args.instance_type == "trn2":
        if args.context_parallel_enabled:
            world_size = 8
            backbone_tp_degree = 4
        else:
            world_size = 4
            backbone_tp_degree = 4

    dtype = torch.bfloat16

    clip_config, t5_config, backbone_config, decoder_config = create_flux_config(args.checkpoint_dir, world_size, backbone_tp_degree, dtype, args.height, args.width)

    flux_app = NeuronFluxApplication(
        model_path=args.checkpoint_dir,
        text_encoder_config=clip_config,
        text_encoder2_config=t5_config,
        backbone_config=backbone_config,
        decoder_config=decoder_config,
        instance_type=args.instance_type,
        height=args.height,
        width=args.width,
    )
    flux_app.compile(BASE_COMPILE_WORK_DIR)
    flux_app.load(BASE_COMPILE_WORK_DIR)

    warmup_rounds = 5
    print("Warming up the model for better latency testing")
    for i in range(warmup_rounds):
        flux_app(
            args.prompt,
            height=args.height,
            width=args.width,
            guidance_scale=args.guidance_scale,
            num_inference_steps=args.num_inference_steps
        ).images[0]


    if args.profile:
        from torch.profiler import profile, ProfilerActivity
        with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True, with_stack=True) as prof:
            _run_flux_helper(flux_app, args)

        print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))
        prof.export_chrome_trace(f"{args.profile_name}")
    else:
        _run_flux_helper(flux_app, args)


def _run_flux_helper(flux_app, args):
    total_time = 0
    for i in range(args.num_images):
        start_time = time.time()

        image = flux_app(
            args.prompt,
            height=args.height,
            width=args.width,
            guidance_scale=args.guidance_scale,
            num_inference_steps=args.num_inference_steps
        ).images[0]

        end_time = time.time()
        generation_time = end_time - start_time
        total_time += generation_time

        if args.save_image:
            filename = f"output_{i+1}.png"
            image.save(filename)

        print(f"Image {i+1} generated in {generation_time:.2f} seconds")

    average_time = total_time / args.num_images
    print(f"\nAverage generation time: {average_time:.2f} seconds")


if __name__ == "__main__":
    # Root of the checkpoint directory downloaded from Hugging Face
    CKPT_DIR = "/shared/flux/FLUX.1-dev/"

    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--prompt", type=str, default="A cat holding a sign that says hello world")
    parser.add_argument("-hh", "--height", type=int, default=1024)
    parser.add_argument("-w", "--width", type=int, default=1024)
    parser.add_argument("-n", "--num_inference_steps", type=int, default=25)
    parser.add_argument("-i", "--instance_type", type=str, default="trn2")
    parser.add_argument("-g", "--guidance_scale", type=float, default=3.5)
    parser.add_argument("-c", "--checkpoint_dir", type=str, default=CKPT_DIR)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--profile_name", type=str, default="flux_torch_profile.json")
    parser.add_argument("--num_images", type=int, default=1)
    parser.add_argument("--save_image", action="store_true")
    parser.add_argument("--context_parallel_enabled", action="store_true", default=True)

    args = parser.parse_args()
    run_flux_generate(args)
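
For readers who want to exercise the new example without the CLI, a minimal driver is sketched below. It is not part of this commit: the argument names mirror the argparse defaults in the script, and importing it as a module named generate_flux is an assumption about how the file is placed on PYTHONPATH.

from argparse import Namespace

import generate_flux  # assumption: examples/generate_flux.py is importable as a module

# Mirror the parser defaults above; every field name comes from the add_argument
# calls in generate_flux.py.
args = Namespace(
    prompt="A cat holding a sign that says hello world",
    height=1024,
    width=1024,
    num_inference_steps=25,
    instance_type="trn2",  # "trn1" or "trn2"; selects world_size and backbone TP degree
    guidance_scale=3.5,
    checkpoint_dir="/shared/flux/FLUX.1-dev/",
    profile=False,
    profile_name="flux_torch_profile.json",
    num_images=1,
    save_image=True,
    context_parallel_enabled=True,
)

generate_flux.run_flux_generate(args)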

examples/generation_llama4.py

Lines changed: 76 additions & 61 deletions
@@ -1,7 +1,6 @@
 import torch
 import os
 import logging
-import base64
 
 from transformers import AutoTokenizer, AutoProcessor, GenerationConfig
 from neuronx_distributed_inference.models.config import OnDeviceSamplingConfig as SmplConfig
@@ -21,8 +20,7 @@
 VISION_TP_DEGERE = 16
 WORLD_SIZE = 64
 BATCH_SIZE = 1
-SEQ_LENGTH = 8192
-# SEQ_LENGTH = 10240 for chunked attention
+SEQ_LENGTH = 16384
 TEXT_TO_TEXT = False
 # TEXT_TO_TEXT = True for text only generation
 DTYPE = torch.bfloat16
@@ -33,46 +31,65 @@
 os.environ['NEURON_RT_NUM_CORES']=f'{TEXT_TP_DEGREE}'
 os.environ['BASE_COMPILE_WORK_DIR'] = "./compiler_path/"
 
-model_path = "/home/ubuntu/models/Llama-4-Scout-17B-16E-Instruct/"
-traced_model_path = "/home/ubuntu/traced_model_Llama-4-Scout-17B-16E-Instruct"
+# Llama4 checkpoints can be downloaded from HuggingFace
+model_path = "/shared/models/Llama-4-Scout-17B-16E-Instruct/"
+# Path to the compiled model artifacts. If this directory exists, the next run will skip
+# the trace and compile steps, reducing test time.
+traced_model_path = "/shared/traced_models/Llama-4/scout_text_vision_baseline_bs1/"
 
 torch.manual_seed(0)
 
 def run_llama_generate_image_to_text():
-    # Initialize configs and tokenizer.
-    batch_size = 1
-    text_neuron_config = Llama4NeuronConfig(batch_size=1,
-        seq_len=SEQ_LENGTH,
-        torch_dtype=torch.bfloat16,
-        skip_sharding=False,
-        save_sharded_checkpoint=False,
-        tp_degree=TEXT_TP_DEGREE,
-        cp_degree=1,
-        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
-        world_size=WORLD_SIZE,
-        capacity_factor=None,
-        fused_qkv=False,
-        attention_dtype=torch.float16,
-        rpl_reduce_dtype=torch.float32,
-        cast_type="as-declared",
-        logical_neuron_cores=2)
-
-    vision_neuron_config = Llama4NeuronConfig(batch_size=1,
-        seq_len=SEQ_LENGTH,
-        torch_dtype=torch.float16,
-        skip_sharding=False,
-        save_sharded_checkpoint=False,
-        tp_degree=VISION_TP_DEGERE,
-        cp_degree=1,
-        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
-        dp_degree=4,
-        world_size=WORLD_SIZE,
-        fused_qkv=True,
-        qkv_kernel_enabled=True,
-        attn_kernel_enabled=True,
-        mlp_kernel_enabled=True,
-        enable_bucketing=False,
-        logical_neuron_cores=2)
+    text_neuron_config = Llama4NeuronConfig(
+        batch_size=1,
+        is_continuous_batching=True,
+        seq_len=SEQ_LENGTH,
+        enable_bucketing=True,
+        context_encoding_buckets=[256, 512, 1024, 2048, 4096, 8192, 10240, 16384],
+        token_generation_buckets=[256, 512, 1024, 2048, 4096, 8192, 10240, 16384],
+        torch_dtype=torch.float16,
+        async_mode=True,
+        rpl_reduce_dtype=torch.float32,
+        tp_degree=TEXT_TP_DEGREE,
+        cp_degree=16,
+        on_device_sampling_config=SmplConfig(dynamic=True, top_k=1, top_k_kernel_enabled=True),
+        world_size=WORLD_SIZE,
+        fused_qkv=True,
+        cast_type="as-declared",
+        save_sharded_checkpoint=True,
+        cc_pipeline_tiling_factor=1,
+        sequence_parallel_enabled=True,
+        qkv_kernel_enabled=True,
+        attn_kernel_enabled=True,
+        attn_block_tkg_nki_kernel_enabled=True,
+        attn_block_tkg_nki_kernel_cache_update=True,
+        k_cache_transposed=False,
+        blockwise_matmul_config={
+            "block_size": 256,
+            "use_block_parallel": True,
+            "block_sharding_strategy": "HI_LO",
+            "skip_dma_token": True,
+            "skip_dma_weight": True,
+            "parallelize_token_to_block_mapping": True
+        },
+        logical_neuron_cores=2)
+
+    vision_neuron_config = Llama4NeuronConfig(
+        batch_size=1,
+        seq_len=SEQ_LENGTH,
+        torch_dtype=torch.float16,
+        tp_degree=VISION_TP_DEGERE,
+        cp_degree=1,
+        dp_degree=4,
+        world_size=WORLD_SIZE,
+        fused_qkv=True,
+        qkv_kernel_enabled=True,
+        attn_kernel_enabled=True,
+        mlp_kernel_enabled=True,
+        enable_bucketing=True,
+        buckets=[8, 28, 88],
+        save_sharded_checkpoint=True,
+        logical_neuron_cores=2)
 
     config = Llama4InferenceConfig(
         text_neuron_config=text_neuron_config,
@@ -85,10 +102,10 @@ def run_llama_generate_image_to_text():
 
     hf_llama4_processor = AutoProcessor.from_pretrained(model_path)
     # Prepare generate outputs.
-    text_prompt="If I had to write a haiku for this one"
+    text_prompt="Describe this image"
     image_path="./dog.jpg"
     role='user'
-
+
     with torch.profiler.record_function("prepare_generation_inputs"):
         input_ids, attention_mask, pixel_values, vision_mask = prepare_generation_inputs_hf(text_prompt, image_path, hf_llama4_processor, role, config)
 
@@ -97,7 +114,7 @@ def run_llama_generate_image_to_text():
     print("\nCompiling and saving model...")
     model = NeuronLlama4ForCausalLM(model_path, config)
     model.compile(traced_model_path)
-    tokenizer.save_pretrained(traced_model_path)
+    tokenizer.save_pretrained(traced_model_path)
 
     # Load from compiled checkpoint.
 
@@ -111,7 +128,7 @@ def run_llama_generate_image_to_text():
     generation_config = GenerationConfig.from_pretrained(model_path)
 
     # Test Sampling Parameters
-    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
+    sampling_params = prepare_sampling_params(batch_size=1, top_k=[1], top_p=[1.0], temperature=[1.0])
     outputs = generation_model.generate(
         input_ids,
         generation_config=generation_config,
@@ -134,7 +151,7 @@ def run_llama_generate_image_to_text():
     role='user'
 
     input_ids, attention_mask, _, _ = prepare_generation_inputs_hf(text_prompt, image_path, hf_llama4_processor, role)
-    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
+    sampling_params = prepare_sampling_params(batch_size=1, top_k=[1], top_p=[1.0], temperature=[1.0])
     outputs = generation_model.generate(
         input_ids,
         generation_config=generation_config,
@@ -159,22 +176,20 @@ def run_llama_generate_image_to_text():
 def run_llama_generate_text_to_text():
     # Initialize configs and tokenizer.
     batch_size = 1
-    neuron_config = Llama4NeuronConfig(batch_size=1,
-        seq_len=SEQ_LENGTH,
-        torch_dtype=torch.bfloat16,
-        skip_sharding=False,
-        save_sharded_checkpoint=True,
-        tp_degree=TEXT_TP_DEGREE,
-        cp_degree=16,
-        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
-        world_size=WORLD_SIZE,
-        capacity_factor=None,
-        fused_qkv=False,
-        attention_dtype=torch.float16,
-        rpl_reduce_dtype=torch.float32,
-        cast_type="as-declared",
-        logical_neuron_cores=2)
-
+    neuron_config = Llama4NeuronConfig(
+        batch_size=1,
+        is_continuous_batching=True,
+        seq_len=SEQ_LENGTH,
+        torch_dtype=torch.float16,
+        rpl_reduce_dtype=torch.float32,
+        tp_degree=TEXT_TP_DEGREE,
+        cp_degree=1,
+        on_device_sampling_config=SmplConfig(dynamic=True, top_k=1),
+        world_size=WORLD_SIZE,
+        fused_qkv=False,
+        cast_type="as-declared",
+        save_sharded_checkpoint=True,
+        logical_neuron_cores=2)
 
     config = LlamaInferenceConfig(
         neuron_config=neuron_config,
@@ -191,7 +206,7 @@ def run_llama_generate_text_to_text():
     print("\nCompiling and saving model...")
     model = NeuronLlama4TextForCausalLM(model_path, config.get_text_config())
     model.compile(traced_model_path)
-    tokenizer.save_pretrained(traced_model_path)
+    tokenizer.save_pretrained(traced_model_path)
     # Load from compiled checkpoint.
 
     print("\nLoading model from compiled checkpoint...")

setup.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ def get_version(version_str):
     package_data={"": []},
     install_requires=[
         "neuronx_distributed",
+        "torch_neuronx>=2.5",
         "transformers==4.51.*",
         "huggingface-hub",
         "sentencepiece",

0 commit comments
