diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2d06124282d1..79149fb76067 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Genmo team and The HuggingFace Team. +# Copyright 2025 The Lightricks team and The HuggingFace Team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,19 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import math from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import Attention +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput @@ -37,20 +37,30 @@ class LTXVideoAttentionProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`" + deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message) + + return LTXVideoAttnProcessor(*args, **kwargs) + + +class LTXVideoAttnProcessor: r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector. + Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX + model. It applies a normalization layer and rotary embedding on the query and key vector. """ + _attention_backend = None + def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." ) def __call__( self, - attn: Attention, + attn: "LTXAttention", hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -78,14 +88,20 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) - - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.to_out[0](hidden_states) @@ -93,6 +109,70 @@ def __call__( return hidden_states +class LTXAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = LTXVideoAttnProcessor + _available_processors = [LTXVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + + norm_eps = 1e-5 + norm_elementwise_affine = True + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + class LTXVideoRotaryPosEmbed(nn.Module): def __init__( self, @@ -231,7 +311,7 @@ def __init__( super().__init__() self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - self.attn1 = Attention( + self.attn1 = LTXAttention( query_dim=dim, heads=num_attention_heads, kv_heads=num_attention_heads, @@ -240,11 +320,10 @@ def __init__( cross_attention_dim=None, out_bias=attention_out_bias, qk_norm=qk_norm, - processor=LTXVideoAttentionProcessor2_0(), ) self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) - self.attn2 = Attention( + self.attn2 = LTXAttention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, @@ -253,7 +332,6 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - processor=LTXVideoAttentionProcessor2_0(), ) self.ff = FeedForward(dim, activation_fn=activation_fn) @@ -299,7 +377,9 @@ def forward( @maybe_allow_in_graph -class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin): +class LTXVideoTransformer3DModel( + ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin +): r""" A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).