import torch
import os
import logging
import base64

from transformers import AutoTokenizer, AutoProcessor, GenerationConfig
from neuronx_distributed_inference.models.config import OnDeviceSamplingConfig as SmplConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter
from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params

from neuronx_distributed_inference.models.llama4.modeling_llama4 import NeuronLlama4ForCausalLM, Llama4InferenceConfig, Llama4NeuronConfig
from neuronx_distributed_inference.models.llama4.modeling_llama4_text import NeuronLlama4TextForCausalLM, LlamaInferenceConfig
from neuronx_distributed_inference.utils.benchmark import benchmark_sampling
from neuronx_distributed_inference.models.llama4.utils.input_processor import prepare_generation_inputs_hf

# TODO: Either read from an OS environment var or from arg_parser.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

TEXT_TP_DEGREE = 64
VISION_TP_DEGREE = 16
WORLD_SIZE = 64
BATCH_SIZE = 1
SEQ_LENGTH = 8192
# SEQ_LENGTH = 10240 for chunked attention
TEXT_TO_TEXT = False
# TEXT_TO_TEXT = True for text-only generation
DTYPE = torch.bfloat16

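# Neuron runtime / compiler environment: target Trn2 with logical NeuronCore config 2,
# and size the runtime core count to the text TP degree.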
os.environ['NEURON_PLATFORM_TARGET_OVERRIDE'] = 'trn2'
os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2"
os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2"
os.environ['NEURON_RT_NUM_CORES'] = f'{TEXT_TP_DEGREE}'
os.environ['BASE_COMPILE_WORK_DIR'] = "./compiler_path/"

model_path = "/home/ubuntu/models/Llama-4-Scout-17B-16E-Instruct/"
traced_model_path = "/home/ubuntu/traced_model_Llama-4-Scout-17B-16E-Instruct"

torch.manual_seed(0)

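# Multimodal path: compile the combined Llama4 text + vision model, generate from an
# image + text prompt, then from a text-only prompt, and benchmark both modes.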
def run_llama_generate_image_to_text():
    # Initialize configs and tokenizer.
    batch_size = 1
    text_neuron_config = Llama4NeuronConfig(
        batch_size=1,
        seq_len=SEQ_LENGTH,
        torch_dtype=torch.bfloat16,
        skip_sharding=False,
        save_sharded_checkpoint=False,
        tp_degree=TEXT_TP_DEGREE,
        cp_degree=1,
        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
        world_size=WORLD_SIZE,
        capacity_factor=None,
        fused_qkv=False,
        attention_dtype=torch.float16,
        rpl_reduce_dtype=torch.float32,
        cast_type="as-declared",
        logical_neuron_cores=2,
    )

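    # Vision encoder config: TP=16 with DP=4 (4 x 16 = 64 ranks), fused QKV with the
    # QKV/attention/MLP kernels enabled, and bucketing disabled.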
    vision_neuron_config = Llama4NeuronConfig(
        batch_size=1,
        seq_len=SEQ_LENGTH,
        torch_dtype=torch.float16,
        skip_sharding=False,
        save_sharded_checkpoint=False,
        tp_degree=VISION_TP_DEGREE,
        cp_degree=1,
        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
        dp_degree=4,
        world_size=WORLD_SIZE,
        fused_qkv=True,
        qkv_kernel_enabled=True,
        attn_kernel_enabled=True,
        mlp_kernel_enabled=True,
        enable_bucketing=False,
        logical_neuron_cores=2,
    )

    config = Llama4InferenceConfig(
        text_neuron_config=text_neuron_config,
        vision_neuron_config=vision_neuron_config,
        load_config=load_pretrained_config(model_path),
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    tokenizer.pad_token = tokenizer.eos_token

    hf_llama4_processor = AutoProcessor.from_pretrained(model_path)
    # Prepare generation inputs.
    text_prompt = "If I had to write a haiku for this one"
    image_path = "./dog.jpg"
    role = 'user'

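    # prepare_generation_inputs_hf tokenizes the prompt and preprocesses the image,
    # returning input_ids, an attention mask, pixel_values, and a vision mask for the
    # image tokens.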
    with torch.profiler.record_function("prepare_generation_inputs"):
        input_ids, attention_mask, pixel_values, vision_mask = prepare_generation_inputs_hf(
            text_prompt, image_path, hf_llama4_processor, role, config)

    if not os.path.exists(traced_model_path):
        # Compile and save model.
        print("\nCompiling and saving model...")
        model = NeuronLlama4ForCausalLM(model_path, config)
        model.compile(traced_model_path)
        tokenizer.save_pretrained(traced_model_path)

    # Load from compiled checkpoint.
    print("\nLoading model from compiled checkpoint...")
    model = NeuronLlama4ForCausalLM(traced_model_path)
    model.load(traced_model_path, skip_warmup=True)
    tokenizer = AutoTokenizer.from_pretrained(traced_model_path)

    generation_model = HuggingFaceGenerationAdapter(model)
    generation_config = GenerationConfig.from_pretrained(model_path)

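    # Greedy decoding: top_k=1, top_p=1.0, temperature=1.0 for each batch entry,
    # matching the greedy on-device sampling config above.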
    # Test sampling parameters.
    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
        attention_mask=attention_mask,
        max_length=model.config.neuron_config.max_length,
        sampling_params=sampling_params,
        pixel_values=pixel_values,
        vision_mask=vision_mask.to(torch.bool),
        max_new_tokens=512,
    )
    output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f"Generated outputs shape: {outputs.shape}")
    for i, output_token in enumerate(output_tokens):
        print(f"Output {i}: {output_token}")

    # Test text-only inputs.
    text_prompt = "what is the recipe of mayonnaise in two sentences?"
    image_path = None
    role = 'user'

    input_ids, attention_mask, _, _ = prepare_generation_inputs_hf(text_prompt, image_path, hf_llama4_processor, role)
    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
        attention_mask=attention_mask,
        max_length=model.config.neuron_config.max_length,
        sampling_params=sampling_params,
        pixel_values=None,
        vision_mask=None,
        max_new_tokens=100,
    )
    output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f"Generated outputs shape: {outputs.shape}")
    for i, output_token in enumerate(output_tokens):
        print(f"Output {i}: {output_token}")

    print("\nPerformance Benchmarking text-only!")
    benchmark_sampling(model=model, draft_model=None, generation_config=generation_config, target="all", image=None,
                       benchmark_report_path="benchmark_report_text_only.json", num_runs=5)

    print("\nPerformance Benchmarking text+image!")
    benchmark_sampling(model=model, draft_model=None, generation_config=generation_config, target="all", image=True,
                       benchmark_report_path="benchmark_report_text_and_image.json", num_runs=5)

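
# Text-only path: compiles only the language model (NeuronLlama4TextForCausalLM),
# here with context parallelism (cp_degree=16) across the same 64-rank world size.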
def run_llama_generate_text_to_text():
    # Initialize configs and tokenizer.
    batch_size = 1
    neuron_config = Llama4NeuronConfig(
        batch_size=1,
        seq_len=SEQ_LENGTH,
        torch_dtype=torch.bfloat16,
        skip_sharding=False,
        save_sharded_checkpoint=True,
        tp_degree=TEXT_TP_DEGREE,
        cp_degree=16,
        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
        world_size=WORLD_SIZE,
        capacity_factor=None,
        fused_qkv=False,
        attention_dtype=torch.float16,
        rpl_reduce_dtype=torch.float32,
        cast_type="as-declared",
        logical_neuron_cores=2,
    )

    config = LlamaInferenceConfig(
        neuron_config=neuron_config,
        load_config=load_pretrained_config(model_path),
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    tokenizer.pad_token = tokenizer.eos_token

    hf_llama4_processor = AutoProcessor.from_pretrained(model_path)

    if not os.path.exists(traced_model_path):
        # Compile and save model.
        print("\nCompiling and saving model...")
        model = NeuronLlama4TextForCausalLM(model_path, config.get_text_config())
        model.compile(traced_model_path)
        tokenizer.save_pretrained(traced_model_path)

    # Load from compiled checkpoint.
    print("\nLoading model from compiled checkpoint...")
    model = NeuronLlama4TextForCausalLM(traced_model_path, config.get_text_config())
    model.load(traced_model_path, skip_warmup=True)
    tokenizer = AutoTokenizer.from_pretrained(traced_model_path)

    # Test text-only inputs.
    text_prompt = "what is the recipe of mayonnaise in two sentences?"

    # Uncomment for a longer prompt:
    # int_list = list(str(i) for i in range(2500))
    # int_str = ', '.join(int_list)
    # text_prompt = f"Keep counting until 3000. I will start {int_str}..."
    image_path = None
    role = 'user'

    generation_model = HuggingFaceGenerationAdapter(model)
    generation_config = GenerationConfig.from_pretrained(model_path)

    input_ids, attention_mask, _, _ = prepare_generation_inputs_hf(text_prompt, image_path, hf_llama4_processor, role)
    print(f"input shape {input_ids.shape}")
    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
        attention_mask=attention_mask,
        max_length=model.config.neuron_config.max_length,
        sampling_params=sampling_params,
        pixel_values=None,
        vision_mask=None,
        max_new_tokens=100,
    )
    output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f"Generated outputs shape: {outputs.shape}")
    for i, output_token in enumerate(output_tokens):
        print(f"Output {i}: {output_token}")

    print("\nPerformance Benchmarking text-only!")
    benchmark_sampling(model=model, draft_model=None, generation_config=generation_config, target="all", image=False,
                       benchmark_report_path="benchmark_report_text_only.json", num_runs=5)


if __name__ == "__main__":
    if TEXT_TO_TEXT:
        run_llama_generate_text_to_text()
    else:
        run_llama_generate_image_to_text()