Usage:

python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite \
- --output-path=golden_DeepSeek-V2-Lite.jsonl --prompts='I love to, Today is a, What is the' \
+ --output-path=golden_DeepSeek-V2-Lite.jsonl --prompts='I love to; Today is a; What is the' \
--gcs-bucket=my-gcs-bucket
+ For large models, you can use an M1 CPU VM. Calling the script directly, instead of via the MaxText module, \
+ skips importing unnecessary dependencies.
+ For large Hugging Face checkpoints, you can use a pre-downloaded checkpoint via the --hf-model-path argument.
+ For multimodal models, use the --image-paths argument to provide image path(s), \
+ and use --apply-chat-template=true to format image+prompt with the HF chat template. \
+ When using the chat template, the prompt should not contain image placeholders.
+
+ More examples:
+ python3 MaxText/scratch_code/generate_hf_golden_logits.py --model-id=meta-llama/Llama-4-Scout-17B-16E \
+ --output-path=golden_Llama-4-Scout-17B-16E_vision.jsonl --prompts='Describe this image.' \
+ --apply-chat-template=true --gcs-bucket=<bucket> --hf-model-path=<hf_checkpoint_path> \
+ --image-paths=MaxText/test_assets/test_image.jpg
+
+ python3 MaxText/scratch_code/generate_hf_golden_logits.py --model-id=google/gemma-3-4b-it \
+ --output-path=golden_gemma-3-4b-it_vision.jsonl --prompts='<start_of_image>' \
+ --apply-chat-template=false --gcs-bucket=<bucket> --hf-model-path=<hf_checkpoint_path> \
+ --image-paths=MaxText/test_assets/test_image.jpg
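+
+ Each line of the output .jsonl file records the prompt, its token ids, and the float32 logits; \
+ multimodal runs also record the formatted prompt, attention mask, image path, and pixel values.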
"""

import torch
import argparse
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
import jsonlines
from google.cloud import storage
+ from PIL import Image


# Load the tokenizer and model from Hugging Face

@@ -37,32 +55,74 @@ def upload_blob(bucket_name, source_file_name, destination_blob_name):
  blob.upload_from_filename(source_file_name)


- def save_golden_logits(model_id, output_path, prompt_texts, gcs_bucket):
+ def save_golden_logits(model_id, output_path, prompt_texts, apply_chat_template, gcs_bucket, hf_model_path, image_paths):
  """save golden logits"""
-   tokenizer = AutoTokenizer.from_pretrained(model_id)
+   if hf_model_path is None:
+     hf_model_path = model_id
+   tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
  model = AutoModelForCausalLM.from_pretrained(
-       model_id,
+       hf_model_path,
      torch_dtype=torch.float32,
      trust_remote_code=True,
  )

  all_data_to_save = []
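+   # Run each prompt (and its paired image, if any) through the model and record reference logits.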
-   for prompt_text in prompt_texts:
+   for i, prompt_text in enumerate(prompt_texts):
    # Encode the prompt text
-     input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
+     if image_paths:
+       try:
+         image = Image.open(image_paths[i])
+       except Exception as e:
+         # Re-raise with the offending path so failures are easy to diagnose.
+         raise ValueError(f"Failed to open image: {image_paths[i]}") from e
+       image = image.convert("RGB")
+       # TODO(aireenmei): remove this when Llama-4 supports dynamic image shapes.
+       if model_id.startswith("meta-llama/Llama-4"):
+         image = image.resize((336, 336))
+       processor = AutoProcessor.from_pretrained(model_id, token=True)
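+       # With the chat template, the processor wraps the prompt in the model's chat format and
+       # inserts the image placeholder tokens itself, so the raw prompt must not contain them.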
+       if apply_chat_template:
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image"},
+                     {"type": "text", "text": prompt_text},
+                 ],
+             },
+         ]
+         formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = processor(text=formatted_prompt, images=image, return_tensors="pt")
+       else:
+         formatted_prompt = prompt_text
+         inputs = processor(text=formatted_prompt, images=image, return_tensors="pt", add_special_tokens=False)
+       with torch.no_grad():
+         outputs = model(**inputs)
+         logits = outputs.logits.cpu().numpy().astype("float32")

-     # Get the logits for the prompt + completion
-     with torch.no_grad():
-       outputs = model(input_ids)
-       logits = outputs.logits.cpu().numpy().astype("float32")
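+       # pixel_values are saved alongside tokens and logits so downstream tests can replay
+       # the exact preprocessed image instead of re-running the HF processor.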
+       data_to_save = {
+           "prompt": prompt_text,
+           "formatted_prompt": formatted_prompt,
+           "tokens": inputs["input_ids"].tolist()[0],
+           "attention_mask": inputs["attention_mask"].tolist()[0],
+           "image_path": image_paths[i],
+           "pixel_values": inputs["pixel_values"].tolist()[0],
+           "logits": logits.tolist()[0],
+       }
+     else:
+       input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
+       # Get the logits for the prompt + completion
+       with torch.no_grad():
+         outputs = model(input_ids)
+         logits = outputs.logits.cpu().numpy().astype("float32")

      # Prepare data to be saved
      data_to_save = {
          "prompt": prompt_text,
          "tokens": input_ids.tolist()[0],
          "logits": logits.tolist()[0],  # Convert numpy array to list for JSON serialization
      }
-     all_data_to_save.append(data_to_save)
+     print(f"Token length is {len(data_to_save['tokens'])} for prompt: {prompt_text}")
+     print(f"raw ids: {data_to_save['tokens']}")
+     all_data_to_save.append(data_to_save)

  with jsonlines.open(output_path, "w") as f:
    f.write_all(all_data_to_save)
@@ -77,13 +137,33 @@ def main(raw_args=None) -> None:
  parser = argparse.ArgumentParser()
  parser.add_argument("--model-id", type=str, required=True, help="The identifier of the model to use.")
  parser.add_argument("--output-path", type=str, required=True, help="The path to save the generated golden logits.")
-   parser.add_argument("--prompts", type=str, required=True, help="A comma-separated list of prompts.")
+   parser.add_argument("--prompts", type=str, required=True, help="A semicolon-separated list of prompts.")
+   parser.add_argument(
+       "--apply-chat-template",
+       # argparse's type=bool would treat any non-empty string (including "false") as True,
+       # so parse the literal strings "true"/"false" instead.
+       type=lambda s: s.lower() == "true",
+       required=False,
+       default=False,
+       help="Whether to apply the chat template from the HF processor. Used for image+text input.",
+   )
  parser.add_argument(
      "--gcs-bucket", type=str, required=False, default=None, help="A GCS bucket to store logits, without gs://."
  )
+   parser.add_argument("--hf-model-path", type=str, required=False, default=None, help="Local path to a pre-downloaded checkpoint, if available.")
+   parser.add_argument(
+       "--image-paths", type=str, required=False, default=None, help="A semicolon-separated list of image paths."
+   )
  args = parser.parse_args(raw_args)
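+   # Prompts (and image paths) are split on ";" rather than "," so that individual prompts may contain commas.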
-   prompts = args.prompts.split(",")
-   save_golden_logits(args.model_id, args.output_path, prompts, args.gcs_bucket)
+   prompts = args.prompts.split(";")
+   image_paths = args.image_paths.split(";") if args.image_paths else []
+   if image_paths:
+     assert len(image_paths) == len(
+         prompts
+     ), "when image paths are provided, image_paths and prompts must have the same length."
+   if args.apply_chat_template:
+     assert image_paths, "apply_chat_template is only used for image+text input, so image_paths must be provided."
+   save_golden_logits(
+       args.model_id, args.output_path, prompts, args.apply_chat_template, args.gcs_bucket, args.hf_model_path, image_paths
+   )


if __name__ == "__main__":