diff --git a/torchtitan/experiments/qwen3/README.md b/torchtitan/experiments/qwen3/README.md
index 77b23d55c..d6e759181 100644
--- a/torchtitan/experiments/qwen3/README.md
+++ b/torchtitan/experiments/qwen3/README.md
@@ -6,22 +6,21 @@ QWEN3 0.6B Dense model is available for:
 
 - FSDP/HSDP, TP, DDP, AC, compile support
 
-Other model sizes are added to the args, but toml file configs need to be added and tested. Further testing is needed to check the coistency of the parallelism implementations.
+Other model sizes are added to the args, but toml file configs need to be added and tested.
 
 #### Download Qwen3 tokenizer
 
-```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --asset tokenizer```
-
+```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --assets tokenizer```
 
 #### Parity with HF
 
-Model parity test has been done and results suggest parity with HF implementation. Further investigation is needed to check the sanity of the Rope function.
+Model parity test has been done and results suggest parity with HF implementation.
 
 #### To be added
 - Modeling
     - Variants of Dense models up to 32B
     - MoE alternatives
-    - Weight tying
+
 - Testing
     - The model should be tested against established performance benchmarks
     - CI integration
diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py
index 9ea4582aa..b5aa870d4 100644
--- a/torchtitan/experiments/qwen3/__init__.py
+++ b/torchtitan/experiments/qwen3/__init__.py
@@ -17,6 +17,7 @@
 from .infra.parallelize import parallelize_qwen3
 from .model.args import Qwen3ModelArgs
 from .model.model import Qwen3Model
+from .model.state_dict_adapter import Qwen3StateDictAdapter
 
 __all__ = [
     "parallelize_qwen3",
@@ -25,7 +26,6 @@
     "qwen3_configs",
 ]
 
-
 # Adding different variants of the model
 
 qwen3_configs = {
@@ -120,5 +120,6 @@
         build_tokenizer_fn=build_hf_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
         build_validator_fn=build_validator,
+        state_dict_adapter=Qwen3StateDictAdapter,
     )
 )
diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py
new file mode 100644
index 000000000..760cc662b
--- /dev/null
+++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+This script is adapted from torchtitan/models/llama3/model/state_dict_adapter.py.
+
+We can use this script to adapt the checkpoint from HF to the format that we can load into the torchtitan model and vice versa.
+This can enable us to do a parity test with the HF implementation and make sure that our results are
+aligned with the HF implementation.
+
+"""
+import re
+from typing import Any
+
+from torchtitan.protocols.state_dict_adapter import StateDictAdapter
+
+from .args import Qwen3ModelArgs
+
+
+class Qwen3StateDictAdapter(StateDictAdapter):
+    def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None):
+        super().__init__(model_args, hf_assets_path)
+
+        self.model_args = model_args
+        self.hf_assets_path = hf_assets_path
+
+        self.from_hf_map = {
+            "model.embed_tokens.weight": "tok_embeddings.weight",
+            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+            "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
+            "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
+            "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+            "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+            "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+            "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            "model.norm.weight": "norm.weight",
+            "lm_head.weight": "output.weight",
+        }
+
+    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
+
+        to_hf_map = {v: k for k, v in self.from_hf_map.items()}
+        hf_state_dict = {}
+
+        for key, value in state_dict.items():
+            if "layers" in key:
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                layer_num = re.search(r"\d+", key).group(0)
+                new_key = to_hf_map[abstract_key]
+
+                if new_key is None:
+                    continue
+                new_key = new_key.format(layer_num)
+            else:
+                new_key = to_hf_map[key]
+
+            hf_state_dict[new_key] = value
+
+        return hf_state_dict
+
+    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
+
+        state_dict = {}
+
+        for key, value in hf_state_dict.items():
+            if "layers" in key:
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                layer_num = re.search(r"\d+", key).group(0)
+                new_key = self.from_hf_map[abstract_key]
+
+                if new_key is None:
+                    continue
+                new_key = new_key.format(layer_num)
+            else:
+                new_key = self.from_hf_map[key]
+
+            state_dict[new_key] = value
+        return state_dict