diff --git a/.github/workflows/e2e_test.yml b/.github/workflows/e2e_test.yml index 0d3828a6..9aad7130 100644 --- a/.github/workflows/e2e_test.yml +++ b/.github/workflows/e2e_test.yml @@ -32,6 +32,7 @@ jobs: llama-3-8b-ddp-fsdp-name: ${{ steps.run-llama-3-8b-ddp-fsdp.outputs.name }} llama-3-8b-fsdp-cp-name: ${{ steps.run-llama-3-8b-fsdp-cp.outputs.name }} mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }} + ds-v3-shallow-name: ${{ steps.run-ds-v3-shallow.outputs.name }} artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }} steps: - name: Record artifact dir @@ -286,6 +287,25 @@ jobs: ici_mesh.fsdp=4 \ profile_start_step=3 + - name: Run Deepseek v3 Shallow + id: run-ds-v3-shallow + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + XLA_IR_DEBUG: 1 + XLA_HLO_DEBUG: 1 + run: | + name=$(e2e_testing/gen_name.py ds-v3-shallow) + echo "name=$name" >> "$GITHUB_OUTPUT" + tp run ${{ steps.docker-url-option.outputs.value }} \ + --name $name \ + torchprime/torch_xla_models/train.py \ + model=deepseek-v3-shallow \ + dataset.block_size=1024 \ + task.global_batch_size=4 \ + task.max_steps=15 \ + ici_mesh.fsdp=4 \ + profile_start_step=5 + # Load reference step times load-benchmarks: name: Load reference step times @@ -336,7 +356,8 @@ jobs: matrix.config.benchmark == 'llama-3-8b-sft' && needs.tp-run.outputs.llama-3-8b-sft-name || matrix.config.benchmark == 'llama-3-8b-2-slice' && needs.tp-run.outputs.llama-3-8b-2-slice-name || matrix.config.benchmark == 'llama-3-8b-ddp-fsdp' && needs.tp-run.outputs.llama-3-8b-ddp-fsdp-name || - matrix.config.benchmark == 'llama-3-8b-fsdp-cp' && needs.tp-run.outputs.llama-3-8b-fsdp-cp-name + matrix.config.benchmark == 'llama-3-8b-fsdp-cp' && needs.tp-run.outputs.llama-3-8b-fsdp-cp-name || + matrix.config.benchmark == 'ds-v3-shallow' && needs.tp-run.outputs.ds-v3-shallow-name }} artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }} step_time_lower_bound: ${{ matrix.config.lower_bound }} diff --git a/e2e_testing/step_time_bounds.yaml b/e2e_testing/step_time_bounds.yaml index 02bfbed5..f0a0525c 100644 --- a/e2e_testing/step_time_bounds.yaml +++ b/e2e_testing/step_time_bounds.yaml @@ -64,6 +64,13 @@ benchmarks: confidence_interval: 0.02426 average: 1.6172 sample_size: 51 + ds-v3-shallow: + name: Deepseek v3 Shallow # dummy number + step_time_lower_bound: 0.1 + step_time_upper_bound: 1.1 + confidence_interval: 0.5 + average: 0.6 + sample_size: 10 metadata: query_start: '2025-07-01T00:00:00-07:00' query_end: '2025-07-23T12:20:48-07:00' diff --git a/e2e_testing/update_step_time.py b/e2e_testing/update_step_time.py index 4ead058f..46b19375 100755 --- a/e2e_testing/update_step_time.py +++ b/e2e_testing/update_step_time.py @@ -122,6 +122,15 @@ def match_llama_3_8b_fsdp_cp(row): ) +def match_ds_v3_debug(row): + config = json.loads(row.configs_framework) + return ( + row.run_id.startswith("ds-v3-shallow") + and config["ici_mesh"]["fsdp"] == 4 + and config["ici_mesh"]["tensor"] == 1 + ) + + BENCHMARKS = { "Llama 3.0 8B": match_llama3_8b, "Llama 3.0 8B (@assume_pure)": match_llama3_8b_pure_mlp, @@ -133,6 +142,7 @@ def match_llama_3_8b_fsdp_cp(row): "Llama 3.0 8B SFT": match_llama_3_8b_sft, "Llama 3.0 8B (ddp + fsdp)": match_llama_3_8b_ddp_fsdp, "Llama 3.0 8B (fsdp + cp)": match_llama_3_8b_fsdp_cp, + "Deepseek v3 Debug Model": match_ds_v3_debug, } STEP_ID_MAPPING = { @@ -146,6 +156,7 @@ def match_llama_3_8b_fsdp_cp(row): "Llama 3.0 8B SFT": "llama-3-8b-sft", "Llama 3.0 8B (ddp + fsdp)": "llama-3-8b-ddp-fsdp", "Llama 3.0 8B (fsdp + cp)": "llama-3-8b-fsdp-cp", + "Deepseek v3 Debug 
Model": "ds-v3-shallow", } """Mapping from the benchmark name to the ID of the E2E test step used in GitHub Actions.""" diff --git a/torchprime/metrics/mfu.py b/torchprime/metrics/mfu.py index 54538d37..a31e1180 100644 --- a/torchprime/metrics/mfu.py +++ b/torchprime/metrics/mfu.py @@ -134,6 +134,109 @@ def calculate_tflops_training_per_device(config: Config, log=True): return total_tflops +def calculate_tflops_training_per_device_deepseek( + *, + per_device_batch_size: int, + seq_len: int, + hidden_size: int, + intermediate_size: int, + moe_intermediate_size: int, + num_hidden_layers: int, + first_k_dense_replace: int, + num_attention_heads: int, + qk_head_dim: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + kv_lora_rank: int, + num_key_value_heads: int, + num_routed_experts: int, + n_shared_experts: int, + num_experts_per_tok: int, + vocab_size: int, + capacity_factor: float = 1.5, + gradient_accumulation_steps: int = 1, + include_softmax: bool = False, +) -> float: + """ + Per-device TFLOPs *per optimizer step* for DeepSeek-v3 training. + + Assumptions + ----------- + • BF16 / FP16 → 2 FLOPs per MAC + • MLA FFN (3 linears + gating multiply) + • MoE begins after `first_k_dense_replace` + • One shared-expert FFN path in every MoE layer + • Optional soft-max term (set include_softmax=True for >~5 % extra) + """ + + # -------------------------------------------------------- constants ---- + B, L, H = per_device_batch_size, seq_len, hidden_size + L_dense = first_k_dense_replace + L_moe = num_hidden_layers - L_dense + tokens = B * L + fwd_bwd = 3 # forward + backward factor + BF16 = 2 # FLOPs per MAC in bf16/fp16 + + # -------------------------------------------------------------- FFNs --- + # Dense MLA FFN (first L_dense layers) + ffn_dense_flops = 3 * H * intermediate_size * BF16 + intermediate_size + ffn_dense_flops *= tokens * L_dense + + # Gating linear in every MoE layer + moe_gate_flops = 2 * H * num_routed_experts * tokens * L_moe + + # Per-expert MLA FFN (K experts/token) + moe_ffn_tok = 3 * H * moe_intermediate_size * BF16 + moe_intermediate_size + moe_ffn_flops = moe_ffn_tok * tokens * num_experts_per_tok * L_moe * capacity_factor + + # Shared-expert MLA FFN (runs on *all* tokens in every MoE layer) + M_shared = moe_intermediate_size * n_shared_experts + shared_ffn_tok = 3 * H * M_shared * BF16 + M_shared + shared_ffn_flops = shared_ffn_tok * tokens * L_moe + + total_ffn_flops = ffn_dense_flops + moe_gate_flops + moe_ffn_flops + shared_ffn_flops + + # ------------------------------------------------------- projections --- + q_proj_flops = 2 * H * num_attention_heads * qk_head_dim * tokens + kv_a_flops = 2 * H * (kv_lora_rank + qk_rope_head_dim) * tokens + kv_b_out_dim = num_attention_heads * (qk_nope_head_dim + v_head_dim) + kv_b_flops = 2 * kv_lora_rank * kv_b_out_dim * tokens + o_proj_flops = 2 * H * num_attention_heads * v_head_dim * tokens + + proj_flops_layer = q_proj_flops + kv_a_flops + kv_b_flops + o_proj_flops + proj_flops_total = proj_flops_layer * num_hidden_layers + + # ---------------------------------------------------- attention core --- + attn_qk = 2 * num_attention_heads * qk_head_dim * L * L * B + attn_av = 2 * num_attention_heads * v_head_dim * L * L * B + attn_core_layer = attn_qk + attn_av + + softmax_flops_layer = 4 * B * L * L * num_attention_heads if include_softmax else 0 + + attn_core_total = (attn_core_layer + softmax_flops_layer) * num_hidden_layers + + # --------------------------------------------- embedding / lm-head ---- + 
embed_flops = 2 * H * vocab_size * tokens # embedding + lm_head + + # ------------------------------------------------ aggregate numbers --- + trainable = (total_ffn_flops + proj_flops_total + embed_flops) * fwd_bwd + attention = attn_core_total * fwd_bwd + total = (trainable + attention) * gradient_accumulation_steps + tflops = total / 1e12 + + # ----------------------------------------------------- quick report --- + print(f"[DeepSeek-v3] TFLOPs/device/step : {tflops:>.2f}") + print(f" • FFNs (dense+MoE+shared) : {total_ffn_flops * fwd_bwd / 1e12:>.2f}") + print(f" • Attn projections : {proj_flops_total * fwd_bwd / 1e12:>.2f}") + print( + f" • Attn QK/AV{' + softmax' if include_softmax else ''} : {attention / 1e12:>.2f}" + ) + print(f" • Embed + LM head : {embed_flops * fwd_bwd / 1e12:>.2f}") + + return tflops + + def compute_mfu( config: dict, batch_size: int, @@ -164,24 +267,49 @@ def compute_mfu( torch_dtype: data type used for training (e.g. `bfloat16`). """ - total_tflops = calculate_tflops_training_per_device( - Config( + if "model_id" in config and "deepseek" in config["model_id"]: + total_tflops = calculate_tflops_training_per_device_deepseek( per_device_batch_size=batch_size, - max_target_length=sequence_length, - mlp_dim=int(config["intermediate_size"]), - emb_dim=int(config["hidden_size"]), - mlp_activations=["silu", "linear"], - num_experts=int(config.get("num_local_experts", 1)), - num_experts_per_tok=int(config.get("num_experts_per_tok", 1)), - num_query_heads=int(config["num_attention_heads"]), - num_kv_heads=int(config["num_key_value_heads"]), - head_dim=int(config["hidden_size"] / config["num_attention_heads"]), - num_decoder_layers=int(config["num_hidden_layers"]), + seq_len=sequence_length, + hidden_size=int(config["hidden_size"]), + intermediate_size=int(config["intermediate_size"]), + moe_intermediate_size=int(config["moe_intermediate_size"]), + num_hidden_layers=int(config["num_hidden_layers"]), + first_k_dense_replace=int(config["first_k_dense_replace"]), + num_attention_heads=int(config["num_attention_heads"]), + qk_head_dim=int(config["qk_head_dim"]), + qk_nope_head_dim=int(config["qk_nope_head_dim"]), + qk_rope_head_dim=int(config["qk_rope_head_dim"]), + v_head_dim=int(config["v_head_dim"]), + kv_lora_rank=int(config["kv_lora_rank"]), + num_key_value_heads=int(config["num_key_value_heads"]), + num_routed_experts=int(config["n_routed_experts"]), + n_shared_experts=int(config["n_shared_experts"]), + num_experts_per_tok=int(config["num_experts_per_tok"]), vocab_size=int(config["vocab_size"]), - gradient_accumulation_steps=gradient_accumulation_steps, - ), - log=False, - ) + capacity_factor=float(config.get("capacity_factor", 1.5)), + gradient_accumulation_steps=1, + include_softmax=True, + ) + else: + total_tflops = calculate_tflops_training_per_device( + Config( + per_device_batch_size=batch_size, + max_target_length=sequence_length, + mlp_dim=int(config["intermediate_size"]), + emb_dim=int(config["hidden_size"]), + mlp_activations=["silu", "linear"], + num_experts=int(config.get("num_local_experts", 1)), + num_experts_per_tok=int(config.get("num_experts_per_tok", 1)), + num_query_heads=int(config["num_attention_heads"]), + num_kv_heads=int(config["num_key_value_heads"]), + head_dim=int(config["hidden_size"] / config["num_attention_heads"]), + num_decoder_layers=int(config["num_hidden_layers"]), + vocab_size=int(config["vocab_size"]), + gradient_accumulation_steps=gradient_accumulation_steps, + ), + log=True, + ) assert torch_dtype == "bfloat16", f"Unsupported 
dtype {torch_dtype}" diff --git a/torchprime/metrics/step_duration.py b/torchprime/metrics/step_duration.py index b6efd58e..75d6fc1b 100644 --- a/torchprime/metrics/step_duration.py +++ b/torchprime/metrics/step_duration.py @@ -155,6 +155,7 @@ def analyze_step_duration_from_pb(xspace: XSpace) -> float: # Confirm we have exactly one unique event name if len(unique_names) > 1: raise ValueError(f"Ambiguous event names found in XSpace: {unique_names}") + # print(f"Ambiguous event names found in XSpace: {unique_names}") inferred_event_name = max(unique_names) diff --git a/torchprime/rope/rope.py b/torchprime/rope/rope.py index 00940601..bb1f544d 100644 --- a/torchprime/rope/rope.py +++ b/torchprime/rope/rope.py @@ -7,6 +7,7 @@ from dataclasses import dataclass import torch +from omegaconf import DictConfig @dataclass(kw_only=True) @@ -72,3 +73,105 @@ def llama3_rope_frequencies( freqs = torch.where(is_medium_freq, smoothed_freqs, freqs) return freqs + + +def deepseek_v3_rope_init_fn(config: DictConfig) -> tuple["torch.Tensor", float]: + """ + copied from HF implementation `_compute_yarn_parameters` function, from + https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C5-L197C29 + + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://huggingface.co/papers/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + + assert hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, DictConfig) + assert config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) == "yarn" + base = config.rope_theta + partial_rotary_factor = ( + config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + ) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + dim = int(head_dim * partial_rotary_factor) + factor = config.rope_scaling["factor"] + attention_factor = config.rope_scaling.get("attention_factor") + mscale = config.rope_scaling.get("mscale") + mscale_all_dim = config.rope_scaling.get("mscale_all_dim") + + # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. 
+ if "original_max_position_embeddings" in config.rope_scaling: + original_max_position_embeddings = config.rope_scaling[ + "original_max_position_embeddings" + ] + factor = config.max_position_embeddings / original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if mscale and mscale_all_dim: + attention_factor = float( + get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim) + ) + else: + attention_factor = get_mscale(factor) + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. 
+ pos_freqs = base ** (torch.arange(0, dim, 2).to(dtype=torch.float) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_max_position_embeddings + ) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to( + dtype=torch.float + ) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + return inv_freq, attention_factor diff --git a/torchprime/torch_xla_models/configs/model/deepseek-v3-mini.yaml b/torchprime/torch_xla_models/configs/model/deepseek-v3-mini.yaml new file mode 100644 index 00000000..74924329 --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/deepseek-v3-mini.yaml @@ -0,0 +1,59 @@ +defaults: + - _self_ # refers to this config file + - sharding: deepseek-fsdp-tp # refers to sharding/deepseek-fsdp-tp.yaml + - remat: deepseek # refers to remat/deepseek.yaml + +model_id: deepseek-v3 +model_class: deepseek_v3.DeepseekV3ForCausalLM +tokenizer_name: deepseek-ai/deepseek-v3 +attention_kernel: null + +# Configuration automatically generated from HF +vocab_size: 129280 +max_position_embeddings: 4096 +hidden_size: 448 # 7168 // 16 +intermediate_size: 1152 # 18432 // 16 +moe_intermediate_size: 128 # 2048 // 16 +num_hidden_layers: 8 # from 61 +num_attention_heads: 8 # 128 // 16 +n_shared_experts: 1 +n_routed_experts: 16 # 256 // 16 +routed_scaling_factor: 2.5 +kv_lora_rank: 32 # 512 // 16 +q_lora_rank: 96 # 1536 // 16 +qk_rope_head_dim: 4 # 64 // 16 +v_head_dim: 8 # 128 // 16 +qk_nope_head_dim: 8 # 128 // 16 +qk_head_dim: 12 # 192 // 16 +head_dim: 4 # 64 // 16 +num_key_value_heads: 8 # 128 // 16 +n_group: 4 # from 8 +topk_group: 4 +num_experts_per_tok: 8 +first_k_dense_replace: 3 # from 3 +norm_topk_prob: true +rope_interleave: true +rope_scaling: + beta_fast: 32 + beta_slow: 1 + factor: 40 + mscale: 1.0 + mscale_all_dim: 1.0 + original_max_position_embeddings: 4096 + rope_type: "yarn" + type: "yarn" +hidden_act: silu +initializer_range: 0.02 +rms_norm_eps: 1.0e-06 +rope_theta: 10000 +attention_bias: false +attention_dropout: 0.0 +return_dict: true +output_hidden_states: false +output_attentions: false +torchscript: false +torch_dtype: bfloat16 +use_bfloat16: false +bos_token_id: 0 +pad_token_id: null +eos_token_id: 1 \ No newline at end of file diff --git a/torchprime/torch_xla_models/configs/model/deepseek-v3-shallow.yaml b/torchprime/torch_xla_models/configs/model/deepseek-v3-shallow.yaml new file mode 100644 index 00000000..4eaed42f --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/deepseek-v3-shallow.yaml @@ -0,0 +1,61 @@ +defaults: + - _self_ # refers to this config file + - sharding: deepseek-fsdp-tp-ep # refers to sharding/deepseek-fsdp-tp.yaml + - remat: deepseek # refers to remat/deepseek.yaml + +model_id: deepseek-v3 +model_class: deepseek_v3.DeepseekV3ForCausalLM +tokenizer_name: deepseek-ai/deepseek-v3 +# choose attention_kernel from: [splash_attention, null] # flash_attention does not work with DeepSeek V3 +attention_kernel: splash_attention +capacity_factor: 1.25 + +# Configuration automatically generated from HF +vocab_size: 129280 +max_position_embeddings: 4096 +hidden_size: 7168 +intermediate_size: 18432 +moe_intermediate_size: 2048 +num_hidden_layers: 4 # scale down from 61 +num_attention_heads: 128 +n_shared_experts: 1 +n_routed_experts: 
256 +routed_scaling_factor: 2.5 +kv_lora_rank: 512 +q_lora_rank: 1536 +qk_rope_head_dim: 64 +v_head_dim: 128 +qk_nope_head_dim: 128 +qk_head_dim: 192 +head_dim: 64 +n_group: 8 +topk_group: 4 +num_experts_per_tok: 8 +first_k_dense_replace: 3 +norm_topk_prob: true +rope_interleave: true +rope_scaling: + beta_fast: 32 + beta_slow: 1 + factor: 40 + mscale: 1.0 + mscale_all_dim: 1.0 + original_max_position_embeddings: 4096 + rope_type: "yarn" + type: "yarn" +num_key_value_heads: 128 +hidden_act: silu +initializer_range: 0.02 +rms_norm_eps: 1.0e-06 +rope_theta: 10000 +attention_bias: false +attention_dropout: 0.0 +return_dict: true +output_hidden_states: false +output_attentions: false +torchscript: false +torch_dtype: bfloat16 +use_bfloat16: false +bos_token_id: 0 +pad_token_id: null +eos_token_id: 1 \ No newline at end of file diff --git a/torchprime/torch_xla_models/configs/model/deepseek-v3.yaml b/torchprime/torch_xla_models/configs/model/deepseek-v3.yaml new file mode 100644 index 00000000..97a1bdf4 --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/deepseek-v3.yaml @@ -0,0 +1,61 @@ +defaults: + - _self_ # refers to this config file + - sharding: deepseek-fsdp-tp-ep # refers to sharding/deepseek-fsdp-tp.yaml + - remat: deepseek # refers to remat/deepseek.yaml + +model_id: deepseek-v3 +model_class: deepseek_v3.DeepseekV3ForCausalLM +tokenizer_name: deepseek-ai/deepseek-v3 +# choose attention_kernel from: [splash_attention, null] # flash_attention does not work with DeepSeek V3 +attention_kernel: splash_attention +capacity_factor: 1.25 + +# Configuration automatically generated from HF +vocab_size: 129280 +max_position_embeddings: 4096 +hidden_size: 7168 +intermediate_size: 18432 +moe_intermediate_size: 2048 +num_hidden_layers: 61 +num_attention_heads: 128 +n_shared_experts: 1 +n_routed_experts: 256 +routed_scaling_factor: 2.5 +kv_lora_rank: 512 +q_lora_rank: 1536 +qk_rope_head_dim: 64 +v_head_dim: 128 +qk_nope_head_dim: 128 +qk_head_dim: 192 +head_dim: 64 +n_group: 8 +topk_group: 4 +num_experts_per_tok: 8 +first_k_dense_replace: 3 +norm_topk_prob: true +rope_interleave: true +rope_scaling: + beta_fast: 32 + beta_slow: 1 + factor: 40 + mscale: 1.0 + mscale_all_dim: 1.0 + original_max_position_embeddings: 4096 + rope_type: "yarn" + type: "yarn" +num_key_value_heads: 128 +hidden_act: silu +initializer_range: 0.02 +rms_norm_eps: 1.0e-06 +rope_theta: 10000 +attention_bias: false +attention_dropout: 0.0 +return_dict: true +output_hidden_states: false +output_attentions: false +torchscript: false +torch_dtype: bfloat16 +use_bfloat16: false +bos_token_id: 0 +pad_token_id: null +eos_token_id: 1 \ No newline at end of file diff --git a/torchprime/torch_xla_models/configs/model/remat/deepseek.yaml b/torchprime/torch_xla_models/configs/model/remat/deepseek.yaml new file mode 100644 index 00000000..49d986cc --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/remat/deepseek.yaml @@ -0,0 +1,5 @@ +activation_checkpoint_layers: + - DeepseekV3DecoderLayer + +optimization_barrier_layers: + - DeepseekV3DecoderLayer diff --git a/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-nomoe.yaml b/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-nomoe.yaml new file mode 100644 index 00000000..eea731eb --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-nomoe.yaml @@ -0,0 +1,33 @@ +# Weights +model.embed_tokens.weight: [fsdp, null] + +model.layers.*.self_attn.q_a_proj.weight: [fsdp, null] +model.layers.*.self_attn.q_a_layernorm.weight: 
[fsdp] +model.layers.*.self_attn.q_b_proj.weight: [fsdp, null] +model.layers.*.self_attn.kv_a_proj_with_mqa.weight: [fsdp, null] +model.layers.*.self_attn.kv_a_layernorm.weight: [fsdp] +model.layers.*.self_attn.kv_b_proj.weight: [null, fsdp] +model.layers.*.self_attn.o_proj.weight: [null, fsdp] + +model.layers.*.mlp.gate_proj.weight: [fsdp, null] +model.layers.*.mlp.up_proj.weight: [fsdp, null] +model.layers.*.mlp.down_proj.weight: [null, fsdp] + +# model.layers.*.mlp.gate.weight: [null, fsdp] +# model.layers.*.mlp.experts.*.gate_proj.weight: [fsdp, null] +# model.layers.*.mlp.experts.*.up_proj.weight: [fsdp, null] +# model.layers.*.mlp.experts.*.down_proj.weight: [null, fsdp] + +# model.layers.*.mlp.shared_experts.gate_proj.weight: [fsdp, null] +# model.layers.*.mlp.shared_experts.up_proj.weight: [fsdp, null] +# model.layers.*.mlp.shared_experts.down_proj.weight: [null, fsdp] + +model.layers.*.input_layernorm.weight: [fsdp] +model.layers.*.post_attention_layernorm.weight: [fsdp] + +model.norm.weight: [fsdp] +lm_head.weight: [fsdp, null] + +# Activations +model.layers.*: [[data, fsdp], null, null] +lm_head: [[data, fsdp], null, null] diff --git a/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-tp-ep.yaml b/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-tp-ep.yaml new file mode 100644 index 00000000..4bcfe7b9 --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-tp-ep.yaml @@ -0,0 +1,57 @@ +# Weights + +# vocab_size, hidden_size +model.embed_tokens.weight: [fsdp, tensor] + +# q_lora_rank, hidden_size +model.layers.*.self_attn.q_a_proj.weight: [fsdp, tensor] +# q_lora_rank +model.layers.*.self_attn.q_a_layernorm.weight: [fsdp] +# num_attention_heads * qk_head_dim, q_lora_rank +model.layers.*.self_attn.q_b_proj.weight: [fsdp, tensor] +# kv_lora_rank + qk_rope_head_dim, hidden_size +model.layers.*.self_attn.kv_a_proj_with_mqa.weight: [fsdp, tensor] +# kv_lora_rank +model.layers.*.self_attn.kv_a_layernorm.weight: [fsdp] +# num_attention_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank +model.layers.*.self_attn.kv_b_proj.weight: [tensor, fsdp] +# hidden_size, kv_lora_rank +model.layers.*.self_attn.o_proj.weight: [tensor, fsdp] + + +# intermediate_size, hidden_size +model.layers.*.mlp.gate_proj.weight: [fsdp, tensor] +# intermediate_size, hidden_size +model.layers.*.mlp.up_proj.weight: [fsdp, tensor] +# hidden_size, intermediate_size +model.layers.*.mlp.down_proj.weight: [tensor, fsdp] + +# n_routed_experts, hidden_size +model.layers.*.mlp.gate.weight: [expert, fsdp] +# n_routed_experts, hidden_size, moe_intermediate_size +model.layers.*.mlp.grouped.W_gate: [expert, tensor, fsdp] +# n_routed_experts, hidden_size, moe_intermediate_size +model.layers.*.mlp.grouped.W_up: [expert, tensor, fsdp] +# n_routed_experts, moe_intermediate_size, hidden_size +model.layers.*.mlp.grouped.W_down: [expert, fsdp, tensor] + +# moe_intermediate_size, hidden_size +model.layers.*.mlp.shared_experts.gate_proj.weight: [fsdp, tensor] +# moe_intermediate_size, hidden_size +model.layers.*.mlp.shared_experts.up_proj.weight: [fsdp, tensor] +# hidden_size, moe_intermediate_size +model.layers.*.mlp.shared_experts.down_proj.weight: [tensor, fsdp] + +# hidden_size +model.layers.*.input_layernorm.weight: [fsdp] +# hidden_size +model.layers.*.post_attention_layernorm.weight: [fsdp] + +# hidden_size +model.norm.weight: [fsdp] +# vocab_size, hidden_size +lm_head.weight: [fsdp, tensor] + +# Activations +model.layers.*: [[data, fsdp], null, tensor] +lm_head: 
[[data, fsdp], null, tensor] diff --git a/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-tp.yaml b/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-tp.yaml new file mode 100644 index 00000000..9f31c81a --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp-tp.yaml @@ -0,0 +1,57 @@ +# Weights + +# vocab_size, hidden_size +model.embed_tokens.weight: [fsdp, tensor] + +# q_lora_rank, hidden_size +model.layers.*.self_attn.q_a_proj.weight: [fsdp, tensor] +# q_lora_rank +model.layers.*.self_attn.q_a_layernorm.weight: [fsdp] +# num_attention_heads * qk_head_dim, q_lora_rank +model.layers.*.self_attn.q_b_proj.weight: [fsdp, tensor] +# kv_lora_rank + qk_rope_head_dim, hidden_size +model.layers.*.self_attn.kv_a_proj_with_mqa.weight: [fsdp, tensor] +# kv_lora_rank +model.layers.*.self_attn.kv_a_layernorm.weight: [fsdp] +# num_attention_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank +model.layers.*.self_attn.kv_b_proj.weight: [tensor, fsdp] +# hidden_size, kv_lora_rank +model.layers.*.self_attn.o_proj.weight: [tensor, fsdp] + + +# intermediate_size, hidden_size +model.layers.*.mlp.gate_proj.weight: [fsdp, tensor] +# intermediate_size, hidden_size +model.layers.*.mlp.up_proj.weight: [fsdp, tensor] +# hidden_size, intermediate_size +model.layers.*.mlp.down_proj.weight: [tensor, fsdp] + +# n_routed_experts, hidden_size +model.layers.*.mlp.gate.weight: [null, fsdp] +# n_routed_experts, hidden_size, moe_intermediate_size +model.layers.*.mlp.grouped.W_gate: [null, tensor, fsdp] +# n_routed_experts, hidden_size, moe_intermediate_size +model.layers.*.mlp.grouped.W_up: [null, tensor, fsdp] +# n_routed_experts, moe_intermediate_size, hidden_size +model.layers.*.mlp.grouped.W_down: [null, fsdp, tensor] + +# moe_intermediate_size, hidden_size +model.layers.*.mlp.shared_experts.gate_proj.weight: [fsdp, tensor] +# moe_intermediate_size, hidden_size +model.layers.*.mlp.shared_experts.up_proj.weight: [fsdp, tensor] +# hidden_size, moe_intermediate_size +model.layers.*.mlp.shared_experts.down_proj.weight: [tensor, fsdp] + +# hidden_size +model.layers.*.input_layernorm.weight: [fsdp] +# hidden_size +model.layers.*.post_attention_layernorm.weight: [fsdp] + +# hidden_size +model.norm.weight: [fsdp] +# vocab_size, hidden_size +lm_head.weight: [fsdp, tensor] + +# Activations +model.layers.*: [[data, fsdp], null, tensor] +lm_head: [[data, fsdp], null, tensor] diff --git a/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp.yaml b/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp.yaml new file mode 100644 index 00000000..da8431b4 --- /dev/null +++ b/torchprime/torch_xla_models/configs/model/sharding/deepseek-fsdp.yaml @@ -0,0 +1,33 @@ +# Weights +model.embed_tokens.weight: [fsdp, null] + +model.layers.*.self_attn.q_a_proj.weight: [fsdp, null] +model.layers.*.self_attn.q_a_layernorm.weight: [fsdp] +model.layers.*.self_attn.q_b_proj.weight: [fsdp, null] +model.layers.*.self_attn.kv_a_proj_with_mqa.weight: [fsdp, null] +model.layers.*.self_attn.kv_a_layernorm.weight: [fsdp] +model.layers.*.self_attn.kv_b_proj.weight: [null, fsdp] +model.layers.*.self_attn.o_proj.weight: [null, fsdp] + +model.layers.*.mlp.gate_proj.weight: [fsdp, null] +model.layers.*.mlp.up_proj.weight: [fsdp, null] +model.layers.*.mlp.down_proj.weight: [null, fsdp] + +model.layers.*.mlp.gate.weight: [null, fsdp] +model.layers.*.mlp.experts.*.gate_proj.weight: [fsdp, null] +model.layers.*.mlp.experts.*.up_proj.weight: [fsdp, null] 
+model.layers.*.mlp.experts.*.down_proj.weight: [null, fsdp] + +model.layers.*.mlp.shared_experts.gate_proj.weight: [fsdp, null] +model.layers.*.mlp.shared_experts.up_proj.weight: [fsdp, null] +model.layers.*.mlp.shared_experts.down_proj.weight: [null, fsdp] + +model.layers.*.input_layernorm.weight: [fsdp] +model.layers.*.post_attention_layernorm.weight: [fsdp] + +model.norm.weight: [fsdp] +lm_head.weight: [fsdp, null] + +# Activations +model.layers.*: [[data, fsdp], null, null] +lm_head: [[data, fsdp], null, null] diff --git a/torchprime/torch_xla_models/model/deepseek_v3/__init__.py b/torchprime/torch_xla_models/model/deepseek_v3/__init__.py new file mode 100644 index 00000000..ba733a21 --- /dev/null +++ b/torchprime/torch_xla_models/model/deepseek_v3/__init__.py @@ -0,0 +1,3 @@ +from .model import DeepseekV3ForCausalLM, convert_hf_state_dict_for_grouped_moe + +__all__ = ["DeepseekV3ForCausalLM", "convert_hf_state_dict_for_grouped_moe"] # noqa: F401 diff --git a/torchprime/torch_xla_models/model/deepseek_v3/model.py b/torchprime/torch_xla_models/model/deepseek_v3/model.py new file mode 100644 index 00000000..af1a7cba --- /dev/null +++ b/torchprime/torch_xla_models/model/deepseek_v3/model.py @@ -0,0 +1,676 @@ +"""PyTorch/XLA Deepseek v3 model. + +Following the Deepseek v3 implementation from HF transformers +https://github.com/huggingface/transformers/blob/18a7c29ff8431193887e1065777e9cde29d46e53/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +""" + +from __future__ import annotations + +import math + +import torch +import torch.nn.functional as F +import torch_xla.debug.profiler as xp +from omegaconf import DictConfig +from torch import nn +from transformers.activations import ACT2FN +from transformers.utils import logging + +from torchprime.layers.sequential import HomogeneousSequential +from torchprime.rope.rope import deepseek_v3_rope_init_fn +from torchprime.torch_xla_models import offloading +from torchprime.torch_xla_models.attention import AttentionModule +from torchprime.torch_xla_models.loss import cross_entropy_loss +from torchprime.torch_xla_models.model.base_causal_lm import BaseCausalLM +from torchprime.torch_xla_models.model.llama.model import apply_rotary_pos_emb + +logger = logging.get_logger(__name__) +BF16 = torch.bfloat16 + + +class DeepseekV3RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, dtype=BF16)) + self.variance_epsilon = eps + + @xp.trace_me("DeepseekV3RMSNorm") + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + # hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class DeepseekV3RotaryEmbedding(nn.Module): + inv_freq: nn.Buffer + + def __init__(self, config: DictConfig): + super().__init__() + self.config = config + inv_freq, self.attention_scaling = deepseek_v3_rope_init_fn(self.config) + self.register_buffer("inv_freq", inv_freq.to(BF16), persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + def forward(self, x: torch.Tensor, position_ids: torch.Tensor): + inv_freq_expanded = ( + self.inv_freq[None, :, None].to(BF16).expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].to(BF16) + + device_type = x.device.type + device_type = ( + 
device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.to(BF16) @ position_ids_expanded.to(BF16)).transpose( + 1, 2 + ) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_interleave( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor | None = None, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV3MLP(nn.Module): + def __init__( + self, + config: DictConfig, + hidden_size: int | None = None, + intermediate_size: int | None = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + @xp.trace_me("DeepseekV3MLP") + def forward(self, x: torch.Tensor) -> torch.Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class DeepseekV3TopkRouter(nn.Module): + def __init__(self, config: DictConfig): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, config.hidden_size), dtype=BF16) + ) + self.register_buffer( + "e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=BF16) + ) + + @torch.no_grad() + def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor: + scores_for_choice = scores.view( + -1, self.n_routed_experts + ) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = 
scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices + + @xp.trace_me("DeepseekV3TopkRouter") + def forward(self, hidden_states: torch.Tensor): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.to(BF16), self.weight.to(BF16)) + scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +class GroupedMoEWeights(nn.Module): + """Grouped expert weights that can be sharded along the expert dim (E).""" + + def __init__(self, E: int, D: int, H: int, dtype: torch.dtype): + super().__init__() + self.W_gate = nn.Parameter(torch.empty(E, D, H, dtype=dtype)) + self.W_up = nn.Parameter(torch.empty(E, D, H, dtype=dtype)) + self.W_down = nn.Parameter(torch.empty(E, H, D, dtype=dtype)) + nn.init.kaiming_uniform_(self.W_gate, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.W_up, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.W_down, a=math.sqrt(5)) + + +class DeepseekV3MoE(nn.Module): + """ + Mixture-of-Experts with grouped einsum over existing per-expert weights. + + XLA-friendly: + - No dynamic-shape ops (no masked_select/index_select/bincount/repeat_interleave) + - Uses sort + scatter_add_ (int32) + gather + einsum + index_add_ + - Capacity dropping without compaction (dropped -> dummy slot with weight=0) + Checkpoint-compatible: + - Keeps self.experts ModuleList with gate/up/down Linear weights and maps to grouped params + """ + + def __init__(self, config: DictConfig): + super().__init__() + self.config = config + self.E = config.n_routed_experts + self.K = config.num_experts_per_tok + self.D = config.hidden_size + self.I = config.moe_intermediate_size + self.capacity_factor = getattr(config, "capacity_factor", 1.25) + + # Router (unchanged keys) + self.gate = DeepseekV3TopkRouter(config) + + # # Experts (preserve parameter names/keys for checkpoint compatibility) + # self.experts = nn.ModuleList( + # [DeepseekV3MLP(config, intermediate_size=self.I) for _ in range(self.E)] + # ) + + # Grouped weights used in the hot path (shardable along E) + self.grouped = GroupedMoEWeights(self.E, self.D, self.I, dtype=BF16) + + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=self.I * config.n_shared_experts + ) + + self.act_fn = ACT2FN[config.hidden_act] + + # Optional static capacity: set config.static_capacity to a positive int to avoid recompiles + self.static_capacity = int(getattr(config, "static_capacity", 0)) + + @torch.no_grad() + def _pre_load_old_keys(self, state_dict, prefix: str): + """When loading, if old per-expert keys exist, copy them into grouped params.""" + has_old = any( + k.startswith(prefix + "experts.0.gate_proj.weight") + for k in state_dict.keys() # noqa: SIM118 + ) + if not has_old: + return + E = self.E + Wg = torch.stack( + [state_dict[f"{prefix}experts.{e}.gate_proj.weight"].t() for e in range(E)], dim=0 + ) + Wu = torch.stack( + [state_dict[f"{prefix}experts.{e}.up_proj.weight"].t() for e in range(E)], dim=0 + ) + Wd = torch.stack( + [state_dict[f"{prefix}experts.{e}.down_proj.weight"].t() for e in range(E)], dim=0 + ) + # Cast to grouped dtype + Wg = Wg.to(self.grouped.W_gate.dtype) + Wu = 
Wu.to(self.grouped.W_up.dtype) + Wd = Wd.to(self.grouped.W_down.dtype) + self.grouped.W_gate.copy_(Wg.contiguous()) + self.grouped.W_up.copy_(Wu.contiguous()) + self.grouped.W_down.copy_(Wd.contiguous()) + + @torch.no_grad() + def _post_state_dict_old_keys(self, state_dict, prefix: str): + """When saving, also write old per-expert keys so external tools remain compatible.""" + E = self.E + for e in range(E): + state_dict[f"{prefix}experts.{e}.gate_proj.weight"] = ( + self.grouped.W_gate[e].t().contiguous().to(BF16) + ) + state_dict[f"{prefix}experts.{e}.up_proj.weight"] = ( + self.grouped.W_up[e].t().contiguous().to(BF16) + ) + state_dict[f"{prefix}experts.{e}.down_proj.weight"] = ( + self.grouped.W_down[e].t().contiguous().to(BF16) + ) + + # ------------------------------ core MoE path ------------------------------ + + @torch.no_grad() + def _compute_capacity(self, T: int) -> int: + if self.static_capacity > 0: + return self.static_capacity + return int(math.ceil(self.capacity_factor * T / self.E)) + + def _grouped_weights(self, dtype: torch.dtype): + # Ensure einsum inputs match activation dtype (bf16 recommended on TPU) + return ( + self.grouped.W_gate.to(dtype), + self.grouped.W_up.to(dtype), + self.grouped.W_down.to(dtype), + ) + + @xp.trace_me("DeepseekV3MoE") + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, D = hidden_states.shape + assert D == self.D + device, dtype = hidden_states.device, hidden_states.dtype + T = B * S + E, K = self.E, self.K + + # Flatten tokens + x = hidden_states.reshape(T, D) + + # Router (cast back to bf16 if topk forced f32) + topk_idx, topk_w = self.gate(x) # [T,K], [T,K] + topk_w = topk_w.to(dtype) + + # Build flat arrays of length N=T*K + token_ids = ( + torch.arange(T, device=device, dtype=torch.long) + .view(T, 1) + .expand(T, K) + .reshape(-1) + ) # [N] + expert_ids = topk_idx.reshape(-1).to(torch.long) # [N] + weights = topk_w.reshape(-1) # [N] + + # Sort tokens by expert + expert_ids_sorted, sort_ix = torch.sort(expert_ids) # [N], [N] + token_ids = torch.gather(token_ids, 0, sort_ix) # [N] + weights = torch.gather(weights, 0, sort_ix) # [N] + + # Per-expert counts via scatter_add_ (int32 robust on XLA) + counts_i32 = torch.zeros(E, device=device, dtype=torch.int32) + ones_i32 = torch.ones_like(expert_ids_sorted, dtype=torch.int32) + counts_i32.scatter_add_(0, expert_ids_sorted.to(torch.int32), ones_i32) # [E] + counts = counts_i32.to(torch.long) # [E] + + # Start offset of each expert's segment + group_start = torch.cumsum( + torch.cat([counts.new_zeros(1), counts[:-1]], dim=0), dim=0 + ) # [E], long + + # Position within expert after sort + N = expert_ids_sorted.numel() + arangeN = torch.arange(N, device=device, dtype=torch.long) # [N] + offsets_rep = torch.gather(group_start, 0, expert_ids_sorted) # [N] + pos_in_exp = arangeN - offsets_rep # [N], long + + # Capacity & destination slot (dropped → expert's slot 0 with weight=0) + C = self._compute_capacity(T) + C_long = torch.tensor(C, device=device, dtype=torch.long) + valid = pos_in_exp < C_long # [N] bool + dest = expert_ids_sorted * C_long + torch.minimum(pos_in_exp, C_long - 1) # [N] + dest = torch.where( + valid, dest, expert_ids_sorted * C_long + torch.zeros_like(pos_in_exp) + ) # route dropped to slot 0 + + # Slot tables of length EC = E*C + EC = E * C + slots_token = torch.zeros(EC, device=device, dtype=torch.long) # token id per slot + slots_w = torch.zeros(EC, device=device, dtype=dtype) # gate weight per slot + slot_fill = torch.zeros( + EC, device=device, 
dtype=dtype + ) # 1.0 if slot filled else 0.0 + + valid_f = valid.to(dtype) + valid_l = valid.to(torch.long) + + # Unique mapping ensures no collisions among valid slots + slots_token.index_add_( + 0, dest, token_ids * valid_l + ) # int add; valid rows write token id + slots_w.index_add_(0, dest, weights * valid_f) # write gate weights at valid slots + slot_fill.index_add_(0, dest, valid_f) # 1.0 for valid slots + + # Gather packed inputs [E, C, D]; dummy slots point to token 0 (weight 0 → no contribution) + gather_idx = slots_token.view(-1, 1).expand(EC, D) # [EC, D] + X_packed = torch.gather(x, 0, gather_idx).view(E, C, D) # [E, C, D] + + # ---------- Grouped MLP via einsum ---------- + W_gate, W_up, W_down = self._grouped_weights(dtype) # [E,D,I], [E,D,I], [E,I,D] + # dims: e=experts, c=capacity, d=hidden, i=intermediate + G = torch.einsum("ecd,edi->eci", X_packed, W_gate) # [E, C, I] + U = torch.einsum("ecd,edi->eci", X_packed, W_up) # [E, C, I] + A = self.act_fn(G) * U # [E, C, I] + Y_packed = torch.einsum("eci,eid->ecd", A, W_down) # [E, C, D] + + # Apply per-slot gate weight (dropped → weight 0 → no contribution) + Y_flat = Y_packed.view(EC, D) * slots_w.unsqueeze(-1) # [EC, D] + + # One global scatter back to [T, D] + out = torch.zeros(T, D, device=device, dtype=dtype) + out.index_add_(0, slots_token, Y_flat) # [T, D] + + # Shared path + reshape + out = out.view(B, S, D) + self.shared_experts(hidden_states) + return out + + +class DeepseekV3Attention(nn.Module): + """Multi-headed latent attention.""" + + def __init__(self, config: DictConfig, layer_idx: int | None = None): + super().__init__() + self.config = config + self.attention_block = AttentionModule(config) + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = ( + config.attention_dropout + ) # this is not used in the current implementation + self.num_heads = config.num_attention_heads + self.rope_theta = config.rope_theta + ############# + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + ############# + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if config.q_lora_rank is None: + self.q_proj = nn.Linear( + config.hidden_size, self.num_heads * self.qk_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + config.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads * (config.qk_nope_head_dim + config.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * config.v_head_dim, config.hidden_size, bias=config.attention_bias + ) + + self.scaling = self.qk_head_dim ** (-0.5) + if config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + 
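  # Shape walk-through with the full-size deepseek-v3 config above: q_b_proj emits
  # num_heads * qk_head_dim = 128 * 192 features per token (qk_nope 128 + qk_rope 64
  # per head); kv_a_proj_with_mqa emits a kv_lora_rank + qk_rope_head_dim = 512 + 64
  # latent; kv_b_proj expands the 512-dim latent to num_heads * (qk_nope_head_dim +
  # v_head_dim) = 128 * 256. Attention then runs with 192-dim Q/K heads and 128-dim V heads.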
@xp.trace_me("DeepseekV3Attention") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + batch_size, seq_length = hidden_states.shape[:2] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = ( + batch_size, + seq_length, + -1, + self.config.qk_nope_head_dim + self.config.v_head_dim, + ) + + if self.config.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split( + q_states, [self.config.qk_nope_head_dim, self.config.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split( + compressed_kv, [self.config.kv_lora_rank, self.config.qk_rope_head_dim], dim=-1 + ) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split( + k_pass, [self.config.qk_nope_head_dim, self.config.v_head_dim], dim=-1 + ) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.config.qk_rope_head_dim) + cos, sin = position_embeddings + if self.config.rope_interleave: + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + attn_output = self.attention_block( + query_states, key_states, value_states, attention_mask + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, -1) + attn_output = self.o_proj(attn_output) + return attn_output + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: DictConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config) + else: + self.mlp = DeepseekV3MLP(config) + self.input_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + @xp.trace_me("DeepseekV3DecoderLayer") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + hidden_states = offloading.offload_name(hidden_states, "decoder_input") + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states, position_embeddings, attention_mask, position_ids + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DeepseekV3Model(nn.Module): + def __init__(self, config: DictConfig): + super().__init__() + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = HomogeneousSequential( + *[ + 
DeepseekV3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) + + @xp.trace_me("DeepseekV3Model") + def forward( + self, input_ids: torch.LongTensor, attention_mask: torch.Tensor | None = None + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + seq_length = inputs_embeds.size(1) + position_ids = ( + torch.arange(seq_length, device=inputs_embeds.device).unsqueeze(0).to(BF16) + ) + + causal_mask = torch.triu( + torch.full((seq_length, seq_length), float("-inf"), device=inputs_embeds.device), + diagonal=1, + ) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + if attention_mask is not None: + causal_mask = causal_mask * attention_mask[:, None, None, :] + + position_embeddings = self.rotary_emb(inputs_embeds, position_ids) + hidden_states = self.layers( + inputs_embeds, + attention_mask=causal_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class DeepseekV3ForCausalLM(BaseCausalLM): + def __init__(self, config: DictConfig): + super().__init__() + self.config = config + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.apply(self._init_weights) + + @xp.trace_me("DeepseekV3ForCausalLM") + def forward( + self, + input_ids: torch.LongTensor, + labels: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + hidden_states = self.model(input_ids=input_ids, attention_mask=attention_mask) + logits = self.lm_head(hidden_states) + # logits = logits.float() + if labels is None: + return logits, None + loss = cross_entropy_loss(logits, labels=labels, vocab_size=self.config.vocab_size) + return logits, loss + + +def convert_hf_state_dict_for_grouped_moe(hf_state_dict, config): + """ + Converts a Hugging Face state_dict with per-expert weights in-place + to use the grouped weight format. + + Args: + hf_state_dict (dict): The state_dict from the Hugging Face model. + config: The model configuration, used to get the number of experts. + + Returns: + dict: The modified state_dict. + """ + # Find all unique MoE layer prefixes (e.g., "model.layers.0.mlp.", "model.layers.1.mlp.", etc.) 
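  # For example, `{prefix}experts.{e}.gate_proj.weight` of shape
  # (moe_intermediate_size, hidden_size) is transposed and stacked over the expert dim
  # into `{prefix}grouped.W_gate` of shape (n_routed_experts, hidden_size,
  # moe_intermediate_size), matching GroupedMoEWeights. A rough usage sketch
  # (hypothetical variable names; loading is up to the caller):
  #   state_dict = convert_hf_state_dict_for_grouped_moe(hf_state_dict, config)
  #   model.load_state_dict(state_dict)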
+ moe_prefixes = set() + for key in hf_state_dict.keys(): # noqa: SIM118 + if "experts.0.gate_proj.weight" in key: + # Assumes key format is like '....experts.0.gate_proj.weight' + prefix = key.split("experts.0.gate_proj.weight")[0] + moe_prefixes.add(prefix) + + if not moe_prefixes: + print("No MoE layers with per-expert weights found to convert.") + return hf_state_dict + + E = config.n_routed_experts + + print(f"Found and converting {len(moe_prefixes)} MoE layers with {E} experts each...") + + for prefix in moe_prefixes: + # Pop all the old per-expert weights from the dictionary, transposing them + w_g_list = [ + hf_state_dict.pop(f"{prefix}experts.{e}.gate_proj.weight").t() for e in range(E) + ] + w_u_list = [ + hf_state_dict.pop(f"{prefix}experts.{e}.up_proj.weight").t() for e in range(E) + ] + w_d_list = [ + hf_state_dict.pop(f"{prefix}experts.{e}.down_proj.weight").t() for e in range(E) + ] + + # Stack them to create the new grouped tensors + Wg = torch.stack(w_g_list, dim=0) + Wu = torch.stack(w_u_list, dim=0) + Wd = torch.stack(w_d_list, dim=0) + + # Add the new grouped weight keys to the dictionary + hf_state_dict[f"{prefix}grouped.W_gate"] = Wg + hf_state_dict[f"{prefix}grouped.W_up"] = Wu + hf_state_dict[f"{prefix}grouped.W_down"] = Wd + + print(f" - Converted weights for prefix: {prefix}") + + return hf_state_dict diff --git a/torchprime/torch_xla_models/model/deepseek_v3/model_from_hf.py b/torchprime/torch_xla_models/model/deepseek_v3/model_from_hf.py new file mode 100644 index 00000000..bd0a1eb2 --- /dev/null +++ b/torchprime/torch_xla_models/model/deepseek_v3/model_from_hf.py @@ -0,0 +1,398 @@ +import math +from collections.abc import Callable + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, + rotate_half, +) +from .configuration_deepseek_v3 import DeepseekV3Config + +logger = logging.get_logger(__name__) + + +class DeepseekV3RMSNorm(LlamaRMSNorm): + pass + + +class DeepseekV3RotaryEmbedding(LlamaRotaryEmbedding): + pass + + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + r""" + TODO let's just use the original freqcis computation to not have the view + transpose + reshape! This is not optimized! + Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV3MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class DeepseekV3TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + + @torch.no_grad() + def get_topk_indices(self, scores): + scores_for_choice = scores.view( + -1, self.n_routed_experts + ) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32) + ) + scores = router_logits.sigmoid() + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator 
= topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+      topk_weights /= denominator
+    topk_weights = topk_weights * self.routed_scaling_factor
+    return topk_indices, topk_weights
+
+
+class DeepseekV3MoE(nn.Module):
+  """
+  A mixture-of-experts module that also contains shared experts.
+  """
+
+  def __init__(self, config):
+    super().__init__()
+    self.config = config
+    self.experts = nn.ModuleList(
+      [
+        DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
+        for _ in range(config.n_routed_experts)
+      ]
+    )
+    self.gate = DeepseekV3TopkRouter(config)
+    self.shared_experts = DeepseekV3MLP(
+      config=config,
+      intermediate_size=config.moe_intermediate_size * config.n_shared_experts,
+    )
+
+  def moe(
+    self,
+    hidden_states: torch.Tensor,
+    topk_indices: torch.Tensor,
+    topk_weights: torch.Tensor,
+  ):
+    r"""
+    CALL FOR CONTRIBUTION! This loop has not been optimised yet; the expert weights need to be
+    fused so we do not have to loop over experts here (DeepSeek has 256 routed experts).
+    """
+    final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+    expert_mask = torch.nn.functional.one_hot(
+      topk_indices, num_classes=len(self.experts)
+    )
+    expert_mask = expert_mask.permute(2, 0, 1)
+
+    for expert_idx in range(len(self.experts)):
+      expert = self.experts[expert_idx]
+      mask = expert_mask[expert_idx]
+      token_indices, weight_indices = torch.where(mask)
+
+      if token_indices.numel() > 0:
+        expert_weights = topk_weights[token_indices, weight_indices]
+        expert_input = hidden_states[token_indices]
+        expert_output = expert(expert_input)
+        weighted_output = expert_output * expert_weights.unsqueeze(-1)
+        final_hidden_states.index_add_(0, token_indices, weighted_output)
+
+    # in the original deepseek, the outputs of the experts are gathered once we leave this module
+    # thus the moe module is itself an IsolatedParallel module
+    # and all experts are "local", meaning we shard but we don't gather
+    return final_hidden_states.type(hidden_states.dtype)
+
+  def forward(self, hidden_states):
+    residuals = hidden_states
+    orig_shape = hidden_states.shape
+    topk_indices, topk_weights = self.gate(hidden_states)
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(
+      *orig_shape
+    )
+    hidden_states = hidden_states + self.shared_experts(residuals)
+    return hidden_states
+
+
+class DeepseekV3Attention(nn.Module):
+  """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+  def __init__(self, config: DeepseekV3Config, layer_idx: int):
+    super().__init__()
+    self.config = config
+    self.layer_idx = layer_idx
+    self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+    self.attention_dropout = config.attention_dropout
+    self.num_heads = config.num_attention_heads
+    self.rope_theta = config.rope_theta
+    self.q_lora_rank = config.q_lora_rank
+    self.qk_rope_head_dim = config.qk_rope_head_dim
+    self.kv_lora_rank = config.kv_lora_rank
+    self.v_head_dim = config.v_head_dim
+    self.qk_nope_head_dim = config.qk_nope_head_dim
+    self.qk_head_dim = config.qk_head_dim
+
+    self.is_causal = True
+    if self.q_lora_rank is None:
+      self.q_proj = nn.Linear(
+        config.hidden_size, self.num_heads * self.qk_head_dim, bias=False
+      )
+    else:
+      self.q_a_proj = nn.Linear(
+        config.hidden_size, config.q_lora_rank, bias=config.attention_bias
+      )
+      self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
+      self.q_b_proj = nn.Linear(
+        config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False
+      )
+
+    
self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_value: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split( + q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split( + k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + if ( + self.config._attn_implementation == "flash_attention_2" + and self.qk_head_dim != self.v_head_dim + ): + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if ( + self.config._attn_implementation == "flash_attention_2" + and self.qk_head_dim != self.v_head_dim + ): + 
attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DeepseekV3DecoderLayer(LlamaDecoderLayer, nn.Module): + def __init__(self, config: DeepseekV3Config, layer_idx: int): + nn.Module().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config) + else: + self.mlp = DeepseekV3MLP(config) + + self.input_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + +class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + LlamaPreTrainedModel._init_weights(module) + if isinstance(module, DeepseekV3TopkRouter): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + +class DeepseekV3Model(LlamaModel): + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"] + + +class DeepseekV3ForCausalLM(LlamaForCausalLM): + pass + + +__all__ = [ + "DeepseekV3PreTrainedModel", + "DeepseekV3Model", + "DeepseekV3ForCausalLM", +] diff --git a/torchprime/torch_xla_models/model_rewriting/rematerialization_utils.py b/torchprime/torch_xla_models/model_rewriting/rematerialization_utils.py index eec70965..b22ef394 100644 --- a/torchprime/torch_xla_models/model_rewriting/rematerialization_utils.py +++ b/torchprime/torch_xla_models/model_rewriting/rematerialization_utils.py @@ -56,6 +56,7 @@ def add_activation_checkpointing_and_scan( ) layers_to_scan = remat_config.get("scan_layers", None) offload_tensors = remat_config.get("offload_tensors", []) + start_from_layer = config.model.get("first_k_dense_replace", None) # Checking preconditions and logging. 
if remat_classes: @@ -80,7 +81,7 @@ def maybe_checkpoint(mod: nn.Module, _name: str) -> nn.Module: return wrap_module(model, maybe_checkpoint) if remat_classes else model if not remat_classes: - return scan_layers.compile(model, layers_to_scan) + return scan_layers.compile(model, layers_to_scan, start_from_layer=start_from_layer) seq = model.get_submodule(layers_to_scan) assert isinstance(seq, HomogeneousSequential) @@ -95,7 +96,9 @@ def maybe_checkpoint(mod: nn.Module, _name: str) -> nn.Module: names_to_offload=offload_tensors, ) ) - return scan_layers.compile(model, layers_to_scan, partition_fn=partition_fn) + return scan_layers.compile( + model, layers_to_scan, partition_fn=partition_fn, start_from_layer=start_from_layer + ) def add_optimization_barriers(model: nn.Module, config: DictConfig) -> nn.Module: diff --git a/torchprime/torch_xla_models/scan_layers.py b/torchprime/torch_xla_models/scan_layers.py index e15fa669..4eefe97c 100644 --- a/torchprime/torch_xla_models/scan_layers.py +++ b/torchprime/torch_xla_models/scan_layers.py @@ -48,13 +48,32 @@ def compile_one_stack( def compile( - mod: nn.Module, sequential_to_scan: str, partition_fn=default_partition + mod: nn.Module, + sequential_to_scan: str, + start_from_layer: int | None = None, + partition_fn=default_partition, ) -> nn.Module: seq = mod.get_submodule(sequential_to_scan) if not isinstance(seq, HomogeneousSequential): raise ValueError(f"compile only supports HomogeneousSequential, got {type(seq)}") # Replace the submodule - mod.set_submodule( - sequential_to_scan, compile_one_stack(seq, partition_fn=partition_fn) - ) + if start_from_layer is None or start_from_layer == 0: + # Whole block is scanned + mod.set_submodule( + sequential_to_scan, compile_one_stack(seq, partition_fn=partition_fn) + ) + else: + # Split: prefix stays, tail gets scanned + prefix_layers = seq[:start_from_layer] + tail_layers = seq[start_from_layer:] + + # Compile the tail + scanned_tail = compile_one_stack( + HomogeneousSequential(*tail_layers), partition_fn=partition_fn + ) + + # Reconstruct full sequence + new_seq = HomogeneousSequential(*prefix_layers, *scanned_tail) + mod.set_submodule(sequential_to_scan, new_seq) + return mod diff --git a/torchprime/torch_xla_models/tests/test_deepseek_v3.py b/torchprime/torch_xla_models/tests/test_deepseek_v3.py new file mode 100644 index 00000000..dbf342c2 --- /dev/null +++ b/torchprime/torch_xla_models/tests/test_deepseek_v3.py @@ -0,0 +1,221 @@ +import copy +from dataclasses import dataclass + +import pytest +import torch +import torch_xla +from omegaconf import OmegaConf +from transformers import AutoConfig +from transformers import DeepseekV3ForCausalLM as HFDeepseekV3ForCausalLM + +from torchprime.torch_xla_models.model.deepseek_v3 import ( + DeepseekV3ForCausalLM, + convert_hf_state_dict_for_grouped_moe, +) + +MOE_START_FROM_LAYER = 2 # layer 0,1 dense layers and layer 2+ moe layers + + +@dataclass +class DeepseekFixture: + vocab_size: int + hf_model: HFDeepseekV3ForCausalLM + model: DeepseekV3ForCausalLM + + +def get_deepseek_v3_dummy() -> DeepseekFixture: + seed = 123 + torch.manual_seed(seed) + torch_xla.manual_seed(seed) + vocab_size = 64 + config = AutoConfig.from_pretrained( + "deepseek-ai/deepseek-v3", + ) + config.vocab_size = vocab_size + config.max_position_embeddings = vocab_size + config.first_k_dense_replace = MOE_START_FROM_LAYER + config.num_hidden_layers = 5 # from 61 + config.n_group = 4 # from 8 + + scale_factor = 32 + config.attention_kernel = "pytorch" + + config.hidden_size //= 
scale_factor + config.intermediate_size //= scale_factor + config.moe_intermediate_size //= scale_factor + config.num_attention_heads //= scale_factor + config.n_routed_experts //= scale_factor + config.kv_lora_rank //= scale_factor + config.q_lora_rank //= scale_factor + config.qk_rope_head_dim //= scale_factor + config.v_head_dim //= scale_factor + config.qk_nope_head_dim //= scale_factor + config.qk_head_dim //= scale_factor + config.head_dim //= scale_factor + config.num_key_value_heads //= scale_factor + config.capacity_factor = 10.0 + + tp_cfg = OmegaConf.create(config.to_dict()) + with torch.device("cpu"): + hf_model = HFDeepseekV3ForCausalLM(config) + hf_model.init_weights() + hf_dict = hf_model.state_dict() + + model = DeepseekV3ForCausalLM(tp_cfg) + converted_dict = convert_hf_state_dict_for_grouped_moe(hf_dict, model.config) + model.load_state_dict(converted_dict, strict=True) + + return DeepseekFixture(vocab_size, hf_model, model) + + +def noop(mod): + return mod + + +def scan_decoders(mod): + import torchprime.torch_xla_models.scan_layers + + return torchprime.torch_xla_models.scan_layers.compile( + mod, "model.layers", MOE_START_FROM_LAYER + ) + + +@pytest.mark.parametrize("transform", [noop, scan_decoders]) +def test_forward_our_model_against_hf_model(transform): + fixture = get_deepseek_v3_dummy() + device = torch_xla.device() + model_xla = copy.deepcopy(fixture.model).to(device) + model_xla = transform(model_xla) + hf_model_xla = copy.deepcopy(fixture.hf_model).to(device) + torch_xla.sync() + for input_size in [8, 16]: + input_ids = torch.randint(fixture.vocab_size, (2, input_size // 2)).to(device) + hf_output = hf_model_xla( + input_ids, labels=input_ids, attention_mask=torch.ones_like(input_ids) + ) + deepseek_xla_logits, deepseek_xla_loss = model_xla( + input_ids, labels=input_ids, attention_mask=torch.ones_like(input_ids) + ) + torch_xla.sync() + torch.testing.assert_close( + hf_output.logits, + deepseek_xla_logits, + atol=1e-2, + rtol=1e-6, + msg="logits are not equal", + ) + torch.testing.assert_close( + hf_output.loss, + deepseek_xla_loss, + atol=1e-2, + rtol=1e-6, + msg="loss is not equal", + ) + + +@pytest.mark.parametrize("transform", [noop, scan_decoders]) +def test_layers_by_layer_against_hf_model(transform): + fixture = get_deepseek_v3_dummy() + device = torch_xla.device() + model_xla = copy.deepcopy(fixture.model).to(device) + model_xla = transform(model_xla) + hf_model_xla = copy.deepcopy(fixture.hf_model).to(device) + + seq_len = 4 + input_ids = torch.randint(fixture.vocab_size, (2, seq_len)).to(device) + attention_mask = torch.ones_like(input_ids) + + inputs_embeds_xla = model_xla.model.embed_tokens(input_ids) + inputs_embeds_hf = hf_model_xla.model.embed_tokens(input_ids) + torch.testing.assert_close( + inputs_embeds_xla, + inputs_embeds_hf, + atol=1e-2, + rtol=1e-6, + msg="emb layer outputs not equal", + ) + + position_ids = torch.arange(seq_len, device=device).unsqueeze(0).float() + causal_mask = ( + torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1) + .unsqueeze(0) + .unsqueeze(0) + ) + causal_mask = causal_mask * attention_mask[:, None, None, :] + + pos_embeds_xla = model_xla.model.rotary_emb(inputs_embeds_xla, position_ids) + pos_embeds_hf = hf_model_xla.model.rotary_emb(inputs_embeds_hf, position_ids) + torch.testing.assert_close( + pos_embeds_xla[0], + pos_embeds_hf[0], + atol=1e-2, + rtol=1e-6, + msg="rotary_emb layer outputs not equal", + ) + torch.testing.assert_close( + pos_embeds_xla[1], + pos_embeds_hf[1], + 
atol=1e-2, + rtol=1e-6, + msg="rotary_emb layer outputs not equal", + ) + + hidden_xla = inputs_embeds_xla + hidden_hf = inputs_embeds_hf + for idx, (layer_xla, layer_hf) in enumerate( + zip(model_xla.model.layers, hf_model_xla.model.layers, strict=True) + ): + hidden_xla = layer_xla( + hidden_xla, + attention_mask=causal_mask, + position_ids=position_ids, + position_embeddings=pos_embeds_xla, + ) + hidden_hf = layer_hf( + hidden_hf, + attention_mask=causal_mask, + position_ids=position_ids, + position_embeddings=pos_embeds_hf, + )[0] + torch_xla.sync() + torch.testing.assert_close( + hidden_xla, + hidden_hf, + atol=1e-2, + rtol=1e-6, + msg=f"decoder layer {idx} outputs not equal", + ) + + +def test_forward_torch_xla_against_native_cpu(): + fixture = get_deepseek_v3_dummy() + input_size = 8 + device = torch.device("cpu") + input_ids = torch.randint(fixture.vocab_size, (2, input_size // 2)) + native_logits, native_loss = fixture.model( + input_ids, labels=input_ids, attention_mask=torch.ones_like(input_ids) + ) + + device = torch_xla.device() + input_ids = input_ids.to(device) + model_xla = copy.deepcopy(fixture.model).to(device) + torch_xla.sync() + + xla_logits, xla_loss = model_xla( + input_ids, labels=input_ids, attention_mask=torch.ones_like(input_ids) + ) + torch_xla.sync() + torch.testing.assert_close( + native_logits, + xla_logits.to("cpu"), + atol=1e-2, + rtol=1e-6, + msg="CPU run and XLA run logits are not equal", + ) + torch.testing.assert_close( + native_loss, + xla_loss.to("cpu"), + atol=1e-2, + rtol=1e-6, + msg="CPU run and XLA run loss is not equal", + )
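
A minimal, self-contained sketch of the group-limited routing that DeepseekV3TopkRouter implements above: the sigmoid scores are biased by e_score_correction_bias only for selection, groups are ranked by the sum of their two best biased scores, experts outside the chosen groups are masked out, and the final weights are gathered from the unbiased scores. All sizes below are made up for illustration and are far smaller than the real model.

import torch

n_routed_experts, n_group, topk_group, top_k, hidden_size = 16, 4, 2, 4, 32
hidden = torch.randn(6, hidden_size)                 # 6 tokens
weight = torch.randn(n_routed_experts, hidden_size)  # router projection (self.weight)
bias = torch.zeros(n_routed_experts)                 # e_score_correction_bias

scores = torch.sigmoid(hidden @ weight.t())          # [tokens, experts]
biased = scores + bias                               # bias is used only to pick experts

# Rank groups by the sum of each group's two best biased scores, keep the best topk_group.
group_scores = (
  biased.view(-1, n_group, n_routed_experts // n_group).topk(2, dim=-1)[0].sum(-1)
)
group_idx = group_scores.topk(topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = (
  group_mask.unsqueeze(-1)
  .expand(-1, n_group, n_routed_experts // n_group)
  .reshape(-1, n_routed_experts)
)

# Final top-k over the surviving experts; the weights come from the unbiased scores.
topk_idx = biased.masked_fill(~score_mask.bool(), 0.0).topk(top_k, dim=-1, sorted=False)[1]
topk_w = scores.gather(1, topk_idx)
topk_w = topk_w / (topk_w.sum(-1, keepdim=True) + 1e-20)  # norm_topk_prob
print(topk_idx.shape, topk_w.shape)                       # torch.Size([6, 4]) for both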
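The dispatch loop in DeepseekV3MoE.moe combines a one-hot expert mask with index_add_. The toy stand-in below shows the same pattern in isolation so the index shapes are easier to follow; the plain linear layers here are placeholders, not the PR's DeepseekV3MLP experts.

import torch
import torch.nn as nn
import torch.nn.functional as F

num_experts, hidden, top_k, tokens = 4, 8, 2, 5
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))
x = torch.randn(tokens, hidden)
topk_indices = torch.randint(num_experts, (tokens, top_k))
topk_weights = torch.rand(tokens, top_k)

out = torch.zeros_like(x)
# [tokens, top_k, num_experts] -> [num_experts, tokens, top_k]
expert_mask = F.one_hot(topk_indices, num_classes=num_experts).permute(2, 0, 1)
for e in range(num_experts):
  # which tokens (and which of their k routing slots) chose expert e
  token_idx, slot_idx = torch.where(expert_mask[e])
  if token_idx.numel() == 0:
    continue
  y = experts[e](x[token_idx]) * topk_weights[token_idx, slot_idx].unsqueeze(-1)
  out.index_add_(0, token_idx, y)  # scatter the weighted expert outputs back per token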
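For reference, the attention scale adjustment driven by yarn_get_mscale works out to a small constant factor on top of 1/sqrt(qk_head_dim). The rope-scaling numbers below (factor=40, mscale_all_dim=1.0, qk_head_dim=192) are illustrative values of the kind DeepSeek-v3 checkpoints ship with, not read from this diff.

import math

qk_head_dim, factor, mscale_all_dim = 192, 40, 1.0
mscale = 0.1 * mscale_all_dim * math.log(factor) + 1.0  # yarn_get_mscale(40, 1.0) ~= 1.369
scaling = qk_head_dim**-0.5 * mscale * mscale           # ~= 0.0722 * 1.874 ~= 0.135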
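The start_from_layer plumbing added to scan_layers.compile exists because scan needs a homogeneous stack: the first first_k_dense_replace decoder layers are dense while the rest are MoE, so only the structurally identical tail can be scanned. A usage sketch, assuming the decoder stack is registered at "model.layers" (as in the test above) and first_k_dense_replace=2:

from torchprime.torch_xla_models import scan_layers

# Layers 0..1 keep ordinary layer-by-layer execution; layers 2..N-1 are compiled into
# a single scanned stack, mirroring what rematerialization_utils now does when
# model.first_k_dense_replace is present in the config.
model = scan_layers.compile(model, "model.layers", start_from_layer=2)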