
Commit 4b78d85

Merge pull request #2084 from AI-Hypercomputer:yixuannwang-doc
PiperOrigin-RevId: 793836391
2 parents 59b8792 + 5a9a063 commit 4b78d85

File tree: 11 files changed, +655 −23 lines
Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
# Checkpoint conversion utilities

This guide provides instructions for using the scripts that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.

## Supported models

The following models are supported:

- Gemma2 (2B, 9B, 27B).
- Gemma3 multimodal (4B, 12B, 27B).
- Qwen3 (0.6B, 4B, 8B, 14B, 32B).

## Prerequisites

- The Hugging Face side of the conversion requires PyTorch.
- Hugging Face model checkpoints require local disk space: model files are always downloaded to a disk cache before being loaded into memory (for more information, see the Hugging Face [docs](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference)). The default local storage path for Hugging Face models is `$HOME/.cache/huggingface/hub`.
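
If the default cache disk is too small, the cache location can be redirected through the standard `HF_HOME` environment variable, and a CPU-only PyTorch wheel is sufficient for the conversion scripts (the install command below is the one used by the test scripts; the cache path is illustrative):

```bash
# Point the Hugging Face cache at a disk with enough free space (example path).
export HF_HOME=/mnt/disks/persist/huggingface

# CPU-only PyTorch is enough for the Hugging Face side of the conversion.
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
```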

## Hugging Face to MaxText

Use the `to_maxtext.py` script to convert a Hugging Face model into a MaxText checkpoint. The script automatically downloads the specified model from the Hugging Face Hub, performs the conversion, and saves the converted checkpoint to the given output directory.

*For complete examples, see the test scripts at [`end_to_end/tpu/qwen3/4b/test_qwen3.sh`](../../../end_to_end/tpu/qwen3/4b/test_qwen3.sh) and [`end_to_end/tpu/gemma3/4b/test_gemma3_unified.sh`](../../../end_to_end/tpu/gemma3/4b/test_gemma3_unified.sh).*

### Usage

The following command demonstrates how to run the conversion. You must provide your Hugging Face token, either as `hf_access_token` in `MaxText/configs/base.yml` or on the command line as shown below.

```bash
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
    model_name=<model-name> \
    base_output_directory=<gcs-path-to-save-checkpoint> \
    hf_access_token=<your-hf-token> \
    use_multimodal=false \
    scan_layers=false
```

**Key arguments:**

* `model_name`: The model identifier; it must be one of the models registered in `HF_IDS` in `MaxText/utils/ckpt_conversion/utils/utils.py`.
* `scan_layers`: Indicates whether the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/checkpoints.md) (`scan_layers=true`) or unscanned (`scan_layers=false`).
* `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
* `hf_access_token`: Your Hugging Face token.
* `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `MaxText/tmp`.

*The script only converts the official versions of the Hugging Face models. The supported versions are listed in `HF_IDS` in `MaxText/utils/ckpt_conversion/utils/utils.py`; you can print them as shown below.*
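
A minimal sketch for listing the supported model keys, assuming the repository root is on `PYTHONPATH` and that `HF_IDS` maps MaxText model keys to Hugging Face repo IDs (as it is used in `to_huggingface.py`):

```python
# Print each supported MaxText model key and the official HF repo ID it maps to.
from MaxText.utils.ckpt_conversion.utils.utils import HF_IDS

for model_key, hf_repo_id in HF_IDS.items():
    print(f"{model_key} -> {hf_repo_id}")
```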

## MaxText to Hugging Face

Use the `to_huggingface.py` script to convert a MaxText checkpoint into the Hugging Face format. This is useful for sharing your models or integrating them with the Hugging Face ecosystem.

*For a complete example, see the test script at [`end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh`](../../../end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh).*

### Usage

The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub.

```bash
python -m MaxText.utils.ckpt_conversion.to_huggingface MaxText/configs/base.yml \
    model_name=<MODEL_NAME> \
    load_parameters_path=<path-to-maxtext-checkpoint> \
    base_output_directory=<path-to-save-converted-checkpoint> \
    scan_layers=false \
    use_multimodal=false \
    hf_access_token=<your-hf-token>
```

**Key arguments:**

* `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
* `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
* `scan_layers`: Indicates whether the source checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/checkpoints.md) (`scan_layers=true`) or unscanned (`scan_layers=false`).
* `hf_access_token`: Your Hugging Face token.
* `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
* `base_output_directory`: The path where the converted Hugging Face checkpoint will be stored; it can be Google Cloud Storage (GCS), the Hugging Face Hub, or local. If not set, the default output directory is `MaxText/tmp`.
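
After conversion, a quick sanity check is to load the result back with standard `transformers` APIs and generate a few tokens. A minimal sketch, assuming the checkpoint was written to a local example path:

```python
# Load the converted checkpoint with stock transformers APIs and run a short
# generation. The directory below is an illustrative output path, not fixed.
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "./tmp/hf/gemma2-2b"
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir)

inputs = tokenizer("What is the", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```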

## Verifying conversion correctness

To ensure the conversion was successful, you can use the `MaxText/tests/forward_pass_logit_checker.py` script. It runs a forward pass on both the original and converted models and compares the output logits, so it can verify the conversion in either direction.

### Usage

```bash
python3 -m MaxText.tests.forward_pass_logit_checker MaxText/configs/base.yml \
    tokenizer_path=assets/<tokenizer> \
    load_parameters_path=<path-to-maxtext-checkpoint> \
    model_name=<MODEL_NAME> \
    scan_layers=false \
    use_multimodal=false \
    --run_hf_model=True \
    --hf_model_path=<path-to-HF-checkpoint> \
    --max_kl_div=0.015
```

**Key arguments:**

* `load_parameters_path`: The path to the source MaxText Orbax checkpoint (e.g., `gs://your-bucket/maxtext-checkpoint/0/items`).
* `model_name`: The corresponding model name in the MaxText configuration (e.g., `qwen3-4b`).
* `scan_layers`: Indicates whether the checkpoint is scanned (`scan_layers=true`) or unscanned (`scan_layers=false`).
* `use_multimodal`: Indicates whether multimodality is used.
* `--run_hf_model`: Indicates whether to load the Hugging Face model from `--hf_model_path`. If not set, the MaxText logits are compared against pre-saved golden logits.
* `--hf_model_path`: The path to the Hugging Face checkpoint.
* `--max_kl_div`: The maximum KL divergence tolerated during the comparison.
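
The reported metric is the per-token KL divergence D_KL(P_golden || Q_model) over the vocabulary. A rough, self-contained illustration of the metric (this mirrors what the checker reports but is not its exact implementation):

```python
# Per-token KL divergence between "golden" and converted-model logits.
import numpy as np

def log_softmax(x, axis=-1):
  x = x - x.max(axis=axis, keepdims=True)
  return x - np.log(np.sum(np.exp(x), axis=axis, keepdims=True))

def per_token_kl(golden_logits, model_logits):
  """D_KL(P_golden || Q_model) per token position; inputs are [seq_len, vocab]."""
  log_p = log_softmax(golden_logits)
  log_q = log_softmax(model_logits)
  return np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)

# Toy example: nearly identical logits give KL far below the 0.015 tolerance.
rng = np.random.default_rng(0)
golden = rng.normal(size=(8, 256))
converted = golden + 0.01 * rng.normal(size=(8, 256))
kl = per_token_kl(golden, converted)
print(f"avg KL: {kl.mean():.6f}, max KL: {kl.max():.6f}")
```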

**Example successful conversion verification:**

Here is part of the output of `forward_pass_logit_checker` for gemma2-2b:

```
--- Prompt: What is the ---

--- MaxText model top 10 tokens ---
| Token ID   | Token                | Score      |
|------------|----------------------|------------|
| 5830       | difference           | 27.2500    |
| 1963       | best                 | 26.6250    |
| 5316       | average              | 26.3750    |
| 2669       | change               | 26.1250    |
| 12070      | percentage           | 26.1250    |
| 1618       | value                | 25.8750    |
| 1546       | most                 | 25.7500    |
| 66202      | molar                | 25.5000    |
| 3051       | total                | 25.5000    |
| 1503       | name                 | 25.3750    |

--- HF model top 10 tokens ---
| Token ID   | Token                | Score      |
|------------|----------------------|------------|
| 5830       | difference           | 27.2500    |
| 1963       | best                 | 26.6250    |
| 5316       | average              | 26.3750    |
| 12070      | percentage           | 26.1250    |
| 2669       | change               | 26.1250    |
| 1618       | value                | 25.8750    |
| 1546       | most                 | 25.7500    |
| 66202      | molar                | 25.5000    |
| 3051       | total                | 25.5000    |
| 6187       | purpose              | 25.3750    |

--- Similarity Metrics of Top Tokens ---
| Metric                         | Value                |
|--------------------------------|----------------------|
| overlap_count                  | 9/10                 |
| jaccard_similarity             | 0.8181818181818182   |
| rank_agreement_percentage      | 70.0                 |

Average KL divergence per token (D_KL(P_golden || Q_model)): 0.000409

Max KL divergence for a single token in the set: 0.003497
```
153+
-----
154+
155+
## Adding support for new models
156+
To extend conversion support to a new model architecture, you must define its specific parameter and configuration mappings. The conversion logic is decoupled, so you only need to modify the mapping files.
157+
158+
1. **Add parameter mappings**:
159+
- In [`utils/param_mapping.py`](./utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer.
160+
- In [`utils/param_mapping.py`](./utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
161+
2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](./utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
162+
3. **Register model key**: In [`utils/utils.py`](./utils/utils.py), add the new model key in `HF_IDS`.
163+
4. **Add transformer config**: In [`utils/hf_model_configs.py`](./utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in ['MaxText/configs/models'](../configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.
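
A hypothetical sketch of the two mapping functions for a new model, following the shape of the existing Gemma/Qwen entries. All parameter names below are illustrative; the real MaxText and Hugging Face names must be read from the actual checkpoints (see the shape-dumping snippet under "Debugging tips"):

```python
def MYMODEL_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers):
  """1-to-1 mapping from MaxText parameter names to Hugging Face names."""
  mapping = {"params-token_embedder-embedding": "model.embed_tokens.weight"}
  for i in range(config["num_hidden_layers"]):
    mapping[f"params-decoder-layers_{i}-self_attention-query-kernel"] = (
        f"model.layers.{i}.self_attn.q_proj.weight"
    )
  return mapping


def MYMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers, saving_to_hf):
  """Per-parameter transformation applied during conversion."""

  def reshape_kernel(x, target_shape):
    # e.g. fold/unfold attention heads: [embed, heads, head_dim] <-> [out, in];
    # the direction depends on saving_to_hf.
    return x.reshape(target_shape)

  return {
      f"params-decoder-layers_{i}-self_attention-query-kernel": reshape_kernel
      for i in range(config["num_hidden_layers"])
  }
```

Scanned-checkpoint handling (where a layer axis replaces the per-layer index) is omitted here for brevity; see the existing entries in `param_mapping.py` for how `scan_layers` changes the key names.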
164+
165+
Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983)
166+
167+
## Debugging tips
168+
169+
If a converted checkpoint loads without errors but produces incorrect output, consider these common issues:
170+
171+
* **Symptom**: The model generates garbage or nonsensical tokens.
172+
173+
* **Potential Cause**: The query/key/value (Q/K/V) or Out vectors weights were likely reshaped incorrectly during conversion.
174+
175+
* **Symptom**: The model generates repetitive text sequences.
176+
177+
* **Potential Cause**: The layer normalization parameters may have been converted incorrectly.
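
When tracking down a reshape or transpose mistake, dumping the Hugging Face parameter names and shapes and comparing them against the entries in `utils/hf_shape.py` is often the fastest diagnostic. A minimal sketch (the model ID is one of the supported models; any supported ID works):

```python
# Dump every Hugging Face parameter name and shape for side-by-side comparison
# with the expected shapes defined in utils/hf_shape.py.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B")
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
```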

MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -67,7 +67,7 @@
     HOOK_FNS,
     PARAM_MAPPING,
 )
-from MaxText.utils.ckpt_conversion.utils.shape_mapping import SHAPE_MAPPING
+from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE
 from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
 from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)
@@ -90,12 +90,12 @@ def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
   Raises:
     ValueError: If mappings for the specified `model_name` are not found.
   """
-  if model_name not in PARAM_MAPPING or model_name not in SHAPE_MAPPING or model_name not in HOOK_FNS:
+  if model_name not in PARAM_MAPPING or model_name not in HF_SHAPE or model_name not in HOOK_FNS:
     raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")

   return {
       "param_mapping": PARAM_MAPPING[model_name](config_dict, scan_layers),
-      "shape_mapping": SHAPE_MAPPING[model_name](config_dict),
+      "shape_mapping": HF_SHAPE[model_name](config_dict),
       "hook_fn_mapping": HOOK_FNS[model_name](config_dict, scan_layers, saving_to_hf=True),
   }
@@ -140,11 +140,12 @@ def main(argv: Sequence[str]) -> None:
   # 2. Load Tokenizer
   if model_key not in HF_IDS:
     raise ValueError(f"HF Tokenizer ID not found for model key: {model_key}")
+  hf_token = config.hf_access_token
   hf_tokenizer_id = HF_IDS[model_key]
-  tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_id)
+  tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_id, token=hf_token)

   # For multi-modal case:
-  processor = AutoProcessor.from_pretrained(hf_tokenizer_id) if config.use_multimodal else None
+  processor = AutoProcessor.from_pretrained(hf_tokenizer_id, token=hf_token) if config.use_multimodal else None

   # 3. Get parameter mappings
   mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict())
```
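
The `token=` argument threaded through above is the standard `transformers` way to authenticate downloads of gated repos (e.g., the Gemma models). A brief illustration of the same pattern, with a placeholder token value:

```python
# Passing `token=` lets transformers fetch gated repos; "hf_..." is a placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", token="hf_...")
```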

MaxText/utils/ckpt_conversion/utils/shape_mapping.py renamed to MaxText/utils/ckpt_conversion/utils/hf_shape.py

Lines changed: 15 additions & 15 deletions
```diff
@@ -15,7 +15,7 @@
 """


-def GEMMA3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):
+def GEMMA3_HF_WEIGHTS_TO_SHAPE(config):
   """Generates a shape mapping for Hugging Face Gemma3 parameters.

   This function computes the expected shapes for all parameters in a Hugging
@@ -153,7 +153,7 @@ def GEMMA3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):
   return shapes


-def GEMMA2_HF_WEIGHTS_TO_SHAPE_MAPPING(config):
+def GEMMA2_HF_WEIGHTS_TO_SHAPE(config):
   """Returns mapping between HuggingFace weights path and weights shape.

   Args:
@@ -208,7 +208,7 @@ def GEMMA2_HF_WEIGHTS_TO_SHAPE_MAPPING(config):
   return mapping


-def QWEN3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):
+def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
   """Returns mapping between HuggingFace Qwen3 weights path and the HuggingFace weights shape.

   To check this mapping, dump the huggingface model shapes:
@@ -308,16 +308,16 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):
   return mapping


-SHAPE_MAPPING = {
-    "gemma2-2b": GEMMA2_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "gemma2-9b": GEMMA2_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "gemma2-27b": GEMMA2_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "qwen3-8b": QWEN3_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "qwen3-14b": QWEN3_HF_WEIGHTS_TO_SHAPE_MAPPING,
-    "qwen3-32b": QWEN3_HF_WEIGHTS_TO_SHAPE_MAPPING,
+HF_SHAPE = {
+    "gemma2-2b": GEMMA2_HF_WEIGHTS_TO_SHAPE,
+    "gemma2-9b": GEMMA2_HF_WEIGHTS_TO_SHAPE,
+    "gemma2-27b": GEMMA2_HF_WEIGHTS_TO_SHAPE,
+    "gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
+    "gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
+    "gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
+    "qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE,
+    "qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE,
+    "qwen3-8b": QWEN3_HF_WEIGHTS_TO_SHAPE,
+    "qwen3-14b": QWEN3_HF_WEIGHTS_TO_SHAPE,
+    "qwen3-32b": QWEN3_HF_WEIGHTS_TO_SHAPE,
 }
```

MaxText/utils/ckpt_conversion/utils/param_mapping.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -280,11 +280,15 @@ def pos_embed(x, target_shape):
   # Vision layers
   vc = config.get("vision_config", {})
   nvis = vc.get("num_hidden_layers", 0)
-  for i in list(range(nvis)):
-    base = f"params-vision_encoder-Gemma3VisionEncoderLayer_0-" f"Transformer-encoderblock_{i}-"
+  vision_layer_ids = [None] if scan_layers else list(range(nvis))
+  for i in vision_layer_ids:
+    base = (
+        f"params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock_{i}-"
+        if i is not None
+        else "params-vision_encoder-Gemma3VisionEncoderLayer_0-Transformer-encoderblock-"
+    )
     # Attention kernels & biases
     for qkv in ["query", "key", "value"]:
-      # key is [1152, 1152]-> [1152, 16, 72]
       hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-kernel"] = reshape_kernel
       hooks[base + f"MultiHeadDotProductAttention_0-{qkv}-bias"] = vis_bias
     # [1152, 1152] -> [16, 72, 1152]
```
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
```bash
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-8 and documentation for how to convert a gemma2-2b MaxText checkpoint to the Hugging Face format.

# The flow of this file is as follows:
# 1. Convert a MaxText checkpoint into a Hugging Face-compatible checkpoint
# 2. Run a forward pass check to compare the logits and KL divergence between the converted ckpt and the original golden HF model


set -ex
idx=$(date +%Y-%m-%d-%H-%M)
MODEL_NAME='gemma2-2b'
export MODEL_VARIATION='2b'
HF_TOKEN='' # Important!!! Save your hf access token here
TOKENIZER_PATH='assets/tokenizer.gemma'

# Installing torch for deps in forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
# Non-Googlers please remember to use separate GCS paths for uploading model weights from Kaggle ($CHKPT_BUCKET) and MaxText-compatible weights ($MODEL_BUCKET).
# Non-Googlers please remember to point these variables to GCS buckets that you own; this script uses internal buckets for testing.
export MODEL_BUCKET=gs://maxtext-gemma/unified/gemma2
# Here is an example of a gemma2-2b MaxText checkpoint, converted from google/gemma-2-2b
export CKPT_PATH=gs://maxtext-gemma/unified/gemma2/2b/unscanned/2025-08-05-18-06/0/items

# You can upload to the Hugging Face Hub or GCS by using HF_CKPT_PATH as base_output_directory
# export HF_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/hf/${idx}
export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx}

python3 -m MaxText.utils.ckpt_conversion.to_huggingface MaxText/configs/base.yml \
    model_name=${MODEL_NAME} \
    hf_access_token=${HF_TOKEN} \
    load_parameters_path=${CKPT_PATH} \
    base_output_directory=${LOCAL_PATH} \
    scan_layers=false

# Alternatively, if you uploaded the converted ckpt, copy it to local storage (HF requires a local copy of the model):
# mkdir -p "${LOCAL_PATH}"
# gcloud storage cp -r ${HF_CKPT_PATH} ${LOCAL_PATH}

# We also test whether the forward pass logits match the original HF model.
# To get higher precision (e.g., float32), run on CPU with `JAX_PLATFORMS=cpu`.
python3 -m MaxText.tests.forward_pass_logit_checker MaxText/configs/base.yml \
    tokenizer_path=${TOKENIZER_PATH} \
    load_parameters_path=${CKPT_PATH} \
    model_name=${MODEL_NAME} \
    scan_layers=false \
    --hf_model_path=${LOCAL_PATH} \
    --max_kl_div=0.015 \
    --run_hf_model=true
```
