From cefbf1265e95dec656200b0bee7f74f6f12c9832 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Wed, 13 Aug 2025 13:58:53 -0700
Subject: [PATCH] [Do Not Land] Debug for SDPA + CP nan issue in DeepSeekV3

Debug-only changes for reproducing and localizing a nan that appears
when running DeepSeekV3 with SDPA under Context Parallel:

- shrink the 16B model to a single layer; n_heads=2 is the smallest
  head count that still reproduces the nan
- log every step, local_batch_size=1, steps=2 for a fast repro
- enable debug mode on the activation-checkpoint wrapper
- log immediately before loss.backward() to bracket where the nan
  first surfaces

[ghstack-poisoned]
---
 torchtitan/models/deepseek_v3/__init__.py                  | 7 +++++--
 .../models/deepseek_v3/train_configs/deepseek_v3_16b.toml  | 6 +++---
 torchtitan/models/llama3/infra/parallelize.py              | 2 +-
 torchtitan/train.py                                        | 3 ++-
 4 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 8243a0a84a..1038b909f9 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -73,9 +73,12 @@
         dim=2048,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=27,
+        # n_layers=27,
+        n_layers=1,
         n_dense_layers=1,
-        n_heads=16,
+        # n_heads=16,
+        # n_heads=1,  # n_heads=2 reproduces the nan error
+        n_heads=2,
         n_routed_experts=64,
         n_shared_experts=2,
         n_activated_experts=6,
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index 4f646c8d0f..7fc5f98c68 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -13,7 +13,7 @@ enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
 
 [metrics]
-log_freq = 10
+log_freq = 1
 disable_color_printing = false
 enable_tensorboard = false
 save_tb_folder = "tb"
@@ -37,10 +37,10 @@ decay_type = "cosine"
 min_lr_factor = 0.1
 
 [training]
-local_batch_size = 8
+local_batch_size = 1
 seq_len = 4096
 max_norm = 1.0  # grad norm clipping
-steps = 1000
+steps = 2
 compile = false
 dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
 
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 6d9bf60c11..7d7de61e19 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -253,7 +253,7 @@ def _apply_ac_to_transformer_block(
     )
 
     if ac_config.mode == "full":
-        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+        return ptd_checkpoint_wrapper(module, preserve_rng_state=False, debug=True)
 
     assert ac_config.mode == "selective", f"{ac_config.mode}"
     use_op_sac = ac_config.selective_ac_option == "op"
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 0955bbb2cb..8ac48c2e25 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -11,9 +11,9 @@
 from typing import Any, Generator, Iterable, Optional
 
 import torch
-from torch.distributed.elastic.multiprocessing.errors import record
 
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.dataloader import DataloaderStopIteration
 from torchtitan.components.ft import FTManager, maybe_semi_sync_training
@@ -448,6 +448,7 @@ def forward_backward_step(
             loss = self.loss_fn(pred, labels)
             # need to free to before bwd to avoid peaking memory
             del pred
+            logger.info("backward")
             loss.backward()
 
         return loss
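
Beyond debug=True on the checkpoint wrapper, autograd's anomaly mode can
name the exact op whose backward first produces a nan. A minimal sketch,
assuming a stand-in model and loss (the torchtitan wiring is omitted):

    import torch

    # Anomaly mode records each op's forward stack and raises a RuntimeError
    # pointing at the forward op whose backward produced the first nan grad.
    torch.autograd.set_detect_anomaly(True)

    model = torch.nn.Linear(8, 8)   # stand-in for the real model
    x = torch.randn(2, 8)
    loss = model(x).pow(2).mean()   # stand-in for self.loss_fn(pred, labels)
    loss.backward()                 # on a nan grad, raises with the forward trace

Anomaly mode is slow, but with local_batch_size=1 and steps=2 the
overhead is tolerable.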
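
On the forward side, per-module hooks can show which submodule first
emits a non-finite tensor, e.g. whether the nan already exists at the
attention output before backward runs. A sketch assuming a plain
nn.Module; install_nan_hooks is a hypothetical helper, not a torchtitan
API:

    import torch

    def install_nan_hooks(model: torch.nn.Module) -> None:
        # Flag every module whose output contains nan/inf; hooks fire in
        # forward order, so the first report is the earliest offender.
        def make_hook(name: str):
            def hook(module, args, output):
                outs = output if isinstance(output, (tuple, list)) else (output,)
                for t in outs:
                    if isinstance(t, torch.Tensor) and not torch.isfinite(t).all():
                        print(f"non-finite output in {name} ({type(module).__name__})")
            return hook

        for name, module in model.named_modules():
            module.register_forward_hook(make_hook(name))

Calling install_nan_hooks(model) once before the first step narrows the
nan to a single submodule per rank.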
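
One known way SDPA itself produces nan -- worth ruling out under CP,
where the sequence and any attention mask are sharded across ranks -- is
a query row whose mask blocks every key: softmax over a row of all -inf
is nan. Whether this is the mechanism behind the n_heads=2 repro is
unverified; the sketch below only demonstrates the failure mode to check
for:

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)

    # Boolean mask, True = may attend; row 0 attends to nothing.
    mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
    mask[..., 0, :] = False
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(torch.isnan(out[..., 0, :]).any())  # tensor(True): fully-masked rows go nan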