@@ -1,7 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
 import torch
 import os
 import shutil
 import logging
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor
+from deepspeed.accelerator import get_accelerator
 from tqdm import tqdm
 from typing import List
@@ -14,8 +20,8 @@
     'word_embeddings',
     'embed_tokens',
     'embedding',
-    'wte', # GPT style embeddings
-    'lm_head' # Language model head, often tied with embeddings
+    'wte',  # GPT style embeddings
+    'lm_head'  # Language model head, often tied with embeddings
 ]
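These patterns flag vocabulary-sized tensors (embedding matrices and tied LM heads) so the universal-checkpoint loader can treat their vocabulary dimension specially. As a quick illustration of how get_parameter_type (defined in the next hunk) reacts to them, with hypothetical parameter names:

    >>> get_parameter_type('model.embed_tokens.weight')               # matches 'embed_tokens'
    {'cat_dim': 0, 'vocab_tensor': True}
    >>> get_parameter_type('model.layers.0.self_attn.q_proj.weight')  # no vocab pattern
    {'cat_dim': 0}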
@@ -24,20 +30,27 @@ def get_parameter_type(name: str) -> dict:
     param_info = {
         'cat_dim': 0  # Default concatenation dimension
     }
-
+
     # Check for vocabulary tensors (embeddings, etc.)
     if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
         param_info['vocab_tensor'] = True
-
+
     # TODO: figure out if we need to check for row-parallel parameters
     return param_info
 
+
 if __name__ == '__main__':
     import argparse
-
+
     parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
-    parser.add_argument('--hf_checkpoint_dir', type=str, required=True, help='Path to the HuggingFace checkpoint directory')
-    parser.add_argument('--safe_serialization', action='store_true', default=False, help='Use safetensors for serialization')
+    parser.add_argument('--hf_checkpoint_dir',
+                        type=str,
+                        required=True,
+                        help='Path to the HuggingFace checkpoint directory')
+    parser.add_argument('--safe_serialization',
+                        action='store_true',
+                        default=False,
+                        help='Use safetensors for serialization')
     parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
     parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
     args = parser.parse_args()
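For reference, a typical invocation of this converter might look like the following; the script filename and the paths are illustrative only, not taken from the PR:

    python hf_to_universal.py \
        --hf_checkpoint_dir /data/llama-7b-hf \
        --save_dir /data/llama-7b-universal/global_step0 \
        --num_workers 8 \
        --safe_serialization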
@@ -50,19 +63,19 @@ def save_parameter(name: str, param: torch.Tensor, save_dir: str):
         # Create parameter directory under zero/
         param_dir = os.path.join(save_dir, name)
         os.makedirs(param_dir, exist_ok=True)
-
+
         # Get parameter type and required fields
         param_info = get_parameter_type(name)
-
+
         # Save parameter in fp32 with proper dictionary structure
         param_path = os.path.join(param_dir, "fp32.pt")
         param_dict = {
             'param': param.to(torch.float32),  # Main tensor goes in 'param' field
             **param_info  # Include all determined parameter info
         }
         torch.save(param_dict, param_path)
-
-        # Since HuggingFace checkpoints do not have optimizer states,
+
+        # Since HuggingFace checkpoints do not have optimizer states,
         # we initialize them with zeros
         for state in ("exp_avg", "exp_avg_sq"):
             state_path = os.path.join(param_dir, f"{state}.pt")
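save_parameter thus gives every tensor its own directory under zero/, holding the fp32 weights plus (per the comment above) zero-initialized Adam moments, so the output mimics a ZeRO optimizer checkpoint. A sketch of the resulting layout and a sanity check, assuming a hypothetical embedding parameter:

    # <save_dir>/model.embed_tokens.weight/
    #     fp32.pt        -> {'param': <fp32 tensor>, 'cat_dim': 0, 'vocab_tensor': True}
    #     exp_avg.pt     -> zero-filled first moment
    #     exp_avg_sq.pt  -> zero-filled second moment
    import torch
    d = torch.load('model.embed_tokens.weight/fp32.pt')
    assert d['param'].dtype == torch.float32 and d['cat_dim'] == 0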
@@ -77,30 +90,30 @@ def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization):
         try:
             shard_path = os.path.join(checkpoint_dir, shard_file)
             logger.info(f"Loading shard from: {shard_path}")
-
+
             if safe_serialization:
                 from safetensors.torch import load_file
                 shard_dict = load_file(shard_path)
             else:
                 shard_dict = torch.load(shard_path, map_location='cpu')
-
+
             # Create progress bar for parameters within this shard
-            pbar = tqdm(total=len(shard_dict),
-                       desc=f"Processing {os.path.basename(shard_file)}",
-                       position=1,
-                       leave=False)
-
+            pbar = tqdm(total=len(shard_dict),
+                        desc=f"Processing {os.path.basename(shard_file)}",
+                        position=1,
+                        leave=False)
+
             for key, param in shard_dict.items():
                 save_parameter(key, param, save_dir)
                 del param
                 pbar.update(1)
                 pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key})
-
+
             pbar.close()
             del shard_dict
-            torch.cuda.empty_cache()
+            get_accelerator().empty_cache()
             logger.info(f"Completed processing shard: {shard_file}")
-
+
         except Exception as e:
             logger.error(f"Error processing shard {shard_file}: {str(e)}")
             raise
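The torch.cuda.empty_cache() → get_accelerator().empty_cache() change routes the cache flush through DeepSpeed's accelerator abstraction, so the script is no longer CUDA-only. A minimal sketch of the pattern:

    from deepspeed.accelerator import get_accelerator

    accel = get_accelerator()  # resolves the active backend (CUDA, XPU, NPU, ...) at runtime
    accel.empty_cache()        # device-agnostic counterpart of torch.cuda.empty_cache()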
@@ -111,7 +124,7 @@ def get_shard_list(checkpoint_dir):
             index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
         else:
             index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
-
+
         if os.path.exists(index_file):
             import json
             with open(index_file, 'r') as f:
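get_shard_list reads the standard HuggingFace shard index, whose weight_map ties each tensor name to the shard file that stores it. An abridged, hypothetical model.safetensors.index.json for illustration:

    {
      "metadata": {"total_size": 13476839424},
      "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
        "lm_head.weight": "model-00003-of-00003.safetensors"
      }
    }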
@@ -131,18 +144,11 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
             futures = []
             for shard_file in shard_files:
-                future = executor.submit(process_shard,
-                                        shard_file,
-                                        checkpoint_dir,
-                                        save_dir,
-                                        safe_serialization)
+                future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization)
                 futures.append((shard_file, future))
-
+
             # Create progress bar for this batch
-            batch_pbar = tqdm(total=len(futures),
-                             desc=f"Processing shard batch",
-                             position=0,
-                             leave=True)
+            batch_pbar = tqdm(total=len(futures), desc=f"Processing shard batch", position=0, leave=True)
 
             # Wait for all futures to complete
             for shard_file, future in futures:
@@ -153,7 +159,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
                 except Exception as e:
                     logger.error(f"Failed processing shard {shard_file}: {str(e)}")
                     raise
-
+
             batch_pbar.close()
 
     try:
@@ -162,42 +168,45 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         if os.path.exists(temp_zero_dir):
             logger.info(f"Removing existing temp directory: {temp_zero_dir}")
             shutil.rmtree(temp_zero_dir)
-
+
         shard_files = get_shard_list(args.hf_checkpoint_dir)
         total_shards = len(shard_files)
         logger.info(f"Found {total_shards} shards to process")
         # Process shards in batches equal to the number of workers
         batch_size = args.num_workers
         for i in range(0, total_shards, batch_size):
             batch_shards = shard_files[i:i + batch_size]
-            logger.info(f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})")
-            process_shard_batch(batch_shards,
-                                args.hf_checkpoint_dir,
-                                temp_zero_dir,  # Changed from temp_save_dir to temp_zero_dir
-                                args.safe_serialization)
-
+            logger.info(
+                f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})"
+            )
+            process_shard_batch(
+                batch_shards,
+                args.hf_checkpoint_dir,
+                temp_zero_dir,  # Changed from temp_save_dir to temp_zero_dir
+                args.safe_serialization)
+
             # Clear CUDA cache after each batch to free up memory
-            torch.cuda.empty_cache()
-
+            get_accelerator().empty_cache()
+
         logger.info("All shard batches processed successfully")
-
+
         final_save_dir = os.path.join(args.save_dir, 'zero')
         if os.path.exists(final_save_dir):
             shutil.rmtree(final_save_dir)
-
+
         # Create the parent directory if it doesn't exist
         os.makedirs(os.path.dirname(final_save_dir), exist_ok=True)
         # Move the zero directory to its final location
         os.rename(temp_zero_dir, final_save_dir)
-
+
         # Clean up the temporary directory
         if os.path.exists(temp_save_dir):
             shutil.rmtree(temp_save_dir)
-
+
         # Write identifier file
         with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f:
             f.write("Huggingface checkpoint")
-
+
         logger.info(f"Successfully saved checkpoint to {final_save_dir}")
 
         # Update latest file
@@ -206,7 +215,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
         with open(latest_file, 'w') as f:
             f.write(step_folder)
-
+
         logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}")
 
     except Exception as e:
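Assuming the conversion succeeds, the output tree should look roughly like the sketch below; checkpoint_root_folder and step_folder are derived in lines not shown in this diff (presumably the parent directory and basename of --save_dir), so the names here are hypothetical:

    <checkpoint_root>/
        latest_universal                  # contains the step folder name, e.g. "global_step0"
        <step_folder>/                    # == --save_dir
            source.txt                    # "Huggingface checkpoint"
            zero/
                model.embed_tokens.weight/{fp32.pt, exp_avg.pt, exp_avg_sq.pt}
                ...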