9 changes: 4 additions & 5 deletions torchtitan/experiments/qwen3/README.md
@@ -6,22 +6,21 @@ QWEN3 0.6B Dense model is available for:

- FSDP/HSDP, TP, DDP, AC, compile support

Other model sizes have been added to the model args, but toml file configs still need to be added and tested. Further testing is needed to check the consistency of the parallelism implementations.
Other model sizes have been added to the model args, but toml file configs still need to be added and tested.

#### Download Qwen3 tokenizer

```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --asset tokenizer```

```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --assets tokenizer```

#### Parity with HF

A model parity test has been run, and the results suggest parity with the HF implementation. Further investigation is needed to check the sanity of the RoPE function.
A model parity test has been run, and the results suggest parity with the HF implementation.
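
Below is a minimal sketch of how such a check could be reproduced with the state dict adapter added in this PR. The `"0.6B"` config key, the `Qwen3Model(model_args)` constructor, the `vocab_size` field, and a forward pass that returns logits are assumptions for illustration, not the exact test that was run.

```python
# Hypothetical parity-check sketch; exact torchtitan entry points may differ.
import torch
from transformers import AutoModelForCausalLM

from torchtitan.experiments.qwen3 import Qwen3Model, Qwen3StateDictAdapter, qwen3_configs

hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B").eval()

model_args = qwen3_configs["0.6B"]          # assumed config key
tt_model = Qwen3Model(model_args).eval()    # assumed constructor

# Rename the HF parameters to torchtitan names with the adapter from this PR.
# strict=False because HF keys mapped to None (e.g. rotary inv_freq) are dropped.
adapter = Qwen3StateDictAdapter(model_args, hf_assets_path=None)
tt_model.load_state_dict(adapter.from_hf(hf_model.state_dict()), strict=False)

tokens = torch.randint(0, model_args.vocab_size, (1, 16))
with torch.no_grad():
    torch.testing.assert_close(tt_model(tokens), hf_model(tokens).logits, rtol=1e-4, atol=1e-4)
```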

#### To be added
- Modeling
- Variants of Dense models up to 32B
- MoE alternatives
- Weight tying

- Testing
- The model should be tested against established performance benchmarks
- CI integration
3 changes: 2 additions & 1 deletion torchtitan/experiments/qwen3/__init__.py
@@ -17,6 +17,7 @@
from .infra.parallelize import parallelize_qwen3
from .model.args import Qwen3ModelArgs
from .model.model import Qwen3Model
from .model.state_dict_adapter import Qwen3StateDictAdapter

__all__ = [
"parallelize_qwen3",
@@ -25,7 +26,6 @@
"qwen3_configs",
]


# Adding different variants of the model

qwen3_configs = {
@@ -120,5 +120,6 @@
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_validator_fn=build_validator,
state_dict_adapter=Qwen3StateDictAdapter,
)
)
86 changes: 86 additions & 0 deletions torchtitan/experiments/qwen3/model/state_dict_adapter.py
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This script is adapted from torchtitan/models/llama3/model/state_dict_adapter.py.

We can use this script to adapt the checkpoint from HF to the format that we can load into the torchtitan model and vice versa.
This can enable us to do a parity test with the HF implementation and make sure that our results are
aligned with the HF implementation.

"""
import re
from typing import Any

from torchtitan.protocols.state_dict_adapter import StateDictAdapter

from .args import Qwen3ModelArgs


class Qwen3StateDictAdapter(StateDictAdapter):
def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None):
super().__init__(model_args, hf_assets_path)

self.model_args = model_args
self.hf_assets_path = hf_assets_path

self.from_hf_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
"model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
to_hf_map = {v: k for k, v in self.from_hf_map.items()}
hf_state_dict = {}

for key, value in state_dict.items():
if "layers" in key:
abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
layer_num = re.search(r"\d+", key).group(0)
new_key = to_hf_map[abstract_key]

if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = to_hf_map[key]

hf_state_dict[new_key] = value

return hf_state_dict

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
state_dict = {}

for key, value in hf_state_dict.items():
if "layers" in key:
abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
layer_num = re.search(r"\d+", key).group(0)
new_key = self.from_hf_map[abstract_key]

if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = self.from_hf_map[key]

state_dict[new_key] = value
return state_dict
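
For reference, a hypothetical round-trip sketch of the adapter above: since it only renames keys, a tiny fake state dict is enough to exercise both directions. Constructing `Qwen3ModelArgs()` with default fields is an assumption for illustration.

```python
# Round-trip sketch (illustration only, not part of the PR).
import torch

from torchtitan.experiments.qwen3.model.args import Qwen3ModelArgs
from torchtitan.experiments.qwen3.model.state_dict_adapter import Qwen3StateDictAdapter

adapter = Qwen3StateDictAdapter(Qwen3ModelArgs(), hf_assets_path=None)  # default args assumed

hf_sd = {
    "model.embed_tokens.weight": torch.zeros(8, 4),
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
    "lm_head.weight": torch.zeros(8, 4),
}
tt_sd = adapter.from_hf(hf_sd)
assert "layers.0.attention.wq.weight" in tt_sd

# Converting back should reproduce the original HF key names.
assert adapter.to_hf(tt_sd).keys() == hf_sd.keys()
```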