 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
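# Both helpers above check `attn.fused_projections`: once `fuse_projections()` has been called
# on the attention module they read from the concatenated `to_qkv` / `to_kv` / `to_added_kv`
# linears, otherwise they fall back to the separate per-projection layers.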
+
+
+class WanAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )

     def __call__(
         self,
-        attn: Attention,
+        attn: "WanAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             # 512 is the context length of the text encoder, hardcoded for now
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states

-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))

         if rotary_emb is not None:

-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
-                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
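            # The rewritten helper rotates the interleaved even/odd channel pairs with real-valued
            # (cos, sin) tables, instead of the float64 / `view_as_complex` formulation it replaces.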

-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)

         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)

-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)

         if hidden_states_img is not None:
@@ -110,6 +156,119 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
         return hidden_states


+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead. "
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
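# `__new__` hands back a `WanAttnProcessor` instance, so code that still instantiates
# `WanAttnProcessor2_0` transparently receives the new processor after the deprecation warning.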
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.set_processor(processor)
+
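    # `fuse_projections()` below concatenates the separate Q/K/V (or K/V) weights into a single
    # linear layer; the fused layer is created on the meta device and populated with
    # `load_state_dict(..., assign=True)`, which reuses the concatenated tensors instead of
    # allocating and copying fresh parameters.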
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
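# A minimal usage sketch for the new module, assuming an illustrative dim=1536 / heads=12 setup
# (values are not taken from this diff):
#
#   attn = WanAttention(dim=1536, heads=12, dim_head=128)            # self-attention layer
#   out = attn(hidden_states)                                        # hidden_states: (B, S, 1536)
#   out = attn(hidden_states, rotary_emb=(freqs_cos, freqs_sin))     # rope pair from WanRotaryPosEmbed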
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
@@ -217,11 +376,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             dim=1,
         )

-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+        return freqs_cos, freqs_sin
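        # The rope module now returns a (freqs_cos, freqs_sin) pair shaped (1, num_patches, 1, head_dim),
        # so it broadcasts over the head axis of the (batch, seq, heads, head_dim) query/key layout
        # used by WanAttnProcessor.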


 class WanTransformerBlock(nn.Module):
@@ -239,33 +408,24 @@ def __init__(

         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
         )

         # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
         )
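        # For attn2, cross_attention_dim_head marks the layer as cross-attention (so only K/V can be
        # fused), and added_kv_proj_dim adds the image K/V projections used by the I2V branch.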
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

@@ -302,12 +462,12 @@ def forward(

         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
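        # WanAttention.forward takes (hidden_states, encoder_hidden_states, attention_mask, rotary_emb)
        # positionally; self-attention passes None for the encoder states and the mask.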
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output

         # 3. Feed-forward
@@ -320,7 +480,9 @@ def forward(
         return hidden_states


-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan model.