From a09953bd6054c2ccf5d346ba3b010583619e802a Mon Sep 17 00:00:00 2001 From: Hossein Kavianihamedani Date: Tue, 19 Aug 2025 11:05:14 -0700 Subject: [PATCH 1/6] Fix config file path in run_train.sh --- torchtitan/experiments/qwen3/README.md | 9 +- .../qwen3/model/state_dict_adapter.py | 82 +++++++++++++++++++ 2 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 torchtitan/experiments/qwen3/model/state_dict_adapter.py diff --git a/torchtitan/experiments/qwen3/README.md b/torchtitan/experiments/qwen3/README.md index dce71ed11..c145e6245 100644 --- a/torchtitan/experiments/qwen3/README.md +++ b/torchtitan/experiments/qwen3/README.md @@ -6,22 +6,23 @@ QWEN3 0.6B Dense model is available for: - FSDP/HSDP, TP, DDP, AC, compile support -Other model sizes are added to the args, but toml file configs need to be added and tested. Further testing is needed to check the coistency of the parallelism implementations. +Other model sizes are added to the args, but toml file configs need to be added and tested. #### Download Qwen3 tokenizer -```python scripts/download_tokenizer.py --repo_id Qwen/Qwen3-0.6B``` +```python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --assets tokenizer``` #### Parity with HF -Model parity test has been done and results suggest parity with HF implementation. Further investigation is needed to check the sanity of the Rope function. +Model parity test has been done and results suggest parity with HF implementation. + #### To be added - Modeling - Variants of Dense models up to 32B - MoE alternatives - - Weight tying + - Testing - The model should be tested against established performance benchmarks - CI integration diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py new file mode 100644 index 000000000..30fef67a1 --- /dev/null +++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script is adapted from torchtitan/models/llama3/model/state_dict_adapter.py. + +We can use this script to adapt the checkpoint from HF to the format that we can load into the model and vice versa. +This can enable us to do a parity test with the HF implementation and make sure that our results are +aligned with the HF implementation. + +""" +import re +from typing import Any + +from torchtitan.protocols.state_dict_adapter import StateDictAdapter + +from .args import Qwen3ModelArgs + + +class Qwen3StateDictAdapter(StateDictAdapter): + def __init__(self, model_args: Qwen3ModelArgs): + self.model_args = model_args + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = to_hf_map[abstract_key] + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = to_hf_map[key] + + hf_state_dict[new_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + + state_dict = {} + + for key, value in hf_state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = self.from_hf_map[abstract_key] + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = self.from_hf_map[key] + + state_dict[new_key] = value + return state_dict From 8dbceebfd1a73b812946458aa9f42cc238450e8b Mon Sep 17 00:00:00 2001 From: Hossein Kavianihamedani Date: Tue, 19 Aug 2025 11:17:37 -0700 Subject: [PATCH 2/6] Adding state_dict_adapter --- torchtitan/experiments/qwen3/model/state_dict_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py index 30fef67a1..ae9d0327b 100644 --- a/torchtitan/experiments/qwen3/model/state_dict_adapter.py +++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py @@ -7,7 +7,7 @@ """ This script is adapted from torchtitan/models/llama3/model/state_dict_adapter.py. -We can use this script to adapt the checkpoint from HF to the format that we can load into the model and vice versa. +We can use this script to adapt the checkpoint from HF to the format that we can load into the torchtitan model and vice versa. This can enable us to do a parity test with the HF implementation and make sure that our results are aligned with the HF implementation. From a755f537ee68c4c05be40987bbb67defa2435423 Mon Sep 17 00:00:00 2001 From: Hossein Kavianihamedani Date: Tue, 19 Aug 2025 11:54:08 -0700 Subject: [PATCH 3/6] Adding state_dict_adapter --- torchtitan/experiments/qwen3/model/state_dict_adapter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchtitan/experiments/qwen3/model/state_dict_adapter.py b/torchtitan/experiments/qwen3/model/state_dict_adapter.py index ae9d0327b..760cc662b 100644 --- a/torchtitan/experiments/qwen3/model/state_dict_adapter.py +++ b/torchtitan/experiments/qwen3/model/state_dict_adapter.py @@ -21,8 +21,12 @@ class Qwen3StateDictAdapter(StateDictAdapter): - def __init__(self, model_args: Qwen3ModelArgs): + def __init__(self, model_args: Qwen3ModelArgs, hf_assets_path: str | None): + super().__init__(model_args, hf_assets_path) + self.model_args = model_args + self.hf_assets_path = hf_assets_path + self.from_hf_map = { "model.embed_tokens.weight": "tok_embeddings.weight", "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", From 0a04bde272e46c74a3f66b2629965e9a1315e755 Mon Sep 17 00:00:00 2001 From: Hossein Kavianihamedani Date: Tue, 19 Aug 2025 16:19:13 -0700 Subject: [PATCH 4/6] Resolve README conflict --- torchtitan/experiments/qwen3/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/experiments/qwen3/README.md b/torchtitan/experiments/qwen3/README.md index c145e6245..0ce2dc958 100644 --- a/torchtitan/experiments/qwen3/README.md +++ b/torchtitan/experiments/qwen3/README.md @@ -17,7 +17,6 @@ Other model sizes are added to the args, but toml file configs need to be added Model parity test has been done and results suggest parity with HF implementation. - #### To be added - Modeling - Variants of Dense models up to 32B From 41f658968cf12b1246fe9dd362c155f7c6704d61 Mon Sep 17 00:00:00 2001 From: Hossein Kavianihamedani Date: Wed, 20 Aug 2025 14:25:19 -0700 Subject: [PATCH 5/6] Resolve README conflict and add StateDictAdapter changes --- torchtitan/experiments/qwen3/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py index d22053ff6..31b3ff699 100644 --- a/torchtitan/experiments/qwen3/__init__.py +++ b/torchtitan/experiments/qwen3/__init__.py @@ -25,7 +25,6 @@ "qwen3_configs", ] - # Adding different variants of the model qwen3_configs = { From 4c790fb6dff976e56740aab0abb3ad1e5245c8a3 Mon Sep 17 00:00:00 2001 From: Hossein Kavianihamedani Date: Wed, 20 Aug 2025 14:30:37 -0700 Subject: [PATCH 6/6] Update __init__.py file --- torchtitan/experiments/qwen3/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py index 31b3ff699..c6d2549ed 100644 --- a/torchtitan/experiments/qwen3/__init__.py +++ b/torchtitan/experiments/qwen3/__init__.py @@ -17,6 +17,7 @@ from .infra.parallelize import parallelize_qwen3 from .model.args import Qwen3ModelArgs from .model.model import Qwen3Model +from .model.state_dict_adapter import Qwen3StateDictAdapter __all__ = [ "parallelize_qwen3", @@ -116,5 +117,6 @@ build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, build_validator_fn=build_validator, + state_dict_adapter=Qwen3StateDictAdapter, ) )