import torch
import os
import logging
-import base64

from transformers import AutoTokenizer, AutoProcessor, GenerationConfig
from neuronx_distributed_inference.models.config import OnDeviceSamplingConfig as SmplConfig
VISION_TP_DEGREE = 16
WORLD_SIZE = 64
BATCH_SIZE = 1
-SEQ_LENGTH = 8192
-# SEQ_LENGTH = 10240 for chunked attention
+SEQ_LENGTH = 16384
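+# SEQ_LENGTH is sized to match the largest compiled bucket configured below (16384).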
TEXT_TO_TEXT = False
# TEXT_TO_TEXT = True for text-only generation
DTYPE = torch.bfloat16
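+# NEURON_RT_NUM_CORES tells the Neuron runtime how many NeuronCores this process
+# may use; here it is pinned to the text model's TP degree.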
os.environ['NEURON_RT_NUM_CORES'] = f'{TEXT_TP_DEGREE}'
os.environ['BASE_COMPILE_WORK_DIR'] = "./compiler_path/"

-model_path = "/home/ubuntu/models/Llama-4-Scout-17B-16E-Instruct/"
-traced_model_path = "/home/ubuntu/traced_model_Llama-4-Scout-17B-16E-Instruct"
+# Llama4 checkpoints can be downloaded from Hugging Face.
+model_path = "/shared/models/Llama-4-Scout-17B-16E-Instruct/"
+# Path to the compiled model artifacts. If this directory already exists, the next
+# run skips the trace and compile steps, reducing test time.
+traced_model_path = "/shared/traced_models/Llama-4/scout_text_vision_baseline_bs1/"

torch.manual_seed(0)

def run_llama_generate_image_to_text():
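+    """Run Llama4 Scout multimodal (image + text prompt) generation on Neuron."""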
-    # Initialize configs and tokenizer.
-    batch_size = 1
-    text_neuron_config = Llama4NeuronConfig(batch_size=1,
-                                            seq_len=SEQ_LENGTH,
-                                            torch_dtype=torch.bfloat16,
-                                            skip_sharding=False,
-                                            save_sharded_checkpoint=False,
-                                            tp_degree=TEXT_TP_DEGREE,
-                                            cp_degree=1,
-                                            on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
-                                            world_size=WORLD_SIZE,
-                                            capacity_factor=None,
-                                            fused_qkv=False,
-                                            attention_dtype=torch.float16,
-                                            rpl_reduce_dtype=torch.float32,
-                                            cast_type="as-declared",
-                                            logical_neuron_cores=2)
-
-    vision_neuron_config = Llama4NeuronConfig(batch_size=1,
-                                              seq_len=SEQ_LENGTH,
-                                              torch_dtype=torch.float16,
-                                              skip_sharding=False,
-                                              save_sharded_checkpoint=False,
-                                              tp_degree=VISION_TP_DEGREE,
-                                              cp_degree=1,
-                                              on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
-                                              dp_degree=4,
-                                              world_size=WORLD_SIZE,
-                                              fused_qkv=True,
-                                              qkv_kernel_enabled=True,
-                                              attn_kernel_enabled=True,
-                                              mlp_kernel_enabled=True,
-                                              enable_bucketing=False,
-                                              logical_neuron_cores=2)
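+    # Text decoder config. Compared with the previous version, this enables continuous
+    # batching, bucketed compilation (prompts are padded up to the nearest bucket),
+    # context parallelism (cp_degree=16), async mode, and the NKI QKV/attention kernels.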
+    text_neuron_config = Llama4NeuronConfig(
+        batch_size=1,
+        is_continuous_batching=True,
+        seq_len=SEQ_LENGTH,
+        enable_bucketing=True,
+        context_encoding_buckets=[256, 512, 1024, 2048, 4096, 8192, 10240, 16384],
+        token_generation_buckets=[256, 512, 1024, 2048, 4096, 8192, 10240, 16384],
+        torch_dtype=torch.float16,
+        async_mode=True,
+        rpl_reduce_dtype=torch.float32,
+        tp_degree=TEXT_TP_DEGREE,
+        cp_degree=16,
+        on_device_sampling_config=SmplConfig(dynamic=True, top_k=1, top_k_kernel_enabled=True),
+        world_size=WORLD_SIZE,
+        fused_qkv=True,
+        cast_type="as-declared",
+        save_sharded_checkpoint=True,
+        cc_pipeline_tiling_factor=1,
+        sequence_parallel_enabled=True,
+        qkv_kernel_enabled=True,
+        attn_kernel_enabled=True,
+        attn_block_tkg_nki_kernel_enabled=True,
+        attn_block_tkg_nki_kernel_cache_update=True,
+        k_cache_transposed=False,
+        blockwise_matmul_config={
+            "block_size": 256,
+            "use_block_parallel": True,
+            "block_sharding_strategy": "HI_LO",
+            "skip_dma_token": True,
+            "skip_dma_weight": True,
+            "parallelize_token_to_block_mapping": True,
+        },
+        logical_neuron_cores=2)
+
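+    # Vision encoder config. The encoder shards with tp_degree=16 and runs data
+    # parallel with dp_degree=4, which together account for the 64 ranks in
+    # WORLD_SIZE (16 * 4 = 64).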
+    vision_neuron_config = Llama4NeuronConfig(
+        batch_size=1,
+        seq_len=SEQ_LENGTH,
+        torch_dtype=torch.float16,
+        tp_degree=VISION_TP_DEGREE,
+        cp_degree=1,
+        dp_degree=4,
+        world_size=WORLD_SIZE,
+        fused_qkv=True,
+        qkv_kernel_enabled=True,
+        attn_kernel_enabled=True,
+        mlp_kernel_enabled=True,
+        enable_bucketing=True,
+        buckets=[8, 28, 88],
+        save_sharded_checkpoint=True,
+        logical_neuron_cores=2)

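+    # Bundle the text and vision configs into a single multimodal inference config.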
    config = Llama4InferenceConfig(
        text_neuron_config=text_neuron_config,
@@ -85,10 +102,10 @@ def run_llama_generate_image_to_text():

    hf_llama4_processor = AutoProcessor.from_pretrained(model_path)
    # Prepare generation inputs.
-    text_prompt = "If I had to write a haiku for this one"
+    text_prompt = "Describe this image"
    image_path = "./dog.jpg"
    role = 'user'
-
+
    with torch.profiler.record_function("prepare_generation_inputs"):
        input_ids, attention_mask, pixel_values, vision_mask = prepare_generation_inputs_hf(text_prompt, image_path, hf_llama4_processor, role, config)

@@ -97,7 +114,7 @@ def run_llama_generate_image_to_text():
    print("\nCompiling and saving model...")
    model = NeuronLlama4ForCausalLM(model_path, config)
    model.compile(traced_model_path)
-    tokenizer.save_pretrained(traced_model_path)
+    tokenizer.save_pretrained(traced_model_path)

    # Load from compiled checkpoint.

@@ -111,7 +128,7 @@ def run_llama_generate_image_to_text():
    generation_config = GenerationConfig.from_pretrained(model_path)

    # Test Sampling Parameters
-    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
+    sampling_params = prepare_sampling_params(batch_size=1, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
@@ -134,7 +151,7 @@ def run_llama_generate_image_to_text():
    role = 'user'

    input_ids, attention_mask, _, _ = prepare_generation_inputs_hf(text_prompt, image_path, hf_llama4_processor, role)
-    sampling_params = prepare_sampling_params(batch_size=batch_size, top_k=[1], top_p=[1.0], temperature=[1.0])
+    sampling_params = prepare_sampling_params(batch_size=1, top_k=[1], top_p=[1.0], temperature=[1.0])
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
@@ -159,22 +176,20 @@ def run_llama_generate_image_to_text():
def run_llama_generate_text_to_text():
    # Initialize configs and tokenizer.
    batch_size = 1
-    neuron_config = Llama4NeuronConfig(batch_size=1,
-                                       seq_len=SEQ_LENGTH,
-                                       torch_dtype=torch.bfloat16,
-                                       skip_sharding=False,
-                                       save_sharded_checkpoint=True,
-                                       tp_degree=TEXT_TP_DEGREE,
-                                       cp_degree=16,
-                                       on_device_sampling_config=SmplConfig(dynamic=False, top_k=1),
-                                       world_size=WORLD_SIZE,
-                                       capacity_factor=None,
-                                       fused_qkv=False,
-                                       attention_dtype=torch.float16,
-                                       rpl_reduce_dtype=torch.float32,
-                                       cast_type="as-declared",
-                                       logical_neuron_cores=2)
-
+    neuron_config = Llama4NeuronConfig(
+        batch_size=1,
+        is_continuous_batching=True,
+        seq_len=SEQ_LENGTH,
+        torch_dtype=torch.float16,
+        rpl_reduce_dtype=torch.float32,
+        tp_degree=TEXT_TP_DEGREE,
+        cp_degree=1,
+        on_device_sampling_config=SmplConfig(dynamic=True, top_k=1),
+        world_size=WORLD_SIZE,
+        fused_qkv=False,
+        cast_type="as-declared",
+        save_sharded_checkpoint=True,
+        logical_neuron_cores=2)

    config = LlamaInferenceConfig(
        neuron_config=neuron_config,
@@ -191,7 +206,7 @@ def run_llama_generate_text_to_text():
    print("\nCompiling and saving model...")
    model = NeuronLlama4TextForCausalLM(model_path, config.get_text_config())
    model.compile(traced_model_path)
-    tokenizer.save_pretrained(traced_model_path)
+    tokenizer.save_pretrained(traced_model_path)
    # Load from compiled checkpoint.

    print("\nLoading model from compiled checkpoint...")