From 77e406c9b9d978e8dfa1bc667a00742fb36a0001 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 6 Feb 2025 22:53:24 +0000 Subject: [PATCH 1/6] 16B deepseek running on v6e --- .../torchax_models/deepseek_v3/model.py | 255 ++++++++++++------ .../deepseek_v3/prefill_benchmark.py | 215 +++++++++++---- .../deepseek_v3/tests/test_prefill.py | 46 ++-- 3 files changed, 356 insertions(+), 160 deletions(-) diff --git a/torchprime/experimental/torchax_models/deepseek_v3/model.py b/torchprime/experimental/torchax_models/deepseek_v3/model.py index 57248bd4..1cf787c8 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/model.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/model.py @@ -2,8 +2,8 @@ from dataclasses import dataclass from typing import Literal +import jax import torch -import torch.distributed as dist import torch.nn.functional as F from torch import nn @@ -382,7 +382,7 @@ def __init__(self, args: ModelArgs): def forward( self, x: torch.Tensor, - start_pos: int, + input_pos: torch.Tensor, freqs_cis: torch.Tensor, mask: torch.Tensor | None, ): @@ -399,7 +399,6 @@ def forward( torch.Tensor: Output tensor with the same shape as the input. """ bsz, seqlen, _ = x.size() - end_pos = start_pos + seqlen q = self.wq(x) if self.q_lora_rank == 0 else self.wq_b(self.q_norm(self.wq_a(x))) q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) q_nope, q_pe = torch.split( @@ -417,12 +416,9 @@ def forward( ) k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1) - self.k_cache[:bsz, start_pos:end_pos] = k - self.v_cache[:bsz, start_pos:end_pos] = v - scores = ( - torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) - * self.softmax_scale - ) + # self.k_cache[:bsz, start_pos:end_pos] = k + # self.v_cache[:bsz, start_pos:end_pos] = v + scores = torch.einsum("bshd,bthd->bsht", q, k) * self.softmax_scale else: wkv_b = ( self.wkv_b.weight @@ -431,19 +427,22 @@ def forward( ) wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim]) - self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) - self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) + # self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) + # self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) + kv_cache = self.kv_norm(kv) + pe_cache = k_pe.squeeze(2) scores = ( - torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) - + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos]) + torch.einsum("bshc,btc->bsht", q_nope, kv_cache) + + torch.einsum("bshr,btr->bsht", q_pe, pe_cache) ) * self.softmax_scale if mask is not None: scores += mask.unsqueeze(1) scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) if attn_impl == "naive": - x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) + x = torch.einsum("bsht,bthd->bshd", scores, v) else: - x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos]) + kv_cache = self.kv_norm(kv) + x = torch.einsum("bsht,btc->bshc", scores, kv_cache) x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :]) x = self.wo(x.flatten(2)) return x @@ -590,71 +589,154 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class MoE(nn.Module): - """ - Mixture-of-Experts (MoE) module. - - Attributes: - dim (int): Dimensionality of input features. 
- n_routed_experts (int): Total number of experts in the model. - n_local_experts (int): Number of experts handled locally in distributed systems. - n_activated_experts (int): Number of experts activated for each input. - gate (nn.Module): Gating mechanism to route inputs to experts. - experts (nn.ModuleList): List of expert modules. - shared_experts (nn.Module): Shared experts applied to all inputs. - """ - - def __init__(self, args: ModelArgs): - """ - Initializes the MoE module. - - Args: - args (ModelArgs): Model arguments containing MoE parameters. - """ +# class MoE(nn.Module): +# """ +# Mixture-of-Experts (MoE) module. + +# Attributes: +# dim (int): Dimensionality of input features. +# n_routed_experts (int): Total number of experts in the model. +# n_local_experts (int): Number of experts handled locally in distributed systems. +# n_activated_experts (int): Number of experts activated for each input. +# gate (nn.Module): Gating mechanism to route inputs to experts. +# experts (nn.ModuleList): List of expert modules. +# shared_experts (nn.Module): Shared experts applied to all inputs. +# """ + +# def __init__(self, args: ModelArgs): +# """ +# Initializes the MoE module. + +# Args: +# args (ModelArgs): Model arguments containing MoE parameters. +# """ +# super().__init__() +# self.dim = args.dim +# assert args.n_routed_experts % world_size == 0 +# self.n_routed_experts = args.n_routed_experts +# self.n_local_experts = args.n_routed_experts // world_size +# self.n_activated_experts = args.n_activated_experts +# self.experts_start_idx = rank * self.n_local_experts +# self.experts_end_idx = self.experts_start_idx + self.n_local_experts +# self.gate = Gate(args) +# self.experts = nn.ModuleList( +# [ +# Expert(args.dim, args.moe_inter_dim) +# if self.experts_start_idx <= i < self.experts_end_idx +# else None +# for i in range(self.n_routed_experts) +# ] +# ) +# self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim) + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# """ +# Forward pass for the MoE module. + +# Args: +# x (torch.Tensor): Input tensor. + + +# Returns: +# torch.Tensor: Output tensor after expert routing and computation. 
+# """ +# shape = x.size() +# x = x.view(-1, self.dim) +# weights, indices = self.gate(x) +# y = torch.zeros_like(x) +# # counts = torch.bincount( +# # indices.flatten(), +# # minlength=self.n_routed_experts).tolist() +# # NOTE: we actually know exact lenght of counts here, +# # however torch.bincount does not take length as an args: +# counts = tx.interop.call_jax( +# jnp.bincount, +# indices.flatten(), +# minlength=self.n_routed_experts, +# length=self.n_routed_experts) +# for i in range(self.experts_start_idx, self.experts_end_idx): +# # if counts[i] == 0: +# # continue +# expert = self.experts[i] +# idx, top = torch.where(indices == i) +# y[idx] += expert(x[idx]) * weights[idx, top, None] +# z = self.shared_experts(x) +# if world_size > 1: +# dist.all_reduce(y) +# return (y + z).view(shape) +class ConditionalFeedForward(torch.nn.Module): + def __init__(self, config): super().__init__() - self.dim = args.dim - assert args.n_routed_experts % world_size == 0 - self.n_routed_experts = args.n_routed_experts - self.n_local_experts = args.n_routed_experts // world_size - self.n_activated_experts = args.n_activated_experts - self.experts_start_idx = rank * self.n_local_experts - self.experts_end_idx = self.experts_start_idx + self.n_local_experts - self.gate = Gate(args) - self.experts = nn.ModuleList( - [ - Expert(args.dim, args.moe_inter_dim) - if self.experts_start_idx <= i < self.experts_end_idx - else None - for i in range(self.n_routed_experts) - ] + # TODO(How to enable quantization?) + self.w1 = nn.Parameter( + torch.empty(config.n_routed_experts, config.moe_inter_dim, config.dim) + ) + self.w2 = nn.Parameter( + torch.empty(config.n_routed_experts, config.dim, config.moe_inter_dim) + ) + self.w3 = nn.Parameter( + torch.empty(config.n_routed_experts, config.moe_inter_dim, config.dim) ) - self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim) + self.config = config + + def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor: + return self.forward_for_long_seq_len(x, expert_indices) + + def forward_for_long_seq_len(self, x, expert_indices): + seqlen = x.shape[0] + num_experts = self.w1.shape[0] + + # e = total num of exp = 8 + # t = seqlen + # o = config.imtermediate size + # i = config.dim + with jax.named_scope("conditional_ff"): + x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1)) + x3 = torch.einsum("ti, eoi-> teo", x, self.w3) + expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) + # e = 8; need to reduce to 2 + seq_indexes = torch.arange(seqlen, device=x.device).unsqueeze(1) + return expert_outs[seq_indexes, expert_indices] + + +class MoE(torch.nn.Module): + def __init__(self, model_args) -> None: + super().__init__() + self.dim = model_args.dim + self.model_args = model_args + # assert args.n_routed_experts % world_size == 0 + # self.n_routed_experts = args.n_routed_experts + # self.n_local_experts = args.n_routed_experts // world_size + # self.n_activated_experts = args.n_activated_experts + # self.experts_start_idx = rank * self.n_local_experts + # self.experts_end_idx = self.experts_start_idx + self.n_local_experts + self.gate = Gate(model_args) + # self.experts = nn.ModuleList( + # [ + # Expert(args.dim, args.moe_inter_dim) + # if self.experts_start_idx <= i < self.experts_end_idx + # else None + # for i in range(self.n_routed_experts) + # ] + # ) + self.shared_experts = MLP( + model_args.dim, model_args.n_shared_experts * model_args.moe_inter_dim + ) + self.cond_ffn = ConditionalFeedForward(model_args) def 
forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass for the MoE module. - - Args: - x (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor after expert routing and computation. - """ - shape = x.size() + bsz, seq, hidden = x.shape + # [B, T, D], combine BT, for prefill B = 1, for decode, T = 1 x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] weights, indices = self.gate(x) - y = torch.zeros_like(x) - counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() - for i in range(self.experts_start_idx, self.experts_end_idx): - if counts[i] == 0: - continue - expert = self.experts[i] - idx, top = torch.where(indices == i) - y[idx] += expert(x[idx]) * weights[idx, top, None] - z = self.shared_experts(x) - if world_size > 1: - dist.all_reduce(y) - return (y + z).view(shape) + expert_outs = self.cond_ffn(x, indices) + expert_outs = torch.einsum("tai,ta -> ti", expert_outs, weights) + # Changes back to [B, T, D] + expert_outs = expert_outs.reshape(bsz, seq, hidden) + return expert_outs class Block(nn.Module): @@ -687,7 +769,7 @@ def __init__(self, layer_id: int, args: ModelArgs): def forward( self, x: torch.Tensor, - start_pos: int, + input_pos: torch.Tensor, freqs_cis: torch.Tensor, mask: torch.Tensor | None, ) -> torch.Tensor: @@ -703,7 +785,7 @@ def forward( Returns: torch.Tensor: Output tensor after block computation. """ - x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask) + x = x + self.attn(self.attn_norm(x), input_pos, freqs_cis, mask) x = x + self.ffn(self.ffn_norm(x)) return x @@ -728,9 +810,8 @@ def __init__(self, args: ModelArgs): Args: args (ModelArgs): Model arguments containing transformer parameters. """ - global world_size, rank - world_size = dist.get_world_size() if dist.is_initialized() else 1 - rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = 1 + rank = 0 Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 super().__init__() self.max_seq_len = args.max_seq_len @@ -742,10 +823,10 @@ def __init__(self, args: ModelArgs): self.head = ColumnParallelLinear( args.dim, args.vocab_size, dtype=torch.get_default_dtype() ) - self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) + self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=True) @torch.inference_mode() - def forward(self, tokens: torch.Tensor, start_pos: int = 0): + def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor): """ Forward pass for the Transformer model. 
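Note on the start_pos -> input_pos change in this file: indexing the rotary table with an explicit position tensor reproduces the old integer slice, presumably so the positions stay a traced tensor once the whole model is jitted through torchax. A minimal standalone check, not part of the patch, with toy sizes and a random stand-in for precompute_freqs_cis:

    import torch

    max_seq_len, half_dim = 16, 4
    freqs_cis = torch.randn(max_seq_len, half_dim)  # stand-in for the real freqs table
    start_pos, seqlen = 3, 5

    # The old code sliced with a Python int; the patch gathers with a position tensor.
    input_pos = torch.arange(start_pos, start_pos + seqlen)
    assert torch.equal(freqs_cis[start_pos:start_pos + seqlen], freqs_cis[input_pos])
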
@@ -758,16 +839,12 @@ def forward(self, tokens: torch.Tensor, start_pos: int = 0): """ seqlen = tokens.size(1) h = self.embed(tokens) - freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + freqs_cis = self.freqs_cis[input_pos] mask = None - if seqlen > 1: - mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) + # if seqlen > 1: + # mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) + h = layer(h, input_pos, freqs_cis, mask) h = self.norm(h)[:, -1] logits = self.head(h) - if world_size > 1: - all_logits = [torch.empty_like(logits) for _ in range(world_size)] - dist.all_gather(all_logits, logits) - logits = torch.cat(all_logits, dim=-1) return logits diff --git a/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py b/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py index 110c69a1..66b7e140 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py @@ -1,71 +1,187 @@ import functools +import json import time import jax +import jax.numpy as jnp +import model as ds_model import torch import torchax import torchax.interop +import torchax.ops.mappings as tx_mappings +from jax.experimental.mesh_utils import create_device_mesh +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from model import ModelArgs, Transformer +from torchax import interop from torchax.interop import JittableModule -from .model import ( - ModelArgs, - Transformer, -) - -def single_device_compile(): - print("======= single_device_compile =======") +def _process_sharding_name(name): + """Replace integers in param name with *. + + Presumably all layers should have the same sharding. 
+ """ + + def is_integer(t): + try: + int(t) + return True + # pylint: disable-next=all + except: # noqa: E722 + return False + + tokens = name.split(".") + for i, t in enumerate(tokens): + if is_integer(t): + tokens[i] = "*" + return ".".join(tokens) + + +def _get_sharding_sepc(sharding_map, name): + sharding_spec = sharding_map.get(name) + if sharding_spec is not None: + return sharding_spec + sharding_spec = sharding_map.get(_process_sharding_name(name)) + return sharding_spec + + +def make_weight_shard(weight_meta, slice_index): + weight_shard_meta = weight_meta[slice_index] + with torchax.default_env(): + return interop.jax_view( + torch.randn(weight_shard_meta.shape, dtype=weight_shard_meta.dtype) + ) + + +def make_cache_shard(weight_meta, slice_index): + weight_shard_meta = weight_meta[slice_index] + return jnp.zeros( + weight_shard_meta.shape, dtype=tx_mappings.t2j_dtype(weight_shard_meta.dtype) + ) + + +def create_sharded_weights(model, mesh, sharding_map, env): + res = {} + for name, weight_meta in model.state_dict().items(): + sharding_spec = _get_sharding_sepc(sharding_map, name) + if sharding_spec is None: + print("Skipping weight:", name) + continue + sharding = NamedSharding(mesh, P(*sharding_spec)) + res[name] = env.j2t_iso( + jax.make_array_from_callback( + weight_meta.shape, sharding, functools.partial(make_weight_shard, weight_meta) + ) + ) + return res + + +def create_sharded_kv_cache(cache_dict, mesh, env): + res = {} + # shard at num device + sharding = NamedSharding(mesh, P(None, None, name0, None)) + for name, weight_meta in cache_dict.items(): + if name.endswith("_cache"): + res[name] = env.j2t_iso( + jax.make_array_from_callback( + weight_meta.shape, sharding, functools.partial(make_cache_shard, weight_meta) + ) + ) + return res + + +name0 = "tp0" +# name1 = "tp1" +sharding_map_1d_tp = { + "embed.weight": (name0, None), + "layers.*.attn.wq.weight": (None, name0), + "layers.*.attn.wq.bias": (name0,), + "layers.*.attn.wkv_a.weight": (None, name0), + "layers.*.attn.kv_norm.weight": (name0,), + "layers.*.attn.wkv_b.weight": (name0, None), + "layers.*.attn.wkv_b.bias": (name0,), + "layers.*.attn.wo.weight": (name0, None), + "layers.*.attn.wo.bias": (name0, None), + "layers.0.ffn.w1.weight": (name0, None), + "layers.0.ffn.w1.bias": (name0,), + "layers.0.ffn.w2.weight": (None, name0), + "layers.0.ffn.w2.bias": (name0,), + "layers.0.ffn.w3.weight": (name0, None), + "layers.0.ffn.w3.bias": (name0,), + "layers.*.ffn.cond_ffn.w1": (None, name0, None), + "layers.*.ffn.cond_ffn.w2": (None, None, name0), + "layers.*.ffn.cond_ffn.w3": (None, name0, None), + "layers.*.ffn.gate.weight": (None, name0), + "layers.*.ffn.gate.bias": (name0,), + "layers.*.ffn.shared_experts.w1.weight": (name0, None), + "layers.*.ffn.shared_experts.w1.bias": (name0,), + "layers.*.ffn.shared_experts.w2.weight": (None, name0), + "layers.*.ffn.shared_experts.w2.bias": (name0,), + "layers.*.ffn.shared_experts.w3.weight": (name0, None), + "layers.*.ffn.shared_experts.w3.bias": (name0,), + "layers.*.attn_norm.weight": (name0,), + "layers.*.ffn_norm.weight": (name0,), + "norm.weight": (name0,), + "head.weight": (name0, None), + "head.bias": (name0,), + "freqs_cis": (), +} + + +def _replicate(x, env, mesh): + with jax.default_device(jax.devices("cpu")[0]): + xj = env.to_xla(x).jax() + xj = env.j2t_iso( + jax.make_array_from_callback(xj.shape, NamedSharding(mesh, P()), lambda a: xj) + ) + return xj + + +def main(config=None, seqlen=2048, batch_size=1): + config_dict = None + if config is not None: + with 
open(config) as f: + config_dict = json.load(f) + + print("======= multi_device =======") torch.set_default_dtype(torch.bfloat16) env = torchax.default_env() + config_dict = config_dict or {} + + env.config.use_torch_native_for_cpu_tensor = False + torch.manual_seed(42) torchax.enable_performance_mode() + torchax.enable_globally() + args = ModelArgs(**config_dict) - args = ModelArgs() + dev_array = create_device_mesh((len(jax.devices()),), allow_split_physical_axes=True) + mesh = Mesh(dev_array, (name0,)) - with torch.no_grad(), env: - x = torch.randint(0, args.vocab_size, (1, 2048)) - x = x.to("jax") + torch.set_default_device("meta") + with env, torch.device("meta"): model = Transformer(args) - model.to("jax") - model.embed = JittableModule(model.embed) - # for i in range(len(model.layers)): - # model.layers[i] = JittableModule(model.layers[i]) - model.norm = JittableModule(model.norm) - model.head = JittableModule(model.head) - for i in range(5): - step_start = time.perf_counter() - logits = model(x, 0) - jax.block_until_ready(torchax.tensor.t2j(logits)) - step_end = time.perf_counter() - print( - i, - "step latency: ", - step_end - step_start, - ) + jitted = JittableModule(model) + freqs_cis = ds_model.precompute_freqs_cis(args) + freqs_cis = _replicate(freqs_cis, env, mesh) + jitted.buffers["freqs_cis"] = freqs_cis + print(model) + caches_dict = create_sharded_kv_cache(jitted.buffers, mesh, env) + sharded_weights = create_sharded_weights(model, mesh, sharding_map_1d_tp, env) -def single_device_eager(): - print("======= single_device_eager =======") - torch.set_default_dtype(torch.bfloat16) - env = torchax.default_env() - torch.manual_seed(42) - torchax.enable_performance_mode() - - args = ModelArgs() - - with torch.no_grad(), env: - x = torch.randint(0, args.vocab_size, (1, 2048)) - x = x.to("jax") - model = Transformer(args) - model.to("jax") - weights = model.state_dict() - model_forward = functools.partial(torch.func.functional_call, model) - # model_forward = torchax.interop.jax_jit(model_forward) + jitted.params = sharded_weights + jitted.buffers.update(caches_dict) + with mesh: + x = torch.randint(0, args.vocab_size, (1, seqlen)) + x = _replicate(x, env, mesh) + input_pos = torch.arange(seqlen, device="jax") for i in range(5): step_start = time.perf_counter() - logits = model_forward(weights, (x, 0)) + logits = jitted(x, input_pos) jax.block_until_ready(torchax.tensor.t2j(logits)) step_end = time.perf_counter() print( @@ -75,15 +191,6 @@ def single_device_eager(): ) -def main(option="single_device_eager"): - if option == "single_device_eager": - single_device_eager() - elif option == "single_device_compile": - single_device_compile() - else: - raise Exception("Invalid option") - - if __name__ == "__main__": import fire diff --git a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py index ac53ef8d..4ddeac21 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py @@ -1,23 +1,35 @@ -import pytest +import unittest +import torch +import torchax +import torchax.interop +from torchprime.experimental.torchax_models.deepseek import model as ds_model -# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/75): Fix the failure on torch 2.6, -# then enable the test unconditionally. 
-@pytest.mark.deepseek -def test_single_device_compile(): - from torchprime.experimental.torchax_models.deepseek_v3.prefill_benchmark import ( - single_device_compile, - ) - single_device_compile() +class DeepseekModuleTest(unittest.TestCase): + def setUp(self): + torchax.enable_globally() -# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/75): Fix the failure on torch 2.6, -# then enable the test unconditionally. -@pytest.mark.deepseek -def test_single_device_eager(): - from torchprime.experimental.torchax_models.deepseek_v3.prefill_benchmark import ( - single_device_eager, - ) + def tearDown(self): + torchax.disable_globally() - single_device_eager() + def test_moe_can_jit(self): + torch.manual_seed(42) + max_seq_len = 512 # 8192 + vocab_size = 128 # 32000 + n_layer = 1 + n_heads = 4 + dim = 8 + block_size = 16 # 2048 + with torch.no_grad(): + x = torch.ones((1, max_seq_len, 2048), dtype=torch.float32, device='jax') + model_args = ds_model.ModelArgs() + model = ds_model.MoE(model_args).to('jax') + + jitted = torchax.interop.JittableModule(model) + print(jitted(x)) + + +if __name__ == '__main__': + unittest.main() From 1598ab8a860ae1d92a3ed7fa915b8cb2bc72ab8e Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 7 Feb 2025 00:21:42 +0000 Subject: [PATCH 2/6] Use pytest format for testing --- .../deepseek_v3/tests/test_prefill.py | 51 ++++++++----------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py index 4ddeac21..cd2e096a 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py @@ -1,35 +1,28 @@ -import unittest +import pytest import torch import torchax import torchax.interop -from torchprime.experimental.torchax_models.deepseek import model as ds_model - - -class DeepseekModuleTest(unittest.TestCase): - - def setUp(self): - torchax.enable_globally() - - def tearDown(self): - torchax.disable_globally() - - def test_moe_can_jit(self): - torch.manual_seed(42) - max_seq_len = 512 # 8192 - vocab_size = 128 # 32000 - n_layer = 1 - n_heads = 4 - dim = 8 - block_size = 16 # 2048 - with torch.no_grad(): - x = torch.ones((1, max_seq_len, 2048), dtype=torch.float32, device='jax') - model_args = ds_model.ModelArgs() - model = ds_model.MoE(model_args).to('jax') - - jitted = torchax.interop.JittableModule(model) - print(jitted(x)) +from torchprime.experimental.torchax_models.deepseek_v3 import model as ds_model + + +@pytest.mark.deepseek +def test_moe_can_jit(): + torchax.enable_globally() + torch.manual_seed(42) + max_seq_len = 512 # 8192 + vocab_size = 128 # 32000 + n_layer = 1 + n_heads = 4 + dim = 8 + block_size = 16 # 2048 + with torch.no_grad(): + x = torch.ones((1, max_seq_len, 2048), dtype=torch.float32, device='jax') + model_args = ds_model.ModelArgs() + model = ds_model.MoE(model_args).to('jax') + + jitted = torchax.interop.JittableModule(model) + print(jitted(x)) + torchax.disable_globally() -if __name__ == '__main__': - unittest.main() From c134f9ffcc768f0836856fccff1a679658a7cb81 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 7 Feb 2025 00:26:29 +0000 Subject: [PATCH 3/6] remove old moe --- .../torchax_models/deepseek_v3/model.py | 75 ------------------- .../deepseek_v3/tests/test_prefill.py | 8 +- 2 files changed, 4 insertions(+), 79 deletions(-) diff --git 
a/torchprime/experimental/torchax_models/deepseek_v3/model.py b/torchprime/experimental/torchax_models/deepseek_v3/model.py index 1cf787c8..27475578 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/model.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/model.py @@ -589,81 +589,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) -# class MoE(nn.Module): -# """ -# Mixture-of-Experts (MoE) module. - -# Attributes: -# dim (int): Dimensionality of input features. -# n_routed_experts (int): Total number of experts in the model. -# n_local_experts (int): Number of experts handled locally in distributed systems. -# n_activated_experts (int): Number of experts activated for each input. -# gate (nn.Module): Gating mechanism to route inputs to experts. -# experts (nn.ModuleList): List of expert modules. -# shared_experts (nn.Module): Shared experts applied to all inputs. -# """ - -# def __init__(self, args: ModelArgs): -# """ -# Initializes the MoE module. - -# Args: -# args (ModelArgs): Model arguments containing MoE parameters. -# """ -# super().__init__() -# self.dim = args.dim -# assert args.n_routed_experts % world_size == 0 -# self.n_routed_experts = args.n_routed_experts -# self.n_local_experts = args.n_routed_experts // world_size -# self.n_activated_experts = args.n_activated_experts -# self.experts_start_idx = rank * self.n_local_experts -# self.experts_end_idx = self.experts_start_idx + self.n_local_experts -# self.gate = Gate(args) -# self.experts = nn.ModuleList( -# [ -# Expert(args.dim, args.moe_inter_dim) -# if self.experts_start_idx <= i < self.experts_end_idx -# else None -# for i in range(self.n_routed_experts) -# ] -# ) -# self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim) - -# def forward(self, x: torch.Tensor) -> torch.Tensor: -# """ -# Forward pass for the MoE module. - -# Args: -# x (torch.Tensor): Input tensor. - - -# Returns: -# torch.Tensor: Output tensor after expert routing and computation. 
-# """ -# shape = x.size() -# x = x.view(-1, self.dim) -# weights, indices = self.gate(x) -# y = torch.zeros_like(x) -# # counts = torch.bincount( -# # indices.flatten(), -# # minlength=self.n_routed_experts).tolist() -# # NOTE: we actually know exact lenght of counts here, -# # however torch.bincount does not take length as an args: -# counts = tx.interop.call_jax( -# jnp.bincount, -# indices.flatten(), -# minlength=self.n_routed_experts, -# length=self.n_routed_experts) -# for i in range(self.experts_start_idx, self.experts_end_idx): -# # if counts[i] == 0: -# # continue -# expert = self.experts[i] -# idx, top = torch.where(indices == i) -# y[idx] += expert(x[idx]) * weights[idx, top, None] -# z = self.shared_experts(x) -# if world_size > 1: -# dist.all_reduce(y) -# return (y + z).view(shape) class ConditionalFeedForward(torch.nn.Module): def __init__(self, config): super().__init__() diff --git a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py index cd2e096a..6708d277 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py @@ -1,13 +1,13 @@ import pytest -import torch -import torchax -import torchax.interop -from torchprime.experimental.torchax_models.deepseek_v3 import model as ds_model @pytest.mark.deepseek def test_moe_can_jit(): + import torch + import torchax + import torchax.interop + from torchprime.experimental.torchax_models.deepseek_v3 import model as ds_model torchax.enable_globally() torch.manual_seed(42) max_seq_len = 512 # 8192 From 55c8a49e4d8fcdf3bce5abe4a5d17d84fe152b0b Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 7 Feb 2025 03:03:43 +0000 Subject: [PATCH 4/6] ruff --- .../torchax_models/deepseek_v3/model.py | 8 +++----- .../deepseek_v3/tests/test_prefill.py | 14 ++++---------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/torchprime/experimental/torchax_models/deepseek_v3/model.py b/torchprime/experimental/torchax_models/deepseek_v3/model.py index 27475578..bb97585c 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/model.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/model.py @@ -609,7 +609,7 @@ def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor def forward_for_long_seq_len(self, x, expert_indices): seqlen = x.shape[0] - num_experts = self.w1.shape[0] + self.w1.shape[0] # e = total num of exp = 8 # t = seqlen @@ -655,7 +655,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.view(-1, self.dim) # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts # x: [T, D] - scores = self.gate(x) # [T, E] + self.gate(x) # [T, E] weights, indices = self.gate(x) expert_outs = self.cond_ffn(x, indices) expert_outs = torch.einsum("tai,ta -> ti", expert_outs, weights) @@ -735,8 +735,6 @@ def __init__(self, args: ModelArgs): Args: args (ModelArgs): Model arguments containing transformer parameters. """ - world_size = 1 - rank = 0 Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 super().__init__() self.max_seq_len = args.max_seq_len @@ -762,7 +760,7 @@ def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
""" - seqlen = tokens.size(1) + tokens.size(1) h = self.embed(tokens) freqs_cis = self.freqs_cis[input_pos] mask = None diff --git a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py index 6708d277..abaebc51 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/tests/test_prefill.py @@ -1,28 +1,22 @@ import pytest - @pytest.mark.deepseek def test_moe_can_jit(): import torch import torchax import torchax.interop + from torchprime.experimental.torchax_models.deepseek_v3 import model as ds_model + torchax.enable_globally() torch.manual_seed(42) max_seq_len = 512 # 8192 - vocab_size = 128 # 32000 - n_layer = 1 - n_heads = 4 - dim = 8 - block_size = 16 # 2048 with torch.no_grad(): - x = torch.ones((1, max_seq_len, 2048), dtype=torch.float32, device='jax') + x = torch.ones((1, max_seq_len, 2048), dtype=torch.float32, device="jax") model_args = ds_model.ModelArgs() - model = ds_model.MoE(model_args).to('jax') + model = ds_model.MoE(model_args).to("jax") jitted = torchax.interop.JittableModule(model) print(jitted(x)) torchax.disable_globally() - - From a72e579340fcc0b9f9ce182108b97d0ff2f36032 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 7 Feb 2025 22:50:17 +0000 Subject: [PATCH 5/6] Add decode benchmark --- .../deepseek_v3/prefill_benchmark.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py b/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py index 66b7e140..f8e427c0 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py @@ -155,6 +155,7 @@ def main(config=None, seqlen=2048, batch_size=1): torchax.enable_performance_mode() torchax.enable_globally() args = ModelArgs(**config_dict) + args.max_batch_size = 1 dev_array = create_device_mesh((len(jax.devices()),), allow_split_physical_axes=True) mesh = Mesh(dev_array, (name0,)) @@ -190,6 +191,20 @@ def main(config=None, seqlen=2048, batch_size=1): step_end - step_start, ) + x = torch.randint(0, args.vocab_size, (1, 1)) + x = _replicate(x, env, mesh) + input_pos = torch.arange(2048, 2049, device="jax") + for i in range(5): + step_start = time.perf_counter() + logits = jitted(x, input_pos) + jax.block_until_ready(torchax.tensor.t2j(logits)) + step_end = time.perf_counter() + print( + i, + "decode step latency: ", + step_end - step_start, + ) + if __name__ == "__main__": import fire From 267e0e868c23f3b7752f6781266c25f09e578d43 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Mon, 10 Feb 2025 22:03:23 +0000 Subject: [PATCH 6/6] Update the rest, R1 runs on v5p-128 --- pyproject.toml | 1 + .../deepseek_v3/configs/config_671B.json | 2 +- .../torchax_models/deepseek_v3/model.py | 1 + .../deepseek_v3/prefill_benchmark.py | 29 ++++++++++++------- torchprime/launcher/Dockerfile | 3 ++ torchprime/launcher/buildpush.py | 2 +- 6 files changed, 25 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 667072cb..f9a4bc34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "tensorboard-plugin-profile==2.18.0", "tf_keras==2.18.0", "protobuf==4.25.5", + "fire", ] [project.optional-dependencies] diff --git a/torchprime/experimental/torchax_models/deepseek_v3/configs/config_671B.json 
b/torchprime/experimental/torchax_models/deepseek_v3/configs/config_671B.json index 48b5c719..38c76296 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/configs/config_671B.json +++ b/torchprime/experimental/torchax_models/deepseek_v3/configs/config_671B.json @@ -18,5 +18,5 @@ "qk_nope_head_dim": 128, "qk_rope_head_dim": 64, "v_head_dim": 128, - "dtype": "fp8" + "dtype": "bfloat16" } \ No newline at end of file diff --git a/torchprime/experimental/torchax_models/deepseek_v3/model.py b/torchprime/experimental/torchax_models/deepseek_v3/model.py index bb97585c..eb0b264b 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/model.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/model.py @@ -543,6 +543,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: else: group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) indices = group_scores.topk(self.topk_groups, dim=-1)[1] + print('i am here') mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True) scores = (scores * mask.unsqueeze(-1)).flatten(1) indices = torch.topk(scores, self.topk, dim=-1)[1] diff --git a/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py b/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py index f8e427c0..690c5fba 100644 --- a/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py +++ b/torchprime/experimental/torchax_models/deepseek_v3/prefill_benchmark.py @@ -66,7 +66,7 @@ def create_sharded_weights(model, mesh, sharding_map, env): for name, weight_meta in model.state_dict().items(): sharding_spec = _get_sharding_sepc(sharding_map, name) if sharding_spec is None: - print("Skipping weight:", name) + print("Skipping weight:", name, weight_meta.shape) continue sharding = NamedSharding(mesh, P(*sharding_spec)) res[name] = env.j2t_iso( @@ -97,18 +97,24 @@ def create_sharded_kv_cache(cache_dict, mesh, env): "embed.weight": (name0, None), "layers.*.attn.wq.weight": (None, name0), "layers.*.attn.wq.bias": (name0,), - "layers.*.attn.wkv_a.weight": (None, name0), + "layers.*.attn.wkv_a.weight": (None, None), "layers.*.attn.kv_norm.weight": (name0,), "layers.*.attn.wkv_b.weight": (name0, None), "layers.*.attn.wkv_b.bias": (name0,), "layers.*.attn.wo.weight": (name0, None), "layers.*.attn.wo.bias": (name0, None), - "layers.0.ffn.w1.weight": (name0, None), - "layers.0.ffn.w1.bias": (name0,), - "layers.0.ffn.w2.weight": (None, name0), - "layers.0.ffn.w2.bias": (name0,), - "layers.0.ffn.w3.weight": (name0, None), - "layers.0.ffn.w3.bias": (name0,), + + "layers.*.attn.wq_a.weight": (None, None), + "layers.*.attn.q_norm.weight": (), + "layers.*.attn.wq_b.weight": (name0, None), + "layers.*.attn.wq_b.bias": (name0,), + + "layers.*.ffn.w1.weight": (name0, None), + "layers.*.ffn.w1.bias": (name0,), + "layers.*.ffn.w2.weight": (None, name0), + "layers.*.ffn.w2.bias": (name0,), + "layers.*.ffn.w3.weight": (name0, None), + "layers.*.ffn.w3.bias": (name0,), "layers.*.ffn.cond_ffn.w1": (None, name0, None), "layers.*.ffn.cond_ffn.w2": (None, None, name0), "layers.*.ffn.cond_ffn.w3": (None, name0, None), @@ -154,6 +160,7 @@ def main(config=None, seqlen=2048, batch_size=1): torch.manual_seed(42) torchax.enable_performance_mode() torchax.enable_globally() + torchax.default_env().config.debug_print_each_op = True args = ModelArgs(**config_dict) args.max_batch_size = 1 @@ -177,9 +184,9 @@ def main(config=None, seqlen=2048, batch_size=1): jitted.buffers.update(caches_dict) with mesh: - x = torch.randint(0, args.vocab_size, (1, 
seqlen)) + x = torch.ones((1, 2048), dtype=torch.int32) x = _replicate(x, env, mesh) - input_pos = torch.arange(seqlen, device="jax") + input_pos = torch.arange(2048, device="jax") for i in range(5): step_start = time.perf_counter() logits = jitted(x, input_pos) @@ -191,7 +198,7 @@ def main(config=None, seqlen=2048, batch_size=1): step_end - step_start, ) - x = torch.randint(0, args.vocab_size, (1, 1)) + x = torch.ones((1, 1), dtype=torch.int32) x = _replicate(x, env, mesh) input_pos = torch.arange(2048, 2049, device="jax") for i in range(5): diff --git a/torchprime/launcher/Dockerfile b/torchprime/launcher/Dockerfile index f4678e7c..49a1b6a2 100644 --- a/torchprime/launcher/Dockerfile +++ b/torchprime/launcher/Dockerfile @@ -26,6 +26,7 @@ WORKDIR /workspaces # Install torchax RUN git clone https://github.com/pytorch/xla.git WORKDIR /workspaces/xla/torchax +RUN git checkout hanq_torchax1 RUN pip install torch_xla[pallas] \ -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html @@ -42,5 +43,7 @@ RUN pip install -e . COPY . /workspaces/torchprime # This should not install any packages. Only symlink the source code. RUN pip install --no-deps -e . +RUN pip install --force-reinstall --upgrade torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu +RUN pip uninstall torchvision -y ENV LIBTPU_INIT_ARGS "--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true" diff --git a/torchprime/launcher/buildpush.py b/torchprime/launcher/buildpush.py index 51746eca..f50dbff6 100755 --- a/torchprime/launcher/buildpush.py +++ b/torchprime/launcher/buildpush.py @@ -47,7 +47,7 @@ def buildpush( # Build, tag, and push Docker image try: _run( - f"{sudo_cmd} docker build --network=host --progress=auto -t {docker_tag} {context_dir} -f {docker_file}", + f"{sudo_cmd} docker build --no-cache --network=host --progress=auto -t {docker_tag} {context_dir} -f {docker_file}", ) _run( f"{sudo_cmd} docker tag {docker_tag} {docker_url}",
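
Note on the expert-routing math introduced by ConditionalFeedForward and the new MoE forward: every expert is applied densely to every token ("ti,eoi->teo" / "teo,eio->tei"), the activated experts are then gathered per token, and the gate weights are folded in with "tai,ta->ti". Presumably the point is to avoid data-dependent shapes under jit, at the cost of computing every expert for every token. Below is a minimal standalone sketch, not part of the patch, that checks this dense formulation against the per-expert loop of the removed MoE implementation; it assumes toy sizes, a plain top-k softmax router standing in for the real Gate, and omits the shared experts:

    import torch
    import torch.nn.functional as F

    # Hypothetical toy sizes: T tokens, E experts, A activated per token,
    # D model dim, O expert inter dim.
    T, E, A, D, O = 6, 4, 2, 8, 16
    x = torch.randn(T, D)
    w1 = torch.randn(E, O, D)
    w2 = torch.randn(E, D, O)
    w3 = torch.randn(E, O, D)

    # Dense path (as in forward_for_long_seq_len): all experts on all tokens.
    x1 = F.silu(torch.einsum("ti,eoi->teo", x, w1))
    x3 = torch.einsum("ti,eoi->teo", x, w3)
    expert_outs = torch.einsum("teo,eio->tei", x1 * x3, w2)      # [T, E, D]

    # Stand-in router; the real Gate does grouped, score-function-dependent routing.
    weights, indices = torch.randn(T, E).softmax(-1).topk(A, dim=-1)
    picked = expert_outs[torch.arange(T).unsqueeze(1), indices]  # [T, A, D]
    dense = torch.einsum("tai,ta->ti", picked, weights)          # [T, D]

    # Reference: loop over experts, as the removed MoE.forward did.
    ref = torch.zeros_like(x)
    for e in range(E):
        tok, slot = torch.where(indices == e)
        h = F.silu(x[tok] @ w1[e].T) * (x[tok] @ w3[e].T)
        ref[tok] += (h @ w2[e].T) * weights[tok, slot, None]

    print(torch.allclose(dense, ref, rtol=1e-4, atol=1e-4))      # expected: True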