From d7b9e42958e41f423e02aa3f1f4a724e69361225 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 04:47:53 +0200 Subject: [PATCH 01/26] update --- src/diffusers/models/attention.py | 477 +++++++++++++++++- src/diffusers/models/attention_processor.py | 248 ++------- .../models/transformers/transformer_flux.py | 457 ++++++++++++----- 3 files changed, 863 insertions(+), 319 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index ae51d3ab1349..b174cb093d58 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,23 +11,494 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn import torch.nn.functional as F -from torch import nn from ..utils import deprecate, logging +from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU -from .attention_processor import Attention, JointAttnProcessor2_0 +from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX +if is_xformers_available(): + import xformers as xops +else: + xops = None + + logger = logging.get_logger(__name__) +class AttentionMixin: + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + """ + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + for module in self.modules(): + if isinstance(module, AttentionModuleMixin): + module.fuse_projections() + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + """ + for module in self.modules(): + if isinstance(module, AttentionModuleMixin): + module.unfuse_projections() + + +class AttentionModuleMixin: + _default_processor_cls = None + _available_processors = [] + fused_projections = False + + def set_processor(self, processor: AttentionProcessor) -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + """ + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + """ + Set whether to use NPU flash attention from `torch_npu` or not. + + Args: + use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not. + """ + + if use_npu_flash_attention: + if not is_torch_npu_available(): + raise ImportError("torch_npu is not available") + + self.set_attention_backend("_native_npu") + + def set_use_xla_flash_attention( + self, + use_xla_flash_attention: bool, + partition_spec: Optional[Tuple[Optional[str], ...]] = None, + is_flux=False, + ) -> None: + """ + Set whether to use XLA flash attention from `torch_xla` or not. + + Args: + use_xla_flash_attention (`bool`): + Whether to use pallas flash attention kernel from `torch_xla` or not. + partition_spec (`Tuple[]`, *optional*): + Specify the partition specification if using SPMD. Otherwise None. + is_flux (`bool`, *optional*, defaults to `False`): + Whether the model is a Flux model. 
+ """ + if use_xla_flash_attention: + if not is_torch_xla_available(): + raise ImportError("torch_xla is not available") + + self.set_attention_backend("_native_xla") + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + """ + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. + """ + if use_memory_efficient_attention_xformers: + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + if is_xformers_available(): + dtype = None + if attention_op is not None: + op_fw, op_bw = attention_op + dtype, *_ = op_fw.SUPPORTED_DTYPES + q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) + _ = xops.memory_efficient_attention(q, q, q) + except Exception as e: + raise e + + self.set_attention_backend("xformers") + + @torch.no_grad() + def fuse_projections(self): + """ + Fuse the query, key, and value projections into a single projection for efficiency. + """ + # Skip if already fused + if getattr(self, "fused_projections", False): + return + + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if hasattr(self, "is_cross_attention") and self.is_cross_attention: + # Fuse cross-attention key-value projections + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if hasattr(self, "use_bias") and self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + else: + # Fuse self-attention projections + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if hasattr(self, "use_bias") and self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + # Handle added projections for models like SD3, Flux, etc. 
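+        # The add_q/k/v layers are kept alongside the fused linear, so unfuse_projections() only drops to_added_qkv.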
+ if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + """ + Unfuse the query, key, and value projections back to separate projections. + """ + # Skip if not fused + if not getattr(self, "fused_projections", False): + return + + # Remove fused projection layers + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + + if hasattr(self, "to_added_qkv"): + delattr(self, "to_added_qkv") + + self.fused_projections = False + + def set_attention_slice(self, slice_size: int) -> None: + """ + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + processor = None + + # Try to get a compatible processor for sliced attention + if slice_size is not None: + processor = self._get_compatible_processor("sliced") + + # If no processor was found or slice_size is None, use default processor + if processor is None: + processor = self.default_processor_cls() + + self.set_processor(processor) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + """ + Reshape the tensor for multi-head attention processing. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Compute the attention scores. 
+ + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + """ + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): The attention mask to prepare. + target_length (`int`): The target length of the attention mask. + batch_size (`int`): The batch size for repeating the attention mask. + out_dim (`int`, *optional*, defaults to `3`): Output dimension. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + """ + Normalize the encoder hidden states. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). 
In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): # "feed_forward_chunk_size" can be used to save memory if hidden_states.shape[chunk_dim] % chunk_size != 0: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4760cfd40b3c..2306bdbc9dbd 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2272,102 +2272,16 @@ def __call__( return hidden_states -class FluxAttnProcessor2_0: +class FluxAttnProcessor2_0_NPU: """Attention processor used typically in processing the SD3-like self-attention projections.""" def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + deprecation_message = ( + "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." ) + deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" @@ -2470,107 +2384,16 @@ def __call__( return hidden_states -class FusedFluxAttnProcessor2_0: +class FusedFluxAttnProcessor2_0_NPU: """Attention processor used typically in processing the SD3-like self-attention projections.""" def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. 
- qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + deprecation_message = ( + "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." 
) + deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" @@ -3459,6 +3282,12 @@ class XLAFluxFlashAttnProcessor2_0: """ def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): + deprecation_message = ( + "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " + "alternative solution to using XLA Flash Attention will be provided in the future." + ) + deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." @@ -5992,17 +5821,6 @@ def __init__(self): pass -class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0): - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead." - deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message) - super().__init__() - - class SanaLinearAttnProcessor2_0: r""" Processor for implementing scaled dot-product linear attention. @@ -6167,6 +5985,40 @@ def __call__( return hidden_states +class FluxAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`" + deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxAttnProcessor + + return FluxAttnProcessor(*args, **kwargs) + + +class FluxSingleAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __new__(cls, *args, **kwargs): + deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead." + deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxAttnProcessor + + return FluxAttnProcessor(*args, **kwargs) + + +class FusedFluxAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. 
Please use `FluxAttnProcessor`" + deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxAttnProcessor + + return FluxAttnProcessor(*args, **kwargs) + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 608ea7f6e5f0..8682bcbb603e 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -13,27 +13,26 @@ # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import ( - Attention, - AttentionProcessor, - FluxAttnProcessor2_0, - FluxAttnProcessor2_0_NPU, - FusedFluxAttnProcessor2_0, -) +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..cache_utils import CacheMixin -from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + FluxPosEmbed, + apply_rotary_emb, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -42,6 +41,322 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class FluxAttnProcessor: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def _get_projections(self, attn, hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_projections = None + if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"): + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + encoder_projections = (encoder_query, encoder_key, encoder_value) + + return query, key, value, encoder_projections + + def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + encoder_projections = None + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1) + encoder_projections = (encoder_query, encoder_key, encoder_value) + + return query, key, value, encoder_projections + + def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None): + if hasattr(attn, "to_qkv") and attn.fused_projections: + return self._get_fused_projections(attn, hidden_states, encoder_hidden_states) + return self._get_projections(attn, hidden_states, encoder_hidden_states) + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + query, key, value, encoder_projections = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if encoder_projections is not None: + encoder_query, encoder_key, encoder_value = encoder_projections + encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + # Concatenate for joint attention + query = torch.cat([encoder_query, query], dim=2) + key = torch.cat([encoder_key, key], dim=2) + value = torch.cat([encoder_value, value], dim=2) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = 
hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FluxIPAdapterAttnProcessor(torch.nn.Module): + """Flux Attention processor for IP-Adapter.""" + + def __init__( + self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") + self.scale = scale + + self.to_k_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + self.to_v_ip = nn.ModuleList( + [ + nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) + for _ in range(len(num_tokens)) + ] + ) + + def __call__( + self, + attn: "FluxAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ip_hidden_states: Optional[List[torch.Tensor]] = None, + ip_adapter_masks: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + hidden_states_query_proj = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
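+            # (the text-stream projections below are concatenated with the image-stream Q/K/V for joint attention)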
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            # IP-adapter
+            ip_query = hidden_states_query_proj
+            ip_attn_output = torch.zeros_like(hidden_states)
+
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+                ip_attn_output += scale * current_ip_hidden_states
+
+            return hidden_states, encoder_hidden_states, ip_attn_output
+        else:
+            return hidden_states
+
+
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = FluxAttnProcessor
+    _available_processors = [
+        FluxAttnProcessor,
+        FluxIPAdapterAttnProcessor,
+    ]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        qk_norm: Optional[str] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int = None,
+        context_pre_only: Optional[bool] = None,
+
pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + assert qk_norm == "rms_norm", "Flux uses RMSNorm" + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + @maybe_allow_in_graph class FluxSingleTransformerBlock(nn.Module): def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): @@ -54,6 +369,8 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) if is_torch_npu_available(): + from ..attention_processor import FluxAttnProcessor2_0_NPU + deprecation_message = ( "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " "should be set explicitly using the `set_attn_processor` method." 
@@ -61,11 +378,10 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, deprecate("npu_processor", "0.34.0", deprecation_message) processor = FluxAttnProcessor2_0_NPU() else: - processor = FluxAttnProcessor2_0() + processor = FluxAttnProcessor() - self.attn = Attention( + self.attn = FluxAttention( query_dim=dim, - cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, @@ -118,16 +434,15 @@ def __init__( self.norm1 = AdaLayerNormZero(dim) self.norm1_context = AdaLayerNormZero(dim) - self.attn = Attention( + self.attn = FluxAttention( query_dim=dim, - cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, - processor=FluxAttnProcessor2_0(), + processor=FluxAttnProcessor(), qk_norm=qk_norm, eps=eps, ) @@ -152,6 +467,7 @@ def forward( encoder_hidden_states, emb=temb ) joint_attention_kwargs = joint_attention_kwargs or {} + # Attention. attention_outputs = self.attn( hidden_states=norm_hidden_states, @@ -180,7 +496,6 @@ def forward( hidden_states = hidden_states + ip_attn_output # Process attention outputs for the `encoder_hidden_states`. - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output @@ -196,7 +511,13 @@ def forward( class FluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, + AttentionMixin, ): """ The Transformer model introduced in Flux. @@ -292,106 +613,6 @@ def __init__( self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. 
- - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedFluxAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. 
- - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - def forward( self, hidden_states: torch.Tensor, From 7e97e43efcad3ac918dd1cd9ef92523136f8da17 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 04:56:55 +0200 Subject: [PATCH 02/26] update --- src/diffusers/models/embeddings.py | 41 +++++-------------- .../models/transformers/transformer_flux.py | 33 ++++++++++++++- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4f268bfa018f..262e57a3a050 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1238,37 +1238,6 @@ def apply_1d_rope(tokens, pos, cos, sin): return x -class FluxPosEmbed(nn.Module): - # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 - def __init__(self, theta: int, axes_dim: List[int]): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - - def forward(self, ids: torch.Tensor) -> torch.Tensor: - n_axes = ids.shape[-1] - cos_out = [] - sin_out = [] - pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - for i in range(n_axes): - cos, sin = get_1d_rotary_pos_embed( - self.axes_dim[i], - pos[:, i], - theta=self.theta, - repeat_interleave_real=True, - use_real=True, - freqs_dtype=freqs_dtype, - ) - cos_out.append(cos) - sin_out.append(sin) - freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) - freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) - return freqs_cos, freqs_sin - - class TimestepEmbedding(nn.Module): def __init__( self, @@ -2619,3 +2588,13 @@ def forward(self, image_embeds: List[torch.Tensor]): projected_image_embeds.append(image_embed) return projected_image_embeds + + +class FluxPosEmbed(nn.Module): + def __new__(cls, *args, **kwargs): + deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`." 
+ deprecate("FluxPosEmbed", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxPosEmbed + + return FluxPosEmbed(*args, **kwargs) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 8682bcbb603e..8218b2ae6e93 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -30,8 +30,8 @@ from ..embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, - FluxPosEmbed, apply_rotary_emb, + get_1d_rotary_pos_embed, ) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -510,6 +510,37 @@ def forward( return encoder_hidden_states, hidden_states +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + class FluxTransformer2DModel( ModelMixin, ConfigMixin, From aaf2df8bc5ac719d7200865ca2ce2fa5aeb72f05 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 07:35:03 +0200 Subject: [PATCH 03/26] update --- .../models/transformers/transformer_wan.py | 205 ++++++++++++++---- .../transformers/transformer_wan_vace.py | 32 ++- 2 files changed, 180 insertions(+), 57 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index bdb9201e62cf..f9b6f01c6b87 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -21,10 +21,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import Attention +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -35,18 +34,49 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class WanAttnProcessor2_0: +class WanAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "WanAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def get_qkv_projections( + self, attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor + ): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + def get_added_kv_projections(self, attn: "WanAttention", encoder_hidden_states_img: torch.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img def __call__( self, - attn: Attention, + attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: @@ -54,17 +84,11 @@ def __call__( image_context_length = encoder_hidden_states.shape[1] - 512 encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] encoder_hidden_states = encoder_hidden_states[:, image_context_length:] - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + query, key, value = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states) - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) + query = attn.norm_q(query) + key = attn.norm_k(key) query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) @@ -92,9 +116,8 @@ def apply_rotary_emb( # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: - key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img, value_img = self.get_added_kv_projections(attn, encoder_hidden_states_img) key_img = attn.norm_added_k(key_img) - value_img = attn.add_v_proj(encoder_hidden_states_img) key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) @@ -119,6 +142,119 @@ def apply_rotary_emb( return hidden_states +class WanAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. " + "Please use WanAttnProcessor instead. 
" + ) + deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + return WanAttnProcessor(*args, **kwargs) + + +class WanAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = WanAttnProcessor + _available_processors = [WanAttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + processor=None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.ModuleList( + [ + torch.nn.Linear(self.inner_dim, dim, bias=True), + torch.nn.Dropout(dropout), + ] + ) + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + self.set_processor(processor) + + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if self.cross_attention_dim_head is None: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_qkv = nn.Linear(in_features, out_features, bias=True) + self.to_qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_kv = nn.Linear(in_features, out_features, bias=True) + self.to_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + delattr(self, "to_added_kv") + + 
self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + + class WanImageEmbedding(torch.nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): super().__init__() @@ -266,33 +402,24 @@ def __init__( # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.attn1 = Attention( - query_dim=dim, + self.attn1 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, - processor=WanAttnProcessor2_0(), + cross_attention_dim_head=None, + processor=WanAttnProcessor(), ) # 2. Cross-attention - self.attn2 = Attention( - query_dim=dim, + self.attn2 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, added_kv_proj_dim=added_kv_proj_dim, - added_proj_bias=True, - processor=WanAttnProcessor2_0(), + cross_attention_dim_head=dim // num_heads, + processor=WanAttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() @@ -315,12 +442,12 @@ def forward( # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) hidden_states = hidden_states + attn_output # 3. Feed-forward @@ -333,7 +460,9 @@ def forward( return hidden_states -class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): +class WanTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): r""" A Transformer model for video-like data used in the Wan model. 
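
As a rough illustration of the fused-projection path added above: `fuse_projections()` concatenates the separate `to_q`/`to_k`/`to_v` weights (or `to_k`/`to_v` for cross-attention) into a single `to_qkv`/`to_kv` linear, and `WanAttnProcessor.get_qkv_projections()` then splits the combined output with `chunk()`. The snippet below is a minimal sketch, assuming the `WanAttention` changes in this patch are applied as-is; the dimensions and tensor shapes are made up purely for illustration:

    import torch
    from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor

    # Self-attention layer: dim == heads * dim_head, no added KV projections.
    attn = WanAttention(dim=64, heads=8, dim_head=8, processor=WanAttnProcessor())
    attn.eval()

    hidden_states = torch.randn(1, 16, 64)  # (batch, seq_len, dim)

    with torch.no_grad():
        out_unfused = attn(hidden_states, None)  # uses separate to_q / to_k / to_v
        attn.fuse_projections()                  # builds a single to_qkv linear
        out_fused = attn(hidden_states, None)    # uses to_qkv followed by chunk(3)
        attn.unfuse_projections()                # removes to_qkv again

    # Fused and unfused paths use the same weights, so the outputs should
    # agree up to floating-point error.
    torch.testing.assert_close(out_unfused, out_fused, rtol=1e-5, atol=1e-5)

The same mechanism is what `AttentionMixin.fuse_qkv_projections()` triggers model-wide, and what the reworked pipeline tests check for via the presence of the fused `to_qkv` layers.
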
diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 1a6f2af59a87..934e12ca54f5 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -22,12 +22,17 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward -from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm -from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock +from .transformer_wan import ( + WanAttention, + WanAttnProcessor, + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -55,33 +60,22 @@ def __init__( # 2. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.attn1 = Attention( - query_dim=dim, + self.attn1 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, - processor=WanAttnProcessor2_0(), + processor=WanAttnProcessor(), ) # 3. Cross-attention - self.attn2 = Attention( - query_dim=dim, + self.attn2 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, added_kv_proj_dim=added_kv_proj_dim, - added_proj_bias=True, - processor=WanAttnProcessor2_0(), + processor=WanAttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() From ecabd2a46e8dbe6e719f3045e368c012a0f36ff4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 07:38:04 +0200 Subject: [PATCH 04/26] add coauthor Co-Authored-By: Dhruv Nair From ff21b7fe8b0e3d2c0bbd0341c8273a1f4bb62a7c Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 07:46:32 +0200 Subject: [PATCH 05/26] improve test --- tests/pipelines/flux/test_pipeline_flux.py | 11 ++++------- tests/pipelines/test_pipelines_common.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 0df0e028ff06..4541521c8941 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -28,8 +28,7 @@ FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, + check_qkv_fused_layers_exist, ) @@ -171,12 +170,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. 
Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 13c25ccaa469..387eb6a614f9 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -37,6 +37,7 @@ from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin +from diffusers.models.attention import AttentionModuleMixin from diffusers.models.attention_processor import AttnProcessor from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel @@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model): return all(p.startswith("Fused") for p in proc_names) +def check_qkv_fused_layers_exist(model, layer_names): + is_fused_submodules = [] + for submodule in model.modules(): + if not isinstance(submodule, AttentionModuleMixin): + continue + is_fused_attribute_set = submodule.fused_projections + is_fused_layer = True + for layer in layer_names: + is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None + is_fused = is_fused_attribute_set and is_fused_layer + is_fused_submodules.append(is_fused) + return all(is_fused_submodules) + + class SDFunctionTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. From b8f7fe61e1c5136cb8f88ee7ebe14c7b7c95fb13 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Jul 2025 08:21:47 +0200 Subject: [PATCH 06/26] handle ip adapter params correctly --- src/diffusers/loaders/ip_adapter.py | 9 +- src/diffusers/loaders/transformer_flux.py | 6 +- src/diffusers/models/attention_processor.py | 156 ++---------------- .../models/transformers/transformer_flux.py | 14 +- 4 files changed, 29 insertions(+), 156 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index e05d53687a24..dca4758ba038 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -40,8 +40,6 @@ from ..models.attention_processor import ( AttnProcessor, AttnProcessor2_0, - FluxAttnProcessor2_0, - FluxIPAdapterJointAttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, @@ -867,6 +865,9 @@ def unload_ip_adapter(self): >>> ... 
``` """ + # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level + from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor + # remove CLIP image encoder if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: self.image_encoder = None @@ -886,9 +887,9 @@ def unload_ip_adapter(self): # restore original Transformer attention processors layers attn_procs = {} for name, value in self.transformer.attn_processors.items(): - attn_processor_class = FluxAttnProcessor2_0() + attn_processor_class = FluxAttnProcessor() attn_procs[name] = ( - attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__() + attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__() ) self.transformer.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index af03d09029c1..0873e8edd08f 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -87,9 +87,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us return image_projection def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT): - from ..models.attention_processor import ( - FluxIPAdapterJointAttnProcessor2_0, - ) + from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor if low_cpu_mem_usage: if is_accelerate_available(): @@ -121,7 +119,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_ else: cross_attention_dim = self.config.joint_attention_dim hidden_size = self.inner_dim - attn_processor_class = FluxIPAdapterJointAttnProcessor2_0 + attn_processor_class = FluxIPAdapterAttnProcessor num_image_text_embeds = [] for state_dict in state_dicts: if "proj.weight" in state_dict["image_proj"]: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 2306bdbc9dbd..e64bd45eb42d 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2501,152 +2501,6 @@ def __call__( return hidden_states -class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module): - """Flux Attention processor for IP-Adapter.""" - - def __init__( - self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None - ): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
- ) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - - if not isinstance(num_tokens, (tuple, list)): - num_tokens = [num_tokens] - - if not isinstance(scale, list): - scale = [scale] * len(num_tokens) - if len(scale) != len(num_tokens): - raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") - self.scale = scale - - self.to_k_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - self.to_v_ip = nn.ModuleList( - [ - nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) - for _ in range(len(num_tokens)) - ] - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ip_hidden_states: Optional[List[torch.Tensor]] = None, - ip_adapter_masks: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - hidden_states_query_proj = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # IP-adapter - ip_query = hidden_states_query_proj - ip_attn_output = torch.zeros_like(hidden_states) - - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) - ip_attn_output += scale * current_ip_hidden_states - - return hidden_states, encoder_hidden_states, ip_attn_output - else: - return hidden_states - - class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -6019,6 +5873,16 @@ def __new__(cls, *args, **kwargs): return FluxAttnProcessor(*args, **kwargs) +class FluxIPAdapterJointAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. 
Please use `FluxIPAdapterAttnProcessor`" + deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message) + + from .transformers.transformer_flux import FluxIPAdapterAttnProcessor + + return FluxIPAdapterAttnProcessor(*args, **kwargs) + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 8218b2ae6e93..7898bdb1f010 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import inspect from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -241,7 +241,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = torch.nn.functional(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -354,6 +356,14 @@ def forward( image_rotary_emb: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) From 0cda91d467636458bf77beb69cfa6ab62ceab324 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 07:51:58 +0200 Subject: [PATCH 07/26] fix chroma qkv fusion test --- tests/pipelines/chroma/test_pipeline_chroma.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py index fc5749f96cd8..5121a2b52d75 100644 --- a/tests/pipelines/chroma/test_pipeline_chroma.py +++ b/tests/pipelines/chroma/test_pipeline_chroma.py @@ -7,12 +7,7 @@ from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers.utils.testing_utils import torch_device -from ..test_pipelines_common import ( - FluxIPAdapterTesterMixin, - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist class ChromaPipelineFastTests( @@ -126,12 +121,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." 
+ self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images From bc64f12c98acb5237634e4864660d72563416a06 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 08:01:42 +0200 Subject: [PATCH 08/26] fix fastercache implementation --- src/diffusers/hooks/faster_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py index 1be5e1436294..a6c250b50ca4 100644 --- a/src/diffusers/hooks/faster_cache.py +++ b/src/diffusers/hooks/faster_cache.py @@ -18,6 +18,7 @@ import torch +from ..models.attention import AttentionModuleMixin from ..models.attention_processor import Attention, MochiAttention from ..models.modeling_outputs import Transformer2DModelOutput from ..utils import logging @@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float: _apply_faster_cache_on_denoiser(module, config) for name, submodule in module.named_modules(): - if not isinstance(submodule, _ATTENTION_CLASSES): + if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): continue if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS): _apply_faster_cache_on_attention_class(name, submodule, config) From c72e72cc466defca19a7569a4b2fadd95e757a07 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 08:12:23 +0200 Subject: [PATCH 09/26] remove set_attention_backend related code --- src/diffusers/models/attention.py | 80 +------------------------------ 1 file changed, 2 insertions(+), 78 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b174cb093d58..f895fe1d3b66 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from ..utils import deprecate, logging -from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available +from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0 @@ -161,82 +161,6 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce if not return_deprecated_lora: return self.processor - def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: - """ - Set whether to use NPU flash attention from `torch_npu` or not. - - Args: - use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not. 
- """ - - if use_npu_flash_attention: - if not is_torch_npu_available(): - raise ImportError("torch_npu is not available") - - self.set_attention_backend("_native_npu") - - def set_use_xla_flash_attention( - self, - use_xla_flash_attention: bool, - partition_spec: Optional[Tuple[Optional[str], ...]] = None, - is_flux=False, - ) -> None: - """ - Set whether to use XLA flash attention from `torch_xla` or not. - - Args: - use_xla_flash_attention (`bool`): - Whether to use pallas flash attention kernel from `torch_xla` or not. - partition_spec (`Tuple[]`, *optional*): - Specify the partition specification if using SPMD. Otherwise None. - is_flux (`bool`, *optional*, defaults to `False`): - Whether the model is a Flux model. - """ - if use_xla_flash_attention: - if not is_torch_xla_available(): - raise ImportError("torch_xla is not available") - - self.set_attention_backend("_native_xla") - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ) -> None: - """ - Set whether to use memory efficient attention from `xformers` or not. - - Args: - use_memory_efficient_attention_xformers (`bool`): - Whether to use memory efficient attention from `xformers` or not. - attention_op (`Callable`, *optional*): - The attention operation to use. Defaults to `None` which uses the default attention operation from - `xformers`. - """ - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - if is_xformers_available(): - dtype = None - if attention_op is not None: - op_fw, op_bw = attention_op - dtype, *_ = op_fw.SUPPORTED_DTYPES - q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) - _ = xops.memory_efficient_attention(q, q, q) - except Exception as e: - raise e - - self.set_attention_backend("xformers") - @torch.no_grad() def fuse_projections(self): """ From a0b276da538e456a48153d7b28027d807a103b08 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 08:30:26 +0200 Subject: [PATCH 10/26] fix more tests --- .../hooks/pyramid_attention_broadcast.py | 3 ++- .../chroma/test_pipeline_chroma_img2img.py | 15 ++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index bbdd1c3f68d4..1c8787194196 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -18,6 +18,7 @@ import torch +from ..models.attention import AttentionModuleMixin from ..models.attention_processor import Attention, MochiAttention from ..utils import logging from .hooks import HookRegistry, ModelHook @@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt config.spatial_attention_block_skip_range = 2 for name, submodule in module.named_modules(): - if not isinstance(submodule, _ATTENTION_CLASSES): + if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): # PAB has been implemented specific to Diffusers' Attention classes. 
However, this does not mean that PAB # cannot be applied to this layer. For custom layers, users can extend this functionality and implement # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. diff --git a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py index 02b20527b2f9..d518e1b7b8d1 100644 --- a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py +++ b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py @@ -8,12 +8,7 @@ from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler from diffusers.utils.testing_utils import floats_tensor, torch_device -from ..test_pipelines_common import ( - FluxIPAdapterTesterMixin, - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist class ChromaImg2ImgPipelineFastTests( @@ -129,12 +124,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." 
inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images From c1415207145b1417e0839fd39862b7eeb38bdaea Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 09:49:17 +0200 Subject: [PATCH 11/26] fight more tests --- .../models/transformers/transformer_flux.py | 2 +- .../test_controlnet_flux_img2img.py | 14 ++++---------- tests/pipelines/flux/test_pipeline_flux_control.py | 14 ++++---------- .../flux/test_pipeline_flux_control_inpaint.py | 14 ++++---------- 4 files changed, 13 insertions(+), 31 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 7898bdb1f010..706438569fca 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -242,7 +242,7 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb) hidden_states = torch.nn.functional.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index 8d63619c402b..ab4cf3273489 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -16,11 +16,7 @@ ) from diffusers.utils.torch_utils import randn_tensor -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -170,12 +166,10 @@ def test_fused_qkv_projections(self): original_image_slice = image[0, -3:, -3:, -1] pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." 
inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py index d8d0774e1e32..42283da6fd03 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control.py +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -8,11 +8,7 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel from diffusers.utils.testing_utils import torch_device -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -140,12 +136,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py index a2f7c9171082..0abd08e37300 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py @@ -15,11 +15,7 @@ torch_device, ) -from ..test_pipelines_common import ( - PipelineTesterMixin, - check_qkv_fusion_matches_attn_procs_length, - check_qkv_fusion_processors_exist, -) +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): @@ -134,12 +130,10 @@ def test_fused_qkv_projections(self): # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added # to the pipeline level. pipe.transformer.fuse_qkv_projections() - assert check_qkv_fusion_processors_exist(pipe.transformer), ( - "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), ) - assert check_qkv_fusion_matches_attn_procs_length( - pipe.transformer, pipe.transformer.original_attn_processors - ), "Something wrong with the attention processors concerning the fused QKV projections." 
inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images From 4dcd6729071306251ea9f49639938c0ea9be0672 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 12:29:28 +0200 Subject: [PATCH 12/26] add back set_attention_backend --- src/diffusers/models/attention.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b174cb093d58..2d5eaaa69122 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -160,6 +160,16 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce """ if not return_deprecated_lora: return self.processor + + def set_attention_backend(self, backend: str): + from .attention_dispatch import AttentionBackendName + + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend.lower()) + self.processor._attention_backend = backend def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: """ From 576da52f45b77bb73ec7ef355011e5e8bc572f80 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 11:43:48 +0200 Subject: [PATCH 13/26] update --- src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention_dispatch.py | 1155 +++++++++++++++++ src/diffusers/models/modeling_utils.py | 50 + .../models/transformers/transformer_flux.py | 104 +- src/diffusers/utils/__init__.py | 6 + src/diffusers/utils/constants.py | 2 + src/diffusers/utils/import_utils.py | 60 + 8 files changed, 1332 insertions(+), 51 deletions(-) create mode 100644 src/diffusers/models/attention_dispatch.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bc81c24f7347..dd46e44991d3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -163,6 +163,7 @@ [ "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", + "AttentionBackendName", "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", @@ -237,6 +238,7 @@ "VQModel", "WanTransformer3DModel", "WanVACETransformer3DModel", + "attention_backend", ] ) _import_structure["modular_pipelines"].extend( @@ -809,6 +811,7 @@ from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, + AttentionBackendName, AuraFlowTransformer2DModel, AutoencoderDC, AutoencoderKL, @@ -882,6 +885,7 @@ VQModel, WanTransformer3DModel, WanVACETransformer3DModel, + attention_backend, ) from .modular_pipelines import ( ComponentsManager, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 73903a627415..f019b35b0fe2 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,6 +26,7 @@ if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] + _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"] _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] @@ -111,6 +112,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter + from .attention_dispatch import AttentionBackendName, attention_backend from .auto_model import AutoModel from .autoencoders import ( AsymmetricAutoencoderKL, diff 
--git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py new file mode 100644 index 000000000000..141a7fee858b --- /dev/null +++ b/src/diffusers/models/attention_dispatch.py @@ -0,0 +1,1155 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import inspect +import math +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch + +from ..utils import ( + get_logger, + is_flash_attn_3_available, + is_flash_attn_available, + is_flash_attn_version, + is_sageattention_available, + is_sageattention_version, + is_torch_npu_available, + is_torch_version, + is_torch_xla_available, + is_torch_xla_version, + is_xformers_available, + is_xformers_version, +) +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): + from flash_attn import flash_attn_func, flash_attn_varlen_func +else: + logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.") + flash_attn_func = None + flash_attn_varlen_func = None + + +if is_flash_attn_3_available(): + from flash_attn_interface import flash_attn_func as flash_attn_3_func + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func +else: + flash_attn_3_func = None + flash_attn_3_varlen_func = None + + +if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp16_triton, + sageattn_varlen, + ) +else: + logger.warning( + "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`." + ) + sageattn = None + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + sageattn_varlen = None + + +if is_torch_version(">=", "2.5.0"): + # We cannot import the flex_attention function from the package directly because it is expected (from the + # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the + # compiled function. + import torch.nn.attention.flex_attention as flex_attention + + +if is_torch_npu_available(): + from torch_npu import npu_fusion_attention +else: + npu_fusion_attention = None + + +if is_torch_xla_available() and is_torch_xla_version(">", "2.2"): + from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention +else: + xla_flash_attention = None + + +if is_xformers_available() and is_xformers_version(">=", "0.0.29"): + import xformers.ops as xops +else: + logger.warning("`xformers` is not available or the version is too old. 
Please install `xformers>=0.0.29`.") + xops = None + + +# TODO(aryan): Add support for the following: +# - Sage Attention++ +# - block sparse, radial and other attention methods +# - CP with sage attention, flex, xformers, other missing backends +# - Add support for normal and CP training with backends that don't support it yet + + +_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] +_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] +_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] + + +class AttentionBackendName(str, Enum): + # EAGER = "eager" + + # `flash-attn` + FLASH = "flash" + FLASH_VARLEN = "flash_varlen" + _FLASH_3 = "_flash_3" + _FLASH_VARLEN_3 = "_flash_varlen_3" + + # PyTorch native + FLEX = "flex" + NATIVE = "native" + _NATIVE_CUDNN = "_native_cudnn" + _NATIVE_EFFICIENT = "_native_efficient" + _NATIVE_FLASH = "_native_flash" + _NATIVE_MATH = "_native_math" + _NATIVE_NPU = "_native_npu" + _NATIVE_XLA = "_native_xla" + + # `sageattention` + SAGE = "sage" + SAGE_VARLEN = "sage_varlen" + _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" + _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" + _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" + _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # TODO: let's not add support for Sparge Attention now because it requires tuning per model + # We can look into supporting something "autotune"-ing in the future + # SPARGE = "sparge" + + # `xformers` + XFORMERS = "xformers" + + +class _AttentionBackendRegistry: + _backends = {} + _constraints = {} + _supported_arg_names = {} + _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) + _checks_enabled = DIFFUSERS_ATTN_CHECKS + + @classmethod + def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None): + logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") + + def decorator(func): + cls._backends[backend] = func + cls._constraints[backend] = constraints or [] + cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) + return func + + return decorator + + @classmethod + def get_active_backend(cls): + return cls._active_backend, cls._backends[cls._active_backend] + + @classmethod + def list_backends(cls): + return list(cls._backends.keys()) + + +@contextlib.contextmanager +def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE): + """ + Context manager to set the active attention backend. 
+ """ + if backend not in _AttentionBackendRegistry._backends: + raise ValueError(f"Backend {backend} is not registered.") + + old_backend = _AttentionBackendRegistry._active_backend + _AttentionBackendRegistry._active_backend = backend + + try: + yield + finally: + _AttentionBackendRegistry._active_backend = old_backend + + +def dispatch_attention_fn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + *, + backend: Optional[AttentionBackendName] = None, +) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + + if backend is None: + # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment + # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager + backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() + else: + backend_name = AttentionBackendName(backend) + backend_fn = _AttentionBackendRegistry._backends.get(backend_name) + + kwargs = { + "query": query, + "key": key, + "value": value, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa, + **attention_kwargs, + } + + if _AttentionBackendRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) + if removed_kwargs: + logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.") + for check in _AttentionBackendRegistry._constraints.get(backend_name): + check(**kwargs) + + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} + return backend_fn(**kwargs) + + +# ===== Checks ===== +# A list of very simple functions to catch common errors quickly when debugging. + + +def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: + if attn_mask is not None and is_causal: + raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") + + +def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.device != key.device or query.device != value.device: + raise ValueError("Query, key, and value must be on the same device.") + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError("Query, key, and value must have the same dtype.") + + +def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device(query, key, value) + if query.device.type != "cuda": + raise ValueError("Query, key, and value must be on a CUDA device.") + + +def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: + def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device_cuda(query, key, value) + if torch.cuda.get_device_capability(query.device) < (major, minor): + raise ValueError( + f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." 
+ ) + + return check_device_cuda + + +def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.dtype != key.dtype: + raise ValueError("Query and key must have the same dtype.") + if query.dtype != value.dtype: + raise ValueError("Query and value must have the same dtype.") + + +def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_qkv_dtype_match(query, key, value) + if query.dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Query, key, and value must be either bfloat16 or float16.") + + +def _check_shape( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> None: + if query.shape[-1] != key.shape[-1]: + raise ValueError("Query and key must have the same last dimension.") + if query.shape[-2] != value.shape[-2]: + raise ValueError("Query and value must have the same second to last dimension.") + if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: + raise ValueError("Attention mask must match the key's second to last dimension.") + + +# ===== Helper functions ===== + + +@functools.lru_cache(maxsize=128) +def _prepare_for_flash_attn_or_sage_varlen_without_mask( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + device: Optional[torch.device] = None, +): + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _prepare_for_flash_attn_or_sage_varlen_with_mask( + batch_size: int, + seq_len_q: int, + attn_mask: torch.Tensor, + device: Optional[torch.device] = None, +): + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _prepare_for_flash_attn_or_sage_varlen( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + attn_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, +) -> None: + if attn_mask is None: + return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device) + return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device) + + +def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: + """ + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in + FlashAttention/Sage varlen. + + Supports 1D to 4D shapes and common broadcasting patterns. 
+ """ + if attn_mask.dtype != torch.bool: + raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") + + if attn_mask.ndim == 1: + # [seq_len_k] -> broadcast across batch + attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 2: + # [batch_size, seq_len_k]. Maybe broadcast across batch + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 3: + # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen. + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." + ) + attn_mask = attn_mask.any(dim=1) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 4: + # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] + attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] + + else: + raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") + + if attn_mask.shape != (batch_size, seq_len_k): + raise ValueError( + f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" + ) + + return attn_mask + + +def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + +# ===== torch op registrations ===== +# Registrations are required for fullgraph tracing compatibility + + +# TODO: library.custom_op and register_fake probably need version guards? +# TODO: this is only required because the beta release FA3 does not have it. 
There is a PR adding +# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 +@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") +def _wrapped_flash_attn_3_original( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = flash_attn_3_func(query, key, value) + lse = lse.permute(0, 2, 1) + return out, lse + + +@torch.library.register_fake("flash_attn_3::_flash_attn_forward") +def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_len, num_heads, head_dim = query.shape + lse_shape = (batch_size, seq_len, num_heads) + return torch.empty_like(query), query.new_empty(lse_shape) + + +# ===== Attention backends ===== + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + out = flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH_VARLEN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = flash_attn_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + 
cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + out = out.unflatten(0, (batch_size, -1)) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_3, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + out, lse, *_ = flash_attn_3_func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + attention_chunk=0, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_VARLEN_3, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out, lse, *_ = flash_attn_3_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=None, + seqused_k=None, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + out = out.unflatten(0, (batch_size, -1)) + + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + 
AttentionBackendName.FLEX, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _native_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + # TODO: should we LRU cache the block mask creation? + score_mod = None + block_mask = None + batch_size, seq_len_q, num_heads, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): + block_mask = attn_mask + elif is_causal: + block_mask = flex_attention.create_block_mask( + _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device + ) + elif torch.is_tensor(attn_mask): + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + + attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) + + if attn_mask.dtype == torch.bool: + # TODO: this probably does not work but verify! + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return attn_mask[batch_idx, head_idx, q_idx, kv_idx] + + block_mask = flex_attention.create_block_mask( + mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + else: + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] + else: + raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = flex_attention.flex_attention( + query=query, + key=key, + value=value, + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + enable_gqa=enable_gqa, + return_lse=return_lse, + kernel_options=kernel_options, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.NATIVE, + constraints=[_check_device, _check_shape], +) +def _native_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_CUDNN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_cudnn_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + 
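As an aside to the native SDPA wrappers above and below, here is a minimal plain-PyTorch sketch of the layout juggling they all share, plus forcing one specific SDPA implementation with `torch.nn.attention.sdpa_kernel` (available in recent PyTorch, roughly 2.3+). Shapes and tensors are illustrative only.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64

# The dispatcher hands tensors to these wrappers in [batch, seq, heads, head_dim] ...
q, k, v = (torch.randn(batch_size, seq_len, num_heads, head_dim) for _ in range(3))

# ... while torch SDPA expects [batch, heads, seq, head_dim], hence the permutes on
# entry and exit of every native backend.
q_t, k_t, v_t = (x.permute(0, 2, 1, 3) for x in (q, k, v))

# Pinning a specific SDPA implementation works like the cuDNN/efficient/flash/math
# wrappers: run the call inside an `sdpa_kernel` context.
with sdpa_kernel(SDPBackend.MATH):  # MATH is always available, including on CPU
    out = F.scaled_dot_product_attention(q_t, k_t, v_t)

out = out.permute(0, 2, 1, 3)
assert out.shape == (batch_size, seq_len, num_heads, head_dim)
```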
out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_EFFICIENT, + constraints=[_check_device, _check_shape], +) +def _native_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_FLASH, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=None, # not supported + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_MATH, + constraints=[_check_device, _check_shape], +) +def _native_math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_NPU, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_npu_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, +) -> torch.Tensor: + return npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tokens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + + +# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853 +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_XLA, + constraints=[_check_device, _check_shape], +) +def _native_xla_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, +) -> torch.Tensor: + query, key, 
value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + query = query / math.sqrt(query.shape[-1]) + out = xla_flash_attention( + q=query, + k=key, + v=value, + causal=is_causal, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE_VARLEN, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + is_causal: bool = False, + scale: Optional[float] = None, + smooth_k: bool = True, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, seq_len_q, _, _ = query.shape + _, seq_len_kv, _, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = sageattn_varlen( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + ) + out = out.unflatten(0, (batch_size, -1)) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + 
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_sm90_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_cuda( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_triton_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_triton( + q=query, + k=key, + v=value, + tensor_layout="NHD", + quantization_backend=quantization_backend, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.XFORMERS, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _xformers_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, seq_len_q, num_heads_q, _ = query.shape + _, seq_len_kv, num_heads_kv, _ = key.shape + + if is_causal: + attn_mask = xops.LowerTriangularMask() + elif attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + elif attn_mask.ndim != 4: + raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + + if enable_gqa: + if num_heads_q % num_heads_kv != 0: + raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") + num_heads_per_group = num_heads_q // num_heads_kv + query = query.unflatten(2, (num_heads_kv, -1)) + key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, 
num_heads_per_group, -1) + value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + + out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + + if enable_gqa: + out = out.flatten(2, 3) + + return out diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d7b2136b4afc..4918fae91d8a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -606,6 +606,56 @@ def enable_group_offload( offload_to_disk_path=offload_to_disk_path, ) + def set_attention_backend(self, backend: str) -> None: + """ + Set the attention backend for the model. + + Args: + backend (`str`): + The name of the backend to set. Must be one of the available backends defined in + `AttentionBackendName`. Available backends can be found in + `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product + attention as backend. + """ + from .attention import AttentionModuleMixin + from .attention_dispatch import AttentionBackendName + + # TODO: the following will not be required when everything is refactored to AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + backend = backend.lower() + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend) + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = backend + + def reset_attention_backend(self) -> None: + """ + Resets the attention backend for the model. Following calls to `forward` will use the environment default or + the torch native scaled dot product attention. + """ + from .attention import AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = None + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 706438569fca..2d4bc172a721 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -26,6 +26,7 @@ from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, @@ -42,6 +43,8 @@ class FluxAttnProcessor: + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") @@ -51,31 +54,25 @@ def _get_projections(self, attn, hidden_states, encoder_hidden_states=None): key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - encoder_projections = None - if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"): + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) - encoder_projections = (encoder_query, encoder_key, encoder_value) - return query, key, value, encoder_projections + return query, key, value, encoder_query, encoder_key, encoder_value def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): - qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) - encoder_projections = None + encoder_query = encoder_key = encoder_value = (None,) if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1) - encoder_projections = (encoder_query, encoder_key, encoder_value) + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) - return query, key, value, encoder_projections + return query, key, value, encoder_query, encoder_key, encoder_value def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None): - if hasattr(attn, "to_qkv") and attn.fused_projections: + if attn.fused_projections: return self._get_fused_projections(attn, hidden_states, encoder_hidden_states) return self._get_projections(attn, hidden_states, encoder_hidden_states) @@ -87,53 +84,43 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - query, key, value, encoder_projections = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads + query, key, value, encoder_query, encoder_key, encoder_value = self.get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) + query = attn.norm_q(query) + key = attn.norm_k(key) - if encoder_projections is not None: - encoder_query, encoder_key, encoder_value = encoder_projections - encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.added_kv_proj_dim is not 
None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) - if attn.norm_added_q is not None: - encoder_query = attn.norm_added_q(encoder_query) - if attn.norm_added_k is not None: - encoder_key = attn.norm_added_k(encoder_key) + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) - # Concatenate for joint attention - query = torch.cat([encoder_query, query], dim=2) - key = torch.cat([encoder_key, key], dim=2) - value = torch.cat([encoder_value, value], dim=2) + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, backend=self._attention_backend + ) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) - hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) @@ -146,6 +133,8 @@ def __call__( class FluxIPAdapterAttnProcessor(torch.nn.Module): """Flux Attention processor for IP-Adapter.""" + _attention_backend = None + def __init__( self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None ): @@ -241,8 +230,14 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -273,8 +268,14 @@ def __call__( ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = torch.nn.functional.scaled_dot_product_attention( - ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + current_ip_hidden_states = dispatch_attention_fn( + ip_query, + ip_key, + ip_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim @@ -323,6 +324,7 @@ def __init__( self.context_pre_only = context_pre_only self.pre_only = pre_only self.heads = out_dim // dim_head if 
out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim self.added_proj_bias = added_proj_bias self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 2df05cb8eb36..cadcedb98a14 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,6 +67,9 @@ is_bitsandbytes_version, is_bs4_available, is_cosmos_guardrail_available, + is_flash_attn_3_available, + is_flash_attn_available, + is_flash_attn_version, is_flax_available, is_ftfy_available, is_gguf_available, @@ -90,6 +93,8 @@ is_peft_version, is_pytorch_retinaface_available, is_safetensors_available, + is_sageattention_available, + is_sageattention_version, is_scipy_available, is_sentencepiece_available, is_tensorboard_available, @@ -108,6 +113,7 @@ is_unidecode_available, is_wandb_available, is_xformers_available, + is_xformers_version, requires_backends, ) from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 7c04287d33ed..f8f04cc03abd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -41,6 +41,8 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DIFFUSERS_REQUEST_TIMEOUT = 60 +DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") +DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f12e9de33172..a27c2da648f4 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -220,6 +220,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity") _nltk_available, _nltk_version = _is_package_available("nltk") _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail") +_sageattention_available, _sageattention_version = _is_package_available("sageattention") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") +_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") def is_torch_available(): @@ -378,6 +381,18 @@ def is_hpu_available(): return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch")) +def is_sageattention_available(): + return _sageattention_available + + +def is_flash_attn_available(): + return _flash_attn_available + + +def is_flash_attn_3_available(): + return _flash_attn_3_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str): return compare_versions(parse(_optimum_quanto_version), operation, version) +def is_xformers_version(operation: str, version: str): + """ + Compares the current xformers version to a given reference with an operation. 
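A hypothetical end-to-end usage sketch of the plumbing introduced in this series: the `DIFFUSERS_ATTN_BACKEND` environment default, the availability helpers above, and the `set_attention_backend` / `reset_attention_backend` methods added to `ModelMixin` earlier in this patch. The checkpoint id is only an example, the `"flash"` string assumes the `AttentionBackendName.FLASH` value and an installed `flash-attn`, and a CUDA device with access to the (gated) Flux weights is assumed.

```python
import os

# Optionally pick a process-wide default before diffusers is imported; "native" is the
# built-in default and maps to torch's scaled_dot_product_attention.
os.environ.setdefault("DIFFUSERS_ATTN_BACKEND", "native")

import torch
from diffusers import FluxTransformer2DModel
from diffusers.utils import is_flash_attn_available

# Illustrative checkpoint; any transformer that picks up the new mixin works the same way.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Switch every attention processor in the model to the flash-attention backend ...
if is_flash_attn_available():
    transformer.set_attention_backend("flash")

# ... and fall back to the default native SDPA path when done experimenting.
transformer.reset_attention_backend()
```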
+ + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _xformers_available: + return False + return compare_versions(parse(_xformers_version), operation, version) + + +def is_sageattention_version(operation: str, version: str): + """ + Compares the current sageattention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _sageattention_available: + return False + return compare_versions(parse(_sageattention_version), operation, version) + + +def is_flash_attn_version(operation: str, version: str): + """ + Compares the current flash-attention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _flash_attn_available: + return False + return compare_versions(parse(_flash_attn_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects From e909b7355fcd7055334e245e532e79349df79a92 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 12:25:51 +0200 Subject: [PATCH 14/26] update --- src/diffusers/models/embeddings.py | 12 ++++++++++-- .../models/transformers/transformer_flux.py | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 262e57a3a050..4d3d246e4815 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1176,6 +1176,7 @@ def apply_rotary_emb( freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, + sequence_dim: int = 2, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. 
This function applies rotary embeddings @@ -1193,8 +1194,15 @@ def apply_rotary_emb( """ if use_real: cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = sin[None, None] + if sequence_dim == 2: + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + elif sequence_dim == 1: + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + else: + raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.") + cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 2d4bc172a721..e5b45bbcd583 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -108,8 +108,8 @@ def __call__( value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, backend=self._attention_backend From 1e7217f82daa10d55b799419300bc746a9aebfc1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 12:30:34 +0200 Subject: [PATCH 15/26] make style --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2d5eaaa69122..c720b379551f 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -160,7 +160,7 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce """ if not return_deprecated_lora: return self.processor - + def set_attention_backend(self, backend: str): from .attention_dispatch import AttentionBackendName From 4f52e3499c68c5f618f70362053947625415a061 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 12:30:45 +0200 Subject: [PATCH 16/26] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 247769306b53..4ec3a3234c96 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -258,6 +258,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AttentionBackendName(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AuraFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1353,6 +1368,10 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +def attention_backend(*args, **kwargs): + requires_backends(attention_backend, ["torch"]) + + class ComponentsManager(metaclass=DummyObject): _backends = ["torch"] From d9c1683b07a46cbfcb0c845a1facca7681157910 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 13:30:52 +0200 Subject: [PATCH 17/26] make ip adapter processor compatible with attention dispatcher --- .../models/transformers/transformer_flux.py | 142 ++++++++---------- 1 file changed, 59 
insertions(+), 83 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index e5b45bbcd583..7640d8d13cdc 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -42,39 +42,42 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class FluxAttnProcessor: - _attention_backend = None +def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) - def _get_projections(self, attn, hidden_states, encoder_hidden_states=None): - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) + return query, key, value, encoder_query, encoder_key, encoder_value - encoder_query = encoder_key = encoder_value = None - if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: - encoder_query = attn.add_q_proj(encoder_hidden_states) - encoder_key = attn.add_k_proj(encoder_hidden_states) - encoder_value = attn.add_v_proj(encoder_hidden_states) - return query, key, value, encoder_query, encoder_key, encoder_value +def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) - def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None): - query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) - encoder_query = encoder_key = encoder_value = (None,) - if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + return query, key, value, encoder_query, encoder_key, encoder_value - return query, key, value, encoder_query, encoder_key, encoder_value - def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None): - if attn.fused_projections: - return self._get_fused_projections(attn, hidden_states, encoder_hidden_states) - return self._get_projections(attn, hidden_states, encoder_hidden_states) +def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class FluxAttnProcessor: + _attention_backend = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") def __call__( self, @@ -84,7 +87,7 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - query, key, value, encoder_query, encoder_key, encoder_value = self.get_qkv_projections( + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( attn, hidden_states, encoder_hidden_states ) @@ -180,55 +183,35 @@ def __call__( ip_hidden_states: Optional[List[torch.Tensor]] = None, ip_adapter_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + batch_size = hidden_states.shape[0] - # `sample` projections. - hidden_states_query_proj = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) - hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) - if attn.norm_q is not None: - hidden_states_query_proj = attn.norm_q(hidden_states_query_proj) - if attn.norm_k is not None: - key = attn.norm_k(key) + query = attn.norm_q(query) + key = attn.norm_k(key) + ip_query = query - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: - # `context` projections. 
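A small illustrative sketch (plain PyTorch, made-up shapes) of the layout change this refactor applies throughout: `unflatten`/`flatten` keep tensors in `[batch, seq, heads, head_dim]`, which carries the same data as the old `view` + `transpose` round trip, and is also why rotary embeddings are now applied with `sequence_dim=1`.

```python
import torch

batch_size, seq_len, heads, head_dim = 2, 10, 4, 8
proj = torch.randn(batch_size, seq_len, heads * head_dim)

# New convention in this refactor: keep [batch, seq, heads, head_dim] end to end.
q_new = proj.unflatten(-1, (heads, head_dim))

# Old convention: reshape to [batch, heads, seq, head_dim] for torch SDPA.
q_old = proj.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Same values, different axis order.
assert torch.equal(q_new, q_old.transpose(1, 2))

# After attention, the new layout collapses back with a single flatten(2, 3).
assert q_new.flatten(2, 3).shape == proj.shape

# The layout change is also why rotary cos/sin of shape [seq, head_dim] now broadcast
# over the heads axis (sequence_dim=1) instead of the sequence axis (sequence_dim=2).
cos = torch.randn(seq_len, head_dim)
assert (q_new * cos[None, :, None, :]).shape == q_new.shape   # sequence_dim=1
assert (q_old * cos[None, None, :, :]).shape == q_old.shape   # sequence_dim=2
```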
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) hidden_states = dispatch_attention_fn( query, @@ -239,23 +222,18 @@ def __call__( is_causal=False, backend=self._attention_backend, ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) - - # linear proj hidden_states = attn.to_out[0](hidden_states) - # dropout hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # IP-adapter - ip_query = hidden_states_query_proj ip_attn_output = torch.zeros_like(hidden_states) for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( @@ -264,10 +242,9 @@ def __call__( ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 + ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) + ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) + current_ip_hidden_states = dispatch_attention_fn( ip_query, ip_key, @@ -277,9 +254,7 @@ def __call__( is_causal=False, backend=self._attention_backend, ) - 
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) + current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) ip_attn_output += scale * current_ip_hidden_states @@ -316,6 +291,7 @@ def __init__( super().__init__() assert qk_norm == "rms_norm", "Flux uses RMSNorm" + self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.query_dim = query_dim self.use_bias = bias From a73cb396bacfd4aa36242d85a4868ba7d7e62b5c Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 13:53:27 +0200 Subject: [PATCH 18/26] refactor chroma as well --- .../models/transformers/transformer_chroma.py | 130 ++---------------- 1 file changed, 15 insertions(+), 115 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 0f6dd677ac5c..bf4df9df93d1 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -24,19 +24,13 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import ( - Attention, - AttentionProcessor, - FluxAttnProcessor2_0, - FluxAttnProcessor2_0_NPU, - FusedFluxAttnProcessor2_0, -) +from ..attention import AttentionMixin, FeedForward from ..cache_utils import CacheMixin from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm +from .transformer_flux import FluxAttention, FluxAttnProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -223,6 +217,8 @@ def __init__( self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) if is_torch_npu_available(): + from ..attention_processor import FluxAttnProcessor2_0_NPU + deprecation_message = ( "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " "should be set explicitly using the `set_attn_processor` method." 
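A hedged sketch of what the deprecation message above asks users to do instead: select the NPU processor explicitly through `set_attn_processor`. The checkpoint path is a placeholder, an Ascend NPU environment with `torch_npu` is assumed, and applying `FluxAttnProcessor2_0_NPU` to every attention layer only makes sense for models whose layers all use the Flux-style processor.

```python
import torch
from diffusers import ChromaTransformer2DModel
from diffusers.models.attention_processor import FluxAttnProcessor2_0_NPU

# "<path-or-repo-to-chroma>" is a placeholder; substitute a real Chroma checkpoint.
transformer = ChromaTransformer2DModel.from_pretrained("<path-or-repo-to-chroma>")

# Opt in to the NPU flash-attention processor explicitly rather than relying on the
# deprecated automatic selection.
transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())
```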
@@ -230,11 +226,10 @@ def __init__( deprecate("npu_processor", "0.34.0", deprecation_message) processor = FluxAttnProcessor2_0_NPU() else: - processor = FluxAttnProcessor2_0() + processor = FluxAttnProcessor() - self.attn = Attention( + self.attn = FluxAttention( query_dim=dim, - cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, @@ -292,16 +287,15 @@ def __init__( self.norm1 = ChromaAdaLayerNormZeroPruned(dim) self.norm1_context = ChromaAdaLayerNormZeroPruned(dim) - self.attn = Attention( + self.attn = FluxAttention( query_dim=dim, - cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, - processor=FluxAttnProcessor2_0(), + processor=FluxAttnProcessor(), qk_norm=qk_norm, eps=eps, ) @@ -376,7 +370,13 @@ def forward( class ChromaTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, + AttentionMixin, ): """ The Transformer model introduced in Flux, modified for Chroma. @@ -475,106 +475,6 @@ def __init__( self.gradient_checkpointing = False - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
- ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - self.set_attn_processor(FusedFluxAttnProcessor2_0()) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - def forward( self, hidden_states: torch.Tensor, From 4940b21dbd82771e7971ea286d83e3d09a341b66 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 13:58:00 +0200 Subject: [PATCH 19/26] attnetion dispatcher support --- src/diffusers/models/attention.py | 4 +- .../models/transformers/transformer_wan.py | 85 +++++++++++-------- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index ba560ec39d09..c720b379551f 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from ..utils import deprecate, logging -from ..utils.import_utils import is_xformers_available +from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0 diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index f9b6f01c6b87..7163a157cdc6 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -34,42 +35,44 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img + + class WanAttnProcessor: + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
) - def get_qkv_projections( - self, attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor - ): - # encoder_hidden_states is only passed for cross-attention - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - if attn.fused_projections: - if attn.cross_attention_dim_head is None: - # In self-attention layers, we can fuse the entire QKV projection into a single linear - query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) - else: - # In cross-attention layers, we can only fuse the KV projections into a single linear - query = attn.to_q(hidden_states) - key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) - else: - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - return query, key, value - - def get_added_kv_projections(self, attn: "WanAttention", encoder_hidden_states_img: torch.Tensor): - if attn.fused_projections: - key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) - else: - key_img = attn.add_k_proj(encoder_hidden_states_img) - value_img = attn.add_v_proj(encoder_hidden_states_img) - return key_img, value_img - def __call__( self, attn: "WanAttention", @@ -85,7 +88,7 @@ def __call__( encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] encoder_hidden_states = encoder_hidden_states[:, image_context_length:] - query, key, value = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states) + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) query = attn.norm_q(query) key = attn.norm_k(key) @@ -116,20 +119,32 @@ def apply_rotary_emb( # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: - key_img, value_img = self.get_added_kv_projections(attn, encoder_hidden_states_img) + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) key_img = attn.norm_added_k(key_img) key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - hidden_states_img = F.scaled_dot_product_attention( - query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) From 5b8927aa97294f75eb72822be706f36511d5d65e Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Jul 2025 14:08:49 +0200 Subject: [PATCH 20/26] remove transpose; fix rope shape --- .../models/transformers/transformer_wan.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 7163a157cdc6..6960522cff80 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -93,9 +93,9 @@ def __call__( 
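In the Wan hunks above, the projection logic moves out of the processor into module-level `_get_qkv_projections` / `_get_added_kv_projections` helpers, and the two direct `F.scaled_dot_product_attention` calls are replaced by `dispatch_attention_fn`, which receives the processor's `_attention_backend` class attribute (left at `None` for the default backend). The sketch below uses a stand-in `FakeWanAttn` rather than the real `WanAttention` and shows the fused-versus-unfused split the helper performs.

```python
# Hedged, self-contained sketch; FakeWanAttn only mimics the attributes the
# helper reads (fused_projections, cross_attention_dim_head, projection layers).
import torch
import torch.nn as nn

class FakeWanAttn(nn.Module):
    def __init__(self, dim=64, fused=True):
        super().__init__()
        self.fused_projections = fused
        self.cross_attention_dim_head = None        # None -> self-attention path
        self.to_qkv = nn.Linear(dim, dim * 3)       # fused q/k/v projection
        self.to_q, self.to_k, self.to_v = (nn.Linear(dim, dim) for _ in range(3))

def get_qkv_projections(attn, hidden_states, encoder_hidden_states=None):
    # Condensed from the helper added in the patch (the fused cross-attention
    # branch that chunks a to_kv projection is omitted here for brevity).
    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    if attn.fused_projections and attn.cross_attention_dim_head is None:
        return attn.to_qkv(hidden_states).chunk(3, dim=-1)
    return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)

x = torch.randn(2, 16, 64)                          # (batch, seq_len, dim)
q, k, v = get_qkv_projections(FakeWanAttn(), x)
print(q.shape, k.shape, v.shape)                    # three tensors of shape (2, 16, 64)
```

The processor then hands the projected tensors to `dispatch_attention_fn` with `backend=self._attention_backend`, so a backend override (for example by the deprecation shims later in the series) reroutes attention without touching the processor body.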
query = attn.norm_q(query) key = attn.norm_k(key) - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) if rotary_emb is not None: @@ -104,8 +104,7 @@ def apply_rotary_emb( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ): - x = hidden_states.view(*hidden_states.shape[:-1], -1, 2) - x1, x2 = x[..., 0], x[..., 1] + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) cos = freqs_cos[..., 0::2] sin = freqs_sin[..., 1::2] out = torch.empty_like(hidden_states) @@ -122,8 +121,8 @@ def apply_rotary_emb( key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) key_img = attn.norm_added_k(key_img) - key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) hidden_states_img = dispatch_attention_fn( query, @@ -134,7 +133,7 @@ def apply_rotary_emb( is_causal=False, backend=self._attention_backend, ) - hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) hidden_states = dispatch_attention_fn( @@ -146,7 +145,7 @@ def apply_rotary_emb( is_causal=False, backend=self._attention_backend, ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) if hidden_states_img is not None: @@ -395,8 +394,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) - freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) - freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) return freqs_cos, freqs_sin From 90cf9679e75cc65bec01e0775a3d4c9c54957272 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 12:39:30 +0200 Subject: [PATCH 21/26] remove rmsnorm assert --- src/diffusers/models/transformers/transformer_chroma.py | 2 -- src/diffusers/models/transformers/transformer_flux.py | 4 ---- 2 files changed, 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index bf4df9df93d1..5823ae9d3da6 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -235,7 +235,6 @@ def __init__( out_dim=dim, bias=True, processor=processor, - qk_norm="rms_norm", eps=1e-6, pre_only=True, ) @@ -296,7 +295,6 @@ def __init__( context_pre_only=False, bias=True, processor=FluxAttnProcessor(), - qk_norm=qk_norm, eps=eps, ) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 7640d8d13cdc..9080cd508de4 100644 --- 
a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -277,7 +277,6 @@ def __init__( dim_head: int = 64, dropout: float = 0.0, bias: bool = False, - qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, added_proj_bias: Optional[bool] = True, out_bias: bool = True, @@ -289,7 +288,6 @@ def __init__( processor=None, ): super().__init__() - assert qk_norm == "rms_norm", "Flux uses RMSNorm" self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads @@ -375,7 +373,6 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, out_dim=dim, bias=True, processor=processor, - qk_norm="rms_norm", eps=1e-6, pre_only=True, ) @@ -431,7 +428,6 @@ def __init__( context_pre_only=False, bias=True, processor=FluxAttnProcessor(), - qk_norm=qk_norm, eps=eps, ) From 91a7b97f86d374ba8d3a355e9c4d049efc187164 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 12:50:54 +0200 Subject: [PATCH 22/26] minify and deprecate npu/xla processors --- src/diffusers/models/attention_processor.py | 396 +++----------------- 1 file changed, 61 insertions(+), 335 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e64bd45eb42d..990245de1742 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2272,235 +2272,6 @@ def __call__( return hidden_states -class FluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - deprecation_message = ( - "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " - "alternative solution to use NPU Flash Attention will be provided in the future." - ) - deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
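Patch 20 above removes the `.transpose(1, 2)` round-trips, so query/key/value stay in `(batch, seq_len, heads, head_dim)` layout all the way into `dispatch_attention_fn`, and the precomputed rotary frequencies are reshaped from `(1, 1, tokens, dim)` to `(1, tokens, 1, dim)` so they broadcast over the heads axis in that layout. Below is a hedged sketch of the interleaved rotation on tensors in this layout; the write-back lines are inferred from the standard interleaved RoPE formula, since the hunk only shows the read side.

```python
# Hedged sketch of the interleaved rotary application on
# (batch, seq, heads, head_dim) tensors with freqs shaped (1, seq, 1, head_dim).
import torch

def apply_rotary_emb(x, freqs_cos, freqs_sin):
    x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1)        # even/odd interleaved pairs
    cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin                 # rotate each 2-D pair (inferred)
    out[..., 1::2] = x1 * sin + x2 * cos
    return out.type_as(x)

B, S, H, D = 2, 8, 4, 16
q = torch.randn(B, S, H, D)
angles = torch.rand(1, S, 1, D // 2) * 6.28
freqs_cos = angles.cos().repeat_interleave(2, dim=-1)    # (1, S, 1, D), broadcasts over batch and heads
freqs_sin = angles.sin().repeat_interleave(2, dim=-1)
print(apply_rotary_emb(q, freqs_cos, freqs_sin).shape)   # torch.Size([2, 8, 4, 16])
```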
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - deprecation_message = ( - "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " - "alternative solution to use NPU Flash Attention will be provided in the future." - ) - deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. 
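Patches 21 and 23 drop the `qk_norm` argument and the `assert qk_norm == "rms_norm"` guard from `FluxAttention`, so the Chroma and Flux call sites shown earlier no longer pass it (the constructor presumably builds its RMSNorm q/k norms unconditionally). For reference, a constructor call in the new form could look like the hedged sketch below; the import path follows the file touched by the series and assumes this PR's branch, and the numeric values are illustrative rather than taken from a real config.

```python
# Hedged sketch of a post-patch FluxAttention construction; keyword names are
# taken from the call sites in the diff, dimensions are illustrative only.
from diffusers.models.transformers.transformer_flux import FluxAttention, FluxAttnProcessor

attn = FluxAttention(
    query_dim=3072,             # block hidden size (24 heads * 128, illustrative)
    added_kv_proj_dim=3072,     # enables the extra context (text) projections
    dim_head=128,
    heads=24,
    out_dim=3072,
    context_pre_only=False,
    bias=True,
    processor=FluxAttnProcessor(),
    eps=1e-6,                   # note: no qk_norm argument anymore
)
```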
- qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -3130,112 +2901,6 @@ def __call__( return hidden_states -class XLAFluxFlashAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. 
- """ - - def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): - deprecation_message = ( - "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " - "alternative solution to using XLA Flash Attention will be provided in the future." - ) - deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - if is_torch_xla_version("<", "2.3"): - raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") - if is_spmd() and is_torch_xla_version("<", "2.4"): - raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") - self.partition_spec = partition_spec - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - query /= math.sqrt(head_dim) - hidden_states = flash_attention(query, key, value, causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - class MochiVaeAttnProcessor2_0: r""" Attention processor used in Mochi VAE. @@ -5883,6 +5548,67 @@ def __new__(cls, *args, **kwargs): return FluxIPAdapterAttnProcessor(*args, **kwargs) +class FluxAttnProcessor2_0_NPU: + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." + ) + deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) + + from .transformers.transformer_flux import FluxAttnProcessor + + processor = FluxAttnProcessor() + processor._attention_backend = "_native_npu" + return processor + + +class FusedFluxAttnProcessor2_0_NPU: + def __new__(self): + deprecation_message = ( + "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." + ) + deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) + + from .transformers.transformer_flux import FluxAttnProcessor + + processor = FluxAttnProcessor() + processor._attention_backend = "_fused_npu" + return processor + + +class XLAFluxFlashAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. 
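Rather than keeping the full NPU/XLA implementations, the patch reduces `FluxAttnProcessor2_0_NPU`, `FusedFluxAttnProcessor2_0_NPU`, and `XLAFluxFlashAttnProcessor2_0` to thin factories: `__new__` emits the deprecation warning and returns a plain `FluxAttnProcessor` whose `_attention_backend` is pinned to `"_native_npu"`, `"_fused_npu"`, or `"_native_xla"`. A generic, hedged sketch of this `__new__`-as-shim pattern follows; `OldProcessor`/`NewProcessor` are placeholders, and `warnings` stands in for diffusers' `deprecate` helper.

```python
# Generic sketch of the "__new__ as deprecation shim" pattern; not diffusers code.
import warnings

class NewProcessor:
    _attention_backend = None              # class default; instances may override

class OldProcessor:
    def __new__(cls, *args, **kwargs):
        warnings.warn("OldProcessor is deprecated; use NewProcessor instead.", FutureWarning)
        processor = NewProcessor(*args, **kwargs)
        processor._attention_backend = "_native_npu"   # route through the matching backend
        # Returning a non-OldProcessor instance means OldProcessor.__init__ is skipped.
        return processor

proc = OldProcessor()
print(type(proc).__name__, proc._attention_backend)    # NewProcessor _native_npu
```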
+ """ + + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " + "alternative solution to using XLA Flash Attention will be provided in the future." + ) + deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + + if is_torch_xla_version("<", "2.3"): + raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") + if is_spmd() and is_torch_xla_version("<", "2.4"): + raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") + + from .transformers.transformer_flux import FluxAttnProcessor + + if len(args) > 0 or kwargs.get("partition_spec", None) is not None: + deprecation_message = ( + "partition_spec was not used in the processor implementation when it was added. Passing it " + "is a no-op and support for it will be removed." + ) + deprecate("partition_spec", "1.0.0", deprecation_message) + + processor = FluxAttnProcessor(*args, **kwargs) + processor._attention_backend = "_native_xla" + return processor + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, From 1e6b1c51a8691c144adb50f2c9d21c53adf5d5e2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 12:39:30 +0200 Subject: [PATCH 23/26] remove rmsnorm assert --- src/diffusers/models/transformers/transformer_chroma.py | 2 -- src/diffusers/models/transformers/transformer_flux.py | 4 ---- 2 files changed, 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index bf4df9df93d1..5823ae9d3da6 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -235,7 +235,6 @@ def __init__( out_dim=dim, bias=True, processor=processor, - qk_norm="rms_norm", eps=1e-6, pre_only=True, ) @@ -296,7 +295,6 @@ def __init__( context_pre_only=False, bias=True, processor=FluxAttnProcessor(), - qk_norm=qk_norm, eps=eps, ) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 7640d8d13cdc..9080cd508de4 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -277,7 +277,6 @@ def __init__( dim_head: int = 64, dropout: float = 0.0, bias: bool = False, - qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, added_proj_bias: Optional[bool] = True, out_bias: bool = True, @@ -289,7 +288,6 @@ def __init__( processor=None, ): super().__init__() - assert qk_norm == "rms_norm", "Flux uses RMSNorm" self.head_dim = dim_head self.inner_dim = out_dim if out_dim is not None else dim_head * heads @@ -375,7 +373,6 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, out_dim=dim, bias=True, processor=processor, - qk_norm="rms_norm", eps=1e-6, pre_only=True, ) @@ -431,7 +428,6 @@ def __init__( context_pre_only=False, bias=True, processor=FluxAttnProcessor(), - qk_norm=qk_norm, eps=eps, ) From 251bb619250900c60273621a071bef22b1971dca Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Jul 2025 12:50:54 +0200 Subject: [PATCH 24/26] minify and deprecate npu/xla processors --- src/diffusers/models/attention_processor.py | 396 +++----------------- 1 file changed, 61 insertions(+), 335 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 
e64bd45eb42d..990245de1742 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2272,235 +2272,6 @@ def __call__( return hidden_states -class FluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - deprecation_message = ( - "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " - "alternative solution to use NPU Flash Attention will be provided in the future." - ) - deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - -class FusedFluxAttnProcessor2_0_NPU: - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__(self): - deprecation_message = ( - "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " - "alternative solution to use NPU Flash Attention will be provided in the future." - ) - deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" - ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. 
- qkv = attn.to_qkv(hidden_states) - split_size = qkv.shape[-1] // 3 - query, key, value = torch.split(qkv, split_size, dim=-1) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - # `context` projections. - if encoder_hidden_states is not None: - encoder_qkv = attn.to_added_qkv(encoder_hidden_states) - split_size = encoder_qkv.shape[-1] // 3 - ( - encoder_hidden_states_query_proj, - encoder_hidden_states_key_proj, - encoder_hidden_states_value_proj, - ) = torch.split(encoder_qkv, split_size, dim=-1) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - if query.dtype in (torch.float16, torch.bfloat16): - hidden_states = torch_npu.npu_fusion_attention( - query, - key, - value, - attn.heads, - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]), - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0, - sync=False, - inner_precise=0, - )[0] - else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -3130,112 +2901,6 @@ def __call__( return hidden_states -class XLAFluxFlashAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. 
- """ - - def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): - deprecation_message = ( - "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " - "alternative solution to using XLA Flash Attention will be provided in the future." - ) - deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - if is_torch_xla_version("<", "2.3"): - raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") - if is_spmd() and is_torch_xla_version("<", "2.4"): - raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") - self.partition_spec = partition_spec - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - query /= math.sqrt(head_dim) - hidden_states = flash_attention(query, key, value, causal=False) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states - - class MochiVaeAttnProcessor2_0: r""" Attention processor used in Mochi VAE. @@ -5883,6 +5548,67 @@ def __new__(cls, *args, **kwargs): return FluxIPAdapterAttnProcessor(*args, **kwargs) +class FluxAttnProcessor2_0_NPU: + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." + ) + deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) + + from .transformers.transformer_flux import FluxAttnProcessor + + processor = FluxAttnProcessor() + processor._attention_backend = "_native_npu" + return processor + + +class FusedFluxAttnProcessor2_0_NPU: + def __new__(self): + deprecation_message = ( + "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An " + "alternative solution to use NPU Flash Attention will be provided in the future." + ) + deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False) + + from .transformers.transformer_flux import FluxAttnProcessor + + processor = FluxAttnProcessor() + processor._attention_backend = "_fused_npu" + return processor + + +class XLAFluxFlashAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. 
+ """ + + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " + "alternative solution to using XLA Flash Attention will be provided in the future." + ) + deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + + if is_torch_xla_version("<", "2.3"): + raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") + if is_spmd() and is_torch_xla_version("<", "2.4"): + raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") + + from .transformers.transformer_flux import FluxAttnProcessor + + if len(args) > 0 or kwargs.get("partition_spec", None) is not None: + deprecation_message = ( + "partition_spec was not used in the processor implementation when it was added. Passing it " + "is a no-op and support for it will be removed." + ) + deprecate("partition_spec", "1.0.0", deprecation_message) + + processor = FluxAttnProcessor(*args, **kwargs) + processor._attention_backend = "_native_xla" + return processor + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, From 6067995c484d297d027efd8b927e87c228d8097e Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 23 Jul 2025 14:38:05 +0200 Subject: [PATCH 25/26] update --- src/diffusers/models/transformers/transformer_wan.py | 2 +- src/diffusers/models/transformers/transformer_wan_vace.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 6960522cff80..143abf142d04 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -261,7 +261,7 @@ def unfuse_projections(self): def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 934e12ca54f5..e039d362193d 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -110,12 +110,12 @@ def forward( norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as( control_hidden_states ) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states) # 2. Cross-attention norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) control_hidden_states = control_hidden_states + attn_output # 3. 
Feed-forward From 3568e9145510ebce3482825b6f7dd7552dc09e9b Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Jul 2025 09:46:55 +0530 Subject: [PATCH 26/26] Update src/diffusers/models/transformers/transformer_wan.py --- src/diffusers/models/transformers/transformer_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index ea3539076f80..8a18ea5f3e2a 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -70,7 +70,7 @@ class WanAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." ) def __call__(