2 changes: 1 addition & 1 deletion .github/workflows/test_cuda.yml
@@ -45,5 +45,5 @@ jobs:
           python examples/example_llama3.py
           python examples/example_dcp.py
           python examples/example_local_map.py
-          python examples/example_ds3_local_map.py
           python examples/example_pp_graph_passes.py
+          torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py
41 changes: 28 additions & 13 deletions autoparallel/_testing/models/dsv3.py
@@ -1529,9 +1529,11 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         )
         self.model_args = model_args

-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        _init_weights_tok_embeddings(self)
-        _init_weights_layers(self, buffer_device)
+    def init_weights(
+        self, buffer_device: torch.device | None = None, seed: int | None = None
+    ) -> None:
+        _init_weights_tok_embeddings(self, seed)
+        _init_weights_layers(self, buffer_device, seed)
         _init_weights_norm_and_output(self)

     def forward(
@@ -1585,8 +1587,10 @@ def forward(self, h):
             h = layer(h, self.freqs_cis)
         return h

-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        _init_weights_layers(self, buffer_device)
+    def init_weights(
+        self, buffer_device: torch.device | None = None, seed: int | None = None
+    ) -> None:
+        _init_weights_layers(self, buffer_device, seed)


 class DeepSeekV3Stage0(DeepSeekV3StageI):
@@ -1600,9 +1604,11 @@ def forward(self, tokens):
         # torch.Size([1024, 1024, 2048])
         return super().forward(h)

-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        _init_weights_tok_embeddings(self)
-        super().init_weights(buffer_device=buffer_device)
+    def init_weights(
+        self, buffer_device: torch.device | None = None, seed: int | None = None
+    ) -> None:
+        _init_weights_tok_embeddings(self, seed)
+        super().init_weights(buffer_device, seed)


 class DeepSeekV3StageN(DeepSeekV3StageI):
@@ -1618,8 +1624,10 @@ def forward(self, h):
         output = self.output(h) if self.output is not None else h
         return output

-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        super().init_weights(buffer_device=buffer_device)
+    def init_weights(
+        self, buffer_device: torch.device | None = None, seed: int | None = None
+    ) -> None:
+        super().init_weights(buffer_device, seed)
         _init_weights_norm_and_output(self)

@@ -1628,23 +1636,30 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
 ######################


-def _init_weights_tok_embeddings(self: Union[DeepSeekV3Model, DeepSeekV3Stage0]):
+def _init_weights_tok_embeddings(
+    self: Union[DeepSeekV3Model, DeepSeekV3Stage0], seed: int | None = None
+):
+    if seed is not None:
+        torch.manual_seed(seed)
     if self.tok_embeddings is not None:
         nn.init.normal_(self.tok_embeddings.weight)


 def _init_weights_layers(
     self: Union[DeepSeekV3Model, DeepSeekV3StageI],
     buffer_device: torch.device | None,
+    seed: int | None = None,
 ):
     if buffer_device is None:
         buffer_device = self.freqs_cis.device  # type: ignore[assignment]
     with torch.device(buffer_device):  # type: ignore[arg-type]
         self.freqs_cis = precompute_freqs_cis(self.model_args)
-    for layer in self.layers.values():
+    for i, layer in enumerate(self.layers.values()):
+        if seed is not None:
+            torch.manual_seed(seed)
         if layer is not None:
             assert isinstance(layer, TransformerBlock)
-            layer.init_weights(buffer_device=buffer_device)  # type: ignore[arg-type]
+            layer.init_weights(buffer_device)  # type: ignore[arg-type]


 def _init_weights_norm_and_output(self: Union[DeepSeekV3Model, DeepSeekV3StageN]):
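The seed parameter threaded through init_weights above reseeds the RNG immediately before the token embedding and before each transformer block is initialized, so a layer's initial weights no longer depend on how many layers were initialized before it on a given process. A minimal call-site sketch (model_args and the device choice are illustrative, not part of this diff):

    # Hypothetical usage: with a fixed seed, a pipeline stage that owns only a
    # subset of layers initializes them to the same values as the full model.
    model = DeepSeekV3Model(model_args)
    model.init_weights(buffer_device=torch.device("cuda"), seed=0)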
86 changes: 86 additions & 0 deletions autoparallel/utils.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+from pathlib import Path
 from typing import Any, Iterable

 import torch
@@ -377,3 +378,88 @@ def print_rank_by_rank(msg: Any):
             print(msg)
             print(f"{rank=} done")
         torch.distributed.barrier()
+
+
+def hash_tensor(t: torch.Tensor) -> str:
+    if isinstance(t, torch.distributed.tensor.DTensor):
+        t = t.to_local()
+        return f"DTensor({hash_tensor(t)})"
+
+    if t.is_complex():
+        return f"real={hash_tensor(t.real)}, imag={hash_tensor(t.imag)})"
+
+    return f"{torch.hash_tensor(t)}"
+
+
+class NumericsLogger:
+    def __init__(self, base_dir: str):
+        self.base = Path(base_dir)
+        self.base.mkdir(parents=True, exist_ok=True)
+        self.rank = torch.distributed.get_rank()
+        self.dir = self._create_run_dir()
+
+    def _create_run_dir(self) -> Path:
+        """
+        Find the next available integer directory name under base_dir.
+        Example: base_dir/0, base_dir/1, base_dir/2, ...
+        """
+        existing = [
+            int(p.name) for p in self.base.iterdir() if p.is_dir() and p.name.isdigit()
+        ]
+        next_id = (max(existing) + 1) if existing else 0
+        run_dir = self.base / str(next_id)
+        torch.distributed.barrier()
+        if self.rank == 0:
+            run_dir.mkdir()
+        torch.distributed.barrier()
+        return run_dir
+
+    def log_model_weights(self, parallel_mod):
+        if self.rank == 0:
+            path = self.dir / "weights.log"
+
+            logs = []
+            for name, param in parallel_mod.named_parameters():
+                logs.append(f"{name=} hash={hash_tensor(param)}")
+            for name, buf in parallel_mod.named_buffers():
+                logs.append(f"{name=} hash={hash_tensor(buf)}")
+
+            with open(path, "a") as f:
+                f.write("\n".join(logs) + "\n")
+
+            print(f"Weight hashes written to {path}")
+
+    def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
+        path = self.dir / "pp_weights.log"
+
+        torch.distributed.barrier()
+        # First print the params of every stage
+        for i in range(num_world_stages):
+            if self.rank in ranks and i in stage_mods:
+                param_logs = []
+                real_params = dict(stage_mods[i].named_parameters())
+                for name, _ in orig_mod.named_parameters():
+                    if name not in real_params:
+                        continue
+                    param = real_params[name]
+                    param_logs.append(f"{name=} hash={hash_tensor(param)}")
+                with open(path, "a") as f:
+                    f.write("\n".join(param_logs) + "\n")
+            torch.distributed.barrier()
+
+        # Then print the buffers of every stage
+        for i in range(num_world_stages):
+            if self.rank in ranks and i in stage_mods:
+                buffer_logs = []
+                real_buffers = dict(stage_mods[i].named_buffers())
+                for name, _ in orig_mod.named_buffers():
+                    if name not in real_buffers:
+                        continue
+                    buffer = real_buffers[name]
+                    buffer_logs.append(f"{name=} hash={hash_tensor(buffer)}")
+                with open(path, "a") as f:
+                    f.write("\n".join(buffer_logs) + "\n")
+            torch.distributed.barrier()
+
+        if self.rank == 0:
+            print(f"Weight hashes written to {path}")