Skip to content

Commit f738a03

Browse files
authored
Switch DeepSeekV3 to Use FlexAttention by Default (#1610)
Currently, the only available backend for SDPA for DeepSeekV3 is efficient attention kernel. For FlashAttentionV2 (what current SDPA supports), the V embedding dimension must be the same as Q and K. For cuDNN attention, it is complaining the head dimension is too large. The reason for defaulting the attention to SDPA in TorchTitan is that FlexCP is not yet ready. However, the combination of SDPA + CP + DeepSeekV3 is also not functional. This PR updates all DeepSeekV3 configurations to use FlexAttention, which significantly improves the overall performance. **Document masking also contributes to MFU improvement, but the majority is from FlexAttention itself**. ``` CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.expert_parallel_degree=8 ``` SDPA: ``` [rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200) [rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Training starts at step 1 [rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - step: 1 loss: 12.0401 grad_norm: 1.7464 memory: 63.55GiB(66.89%) tps: 1,416 tflops: 24.67 mfu: 2.49% [rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-08-20 18:29:46,138 - root - INFO - step: 10 loss: 10.3087 grad_norm: 3.1896 memory: 78.14GiB(82.25%) tps: 7,008 tflops: 122.12 mfu: 12.35% [rank0]:[titan] 2025-08-20 18:30:33,628 - root - INFO - step: 20 loss: 8.7601 grad_norm: 2.5195 memory: 78.14GiB(82.25%) tps: 6,900 tflops: 120.24 mfu: 12.16% [rank0]:[titan] 2025-08-20 18:31:22,497 - root - INFO - step: 30 loss: 7.7450 grad_norm: 1.9296 memory: 78.14GiB(82.25%) tps: 6,705 tflops: 116.85 mfu: 11.82% [rank0]:[titan] 2025-08-20 18:32:19,709 - root - INFO - step: 40 loss: 6.9795 grad_norm: 0.6893 memory: 78.14GiB(82.25%) tps: 5,728 tflops: 99.81 mfu: 10.09% [rank0]:[titan] 2025-08-20 18:33:34,343 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds [rank0]:[titan] 2025-08-20 18:33:43,863 - root - INFO - step: 50 loss: 6.8381 grad_norm: 1.1848 memory: 78.14GiB(82.25%) tps: 3,894 tflops: 67.86 mfu: 6.86% [rank0]:[titan] 2025-08-20 18:34:37,289 - root - INFO - step: 60 loss: 6.5727 grad_norm: 0.9871 memory: 78.14GiB(82.25%) tps: 6,133 tflops: 106.88 mfu: 10.81% [rank0]:[titan] 2025-08-20 18:35:27,959 - root - INFO - step: 70 loss: 6.5041 grad_norm: 1.5895 memory: 78.14GiB(82.25%) tps: 6,467 tflops: 112.70 mfu: 11.40% [rank0]:[titan] 2025-08-20 18:36:16,732 - root - INFO - step: 80 loss: 6.3179 grad_norm: 0.9556 memory: 78.14GiB(82.25%) tps: 6,719 tflops: 117.08 mfu: 11.84% [rank0]:[titan] 2025-08-20 18:37:05,604 - root - INFO - step: 90 loss: 6.2124 grad_norm: 0.8286 memory: 78.14GiB(82.25%) tps: 6,705 tflops: 116.85 mfu: 11.81% [rank0]:[titan] 2025-08-20 18:37:49,285 - root - INFO - [GC] Peforming periodical GC collection 0.04 seconds [rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - step: 100 loss: 6.2596 grad_norm: 1.5143 memory: 78.14GiB(82.25%) tps: 6,721 tflops: 117.12 mfu: 11.84% [rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank0]:[titan] 2025-08-20 18:37:56,364 - root - INFO - Training completed [rank0]:[titan] 2025-08-20 18:37:57,535 - root - INFO - Process group destroyed ``` FlexAttention (now) ``` [rank0]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.) [rank0]: return _C._get_float32_matmul_precision() [rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - step: 1 loss: 11.9984 grad_norm: 1.7288 memory: 63.55GiB(66.89%) tps: 727 tflops: 12.67 mfu: 1.28% [rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-08-20 22:17:32,228 - root - INFO - step: 10 loss: 10.3101 grad_norm: 2.9111 memory: 78.14GiB(82.25%) tps: 9,066 tflops: 157.99 mfu: 15.97% [rank0]:[titan] 2025-08-20 22:18:08,957 - root - INFO - step: 20 loss: 8.7431 grad_norm: 2.5391 memory: 78.14GiB(82.25%) tps: 8,922 tflops: 155.47 mfu: 15.72% [rank0]:[titan] 2025-08-20 22:18:46,981 - root - INFO - step: 30 loss: 7.7133 grad_norm: 1.7743 memory: 78.14GiB(82.25%) tps: 8,618 tflops: 150.18 mfu: 15.19% [rank0]:[titan] 2025-08-20 22:19:26,672 - root - INFO - step: 40 loss: 6.9643 grad_norm: 0.7227 memory: 78.14GiB(82.25%) tps: 8,256 tflops: 143.88 mfu: 14.55% [rank0]:[titan] 2025-08-20 22:20:01,975 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds [rank0]:[titan] 2025-08-20 22:20:06,015 - root - INFO - step: 50 loss: 6.8046 grad_norm: 1.0556 memory: 78.14GiB(82.25%) tps: 8,329 tflops: 145.15 mfu: 14.68% [rank0]:[titan] 2025-08-20 22:20:45,784 - root - INFO - step: 60 loss: 6.5364 grad_norm: 1.7141 memory: 78.14GiB(82.25%) tps: 8,240 tflops: 143.59 mfu: 14.52% [rank0]:[titan] 2025-08-20 22:21:25,078 - root - INFO - step: 70 loss: 6.4709 grad_norm: 1.2385 memory: 78.14GiB(82.25%) tps: 8,340 tflops: 145.33 mfu: 14.69% [rank0]:[titan] 2025-08-20 22:22:03,088 - root - INFO - step: 80 loss: 6.2786 grad_norm: 2.2534 memory: 78.14GiB(82.25%) tps: 8,621 tflops: 150.24 mfu: 15.19% [rank0]:[titan] 2025-08-20 22:22:41,254 - root - INFO - step: 90 loss: 6.1441 grad_norm: 0.6878 memory: 78.14GiB(82.25%) tps: 8,586 tflops: 149.62 mfu: 15.13% [rank0]:[titan] 2025-08-20 22:23:15,059 - root - INFO - [GC] Peforming periodical GC collection 0.05 seconds [rank0]:[titan] 2025-08-20 22:23:19,063 - root - INFO - step: 100 loss: 6.1348 grad_norm: 1.2875 memory: 78.14GiB(82.25%) tps: 8,667 tflops: 151.04 mfu: 15.27% [rank0]:[titan] 2025-08-20 22:23:19,064 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank0]:[titan] 2025-08-20 22:23:21,065 - root - INFO - Training completed [rank0]:[titan] 2025-08-20 22:23:22,436 - root - INFO - Process group destroyed ```
1 parent 8a749c6 commit f738a03

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@
100100
qk_rope_head_dim=64,
101101
v_head_dim=128,
102102
mscale=0.70,
103+
use_flex_attn=True,
104+
attn_mask_type="block_causal",
103105
),
104106
"236B": DeepSeekV3ModelArgs(
105107
vocab_size=102400,
@@ -125,6 +127,8 @@
125127
qk_nope_head_dim=128,
126128
qk_rope_head_dim=64,
127129
v_head_dim=128,
130+
use_flex_attn=True,
131+
attn_mask_type="block_causal",
128132
),
129133
"671B": DeepSeekV3ModelArgs(
130134
vocab_size=129280,
@@ -150,6 +154,8 @@
150154
qk_nope_head_dim=128,
151155
qk_rope_head_dim=64,
152156
v_head_dim=128,
157+
use_flex_attn=True,
158+
attn_mask_type="block_causal",
153159
),
154160
}
155161

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def apply_tp(
243243
# the result of max, since the absolute maximum is
244244
# used to compute the scaling factor for quantization.
245245
torch.ops.aten.max.default,
246+
torch._higher_order_ops.flex_attention,
246247
}
247248

248249

0 commit comments

Comments
 (0)