Skip to content

Commit 61b0070

Browse files
authored
Support pytorch checkpoints for Qwen3/Phi4 mini (#13984)
1 parent 2845fd3 commit 61b0070

File tree

3 files changed

+70
-22
lines changed

3 files changed

+70
-22
lines changed

examples/models/checkpoint.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
# pyre-unsafe
99

10+
import json
11+
import os
1012
from pathlib import Path
1113
from typing import Any, Dict, Optional
1214

@@ -74,3 +76,39 @@ def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
7476
f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
7577
)
7678
return dtype
79+
80+
81+
def load_checkpoint_from_pytorch_model(input_dir: str) -> Dict:
    """Load a PyTorch-format (.bin) checkpoint from *input_dir*.

    Handles both layouts produced by Hugging Face `save_pretrained`:
    a sharded checkpoint described by ``pytorch_model.bin.index.json``,
    or a single ``pytorch_model.bin`` file.

    Args:
        input_dir (str): Directory containing the checkpoint files.

    Returns:
        Dict: State dict mapping weight names to CPU tensors.

    Raises:
        FileNotFoundError: If neither checkpoint layout is present.
    """
    index_file = os.path.join(input_dir, "pytorch_model.bin.index.json")
    if os.path.exists(index_file):
        # Sharded checkpoint: the index maps each weight name to the
        # shard file that stores it.
        with open(index_file, "r") as fp:
            weight_map = json.load(fp)["weight_map"]

        # Load every distinct shard into memory once.
        loaded_shards = {
            shard_file: torch.load(
                os.path.join(input_dir, shard_file),
                weights_only=True,
                map_location=torch.device("cpu"),
            )
            for shard_file in sorted(set(weight_map.values()))
        }

        # Stitch the shards back into one consolidated state dict.
        return {
            name: loaded_shards[shard_file][name]
            for name, shard_file in weight_map.items()
        }

    # Single-file checkpoint.
    single_path = os.path.join(input_dir, "pytorch_model.bin")
    if os.path.exists(single_path):
        return torch.load(
            single_path, weights_only=True, map_location=torch.device("cpu")
        )

    raise FileNotFoundError(f"Could not find pytorch_model checkpoint in {input_dir}")

examples/models/phi_4_mini/convert_weights.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import argparse
2-
import os
32
from typing import Dict
43

54
import torch
5+
from executorch.examples.models.checkpoint import load_checkpoint_from_pytorch_model
66

77
from torchtune.models.convert_weights import get_mapped_key
88

@@ -87,10 +87,8 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
8787
Convert a state dict from torchtune's format to Meta's format. This function
8888
doesn't handle any sharding or splitting of state dicts. It follows the
8989
state_dict IN -> state_dict OUT pattern.
90-
9190
Args:
9291
state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
93-
9492
Returns:
9593
Dict[str, torch.Tensor]: State dict in Meta's format.
9694
"""
@@ -105,14 +103,15 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
105103
converted_state_dict["output.weight"] = converted_state_dict[
106104
"tok_embeddings.weight"
107105
]
108-
109106
return converted_state_dict
110107

111108

112109
def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
113-
# If input_dir_or_checkpoint is a directory downloaded from HF, FullModelHFCheckpointer is used to extract the state dict
114-
# If input_dir_or_checkpoint is a checkpoint (from eager model model), it is loaded directly
115-
if os.path.isdir(input_dir_or_checkpoint):
110+
try:
111+
sd = load_checkpoint_from_pytorch_model(input_dir_or_checkpoint)
112+
print("Converting checkpoint...")
113+
sd = phi_4_hf_to_meta(sd)
114+
except FileNotFoundError:
116115
checkpointer = FullModelHFCheckpointer(
117116
checkpoint_dir=input_dir_or_checkpoint,
118117
checkpoint_files=[
@@ -127,11 +126,6 @@ def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
127126
sd = sd["model"]
128127
print("Converting checkpoint...")
129128
sd = phi_4_tune_to_meta(sd)
130-
else:
131-
print("Loading checkpoint from file...")
132-
sd = torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
133-
print("Converting checkpoint...")
134-
sd = phi_4_hf_to_meta(sd)
135129

136130
print("Saving checkpoint...")
137131
torch.save(sd, output_file)

examples/models/qwen3/convert_weights.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict
66

77
import torch
8+
from executorch.examples.models.checkpoint import load_checkpoint_from_pytorch_model
89
from safetensors.torch import load_file
910

1011
from torchtune.models.convert_weights import get_mapped_key
@@ -80,19 +81,34 @@ def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
8081
tensor = shard_to_weights[shard][weight_name]
8182
merged_state_dict[weight_name] = tensor
8283
return merged_state_dict
83-
else:
84-
# Single checkpoint.
85-
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
86-
return state_dict
84+
85+
# Single checkpoint.
86+
model_path = os.path.join(input_dir, "model.safetensors")
87+
if os.path.exists(model_path):
88+
return load_file(os.path.join(input_dir, "model.safetensors"))
89+
90+
raise FileNotFoundError(f"Could not find safetensors checkpoint in {input_dir}")
8791

8892

8993
def load_checkpoint(input_dir: str) -> Dict:
    """Load a model state dict from *input_dir*.

    Tries the PyTorch ``pytorch_model.bin`` layout first, then falls back
    to the safetensors layout; each loader raises FileNotFoundError when
    its files are absent.

    Args:
        input_dir (str): Directory containing the checkpoint files.

    Returns:
        Dict: State dict mapping weight names to tensors.

    Raises:
        FileNotFoundError: If no supported checkpoint format is found.
    """
    try:
        print("Loading checkpoint from pytorch_model directory")
        # Return directly instead of binding to a temporary first.
        return load_checkpoint_from_pytorch_model(input_dir)
    except FileNotFoundError:
        # Not fatal yet — the safetensors layout may still be present.
        print(
            "Could not find pytorch_model checkpoints in directory, trying safetensors"
        )

    try:
        print("Loading checkpoint from safetensors directory")
        return load_checkpoint_from_safetensors(input_dir)
    except FileNotFoundError:
        pass

    raise FileNotFoundError(f"Could not find checkpoint in {input_dir}")
96112

97113

98114
def convert_weights(input_dir: str, output_file: str) -> None:

0 commit comments

Comments
 (0)