1+ from transformers import ViTModel , AutoImageProcessor
2+ from PIL import Image
3+ import time
4+ import torch
5+ import os
6+ import numpy as np
7+ import logging
8+
9+ import torch_xla
10+
11+ from neuronx_distributed_inference .utils .hf_adapter import load_pretrained_config
12+ from neuronx_distributed_inference .models .config import NeuronConfig
13+ from neuronx_distributed_inference .utils .accuracy import check_accuracy_embeddings
14+ from neuronx_distributed_inference .utils .benchmark import LatencyCollector
15+ from neuronx_distributed_inference .models .vit .modeling_vit import NeuronViTForImageEncoding , ViTInferenceConfig
16+
17+
# How many timed forward passes each backend runs for the latency report.
NUM_BENCHMARK_ITER = 10

# Directory the pretrained ViT checkpoint is loaded from.
MODEL_PATH = "/home/ubuntu/model_hf/google--vit-huge-patch14-224-in21k/"
# Compiled artifacts live in a sub-directory of the checkpoint directory.
TRACED_MODEL_PATH = MODEL_PATH + "traced_model/"

# Module-level logger; INFO and above are enabled on it.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
24+
def setup_debug_env():
    """Prepare the process for a debuggable, repeatable run.

    Exports the XLA / Neuron environment flags this example relies on,
    switches on IR debugging in ``torch_xla``, and seeds torch's RNG with 0.
    """
    # Insertion order of the mapping matches the order the flags are exported.
    debug_flags = {
        "XLA_FALLBACK_CPU": "0",
        "XLA_IR_DEBUG": "1",
        "XLA_HLO_DEBUG": "1",
        "NEURON_FUSE_SOFTMAX": "1",
    }
    os.environ.update(debug_flags)

    torch_xla._XLAC._set_ir_debug(True)
    torch.manual_seed(0)
32+
33+
def _report_latency_percentiles(label, latency_list):
    """Print the p25/p50/p90/p99 of *latency_list*, scaled by 1000 for the ms report."""
    for p in [25, 50, 90, 99]:
        latency = np.percentile(latency_list, p) * 1000
        print(f"{label} inference latency_ms_p{p} : {latency} ")


def run_vit_encoding(validate_accuracy=True):
    """Compile, load and benchmark the NxDI ViT encoder on one image.

    Steps:
      1. Build the Neuron and ViT inference configs from ``MODEL_PATH``.
      2. Read ``dog.jpg`` from the current working directory and preprocess it
         with the checkpoint's image processor.
      3. Compile the Neuron model into ``TRACED_MODEL_PATH`` and load it back.
      4. Run ``NUM_BENCHMARK_ITER`` timed forward passes and print latency
         percentiles.
      5. Optionally repeat the timing with the Hugging Face ``ViTModel`` on CPU
         and compare its ``last_hidden_state`` against the Neuron output.

    Args:
        validate_accuracy: When true, also run the CPU reference model and
            check the Neuron output against it.

    Returns:
        None. All results are printed.
    """
    # Define configs
    neuron_config = NeuronConfig(
        tp_degree=32,
        torch_dtype=torch.float32,
    )
    inference_config = ViTInferenceConfig(
        neuron_config=neuron_config,
        load_config=load_pretrained_config(MODEL_PATH),
        use_mask_token=False,
        add_pooling_layer=False,
        interpolate_pos_encoding=False,
    )

    # Input image; convert() decodes the pixels while the file is still open.
    image_file = "dog.jpg"  # [512, 512]
    with open(image_file, "rb") as f:
        image = Image.open(f).convert("RGB")
    print(f"Input image size {image.size} ")

    # Preprocess input image
    image_processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]

    # Get neuron model
    neuron_model = NeuronViTForImageEncoding(model_path=MODEL_PATH, config=inference_config)

    # Compile model on Neuron
    compile_start_time = time.time()
    neuron_model.compile(TRACED_MODEL_PATH)
    compile_elapsed_time = time.time() - compile_start_time
    print(f"Compilation time taken {compile_elapsed_time} s")

    # Load model on Neuron
    neuron_model.load(TRACED_MODEL_PATH)
    print("Done loading neuron model")

    # Run NxDI implementation on Neuron; the output of the last pass is kept
    # for the accuracy comparison below.
    neuron_latency_collector = LatencyCollector()
    for _ in range(NUM_BENCHMARK_ITER):
        neuron_latency_collector.pre_hook()
        # NeuronViTModel output (sequence_output,) or (sequence_output, pooled_output)
        neuron_output = neuron_model(pixel_values)[0]
        neuron_latency_collector.hook()
    print(f"Got neuron output {neuron_output.shape} {neuron_output} ")
    _report_latency_percentiles("Neuron", neuron_latency_collector.latency_list)

    # The below section is optional, use if you want to validate e2e accuracy against golden
    if not validate_accuracy:
        return

    # Get CPU model
    cpu_model = ViTModel.from_pretrained(MODEL_PATH)
    print(f"cpu model {cpu_model} ")

    # Get golden output by running the original implementation on CPU.
    # no_grad() keeps autograd from recording a graph on every pass, which
    # would inflate the measured latency and leave a grad-tracking tensor in
    # the comparison.
    cpu_latency_collector = LatencyCollector()
    with torch.no_grad():
        for _ in range(NUM_BENCHMARK_ITER):
            cpu_latency_collector.pre_hook()
            golden_output = cpu_model(pixel_values).last_hidden_state
            cpu_latency_collector.hook()
    print(f"expected_output {golden_output.shape} {golden_output} ")
    _report_latency_percentiles("CPU", cpu_latency_collector.latency_list)

    # Compare the Neuron embeddings against the CPU golden
    passed, max_err = check_accuracy_embeddings(
        neuron_output, golden_output, plot_outputs=True, atol=1e-5, rtol=1e-5
    )
    print(f"Golden and Neuron outputs match: {passed} , max relative error: {max_err} ")
103+
104+
105+
def _main():
    """Script entry point: configure debugging, then run the ViT example."""
    # Debug flags are set before any model work starts.
    setup_debug_env()
    run_vit_encoding(validate_accuracy=True)


if __name__ == "__main__":
    _main()
0 commit comments