Switch DeepSeekV3 to Use FlexAttention by Default #1610


Merged: fegin merged 2 commits into main from chienchin/ds3_default_flex on Aug 22, 2025

Conversation

fegin (Contributor) commented on Aug 21, 2025

Currently, the only SDPA backend that works for DeepSeekV3 is the memory-efficient attention kernel. FlashAttention-2 (the version current SDPA supports) requires the V embedding dimension to match that of Q and K, and cuDNN attention complains that the head dimension is too large.
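
As a minimal sketch of the backend behavior described above (illustrative MLA-like shapes, not the torchtitan code), restricting SDPA to the FlashAttention backend rejects a V head dim that differs from Q/K, while the memory-efficient backend still runs:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative MLA-like shapes: the V head dim differs from Q/K.
q = torch.randn(1, 16, 4096, 192, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 16, 4096, 192, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 16, 4096, 128, device="cuda", dtype=torch.bfloat16)

# FlashAttention requires matching Q/K/V head dims, so with only that
# backend enabled SDPA raises "no available kernel" for these shapes.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    try:
        F.scaled_dot_product_attention(q, k, v, is_causal=True)
    except RuntimeError as e:
        print("FlashAttention backend rejected the input:", e)

# The memory-efficient backend accepts the shapes, but at lower throughput.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)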

The reason for defaulting the attention to SDPA in TorchTitan is that FlexCP is not yet ready. However, the combination of SDPA + CP + DeepSeekV3 is also not functional.

This PR updates all DeepSeekV3 configurations to use FlexAttention, which significantly improves overall performance. Document masking also contributes to the MFU improvement, but the majority comes from FlexAttention itself.
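
For reference, here is a hedged sketch of the kind of document (block-causal) masking FlexAttention enables; sizes, eos_id, and the doc-id computation below are illustrative, not the torchtitan implementation:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S = 1, 16, 4096
eos_id = 0  # hypothetical eos token id
tokens = torch.randint(0, 32000, (B, S), device="cuda")
# Illustrative document ids: count eos tokens seen up to each position.
doc_id = (tokens == eos_id).cumsum(dim=-1)

def block_causal(b, h, q_idx, kv_idx):
    # Attend causally and only within the same document.
    return (q_idx >= kv_idx) & (doc_id[b, q_idx] == doc_id[b, kv_idx])

block_mask = create_block_mask(block_causal, B, None, S, S, device="cuda")

# FlexAttention also allows the V head dim to differ from Q/K.
q = torch.randn(B, H, S, 192, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, 192, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, 128, device="cuda", dtype=torch.bfloat16)

# In practice flex_attention is wrapped in torch.compile for performance.
out = flex_attention(q, k, v, block_mask=block_mask)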

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.expert_parallel_degree=8

SDPA:

[rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200)
[rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Training starts at step 1
[rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - step:  1  loss: 12.0401  grad_norm:  1.7464  memory: 63.55GiB(66.89%)  tps: 1,416  tflops: 24.67  mfu: 2.49%
[rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-20 18:29:46,138 - root - INFO - step: 10  loss: 10.3087  grad_norm:  3.1896  memory: 78.14GiB(82.25%)  tps: 7,008  tflops: 122.12  mfu: 12.35%
[rank0]:[titan] 2025-08-20 18:30:33,628 - root - INFO - step: 20  loss:  8.7601  grad_norm:  2.5195  memory: 78.14GiB(82.25%)  tps: 6,900  tflops: 120.24  mfu: 12.16%
[rank0]:[titan] 2025-08-20 18:31:22,497 - root - INFO - step: 30  loss:  7.7450  grad_norm:  1.9296  memory: 78.14GiB(82.25%)  tps: 6,705  tflops: 116.85  mfu: 11.82%
[rank0]:[titan] 2025-08-20 18:32:19,709 - root - INFO - step: 40  loss:  6.9795  grad_norm:  0.6893  memory: 78.14GiB(82.25%)  tps: 5,728  tflops: 99.81  mfu: 10.09%
[rank0]:[titan] 2025-08-20 18:33:34,343 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds
[rank0]:[titan] 2025-08-20 18:33:43,863 - root - INFO - step: 50  loss:  6.8381  grad_norm:  1.1848  memory: 78.14GiB(82.25%)  tps: 3,894  tflops: 67.86  mfu: 6.86%
[rank0]:[titan] 2025-08-20 18:34:37,289 - root - INFO - step: 60  loss:  6.5727  grad_norm:  0.9871  memory: 78.14GiB(82.25%)  tps: 6,133  tflops: 106.88  mfu: 10.81%
[rank0]:[titan] 2025-08-20 18:35:27,959 - root - INFO - step: 70  loss:  6.5041  grad_norm:  1.5895  memory: 78.14GiB(82.25%)  tps: 6,467  tflops: 112.70  mfu: 11.40%
[rank0]:[titan] 2025-08-20 18:36:16,732 - root - INFO - step: 80  loss:  6.3179  grad_norm:  0.9556  memory: 78.14GiB(82.25%)  tps: 6,719  tflops: 117.08  mfu: 11.84%
[rank0]:[titan] 2025-08-20 18:37:05,604 - root - INFO - step: 90  loss:  6.2124  grad_norm:  0.8286  memory: 78.14GiB(82.25%)  tps: 6,705  tflops: 116.85  mfu: 11.81%
[rank0]:[titan] 2025-08-20 18:37:49,285 - root - INFO - [GC] Peforming periodical GC collection 0.04 seconds
[rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - step: 100  loss:  6.2596  grad_norm:  1.5143  memory: 78.14GiB(82.25%)  tps: 6,721  tflops: 117.12  mfu: 11.84%
[rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-08-20 18:37:56,364 - root - INFO - Training completed
[rank0]:[titan] 2025-08-20 18:37:57,535 - root - INFO - Process group destroyed

FlexAttention (with this PR):

[rank0]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.)
[rank0]:  return _C._get_float32_matmul_precision()
[rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - step:  1  loss: 11.9984  grad_norm:  1.7288  memory: 63.55GiB(66.89%)  tps: 727  tflops: 12.67  mfu: 1.28%
[rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-20 22:17:32,228 - root - INFO - step: 10  loss: 10.3101  grad_norm:  2.9111  memory: 78.14GiB(82.25%)  tps: 9,066  tflops: 157.99  mfu: 15.97%
[rank0]:[titan] 2025-08-20 22:18:08,957 - root - INFO - step: 20  loss:  8.7431  grad_norm:  2.5391  memory: 78.14GiB(82.25%)  tps: 8,922  tflops: 155.47  mfu: 15.72%
[rank0]:[titan] 2025-08-20 22:18:46,981 - root - INFO - step: 30  loss:  7.7133  grad_norm:  1.7743  memory: 78.14GiB(82.25%)  tps: 8,618  tflops: 150.18  mfu: 15.19%
[rank0]:[titan] 2025-08-20 22:19:26,672 - root - INFO - step: 40  loss:  6.9643  grad_norm:  0.7227  memory: 78.14GiB(82.25%)  tps: 8,256  tflops: 143.88  mfu: 14.55%
[rank0]:[titan] 2025-08-20 22:20:01,975 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds
[rank0]:[titan] 2025-08-20 22:20:06,015 - root - INFO - step: 50  loss:  6.8046  grad_norm:  1.0556  memory: 78.14GiB(82.25%)  tps: 8,329  tflops: 145.15  mfu: 14.68%
[rank0]:[titan] 2025-08-20 22:20:45,784 - root - INFO - step: 60  loss:  6.5364  grad_norm:  1.7141  memory: 78.14GiB(82.25%)  tps: 8,240  tflops: 143.59  mfu: 14.52%
[rank0]:[titan] 2025-08-20 22:21:25,078 - root - INFO - step: 70  loss:  6.4709  grad_norm:  1.2385  memory: 78.14GiB(82.25%)  tps: 8,340  tflops: 145.33  mfu: 14.69%
[rank0]:[titan] 2025-08-20 22:22:03,088 - root - INFO - step: 80  loss:  6.2786  grad_norm:  2.2534  memory: 78.14GiB(82.25%)  tps: 8,621  tflops: 150.24  mfu: 15.19%
[rank0]:[titan] 2025-08-20 22:22:41,254 - root - INFO - step: 90  loss:  6.1441  grad_norm:  0.6878  memory: 78.14GiB(82.25%)  tps: 8,586  tflops: 149.62  mfu: 15.13%
[rank0]:[titan] 2025-08-20 22:23:15,059 - root - INFO - [GC] Peforming periodical GC collection 0.05 seconds
[rank0]:[titan] 2025-08-20 22:23:19,063 - root - INFO - step: 100  loss:  6.1348  grad_norm:  1.2875  memory: 78.14GiB(82.25%)  tps: 8,667  tflops: 151.04  mfu: 15.27%
[rank0]:[titan] 2025-08-20 22:23:19,064 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-08-20 22:23:21,065 - root - INFO - Training completed
[rank0]:[titan] 2025-08-20 22:23:22,436 - root - INFO - Process group destroyed

fegin added 2 commits on August 20, 2025 16:25
With the current implementation, DeepSeekV3 will use neither FlashAttention nor cuDNN attention; it falls back to the memory-efficient attention kernel, which has lower performance.

The motivation for defaulting the models to SDPA in TorchTitan is that FlexCP is not ready, but SDPA + CP + DeepSeekV3 doesn't work either.

So this PR makes all DeepSeekV3 configurations use FlexAttention, which gives us higher MFU.
fegin requested a review from drisspg on August 21, 2025 05:32
meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Aug 21, 2025
tianyu-l (Contributor) left a comment

However, the combination of SDPA + CP + DeepSeekV3 is also not functional.

I do want to understand this a bit more.
What stops MemEfficientAttention from working with CP? According to #1522 (comment), it is caused by precision issues when load balancing + AC are enabled.

tianyu-l (Contributor) left a comment

Also, PP doesn't currently work with FlexAttention block-causal masking, because PP can't receive eos_id as a non-Tensor input (nor can it receive a mask function).
https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L433

Context: this regression comes from a recent refactor (#1424) that moved eos_id out of ModelArgs (to remove the model's dependency on the tokenizer).

PP is indeed very important for DSV3 training at large scale, so we should make sure it runs. Created #1612 to track.

cc @H-Huang

fegin (Contributor, Author) commented on Aug 21, 2025

What stops MemEfficientAttention from working with CP? According to #1522 (comment), it is caused by precision issues when load balancing + AC are enabled.

The precision issue is a hypothesis, but we could not confirm it. The only evidence is that converting Q, K, V to float32 makes it work. I'm working with a researcher to see whether his debugging tool can help confirm the hypothesis.
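
In code, that workaround amounts to roughly the following sketch (not the actual torchtitan change): upcast Q, K, V around SDPA and cast the output back.

import torch.nn.functional as F

def sdpa_fp32(q, k, v, **kwargs):
    # Run the attention math in float32, then cast back to the original
    # (e.g. bf16) dtype; with this, SDPA + CP + DeepSeekV3 reportedly works,
    # which is why a precision issue is suspected.
    out = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), **kwargs)
    return out.to(q.dtype)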

fegin (Contributor, Author) commented on Aug 21, 2025

Created #1612 to track.

I have another PR that may fix the PP issue. Will submit it today.

eqy commented on Aug 21, 2025

@fegin could you share some more details about the platform(s) you're testing on? In theory, cutting-edge cuDNN should support larger head dims, including DeepSeekV3's, but we have been conservative in enabling it due to a somewhat uneven support surface.

fegin (Contributor, Author) commented on Aug 22, 2025

@eqy H100, CUDA 12.6, CUDNN 9.10.2.21.

eqy commented on Aug 22, 2025

@fegin thanks, I think we can try enabling this case for H100 via pytorch/pytorch#161210
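
A hedged sketch of how one could verify that once it lands (illustrative DeepSeekV3-like shapes, not from this PR): restrict SDPA to the cuDNN backend and check that the call no longer errors on H100.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 16, 4096, 192, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 16, 4096, 192, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 16, 4096, 128, device="cuda", dtype=torch.bfloat16)

# Force the cuDNN SDPA backend; on current builds this errors for these
# head dims, and pytorch/pytorch#161210 would enable the case on H100.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)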

tianyu-l (Contributor) left a comment

LGTM

fegin merged commit f738a03 into main on Aug 22, 2025 (7 checks passed)
fegin deleted the chienchin/ds3_default_flex branch on August 22, 2025 15:37
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 22, 2025
reference: pytorch/torchtitan#1610

9.10 only for now; we want to hold off on upgrading to cuDNN frontend 1.14+ / cuDNN 9.11+ due to some head-dim > 128 handling issues.

Pull Request resolved: #161210
Approved by: https://github.com/Skylion007