Onboard deepseek v3 #350

Draft: wants to merge 19 commits into main
27 changes: 26 additions & 1 deletion .github/workflows/e2e_test.yml
@@ -32,6 +32,7 @@ jobs:
llama-3-8b-ddp-fsdp-name: ${{ steps.run-llama-3-8b-ddp-fsdp.outputs.name }}
llama-3-8b-fsdp-cp-name: ${{ steps.run-llama-3-8b-fsdp-cp.outputs.name }}
mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
ds-v3-shallow-name: ${{ steps.run-ds-v3-shallow.outputs.name }}
artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
steps:
- name: Record artifact dir
@@ -286,6 +287,29 @@ jobs:
ici_mesh.fsdp=4 \
profile_start_step=3

- name: Run Deepseek v3 Shallow
id: run-ds-v3-shallow
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
XLA_IR_DEBUG: 1
XLA_HLO_DEBUG: 1
run: |
name=$(e2e_testing/gen_name.py ds-v3-shallow)
echo "name=$name" >> "$GITHUB_OUTPUT"
tp run ${{ steps.docker-url-option.outputs.value }} \
--name $name \
torchprime/torch_xla_models/train.py \
model=deepseek-v3-shallow \
model.attention_kernel=splash_attention \
dataset=wikitext \
dataset.block_size=1024 \
task=train \
task.global_batch_size=4 \
task.lr_scheduler.type=constant \
task.max_steps=15 \
ici_mesh.fsdp=4 \
profile_start_step=3

# Load reference step times
load-benchmarks:
name: Load reference step times
@@ -336,7 +360,8 @@ jobs:
matrix.config.benchmark == 'llama-3-8b-sft' && needs.tp-run.outputs.llama-3-8b-sft-name ||
matrix.config.benchmark == 'llama-3-8b-2-slice' && needs.tp-run.outputs.llama-3-8b-2-slice-name ||
matrix.config.benchmark == 'llama-3-8b-ddp-fsdp' && needs.tp-run.outputs.llama-3-8b-ddp-fsdp-name ||
matrix.config.benchmark == 'llama-3-8b-fsdp-cp' && needs.tp-run.outputs.llama-3-8b-fsdp-cp-name
matrix.config.benchmark == 'llama-3-8b-fsdp-cp' && needs.tp-run.outputs.llama-3-8b-fsdp-cp-name ||
matrix.config.benchmark == 'ds-v3-shallow' && needs.tp-run.outputs.ds-v3-shallow-name
}}
artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
step_time_lower_bound: ${{ matrix.config.lower_bound }}
7 changes: 7 additions & 0 deletions e2e_testing/step_time_bounds.yaml
@@ -64,6 +64,13 @@ benchmarks:
confidence_interval: 0.02426
average: 1.6172
sample_size: 51
ds-v3-shallow:
    name: Deepseek v3 Shallow  # placeholder bounds until real step times are measured
step_time_lower_bound: 1.0
step_time_upper_bound: 11.0
confidence_interval: 5.0
average: 6.0
sample_size: 10
metadata:
query_start: '2025-07-01T00:00:00-07:00'
query_end: '2025-07-23T12:20:48-07:00'
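
Note: the ds-v3-shallow bounds above are explicitly dummy values. They are at least self-consistent with an average ± confidence_interval pattern (6.0 - 5.0 = 1.0 and 6.0 + 5.0 = 11.0), though that convention is an inference from the numbers, not something stated in this PR.
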
11 changes: 11 additions & 0 deletions e2e_testing/update_step_time.py
@@ -122,6 +122,15 @@ def match_llama_3_8b_fsdp_cp(row):
)


def match_ds_v3_debug(row):
config = json.loads(row.configs_framework)
return (
row.run_id.startswith("ds-v3-shallow")
and config["ici_mesh"]["fsdp"] == 4
and config["ici_mesh"]["tensor"] == 1
)
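

# A hypothetical sanity check for the matcher above (illustration only, not
# part of this PR): build a fake row shaped like the benchmark-database rows
# this script reads and confirm the matcher fires on it. `json` is already
# imported at the top of this file.
from types import SimpleNamespace

_fake_ds_v3_row = SimpleNamespace(
  run_id="ds-v3-shallow-example",
  configs_framework=json.dumps({"ici_mesh": {"fsdp": 4, "tensor": 1}}),
)
assert match_ds_v3_debug(_fake_ds_v3_row)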


BENCHMARKS = {
"Llama 3.0 8B": match_llama3_8b,
"Llama 3.0 8B (@assume_pure)": match_llama3_8b_pure_mlp,
@@ -133,6 +142,7 @@ def match_llama_3_8b_fsdp_cp(row):
"Llama 3.0 8B SFT": match_llama_3_8b_sft,
"Llama 3.0 8B (ddp + fsdp)": match_llama_3_8b_ddp_fsdp,
"Llama 3.0 8B (fsdp + cp)": match_llama_3_8b_fsdp_cp,
"Deepseek v3 Debug Model": match_ds_v3_debug,
}

STEP_ID_MAPPING = {
@@ -146,6 +156,7 @@ def match_llama_3_8b_fsdp_cp(row):
"Llama 3.0 8B SFT": "llama-3-8b-sft",
"Llama 3.0 8B (ddp + fsdp)": "llama-3-8b-ddp-fsdp",
"Llama 3.0 8B (fsdp + cp)": "llama-3-8b-fsdp-cp",
"Deepseek v3 Debug Model": "ds-v3-shallow",
}
"""Mapping from the benchmark name to the ID of the E2E test step used in GitHub Actions."""

134 changes: 133 additions & 1 deletion torchprime/metrics/mfu.py
@@ -134,6 +134,111 @@ def calculate_tflops_training_per_device(config: Config, log=True):
return total_tflops


# ---------------------------------------------------------------------------
# DeepSeek-v3 FLOPs model (BF16, gated-MLP FFN, gated MoE, MLA two-stage KV projection)
# ---------------------------------------------------------------------------
def calculate_tflops_training_per_device_deepseek(
*,
per_device_batch_size: int,
seq_len: int,
hidden_size: int,
intermediate_size: int,
moe_intermediate_size: int,
num_hidden_layers: int,
first_k_dense_replace: int,
num_attention_heads: int,
qk_head_dim: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
kv_lora_rank: int,
num_key_value_heads: int,
num_routed_experts: int,
n_shared_experts: int,
num_experts_per_tok: int,
vocab_size: int,
gradient_accumulation_steps: int = 1,
include_softmax: bool = False,
) -> float:
"""
Per-device TFLOPs *per optimizer step* for DeepSeek-v3 training.

Assumptions
-----------
• BF16 / FP16 → 2 FLOPs per MAC
  • Gated-MLP FFN (3 linears + element-wise gating multiply)
  • MoE starts after the first `first_k_dense_replace` dense layers
  • One shared-expert FFN path in every MoE layer
  • Optional softmax term (include_softmax=True adds roughly 5% extra)
"""

# -------------------------------------------------------- constants ----
B, L, H = per_device_batch_size, seq_len, hidden_size
L_dense = first_k_dense_replace
L_moe = num_hidden_layers - L_dense
tokens = B * L
fwd_bwd = 3 # forward + backward factor
BF16 = 2 # FLOPs per MAC in bf16/fp16

# -------------------------------------------------------------- FFNs ---
  # Dense gated-MLP FFN (first L_dense layers)
ffn_dense_flops = 3 * H * intermediate_size * BF16 + intermediate_size
ffn_dense_flops *= tokens * L_dense

# Gating linear in every MoE layer
moe_gate_flops = 2 * H * num_routed_experts * tokens * L_moe

  # Per-expert gated-MLP FFN (num_experts_per_tok experts per token)
moe_ffn_tok = 3 * H * moe_intermediate_size * BF16 + moe_intermediate_size
moe_ffn_flops = moe_ffn_tok * tokens * num_experts_per_tok * L_moe

  # Shared-expert gated-MLP FFN (runs on *all* tokens in every MoE layer)
M_shared = moe_intermediate_size * n_shared_experts
shared_ffn_tok = 3 * H * M_shared * BF16 + M_shared
shared_ffn_flops = shared_ffn_tok * tokens * L_moe

total_ffn_flops = ffn_dense_flops + moe_gate_flops + moe_ffn_flops + shared_ffn_flops

# ------------------------------------------------------- projections ---
q_proj_flops = 2 * H * num_attention_heads * qk_head_dim * tokens
kv_a_flops = 2 * H * (kv_lora_rank + qk_rope_head_dim) * tokens
kv_b_out_dim = num_attention_heads * (qk_nope_head_dim + v_head_dim)
kv_b_flops = 2 * kv_lora_rank * kv_b_out_dim * tokens
o_proj_flops = 2 * H * num_attention_heads * v_head_dim * tokens

proj_flops_layer = q_proj_flops + kv_a_flops + kv_b_flops + o_proj_flops
proj_flops_total = proj_flops_layer * num_hidden_layers

# ---------------------------------------------------- attention core ---
attn_qk = 2 * num_attention_heads * qk_head_dim * L * L * B
attn_av = 2 * num_attention_heads * v_head_dim * L * L * B
attn_core_layer = attn_qk + attn_av

softmax_flops_layer = 4 * B * L * L * num_attention_heads if include_softmax else 0

attn_core_total = (attn_core_layer + softmax_flops_layer) * num_hidden_layers

# --------------------------------------------- embedding / lm-head ----
embed_flops = 2 * H * vocab_size * tokens # embedding + lm_head

# ------------------------------------------------ aggregate numbers ---
trainable = (total_ffn_flops + proj_flops_total + embed_flops) * fwd_bwd
attention = attn_core_total * fwd_bwd
total = (trainable + attention) * gradient_accumulation_steps
tflops = total / 1e12

# ----------------------------------------------------- quick report ---
print(f"[DeepSeek-v3] TFLOPs/device/step : {tflops:>.2f}")
print(f" • FFNs (dense+MoE+shared) : {total_ffn_flops * fwd_bwd / 1e12:>.2f}")
print(f" • Attn projections : {proj_flops_total * fwd_bwd / 1e12:>.2f}")
print(
f" • Attn QK/AV{' + softmax' if include_softmax else ''} : {attention / 1e12:>.2f}"
)
print(f" • Embed + LM head : {embed_flops * fwd_bwd / 1e12:>.2f}")

return tflops
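

# A rough standalone usage sketch of the helper above (illustration only, not
# part of this PR). Every argument value below is made up for demonstration
# and does not correspond to the real deepseek-v3-shallow config.
_example_tflops = calculate_tflops_training_per_device_deepseek(
  per_device_batch_size=1,
  seq_len=1024,
  hidden_size=1024,
  intermediate_size=4096,
  moe_intermediate_size=512,
  num_hidden_layers=4,
  first_k_dense_replace=1,
  num_attention_heads=16,
  qk_head_dim=192,  # qk_nope_head_dim + qk_rope_head_dim
  qk_nope_head_dim=128,
  qk_rope_head_dim=64,
  v_head_dim=128,
  kv_lora_rank=512,
  num_key_value_heads=16,
  num_routed_experts=8,
  n_shared_experts=1,
  num_experts_per_tok=2,
  vocab_size=32000,
)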


def compute_mfu(
config: dict,
batch_size: int,
@@ -180,9 +285,36 @@ def compute_mfu(
vocab_size=int(config["vocab_size"]),
gradient_accumulation_steps=gradient_accumulation_steps,
),
log=False,
log=True,
)

try:
total_tflops_deepseek = calculate_tflops_training_per_device_deepseek(
per_device_batch_size=batch_size,
seq_len=sequence_length,
hidden_size=int(config["hidden_size"]),
intermediate_size=int(config["intermediate_size"]),
moe_intermediate_size=int(config["moe_intermediate_size"]),
num_hidden_layers=int(config["num_hidden_layers"]),
first_k_dense_replace=int(config["first_k_dense_replace"]),
num_attention_heads=int(config["num_attention_heads"]),
qk_head_dim=int(config["qk_head_dim"]),
qk_nope_head_dim=int(config["qk_nope_head_dim"]),
qk_rope_head_dim=int(config["qk_rope_head_dim"]),
v_head_dim=int(config["v_head_dim"]),
kv_lora_rank=int(config["kv_lora_rank"]),
num_key_value_heads=int(config["num_key_value_heads"]),
num_routed_experts=int(config["n_routed_experts"]),
n_shared_experts=int(config["n_shared_experts"]),
num_experts_per_tok=int(config["num_experts_per_tok"]),
vocab_size=int(config["vocab_size"]),
gradient_accumulation_steps=1,
include_softmax=True,
)
total_tflops = total_tflops_deepseek
except Exception as e:
print(f"Error occurred while calculating TFLOPs: {e}")

assert torch_dtype == "bfloat16", f"Unsupported dtype {torch_dtype}"

chip_count_per_slice, tflops_per_chip = get_num_chips_and_tflops_per_chip(tpu_name)
1 change: 1 addition & 0 deletions torchprime/metrics/step_duration.py
@@ -155,6 +155,7 @@ def analyze_step_duration_from_pb(xspace: XSpace) -> float:
# Confirm we have exactly one unique event name
if len(unique_names) > 1:
raise ValueError(f"Ambiguous event names found in XSpace: {unique_names}")
# print(f"Ambiguous event names found in XSpace: {unique_names}")

inferred_event_name = max(unique_names)

103 changes: 103 additions & 0 deletions torchprime/rope/rope.py
@@ -7,6 +7,7 @@
from dataclasses import dataclass

import torch
from omegaconf import DictConfig


@dataclass(kw_only=True)
@@ -72,3 +73,105 @@ def llama3_rope_frequencies(
freqs = torch.where(is_medium_freq, smoothed_freqs, freqs)

return freqs


def deepseek_v3_rope_init_fn(config: DictConfig) -> tuple["torch.Tensor", float]:
"""
  Copied from the Hugging Face `_compute_yarn_parameters` implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C5-L197C29

Computes the inverse frequencies with NTK scaling. Please refer to the
[original paper](https://huggingface.co/papers/2309.00071)
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin.
"""

assert hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, DictConfig)
assert config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) == "yarn"
base = config.rope_theta
partial_rotary_factor = (
config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
)
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
dim = int(head_dim * partial_rotary_factor)
factor = config.rope_scaling["factor"]
attention_factor = config.rope_scaling.get("attention_factor")
mscale = config.rope_scaling.get("mscale")
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")

  # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if "original_max_position_embeddings" in config.rope_scaling:
original_max_position_embeddings = config.rope_scaling[
"original_max_position_embeddings"
]
factor = config.max_position_embeddings / original_max_position_embeddings
else:
original_max_position_embeddings = config.max_position_embeddings

def get_mscale(scale, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0

# Sets the attention factor as suggested in the paper
if attention_factor is None:
if mscale and mscale_all_dim:
attention_factor = float(
get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)
)
else:
attention_factor = get_mscale(factor)

# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
beta_fast = config.rope_scaling.get("beta_fast") or 32
beta_slow = config.rope_scaling.get("beta_slow") or 1

# Compute the inverse frequencies
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
"""Inverse dimension formula to find the dimension based on the number of rotations"""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)

def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
"""Find dimension range bounds based on rotations"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)

def linear_ramp_factor(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity

linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func

# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# to expand the possible context length. In other words, interpolation = apply scaling factor.
pos_freqs = base ** (torch.arange(0, dim, 2).to(dtype=torch.float) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)

low, high = find_correction_range(
beta_fast, beta_slow, dim, base, original_max_position_embeddings
)

# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(
dtype=torch.float
)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)
return inv_freq, attention_factor
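

# A minimal usage sketch of the initializer above (illustration only, not part
# of this PR). The field values are hypothetical and chosen only so that every
# attribute the function reads is present; they are not the real DeepSeek-v3
# settings.
from omegaconf import OmegaConf

_example_cfg = OmegaConf.create(
  {
    "rope_theta": 10000.0,
    "head_dim": 64,  # stands in for the rotary sub-dimension (qk_rope_head_dim)
    "partial_rotary_factor": 1.0,
    "hidden_size": 1024,
    "num_attention_heads": 16,
    "max_position_embeddings": 8192,
    "rope_scaling": {
      "rope_type": "yarn",
      "factor": 4.0,
      "original_max_position_embeddings": 2048,
      "mscale": 1.0,
      "mscale_all_dim": 1.0,
    },
  }
)
_inv_freq, _attention_scale = deepseek_v3_rope_init_fn(_example_cfg)
assert _inv_freq.shape == (32,)  # head_dim // 2 inverse frequencies
assert _attention_scale == 1.0  # mscale == mscale_all_dim, so no extra scaling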