Changes from all commits
26 commits
3453fdd
skip flash block sizes setting for cross attention.
entrpn Sep 16, 2025
c16045a
change sharding based on cross/self attention.
entrpn Sep 18, 2025
a9d9691
update sharding rules for attn.
entrpn Sep 19, 2025
695b95e
lint.
entrpn Sep 19, 2025
b185681
ring attention rules are added at front if not present to shard seque…
coolkp Sep 26, 2025
65da062
test fix
coolkp Sep 26, 2025
42cbb0e
Add dense padded attention kernel and use unsafe rng key for generation
coolkp Oct 20, 2025
920bda4
Update
coolkp Oct 22, 2025
4cbb943
Ignore history
coolkp Oct 22, 2025
94ecdca
remove file
coolkp Oct 22, 2025
54ae6e7
Flag for using segment ids and masking padding tokens in attention
coolkp Nov 10, 2025
0abc904
Tokamax splash attn
coolkp Nov 11, 2025
b0bc3a3
Flag for using same sequence sharding for self and cross
coolkp Nov 11, 2025
5182222
update requirements.txt
coolkp Nov 11, 2025
8e364d9
Merge branch 'main' into cross_self_attention_switch
coolkp Nov 11, 2025
225b79e
Delete splash_attn_benchmark.py
coolkp Nov 11, 2025
1d21a53
Delete padded_flash_attn.py
coolkp Nov 11, 2025
69d2a30
Merge main
coolkp Nov 11, 2025
0d97a52
Ruff format
coolkp Nov 11, 2025
be62d37
Ruff format
coolkp Nov 11, 2025
5498223
Ruff format
coolkp Nov 11, 2025
19fb249
Address comments
coolkp Nov 11, 2025
463baaf
Address comments
coolkp Nov 11, 2025
85cec45
Address comments
coolkp Nov 11, 2025
e082175
Fix pprint error, add description of attention configuration params
coolkp Nov 12, 2025
2bb97dc
Fix pprint error, add description of attention configuration params
coolkp Nov 12, 2025
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
@@ -58,7 +58,7 @@ jobs:
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
- name: PyTest
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=65472" python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
2 changes: 1 addition & 1 deletion .gitignore
@@ -4,7 +4,6 @@
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

@@ -98,6 +97,7 @@ celerybeat-schedule

# Environments
.env
.history
.venv
env/
venv/
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,6 +13,7 @@ ftfy
tensorboard>=2.17.0
tensorboardx>=2.6.2.2
tensorboard-plugin-profile>=2.15.2
tokamax
Jinja2
scikit-image
parameterized
34 changes: 33 additions & 1 deletion src/maxdiffusion/common_types.py
@@ -33,7 +33,11 @@
BlockSizes = splash_attention_kernel.BlockSizes

AxisNames = tuple[str, ...]

# Physical axis names for device meshes.
DATA = "data"
FSDP = "fsdp"
TENSOR = "tensor"
# Logical axis names for model parameters and activations.
BATCH = "activation_batch"
LENGTH = "activation_length"
KV_LENGTH = "activation_kv_length"
@@ -44,4 +48,32 @@
KEEP_2 = "activation_keep_2"
CONV_OUT = "activation_conv_out_channels"

# For setting self/cross attention independently in splash kernel
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
CROSS_ATTN_HEAD = "activation_cross_attn_heads"
CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"


WAN_MODEL = "Wan2.1"

### Common axis rules for ring attention ###
RING_ATTENTION_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, FSDP],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, FSDP],
]

SEQUENCE_PARALLEL_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, None],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, None],
]
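For orientation, here is a minimal sketch (not part of this PR; the mesh shape and the example PartitionSpec are assumptions) of how rule lists such as SEQUENCE_PARALLEL_AXIS_RULES above are typically consumed: Flax's logical_to_mesh_sharding resolves a logical spec built from these axis names into a concrete sharding over the (data, fsdp, tensor) mesh.

```python
# Sketch only: resolving the logical-axis rules above against a device mesh.
import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, PartitionSpec

DATA, FSDP, TENSOR = "data", "fsdp", "tensor"
BATCH = "activation_batch"
SELF_ATTN_HEAD = "activation_self_attn_heads"
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"

# Subset of SEQUENCE_PARALLEL_AXIS_RULES plus a batch rule.
rules = [
    (BATCH, DATA),
    (SELF_ATTN_HEAD, None),
    (SELF_ATTN_Q_LENGTH, FSDP),
]

# Hypothetical mesh that places every device on the fsdp axis.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1, 1), (DATA, FSDP, TENSOR))

# Logical spec for q activations of shape (batch, q_length, heads, head_dim).
logical_spec = PartitionSpec(BATCH, SELF_ATTN_Q_LENGTH, SELF_ATTN_HEAD, None)
sharding = nn.logical_to_mesh_sharding(logical_spec, mesh, rules)
print(sharding)  # q_length sharded over fsdp; heads and head_dim replicated
```

The only difference with RING_ATTENTION_AXIS_RULES is that the kv-length axes also map to fsdp, so the k/v sequence dimension is sharded as well.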
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base14.yml
@@ -50,6 +50,15 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
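To make the new mask_padding_tokens flag concrete, here is a hedged sketch of the mechanism it toggles (an assumption about how segment ids are built, not the PR's exact code): with the flag on, a padding mask is converted into splash-attention segment ids so real tokens never attend to padding; with it off, no segment ids are passed and the kernel skips that work.

```python
# Sketch (assumption): deriving splash-attention segment ids from a padding mask.
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel


def make_segment_ids(q_valid, kv_valid, mask_padding_tokens: bool):
  """q_valid / kv_valid: int arrays with 1 for real tokens and 0 for padding."""
  if not mask_padding_tokens:
    # Faster on VPU-bound hardware such as Trillium, but padding is attended to.
    return None
  # Tokens only attend within the same segment id, so giving padding a different
  # id (0) than real tokens (1) prevents real tokens from attending to padding.
  return splash_attention_kernel.SegmentIds(
      q=q_valid.astype(jnp.int32),
      kv=kv_valid.astype(jnp.int32),
  )
```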
10 changes: 10 additions & 0 deletions src/maxdiffusion/configs/base21.yml
@@ -49,6 +49,16 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
10 changes: 10 additions & 0 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -50,6 +50,16 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {}
# to override default block sizes for flash attention
# flash_block_sizes:
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -63,6 +63,15 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True

flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -63,6 +63,15 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True

#flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -62,6 +62,15 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True
flash_block_sizes: {
"block_q" : 256,
"block_kv_compute" : 256,
46 changes: 34 additions & 12 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -60,18 +60,28 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
flash_min_seq_length: 0

# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True
dropout: 0.1

flash_block_sizes: {
"block_q" : 1024,
"block_kv_compute" : 256,
"block_kv" : 1024,
"block_q_dkv" : 1024,
"block_kv_dkv" : 1024,
"block_kv_dkv_compute" : 256,
"block_q_dq" : 1024,
"block_kv_dq" : 1024
"block_q" : 3024,
"block_kv_compute" : 1024,
"block_kv" : 2048,
"block_q_dkv" : 3024,
"block_kv_dkv" : 2048,
"block_kv_dkv_compute" : 2048,
"block_q_dq" : 3024,
"block_kv_dq" : 2048
}
# Use on v6e
# flash_block_sizes: {
Expand All @@ -80,11 +90,22 @@ flash_block_sizes: {
# "block_kv" : 2048,
# "block_q_dkv" : 3024,
# "block_kv_dkv" : 2048,
# "block_kv_dkv_compute" : 2048,
# "block_kv_dkv_compute" : 1024,
# "block_q_dq" : 3024,
# "block_kv_dq" : 2048,
# "use_fused_bwd_kernel": False,
# }
# Use on v5p
# flash_block_sizes: {
# "block_q" : 3024,
# "block_kv_compute" : 1024,
# "block_kv" : 2048,
# "block_q_dkv" : 1024,
# "block_kv_dkv" : 3072,
# "block_kv_dkv_compute" : 256,
# "block_q_dq" : 1024,
# "block_kv_dq" : 3072
# }
# GroupNorm groups
norm_num_groups: 32

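Since the recommended block sizes now differ per TPU generation (the new defaults above versus the commented v5p block), a small hypothetical helper could select a dict programmatically; the helper and the device_kind substrings are assumptions, while the values are copied from this config.

```python
# Hypothetical helper (not in this PR): pick flash_block_sizes per TPU generation.
import jax

DEFAULT_BLOCKS = {  # values from the flash_block_sizes block in this config
    "block_q": 3024, "block_kv_compute": 1024, "block_kv": 2048,
    "block_q_dkv": 3024, "block_kv_dkv": 2048, "block_kv_dkv_compute": 2048,
    "block_q_dq": 3024, "block_kv_dq": 2048,
}
V5P_BLOCKS = {  # values from the commented "Use on v5p" block in this config
    "block_q": 3024, "block_kv_compute": 1024, "block_kv": 2048,
    "block_q_dkv": 1024, "block_kv_dkv": 3072, "block_kv_dkv_compute": 256,
    "block_q_dq": 1024, "block_kv_dq": 3072,
}


def flash_blocks_for_device() -> dict:
  kind = jax.devices()[0].device_kind.lower()  # e.g. "tpu v5p" (assumed string)
  return V5P_BLOCKS if "v5p" in kind else DEFAULT_BLOCKS
```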
@@ -145,8 +166,9 @@ mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],

['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
@@ -276,7 +298,7 @@ flow_shift: 3.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
fps: 24
fps: 16
save_final_checkpoint: False

# SDXL Lightning parameters
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -61,6 +61,15 @@ from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
flash_min_seq_length: 4096
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True
dropout: 0.1

flash_block_sizes: {
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_xl.yml
@@ -50,6 +50,15 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'dot_product' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
9 changes: 9 additions & 0 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -48,6 +48,15 @@ jit_initializers: True
from_pt: False
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash
# If mask_padding_tokens is True, segment ids are passed to splash attention so that padding tokens are not attended to.
# If False, segment ids are not passed, which is faster on VPU-bound hardware such as Trillium.
# When padding tokens make up a significant fraction of the sequence, skipping the mask degrades quality, so keep this True.
mask_padding_tokens: True
# Maxdiffusion supports 2 attention sharding strategies:
# 1. attention_sharding_uniform = True : the same sequence sharding rules are applied to q in both self and cross attention.
# 2. attention_sharding_uniform = False : heads are sharded across devices for self attention, while the sequence is sharded
#    for the cross attention q.
attention_sharding_uniform: True
flash_block_sizes: {}
# GroupNorm groups
norm_num_groups: 32
9 changes: 9 additions & 0 deletions src/maxdiffusion/generate_wan.py
@@ -62,6 +62,15 @@ def delete_file(file_path: str):


jax.config.update("jax_use_shardy_partitioner", True)
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
# TF allocates extraneous GPU memory when using TFDS data
# this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
# tf.config.set_visible_devices([], "GPU")
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
max_logging.log("Enabling unsafe RNG bit generator for TPU SPMD.")
os.environ["LIBTPU_INIT_ARGS"] = (
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
)

def get_pipeline(model_name: str):
if model_name == "wan2.1":
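The generate_wan.py change above switches generation to RBG-based random keys. As a rough illustration (the trade-off described here is summarized from JAX's documentation of these options, not from this PR), the default PRNG implementation affects every key created afterwards: unsafe_rbg, together with the xla_tpu_spmd_rng_bit_generator_unsafe flag, gives faster, SPMD-friendly random bit generation at the cost of strict reproducibility guarantees.

```python
# Illustration of the config switch above; the behavior described in the comments
# is summarized from JAX documentation and is not a claim made by this PR.
import jax

jax.config.update("jax_default_prng_impl", "unsafe_rbg")

key = jax.random.PRNGKey(0)              # now an RBG-based key
noise = jax.random.normal(key, (4, 4))   # bits come from XLA's RngBitGenerator
print(noise.shape, noise.dtype)
```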
1 change: 1 addition & 0 deletions src/maxdiffusion/max_utils.py
@@ -501,6 +501,7 @@ def get_flash_block_sizes(config):
"""Create custom flash attention BlockSizes."""
flash_block_sizes = None
if len(config.flash_block_sizes.keys()) > 0:
use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
flash_block_sizes = splash_attention_kernel.BlockSizes(
block_q=config.flash_block_sizes["block_q"],
block_kv_compute=config.flash_block_sizes["block_kv_compute"],
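The max_utils.py hunk above is collapsed after block_kv_compute; here is a hedged sketch of how the new use_fused_bwd_kernel flag plausibly threads through the rest of the call. Everything past the visible lines is an assumption based on the splash-attention BlockSizes API, not confirmed by the diff.

```python
# Sketch of the collapsed remainder of get_flash_block_sizes; only the lines shown
# in the hunk above are confirmed by the PR, the rest is assumed from the API.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel


def get_flash_block_sizes(config):
  """Create custom flash attention BlockSizes."""
  flash_block_sizes = None
  if len(config.flash_block_sizes.keys()) > 0:
    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
    flash_block_sizes = splash_attention_kernel.BlockSizes(
        block_q=config.flash_block_sizes["block_q"],
        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
        block_kv=config.flash_block_sizes["block_kv"],
        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
        # The fused backward kernel computes dq alongside dkv, so the separate
        # dq block sizes must stay unset when it is enabled.
        block_q_dq=None if use_fused_bwd_kernel else config.flash_block_sizes["block_q_dq"],
        block_kv_dq=None if use_fused_bwd_kernel else config.flash_block_sizes["block_kv_dq"],
        use_fused_bwd_kernel=use_fused_bwd_kernel,
    )
  return flash_block_sizes
```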