
Commit 3e20f8a

Log weight hashes for DSv3 w/ pp vs w/o pp (#240)
stack-info: PR: #240, branch: xmfan/stack/18
1 parent 14366af commit 3e20f8a

5 files changed: +331 -129 lines changed
.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion
@@ -45,5 +45,5 @@ jobs:
     python examples/example_llama3.py
     python examples/example_dcp.py
     python examples/example_local_map.py
-    python examples/example_ds3_local_map.py
     python examples/example_pp_graph_passes.py
+    torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py

autoparallel/_testing/models/dsv3.py

Lines changed: 28 additions & 13 deletions
@@ -1529,9 +1529,11 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         )
         self.model_args = model_args
 
-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        _init_weights_tok_embeddings(self)
-        _init_weights_layers(self, buffer_device)
+    def init_weights(
+        self, buffer_device: torch.device | None = None, seed: int | None = None
+    ) -> None:
+        _init_weights_tok_embeddings(self, seed)
+        _init_weights_layers(self, buffer_device, seed)
         _init_weights_norm_and_output(self)
 
     def forward(
@@ -1585,8 +1587,10 @@ def forward(self, h):
             h = layer(h, self.freqs_cis)
         return h
 
-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        _init_weights_layers(self, buffer_device)
+    def init_weights(
+        self, buffer_device: torch.device | None = None, seed: int | None = None
+    ) -> None:
+        _init_weights_layers(self, buffer_device, seed)
 
 
 class DeepSeekV3Stage0(DeepSeekV3StageI):
@@ -1600,9 +1604,11 @@ def forward(self, tokens):
         # torch.Size([1024, 1024, 2048])
         return super().forward(h)
 
-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        _init_weights_tok_embeddings(self)
-        super().init_weights(buffer_device=buffer_device)
+    def init_weights(
+        self, buffer_device: torch.device | None = None, seed: int | None = None
+    ) -> None:
+        _init_weights_tok_embeddings(self, seed)
+        super().init_weights(buffer_device, seed)
 
 
 class DeepSeekV3StageN(DeepSeekV3StageI):
@@ -1618,8 +1624,10 @@ def forward(self, h):
         output = self.output(h) if self.output is not None else h
         return output
 
-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        super().init_weights(buffer_device=buffer_device)
+    def init_weights(
+        self, buffer_device: torch.device | None = None, seed: int | None = None
+    ) -> None:
+        super().init_weights(buffer_device, seed)
         _init_weights_norm_and_output(self)
 
 
@@ -1628,23 +1636,30 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
 ######################
 
 
-def _init_weights_tok_embeddings(self: Union[DeepSeekV3Model, DeepSeekV3Stage0]):
+def _init_weights_tok_embeddings(
+    self: Union[DeepSeekV3Model, DeepSeekV3Stage0], seed: int | None = None
+):
+    if seed is not None:
+        torch.manual_seed(seed)
     if self.tok_embeddings is not None:
         nn.init.normal_(self.tok_embeddings.weight)
 
 
 def _init_weights_layers(
     self: Union[DeepSeekV3Model, DeepSeekV3StageI],
     buffer_device: torch.device | None,
+    seed: int | None = None,
 ):
     if buffer_device is None:
         buffer_device = self.freqs_cis.device  # type: ignore[assignment]
     with torch.device(buffer_device):  # type: ignore[arg-type]
         self.freqs_cis = precompute_freqs_cis(self.model_args)
-    for layer in self.layers.values():
+    for i, layer in enumerate(self.layers.values()):
+        if seed is not None:
+            torch.manual_seed(seed)
         if layer is not None:
             assert isinstance(layer, TransformerBlock)
-            layer.init_weights(buffer_device=buffer_device)  # type: ignore[arg-type]
+            layer.init_weights(buffer_device)  # type: ignore[arg-type]
 
 
 def _init_weights_norm_and_output(self: Union[DeepSeekV3Model, DeepSeekV3StageN]):
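
The seed plumbed through these init_weights overloads is what makes the commit's hash comparison meaningful: the RNG is reseeded with the same value before every layer's init, so a given layer gets bit-identical weights whether it is initialized as part of the full DeepSeekV3Model or inside a pipeline stage that owns only a slice of the layers. A minimal standalone sketch of that property (the Linear layers and seed value below are illustrative, not from this diff):

import torch
import torch.nn as nn

def init_layer(layer: nn.Linear, seed: int) -> None:
    # Reseed before each layer, as _init_weights_layers does above, so the
    # RNG state at init time does not depend on which layers ran earlier.
    torch.manual_seed(seed)
    nn.init.normal_(layer.weight)

full_model = [nn.Linear(8, 8) for _ in range(4)]  # "w/o pp": all layers
pp_stage = [nn.Linear(8, 8) for _ in range(2)]    # "w/ pp": stage owning layers 2..3

for layer in full_model:
    init_layer(layer, seed=42)
for layer in pp_stage:
    init_layer(layer, seed=42)

# Layer 2 of the full model matches layer 0 of the stage bit for bit.
assert torch.equal(full_model[2].weight, pp_stage[0].weight)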

autoparallel/utils.py

Lines changed: 86 additions & 0 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+from pathlib import Path
 from typing import Any, Iterable
 
 import torch
@@ -377,3 +378,88 @@ def print_rank_by_rank(msg: Any):
         print(msg)
     print(f"{rank=} done")
     torch.distributed.barrier()
+
+
+def hash_tensor(t: torch.Tensor) -> str:
+    if isinstance(t, torch.distributed.tensor.DTensor):
+        t = t.to_local()
+        return f"DTensor({hash_tensor(t)})"
+
+    if t.is_complex():
+        return f"real={hash_tensor(t.real)}, imag={hash_tensor(t.imag)}"
+
+    return f"{torch.hash_tensor(t)}"
+
+
+class NumericsLogger:
+    def __init__(self, base_dir: str):
+        self.base = Path(base_dir)
+        self.base.mkdir(parents=True, exist_ok=True)
+        self.rank = torch.distributed.get_rank()
+        self.dir = self._create_run_dir()
+
+    def _create_run_dir(self) -> Path:
+        """
+        Find the next available integer directory name under base_dir.
+        Example: base_dir/0, base_dir/1, base_dir/2, ...
+        """
+        existing = [
+            int(p.name) for p in self.base.iterdir() if p.is_dir() and p.name.isdigit()
+        ]
+        next_id = (max(existing) + 1) if existing else 0
+        run_dir = self.base / str(next_id)
+        torch.distributed.barrier()
+        if self.rank == 0:
+            run_dir.mkdir()
+        torch.distributed.barrier()
+        return run_dir
+
+    def log_model_weights(self, parallel_mod):
+        if self.rank == 0:
+            path = self.dir / "weights.log"
+
+            logs = []
+            for name, param in parallel_mod.named_parameters():
+                logs.append(f"{name=} hash={hash_tensor(param)}")
+            for name, buf in parallel_mod.named_buffers():
+                logs.append(f"{name=} hash={hash_tensor(buf)}")
+
+            with open(path, "a") as f:
+                f.write("\n".join(logs) + "\n")
+
+            print(f"Weight hashes written to {path}")
+
+    def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
+        path = self.dir / "pp_weights.log"
+
+        torch.distributed.barrier()
+        # First print the params of every stage
+        for i in range(num_world_stages):
+            if self.rank in ranks and i in stage_mods:
+                param_logs = []
+                real_params = dict(stage_mods[i].named_parameters())
+                for name, _ in orig_mod.named_parameters():
+                    if name not in real_params:
+                        continue
+                    param = real_params[name]
+                    param_logs.append(f"{name=} hash={hash_tensor(param)}")
+                with open(path, "a") as f:
+                    f.write("\n".join(param_logs) + "\n")
+            torch.distributed.barrier()
+
+        # Then print the buffers of every stage
+        for i in range(num_world_stages):
+            if self.rank in ranks and i in stage_mods:
+                buffer_logs = []
+                real_buffers = dict(stage_mods[i].named_buffers())
+                for name, _ in orig_mod.named_buffers():
+                    if name not in real_buffers:
+                        continue
+                    buffer = real_buffers[name]
+                    buffer_logs.append(f"{name=} hash={hash_tensor(buffer)}")
+                with open(path, "a") as f:
+                    f.write("\n".join(buffer_logs) + "\n")
+            torch.distributed.barrier()
+
+        if self.rank == 0:
+            print(f"Weight hashes written to {path}")
