14
14
'word_embeddings' ,
15
15
'embed_tokens' ,
16
16
'embedding' ,
17
- 'wte' , # GPT style embeddings
18
- 'lm_head' # Often tied with embeddings
17
+ 'wte' , # GPT style embeddings
18
+ 'lm_head' # Language model head, often tied with embeddings
19
19
]
20
20
21
21
@@ -35,8 +35,8 @@ def get_parameter_type(name: str) -> dict:
35
35
if __name__ == '__main__' :
36
36
import argparse
37
37
38
- parser = argparse .ArgumentParser (description = 'Load a HuggingFace model ' )
39
- parser .add_argument ('--hf_checkpoint_dir' , type = str , help = 'Path to the HuggingFace checkpoint directory' )
38
+ parser = argparse .ArgumentParser (description = 'Convert HuggingFace checkpoint to Universal Checkpoint format ' )
39
+ parser .add_argument ('--hf_checkpoint_dir' , type = str , required = True , help = 'Path to the HuggingFace checkpoint directory' )
40
40
parser .add_argument ('--safe_serialization' , action = 'store_true' , default = False , help = 'Use safetensors for serialization' )
41
41
parser .add_argument ('--num_workers' , type = int , default = 4 , help = 'Number of workers to use for saving checkpoints' )
42
42
parser .add_argument ('--save_dir' , type = str , required = True , help = 'Directory to save checkpoints' )
@@ -119,10 +119,12 @@ def get_shard_list(checkpoint_dir):
119
119
return list (set (index ['weight_map' ].values ()))
120
120
else :
121
121
# Handle single file case
122
- if args .safe_serialization :
122
+ if args .safe_serialization and os . path . exists ( os . path . join ( checkpoint_dir , "model.safetensors" )) :
123
123
return ["model.safetensors" ]
124
- else :
124
+ elif os . path . exists ( os . path . join ( checkpoint_dir , "pytorch_model.bin" )) :
125
125
return ["pytorch_model.bin" ]
126
+ else :
127
+ raise FileNotFoundError (f"No checkpoint files found in { checkpoint_dir } " )
126
128
127
129
def process_shard_batch (shard_files : List [str ], checkpoint_dir : str , save_dir : str , safe_serialization : bool ):
128
130
"""Process a batch of shards in parallel."""
0 commit comments