diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index b6c01c13c12f..8a18ea5f3e2a 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -21,10 +21,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward -from ..attention_processor import Attention +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -35,18 +35,51 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class WanAttnProcessor2_0: +def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img + + +class WanAttnProcessor: + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) def __call__( self, - attn: Attention, + attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: @@ -54,21 +87,15 @@ def __call__( image_context_length = encoder_hidden_states.shape[1] - 512 encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] encoder_hidden_states = encoder_hidden_states[:, image_context_length:] - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) + query = attn.norm_q(query) + key = attn.norm_k(key) - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) if rotary_emb is not None: @@ -77,8 +104,7 @@ def apply_rotary_emb( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ): - x = hidden_states.view(*hidden_states.shape[:-1], -1, 2) - x1, x2 = x[..., 0], x[..., 1] + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) cos = freqs_cos[..., 0::2] sin = freqs_sin[..., 1::2] out = torch.empty_like(hidden_states) @@ -92,23 +118,34 @@ def apply_rotary_emb( # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: - key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) key_img = attn.norm_added_k(key_img) - value_img = attn.add_v_proj(encoder_hidden_states_img) - - key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - hidden_states_img = F.scaled_dot_product_attention( - query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) - hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) if hidden_states_img is not None: @@ -119,6 +156,119 @@ def apply_rotary_emb( return hidden_states +class WanAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. " + "Please use WanAttnProcessor instead. " + ) + deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + return WanAttnProcessor(*args, **kwargs) + + +class WanAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = WanAttnProcessor + _available_processors = [WanAttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + processor=None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.ModuleList( + [ + torch.nn.Linear(self.inner_dim, dim, bias=True), + torch.nn.Dropout(dropout), + ] + ) + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + self.set_processor(processor) + + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if self.cross_attention_dim_head is None: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_qkv = nn.Linear(in_features, out_features, bias=True) + self.to_qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_kv = nn.Linear(in_features, out_features, bias=True) + self.to_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + delattr(self, "to_added_kv") + + self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + + class WanImageEmbedding(torch.nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): super().__init__() @@ -247,8 +397,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) - freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) - freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) return freqs_cos, freqs_sin @@ -269,33 +419,24 @@ def __init__( # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.attn1 = Attention( - query_dim=dim, + self.attn1 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, - processor=WanAttnProcessor2_0(), + cross_attention_dim_head=None, + processor=WanAttnProcessor(), ) # 2. Cross-attention - self.attn2 = Attention( - query_dim=dim, + self.attn2 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, added_kv_proj_dim=added_kv_proj_dim, - added_proj_bias=True, - processor=WanAttnProcessor2_0(), + cross_attention_dim_head=dim // num_heads, + processor=WanAttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() @@ -332,12 +473,12 @@ def forward( # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) hidden_states = hidden_states + attn_output # 3. Feed-forward @@ -350,7 +491,9 @@ def forward( return hidden_states -class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): +class WanTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): r""" A Transformer model for video-like data used in the Wan model. diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 1a6f2af59a87..e039d362193d 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -22,12 +22,17 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward -from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm -from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock +from .transformer_wan import ( + WanAttention, + WanAttnProcessor, + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -55,33 +60,22 @@ def __init__( # 2. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.attn1 = Attention( - query_dim=dim, + self.attn1 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, - processor=WanAttnProcessor2_0(), + processor=WanAttnProcessor(), ) # 3. Cross-attention - self.attn2 = Attention( - query_dim=dim, + self.attn2 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, added_kv_proj_dim=added_kv_proj_dim, - added_proj_bias=True, - processor=WanAttnProcessor2_0(), + processor=WanAttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() @@ -116,12 +110,12 @@ def forward( norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as( control_hidden_states ) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states) # 2. Cross-attention norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) control_hidden_states = control_hidden_states + attn_output # 3. Feed-forward