
Commit 2930f2a

xylian86 authored and Schwidola0607 committed
nits
Signed-off-by: Schwidola0607 <[email protected]>
1 parent 7bef517 commit 2930f2a

File tree

1 file changed: 8 additions, 6 deletions


deepspeed/checkpoint/hf_to_universal.py

Lines changed: 8 additions & 6 deletions
@@ -14,8 +14,8 @@
     'word_embeddings',
     'embed_tokens',
     'embedding',
-    'wte', # GPT style embeddings
-    'lm_head' # Often tied with embeddings
+    'wte', # GPT style embeddings
+    'lm_head' # Language model head, often tied with embeddings
 ]
 
 
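The patterns in this hunk are presumably used for name-based classification of parameters (the next hunk's header references get_parameter_type). A hypothetical sketch of that kind of substring matching follows; the list and helper names below are illustrative and not taken from this commit.

EMBEDDING_PATTERNS = [
    'word_embeddings',
    'embed_tokens',
    'embedding',
    'wte',  # GPT style embeddings
    'lm_head'  # Language model head, often tied with embeddings
]


def is_embedding_param(name: str) -> bool:
    # Flag a parameter as embedding-related if any known pattern appears in its name.
    return any(pattern in name for pattern in EMBEDDING_PATTERNS)


# e.g. is_embedding_param('model.embed_tokens.weight') -> True
# e.g. is_embedding_param('model.layers.0.self_attn.q_proj.weight') -> False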
@@ -35,8 +35,8 @@ def get_parameter_type(name: str) -> dict:
 if __name__ == '__main__':
     import argparse
 
-    parser = argparse.ArgumentParser(description='Load a HuggingFace model')
-    parser.add_argument('--hf_checkpoint_dir', type=str, help='Path to the HuggingFace checkpoint directory')
+    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
+    parser.add_argument('--hf_checkpoint_dir', type=str, required=True, help='Path to the HuggingFace checkpoint directory')
     parser.add_argument('--safe_serialization', action='store_true', default=False, help='Use safetensors for serialization')
     parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
     parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
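Marking --hf_checkpoint_dir as required means an invocation must now supply both --hf_checkpoint_dir and --save_dir up front; argparse rejects the command at parse time instead of letting the conversion fail later on a None path. A small sketch of just that behavior, reusing the argument names from the diff (the other flags are omitted here and the example paths are placeholders):

import argparse

# Sketch: only the two required flags from the diff are reproduced here.
parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
parser.add_argument('--hf_checkpoint_dir', type=str, required=True,
                    help='Path to the HuggingFace checkpoint directory')
parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')

args = parser.parse_args(['--hf_checkpoint_dir', '/path/to/hf_checkpoint',
                          '--save_dir', '/path/to/universal_checkpoint'])
print(args.hf_checkpoint_dir, args.save_dir)

# Calling parser.parse_args([]) now exits with an error naming the missing required
# arguments, rather than proceeding with hf_checkpoint_dir=None.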
@@ -119,10 +119,12 @@ def get_shard_list(checkpoint_dir):
         return list(set(index['weight_map'].values()))
     else:
         # Handle single file case
-        if args.safe_serialization:
+        if args.safe_serialization and os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
             return ["model.safetensors"]
-        else:
+        elif os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
             return ["pytorch_model.bin"]
+        else:
+            raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")
 
 def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: str, safe_serialization: bool):
     """Process a batch of shards in parallel."""
