@@ -119,7 +119,7 @@ def export_hf_checkpoint(tag: str, export_path: Path):
     torch.cuda.empty_cache()


-def export_te_checkpoint(te_checkpoint_path: str, output_path: str):
+def export_te_checkpoint(te_checkpoint_path: str, output_path: Path):
     """Export a Transformer Engine checkpoint back to the original HuggingFace Facebook ESM-2 format.

     This function converts from the NVIDIA Transformer Engine (TE) format back to the
@@ -129,7 +129,7 @@ def export_te_checkpoint(te_checkpoint_path: str, output_path: str):

     Args:
         te_checkpoint_path (str): Path to the TE checkpoint
-        output_path (str): Output path for the converted Facebook ESM-2 format model
+        output_path (Path): Output path for the converted Facebook ESM-2 format model
     """
     if not Path(te_checkpoint_path).exists():
         raise FileNotFoundError(f"TE checkpoint {te_checkpoint_path} not found")
@@ -143,15 +143,15 @@ def export_te_checkpoint(te_checkpoint_path: str, output_path: str):

     tokenizer_config_path = Path(te_checkpoint_path) / "tokenizer_config.json"
     if tokenizer_config_path.exists():
-        shutil.copy(tokenizer_config_path, Path(output_path) / "tokenizer_config.json")
+        shutil.copy(tokenizer_config_path, output_path / "tokenizer_config.json")

     vocab_path = Path(te_checkpoint_path) / "vocab.txt"
     if vocab_path.exists():
-        shutil.copy(vocab_path, Path(output_path) / "vocab.txt")
+        shutil.copy(vocab_path, output_path / "vocab.txt")

     special_tokens_path = Path(te_checkpoint_path) / "special_tokens_map.json"
     if special_tokens_path.exists():
-        shutil.copy(special_tokens_path, Path(output_path) / "special_tokens_map.json")
+        shutil.copy(special_tokens_path, output_path / "special_tokens_map.json")

     model_hf = AutoModelForMaskedLM.from_pretrained(
         output_path,
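For context, a minimal usage sketch of the function after this change (the checkpoint paths below are hypothetical, and the sketch assumes the module's own imports are already in place):

from pathlib import Path

# Hypothetical locations; substitute your own checkpoints.
te_checkpoint = "/checkpoints/esm2_te"           # te_checkpoint_path stays a str
hf_output = Path("/checkpoints/esm2_hf_export")  # output_path must now be a Path

# The function joins with the / operator (e.g. output_path / "vocab.txt"),
# so passing a plain str here would raise a TypeError after this diff.
hf_output.mkdir(parents=True, exist_ok=True)  # assumption: the caller creates the directory
export_te_checkpoint(te_checkpoint, hf_output)

Taking a Path up front also drops the repeated Path(output_path) wrapping the old code needed at each shutil.copy call.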