import torch
import os
import logging

from transformers import AutoTokenizer, AutoProcessor, GenerationConfig
from neuronx_distributed_inference.models.config import OnDeviceSamplingConfig as SmplConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter
from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params

from neuronx_distributed_inference.models.llama4.modeling_llama4 import NeuronLlama4ForCausalLM, Llama4InferenceConfig, Llama4NeuronConfig
from neuronx_distributed_inference.models.llama4.modeling_llama4_text import NeuronLlama4TextForCausalLM, LlamaInferenceConfig
from neuronx_distributed_inference.utils.benchmark import benchmark_sampling
from neuronx_distributed_inference.models.llama4.utils.input_processor import prepare_generation_inputs_hf

# TODO: Read these settings from environment variables or an argument parser.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

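# Parallelism settings for a Trn2 instance: the text model is sharded 64-way
# tensor parallel, while the vision encoder uses 16-way tensor parallelism.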
TEXT_TP_DEGREE = 64
VISION_TP_DEGREE = 16
WORLD_SIZE = 64
BATCH_SIZE = 1
SEQ_LENGTH = 8192
# Use SEQ_LENGTH = 10240 for chunked attention.
TEXT_TO_TEXT = False
# Set TEXT_TO_TEXT = True for text-only generation.
DTYPE = torch.bfloat16

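# Configure the Neuron runtime for trn2 with logical NeuronCore config 2
# (two physical cores per logical core) and one runtime core per text TP rank.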
os.environ['NEURON_PLATFORM_TARGET_OVERRIDE'] = 'trn2'
os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2"
os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2"
os.environ['NEURON_RT_NUM_CORES'] = f'{TEXT_TP_DEGREE}'
os.environ['BASE_COMPILE_WORK_DIR'] = "./compiler_path/"

model_path = "/home/ubuntu/models/Llama-4-Scout-17B-16E-Instruct/"
traced_model_path = "/home/ubuntu/traced_model_Llama-4-Scout-17B-16E-Instruct"

torch.manual_seed(0)


def run_llama_generate_image_to_text():
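    """Compile and run Llama4 Scout multimodal (image + text) generation.

    Traces and compiles the model on the first run, reloads the compiled
    checkpoint, generates from an image+text prompt and a text-only prompt,
    then benchmarks both paths.
    """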
    # Initialize configs and tokenizer.
    batch_size = BATCH_SIZE
    text_neuron_config = Llama4NeuronConfig(
        batch_size=batch_size,
        seq_len=SEQ_LENGTH,
        torch_dtype=DTYPE,
        skip_sharding=False,
        save_sharded_checkpoint=False,
        tp_degree=TEXT_TP_DEGREE,
        cp_degree=1,
        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
        world_size=WORLD_SIZE,
        capacity_factor=None,
        fused_qkv=False,
        attention_dtype=torch.float16,
        rpl_reduce_dtype=torch.float32,
        cast_type="as-declared",
        logical_neuron_cores=2,
    )

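    # The vision encoder runs 16-way tensor parallel with 4-way data
    # parallelism (16 * 4 = WORLD_SIZE) and enables the fused QKV,
    # attention, and MLP kernels.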
    vision_neuron_config = Llama4NeuronConfig(
        batch_size=batch_size,
        seq_len=SEQ_LENGTH,
        torch_dtype=torch.float16,
        skip_sharding=False,
        save_sharded_checkpoint=False,
        tp_degree=VISION_TP_DEGREE,
        cp_degree=1,
        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
        dp_degree=4,
        world_size=WORLD_SIZE,
        fused_qkv=True,
        qkv_kernel_enabled=True,
        attn_kernel_enabled=True,
        mlp_kernel_enabled=True,
        enable_bucketing=False,
        logical_neuron_cores=2,
    )

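    # The combined inference config pairs the text and vision Neuron configs
    # with the Hugging Face checkpoint configuration.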
    config = Llama4InferenceConfig(
        text_neuron_config=text_neuron_config,
        vision_neuron_config=vision_neuron_config,
        load_config=load_pretrained_config(model_path),
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    tokenizer.pad_token = tokenizer.eos_token

    hf_llama4_processor = AutoProcessor.from_pretrained(model_path)

    # Prepare generation inputs: tokenize the prompt and preprocess the image
    # into pixel_values and a vision mask.
    text_prompt = "If I had to write a haiku for this one"
    image_path = "./dog.jpg"
    role = 'user'

    with torch.profiler.record_function("prepare_generation_inputs"):
        input_ids, attention_mask, pixel_values, vision_mask = prepare_generation_inputs_hf(
            text_prompt, image_path, hf_llama4_processor, role, config
        )

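    # Compile once: trace and save the model on the first run, then reuse
    # the cached artifacts under traced_model_path on subsequent runs.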
    if not os.path.exists(traced_model_path):
        print("\nCompiling and saving model...")
        model = NeuronLlama4ForCausalLM(model_path, config)
        model.compile(traced_model_path)
        tokenizer.save_pretrained(traced_model_path)

    # Load from compiled checkpoint.
    print("\nLoading model from compiled checkpoint...")
    model = NeuronLlama4ForCausalLM(traced_model_path)
    model.load(traced_model_path, skip_warmup=True)
    tokenizer = AutoTokenizer.from_pretrained(traced_model_path)

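    # Wrap the Neuron model so it can be driven through the standard
    # Hugging Face generate() API.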
    generation_model = HuggingFaceGenerationAdapter(model)

    generation_config = GenerationConfig.from_pretrained(model_path)

    # Test sampling parameters (greedy decoding: top_k=1).
    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
        attention_mask=attention_mask,
        max_length=model.config.neuron_config.max_length,
        sampling_params=sampling_params,
        pixel_values=pixel_values,
        vision_mask=vision_mask.to(torch.bool),
        max_new_tokens=512,
    )
    output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f"Generated outputs shape: {outputs.shape}")
    for i, output_token in enumerate(output_tokens):
        print(f"Output {i}: {output_token}")

    # Test text-only inputs.
    text_prompt = "what is the recipe of mayonnaise in two sentences?"
    image_path = None
    role = 'user'

    input_ids, attention_mask, _, _ = prepare_generation_inputs_hf(text_prompt, image_path, hf_llama4_processor, role)
    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
        attention_mask=attention_mask,
        max_length=model.config.neuron_config.max_length,
        sampling_params=sampling_params,
        pixel_values=None,
        vision_mask=None,
        max_new_tokens=100,
    )
    output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f"Generated outputs shape: {outputs.shape}")
    for i, output_token in enumerate(output_tokens):
        print(f"Output {i}: {output_token}")

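    # benchmark_sampling runs num_runs end-to-end iterations and writes its
    # report to benchmark_report_path.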
    print("\nPerformance benchmarking text-only!")
    benchmark_sampling(
        model=model,
        draft_model=None,
        generation_config=generation_config,
        target="all",
        image=None,
        benchmark_report_path="benchmark_report_text_only.json",
        num_runs=5,
    )

    print("\nPerformance benchmarking text+image!")
    benchmark_sampling(
        model=model,
        draft_model=None,
        generation_config=generation_config,
        target="all",
        image=True,
        benchmark_report_path="benchmark_report_text_and_image.json",
        num_runs=5,
    )


def run_llama_generate_text_to_text():
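    """Compile and run text-only generation with the Llama4 text model.

    Traces and compiles the text model on the first run, reloads the
    compiled checkpoint, generates from a text prompt, and benchmarks
    text-only sampling.
    """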
    # Initialize configs and tokenizer.
    batch_size = BATCH_SIZE
    neuron_config = Llama4NeuronConfig(
        batch_size=batch_size,
        seq_len=SEQ_LENGTH,
        torch_dtype=DTYPE,
        skip_sharding=False,
        save_sharded_checkpoint=True,
        tp_degree=TEXT_TP_DEGREE,
        cp_degree=16,
        on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
        world_size=WORLD_SIZE,
        capacity_factor=None,
        fused_qkv=False,
        attention_dtype=torch.float16,
        rpl_reduce_dtype=torch.float32,
        cast_type="as-declared",
        logical_neuron_cores=2,
    )

    config = LlamaInferenceConfig(
        neuron_config=neuron_config,
        load_config=load_pretrained_config(model_path),
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    tokenizer.pad_token = tokenizer.eos_token

    hf_llama4_processor = AutoProcessor.from_pretrained(model_path)

    if not os.path.exists(traced_model_path):
        # Compile and save model.
        print("\nCompiling and saving model...")
        model = NeuronLlama4TextForCausalLM(model_path, config.get_text_config())
        model.compile(traced_model_path)
        tokenizer.save_pretrained(traced_model_path)

    # Load from compiled checkpoint.
    print("\nLoading model from compiled checkpoint...")
    model = NeuronLlama4TextForCausalLM(traced_model_path, config.get_text_config())
    model.load(traced_model_path, skip_warmup=True)
    tokenizer = AutoTokenizer.from_pretrained(traced_model_path)

    # Test text-only inputs.
    text_prompt = "what is the recipe of mayonnaise in two sentences?"

    # Uncomment for a longer prompt:
    # int_list = list(str(i) for i in range(2500))
    # int_str = ', '.join(int_list)
    # text_prompt = f"Keep counting until 3000. I will start {int_str}..."
    image_path = None
    role = 'user'

    generation_model = HuggingFaceGenerationAdapter(model)
    generation_config = GenerationConfig.from_pretrained(model_path)

    input_ids, attention_mask, _, _ = prepare_generation_inputs_hf(text_prompt, image_path, hf_llama4_processor, role)
    print(f"Input shape: {input_ids.shape}")
    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
        attention_mask=attention_mask,
        max_length=model.config.neuron_config.max_length,
        sampling_params=sampling_params,
        pixel_values=None,
        vision_mask=None,
        max_new_tokens=100,
    )
    output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f"Generated outputs shape: {outputs.shape}")
    for i, output_token in enumerate(output_tokens):
        print(f"Output {i}: {output_token}")

    print("\nPerformance benchmarking text-only!")
    benchmark_sampling(
        model=model,
        draft_model=None,
        generation_config=generation_config,
        target="all",
        image=False,
        benchmark_report_path="benchmark_report_text_only.json",
        num_runs=5,
    )

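# Entry point: TEXT_TO_TEXT selects the text-only flow; otherwise the
# multimodal image-to-text flow runs.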
if __name__ == "__main__":
    if TEXT_TO_TEXT:
        run_llama_generate_text_to_text()
    else:
        run_llama_generate_image_to_text()