
Commit 06a4a9f

formatting and license
1 parent f34c6df commit 06a4a9f

3 files changed: 60 additions & 51 deletions

deepspeed/checkpoint/hf_to_universal.py

Lines changed: 58 additions & 49 deletions
@@ -1,8 +1,14 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
 import torch
 import os
 import shutil
 import logging
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor
+from deepspeed.accelerator import get_accelerator
 from tqdm import tqdm
 from typing import List
 
@@ -14,8 +20,8 @@
     'word_embeddings',
     'embed_tokens',
     'embedding',
-    'wte', # GPT style embeddings
-    'lm_head' # Language model head, often tied with embeddings
+    'wte',  # GPT style embeddings
+    'lm_head'  # Language model head, often tied with embeddings
 ]
 
 
@@ -24,20 +30,27 @@ def get_parameter_type(name: str) -> dict:
     param_info = {
         'cat_dim': 0  # Default concatenation dimension
     }
-
+
     # Check for vocabulary tensors (embeddings, etc.)
     if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
         param_info['vocab_tensor'] = True
-
+
     # TODO: figure out if we need to check for row-parallel parameters
     return param_info
 
+
 if __name__ == '__main__':
     import argparse
-
+
     parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
-    parser.add_argument('--hf_checkpoint_dir', type=str, required=True, help='Path to the HuggingFace checkpoint directory')
-    parser.add_argument('--safe_serialization', action='store_true', default=False, help='Use safetensors for serialization')
+    parser.add_argument('--hf_checkpoint_dir',
+                        type=str,
+                        required=True,
+                        help='Path to the HuggingFace checkpoint directory')
+    parser.add_argument('--safe_serialization',
+                        action='store_true',
+                        default=False,
+                        help='Use safetensors for serialization')
     parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
     parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
     args = parser.parse_args()
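
As a quick illustration of the get_parameter_type helper in the hunk above (not part of this commit; the parameter names below are hypothetical), only names matching one of the VOCAB_PARAMETER_PATTERNS get the vocab_tensor tag:

    # Hypothetical parameter names, used only to illustrate the helper above.
    print(get_parameter_type('model.embed_tokens.weight'))
    # {'cat_dim': 0, 'vocab_tensor': True}  -- matches the 'embed_tokens' pattern
    print(get_parameter_type('model.layers.0.self_attn.q_proj.weight'))
    # {'cat_dim': 0}  -- no vocab pattern, so only the default cat_dim
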
@@ -50,19 +63,19 @@ def save_parameter(name: str, param: torch.Tensor, save_dir: str):
         # Create parameter directory under zero/
         param_dir = os.path.join(save_dir, name)
         os.makedirs(param_dir, exist_ok=True)
-
+
         # Get parameter type and required fields
         param_info = get_parameter_type(name)
-
+
         # Save parameter in fp32 with proper dictionary structure
         param_path = os.path.join(param_dir, "fp32.pt")
         param_dict = {
             'param': param.to(torch.float32),  # Main tensor goes in 'param' field
             **param_info  # Include all determined parameter info
         }
         torch.save(param_dict, param_path)
-
-        # Since HuggingFace checkpoints do not have optimizer states,
+
+        # Since HuggingFace checkpoints do not have optimizer states,
         # we initialize them with zeros
         for state in ("exp_avg", "exp_avg_sq"):
             state_path = os.path.join(param_dir, f"{state}.pt")
@@ -77,30 +90,30 @@ def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization):
         try:
             shard_path = os.path.join(checkpoint_dir, shard_file)
             logger.info(f"Loading shard from: {shard_path}")
-
+
             if safe_serialization:
                 from safetensors.torch import load_file
                 shard_dict = load_file(shard_path)
             else:
                 shard_dict = torch.load(shard_path, map_location='cpu')
-
+
             # Create progress bar for parameters within this shard
-            pbar = tqdm(total=len(shard_dict),
-                        desc=f"Processing {os.path.basename(shard_file)}",
-                        position=1,
-                        leave=False)
-
+            pbar = tqdm(total=len(shard_dict),
+                        desc=f"Processing {os.path.basename(shard_file)}",
+                        position=1,
+                        leave=False)
+
             for key, param in shard_dict.items():
                 save_parameter(key, param, save_dir)
                 del param
                 pbar.update(1)
                 pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key})
-
+
             pbar.close()
             del shard_dict
-            torch.cuda.empty_cache()
+            get_accelerator().empty_cache()
             logger.info(f"Completed processing shard: {shard_file}")
-
+
         except Exception as e:
             logger.error(f"Error processing shard {shard_file}: {str(e)}")
             raise
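
The substantive change in this hunk is the switch from torch.cuda.empty_cache() to get_accelerator().empty_cache(). A minimal sketch of that pattern (not part of the diff), assuming only that DeepSpeed is installed:

    from deepspeed.accelerator import get_accelerator

    def free_device_cache():
        # Dispatches to whichever accelerator backend is active (CUDA, ROCm,
        # XPU, CPU, ...) rather than hard-coding torch.cuda.
        get_accelerator().empty_cache()
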
@@ -111,7 +124,7 @@ def get_shard_list(checkpoint_dir):
             index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
         else:
             index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
-
+
         if os.path.exists(index_file):
             import json
             with open(index_file, 'r') as f:
@@ -131,18 +144,11 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
             futures = []
             for shard_file in shard_files:
-                future = executor.submit(process_shard,
-                                         shard_file,
-                                         checkpoint_dir,
-                                         save_dir,
-                                         safe_serialization)
+                future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization)
                 futures.append((shard_file, future))
-
+
             # Create progress bar for this batch
-            batch_pbar = tqdm(total=len(futures),
-                              desc=f"Processing shard batch",
-                              position=0,
-                              leave=True)
+            batch_pbar = tqdm(total=len(futures), desc=f"Processing shard batch", position=0, leave=True)
 
             # Wait for all futures to complete
             for shard_file, future in futures:
@@ -153,7 +159,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
                 except Exception as e:
                     logger.error(f"Failed processing shard {shard_file}: {str(e)}")
                     raise
-
+
             batch_pbar.close()
 
     try:
@@ -162,42 +168,45 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         if os.path.exists(temp_zero_dir):
             logger.info(f"Removing existing temp directory: {temp_zero_dir}")
             shutil.rmtree(temp_zero_dir)
-
+
         shard_files = get_shard_list(args.hf_checkpoint_dir)
         total_shards = len(shard_files)
         logger.info(f"Found {total_shards} shards to process")
         # Process shards in batches equal to the number of workers
         batch_size = args.num_workers
         for i in range(0, total_shards, batch_size):
             batch_shards = shard_files[i:i + batch_size]
-            logger.info(f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})")
-            process_shard_batch(batch_shards,
-                                args.hf_checkpoint_dir,
-                                temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir
-                                args.safe_serialization)
-
+            logger.info(
+                f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})"
+            )
+            process_shard_batch(
+                batch_shards,
+                args.hf_checkpoint_dir,
+                temp_zero_dir,  # Changed from temp_save_dir to temp_zero_dir
+                args.safe_serialization)
+
             # Clear CUDA cache after each batch to free up memory
-            torch.cuda.empty_cache()
-
+            get_accelerator().empty_cache()
+
         logger.info("All shard batches processed successfully")
-
+
         final_save_dir = os.path.join(args.save_dir, 'zero')
         if os.path.exists(final_save_dir):
             shutil.rmtree(final_save_dir)
-
+
         # Create the parent directory if it doesn't exist
         os.makedirs(os.path.dirname(final_save_dir), exist_ok=True)
         # Move the zero directory to its final location
         os.rename(temp_zero_dir, final_save_dir)
-
+
         # Clean up the temporary directory
         if os.path.exists(temp_save_dir):
             shutil.rmtree(temp_save_dir)
-
+
         # Write identifier file
         with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f:
             f.write("Huggingface checkpoint")
-
+
         logger.info(f"Successfully saved checkpoint to {final_save_dir}")
 
         # Update latest file
@@ -206,7 +215,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
         with open(latest_file, 'w') as f:
             f.write(step_folder)
-
+
         logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}")
 
     except Exception as e:
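
For context, a sketch of how the converter might be run and how its output could be inspected (not part of this commit; all paths and the parameter name below are placeholders):

    # Invocation (placeholder paths):
    #   python deepspeed/checkpoint/hf_to_universal.py \
    #       --hf_checkpoint_dir /path/to/hf_ckpt \
    #       --save_dir /path/to/universal_ckpt \
    #       --num_workers 4 --safe_serialization
    import os
    import torch

    save_dir = "/path/to/universal_ckpt"  # placeholder
    name = "model.embed_tokens.weight"  # placeholder parameter name

    # Each parameter lands in <save_dir>/zero/<name>/ as an fp32.pt dict
    # ({'param': tensor, 'cat_dim': 0, optionally 'vocab_tensor': True}) plus
    # zero-initialized exp_avg.pt / exp_avg_sq.pt optimizer states; the run
    # also writes <save_dir>/source.txt and updates a 'latest_universal'
    # pointer in the checkpoint root folder.
    fp32 = torch.load(os.path.join(save_dir, "zero", name, "fp32.pt"))
    print(sorted(fp32.keys()), fp32["param"].dtype)
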

deepspeed/runtime/engine.py

Lines changed: 1 addition & 1 deletion
@@ -3035,7 +3035,7 @@ def _load_checkpoint(self,
         if self.load_universal_checkpoint() and len(ckpt_list) == 0:
             logger.warning(f"Unable to find {ckpt_file_pattern} files in UCP folder {load_dir}")
             return None, {}
-
+
         sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
 
         is_pipe_parallel = isinstance(self.module, PipelineModule)

deepspeed/runtime/zero/stage3.py

Lines changed: 1 addition & 1 deletion
@@ -2812,7 +2812,7 @@ def _load_global_state_stage3(self, sd):
 
     def load_hp_checkpoint_state(self, folder, key):
         local_rank = dist.get_local_rank()
-
+
         # Load tensors from files and reshape them to flat vectors
         loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False)
         if isinstance(loaded_state, dict):
