diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index f81db35341..1c3d2b19d2 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -100,6 +100,8 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
         vocab_size=102400,
@@ -125,6 +127,8 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
     ),
     "671B": DeepSeekV3ModelArgs(
         vocab_size=129280,
@@ -150,6 +154,8 @@
         qk_nope_head_dim=128,
         qk_rope_head_dim=64,
         v_head_dim=128,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
     ),
 }
 
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 6d9bf60c11..eaffe8c827 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -240,6 +240,7 @@ def apply_tp(
     # the result of max, since the absolute maximum is
     # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
+    torch._higher_order_ops.flex_attention,
 }
 
 
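
Context for the two changes above: `attn_mask_type="block_causal"` makes attention causal within each packed document, so tokens never attend across document boundaries, and the `parallelize.py` hunk adds the flex_attention higher-order op to the same save set the surrounding comment describes, so its result is saved for backward rather than recomputed under selective activation checkpointing. Below is a minimal FlexAttention sketch of block-causal masking, not torchtitan's actual mask construction; the shapes and the `doc_ids` packing layout are hypothetical.

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    B, H, S, D = 2, 4, 256, 64
    device = "cuda"

    # Hypothetical packing: each sequence holds two documents of equal length.
    doc_ids = torch.zeros(B, S, dtype=torch.int32, device=device)
    doc_ids[:, S // 2 :] = 1

    def block_causal(b, h, q_idx, kv_idx):
        # Causal within a document; no attention across document boundaries.
        return (q_idx >= kv_idx) & (doc_ids[b, q_idx] == doc_ids[b, kv_idx])

    # The block mask is built once per packing layout and reused across layers.
    block_mask = create_block_mask(block_causal, B=B, H=None, Q_LEN=S, KV_LEN=S, device=device)

    q = torch.randn(B, H, S, D, device=device, dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = flex_attention(q, k, v, block_mask=block_mask)  # (B, H, S, D)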