
Commit c02c4a6

a-r-r-o-w and DN6 authored
[refactor] Wan single file implementation (#11918)
* update
* update
* update
* add coauthor (Co-Authored-By: Dhruv Nair <[email protected]>)
* improve test
* handle ip adapter params correctly
* fix chroma qkv fusion test
* fix fastercache implementation
* remove set_attention_backend related code
* fix more tests
* fight more tests
* add back set_attention_backend
* update
* update
* make style
* make fix-copies
* make ip adapter processor compatible with attention dispatcher
* refactor chroma as well
* attention dispatcher support
* remove transpose; fix rope shape
* remove rmsnorm assert
* minify and deprecate npu/xla processors
* remove rmsnorm assert
* minify and deprecate npu/xla processors
* update
* Update src/diffusers/models/transformers/transformer_wan.py

---------

Co-authored-by: Dhruv Nair <[email protected]>
1 parent 6f3ac30 commit c02c4a6

File tree

2 files changed: +212 -75 lines changed


src/diffusers/models/transformers/transformer_wan.py

Lines changed: 197 additions & 54 deletions
@@ -21,10 +21,10 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -35,40 +35,67 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
+
+class WanAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )

     def __call__(
         self,
-        attn: Attention,
+        attn: "WanAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             # 512 is the context length of the text encoder, hardcoded for now
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states

-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))

         if rotary_emb is not None:

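
Note: the new `_get_qkv_projections` helper leans on the fact that a single fused linear whose output is split with `chunk` is equivalent to separate Q/K/V linears whose weights were concatenated. A minimal, self-contained sketch of that equivalence in plain PyTorch (not the diffusers module itself):

```python
# Standalone sketch: a fused QKV linear built from concatenated weights matches
# three separate projections when its output is chunked along the last dim.
import torch
import torch.nn as nn

dim, batch, seq = 64, 2, 8
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# Build the fused projection by concatenating weights and biases along dim 0.
to_qkv = nn.Linear(dim, 3 * dim, bias=True)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

x = torch.randn(batch, seq, dim)
q_fused, k_fused, v_fused = to_qkv(x).chunk(3, dim=-1)

assert torch.allclose(q_fused, to_q(x), atol=1e-6)
assert torch.allclose(k_fused, to_k(x), atol=1e-6)
assert torch.allclose(v_fused, to_v(x), atol=1e-6)
```
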
@@ -77,8 +104,7 @@ def apply_rotary_emb(
                 freqs_cos: torch.Tensor,
                 freqs_sin: torch.Tensor,
             ):
-                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
-                x1, x2 = x[..., 0], x[..., 1]
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
                 cos = freqs_cos[..., 0::2]
                 sin = freqs_sin[..., 1::2]
                 out = torch.empty_like(hidden_states)
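
Note: the rewritten pair-split inside `apply_rotary_emb` is a pure reshaping change. A small standalone check that `unflatten(-1, (-1, 2)).unbind(-1)` produces the same tensors as the previous `view` plus last-axis indexing:

```python
# Standalone equivalence check for the interleaved-pair split (plain PyTorch).
import torch

hidden_states = torch.randn(2, 8, 4, 16)  # e.g. [batch, seq, heads, head_dim]

# Old formulation
x_old = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
x1_old, x2_old = x_old[..., 0], x_old[..., 1]

# New formulation
x1_new, x2_new = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)

assert torch.equal(x1_old, x1_new) and torch.equal(x2_old, x2_new)
```
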
@@ -92,23 +118,34 @@ def apply_rotary_emb(
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)

-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)

         if hidden_states_img is not None:
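
Note: the processor now keeps query/key/value in `[batch, seq, heads, head_dim]` and leaves layout handling to `dispatch_attention_fn`, so the `transpose(1, 2)` round-trip around `F.scaled_dot_product_attention` disappears from the processor. The sketch below uses a hand-written stand-in for the dispatcher (`sdpa_bshd` is a hypothetical helper, not the diffusers function) to illustrate that the two call patterns compute the same result:

```python
# Standalone sketch of the layout convention change (plain PyTorch).
import torch
import torch.nn.functional as F

def sdpa_bshd(query, key, value):
    # Stand-in for a dispatch function whose inputs/outputs are [batch, seq, heads, head_dim]:
    # it transposes to the [batch, heads, seq, head_dim] layout SDPA expects internally.
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
        attn_mask=None, dropout_p=0.0, is_causal=False,
    )
    return out.transpose(1, 2)

batch, seq, heads, head_dim = 2, 16, 4, 32
q = torch.randn(batch, seq, heads, head_dim)
k = torch.randn(batch, seq, heads, head_dim)
v = torch.randn(batch, seq, heads, head_dim)

# Old path: the caller transposes before and after SDPA.
old = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2).flatten(2, 3)

# New path: the caller stays in [batch, seq, heads, head_dim]; only .flatten(2, 3) remains.
new = sdpa_bshd(q, k, v).flatten(2, 3)

assert torch.allclose(old, new, atol=1e-6)
```
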
@@ -119,6 +156,119 @@ def apply_rotary_emb(
         return hidden_states


+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead. "
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
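
Note: `fuse_projections` builds the fused layer on the meta device and then materializes it with `load_state_dict(..., assign=True)` (requires a recent PyTorch, roughly 2.1+), so no extra weight buffer is allocated. A self-contained sketch of that pattern for the KV branch, using two plain linears rather than the `WanAttention` module:

```python
# Standalone sketch of the meta-device + assign=True fusion pattern (plain PyTorch).
import torch
import torch.nn as nn

dim = 32
to_k = nn.Linear(dim, dim, bias=True)
to_v = nn.Linear(dim, dim, bias=True)

concatenated_weights = torch.cat([to_k.weight.data, to_v.weight.data])
concatenated_bias = torch.cat([to_k.bias.data, to_v.bias.data])
out_features, in_features = concatenated_weights.shape

# Create the fused module without allocating storage, then adopt the real tensors.
with torch.device("meta"):
    to_kv = nn.Linear(in_features, out_features, bias=True)
to_kv.load_state_dict({"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True)

x = torch.randn(2, 7, dim)
k_fused, v_fused = to_kv(x).chunk(2, dim=-1)
assert torch.allclose(k_fused, to_k(x), atol=1e-6)
assert torch.allclose(v_fused, to_v(x), atol=1e-6)
```
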
@@ -247,8 +397,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

-        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)

         return freqs_cos, freqs_sin

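
Note: the rotary tables are now reshaped to `(1, frames*height*width, 1, head_dim)` instead of `(1, 1, frames*height*width, head_dim)`, matching the `[batch, seq, heads, head_dim]` layout used in the processor above. A shape-only sketch with placeholder sizes (not a real Wan configuration):

```python
# Standalone broadcast check for the new rope table shape (plain PyTorch).
import torch

batch, heads, head_dim = 2, 4, 64
ppf, pph, ppw = 4, 6, 6                           # post-patchify frame/height/width counts (illustrative)
seq = ppf * pph * ppw

query = torch.randn(batch, seq, heads, head_dim)  # new [batch, seq, heads, head_dim] layout
freqs_cos = torch.randn(1, seq, 1, head_dim)      # reshape(1, ppf * pph * ppw, 1, -1)

# The table broadcasts over the batch and heads axes without any transpose; the old
# (1, 1, seq, head_dim) shape matched the [batch, heads, seq, head_dim] layout instead.
assert (query * freqs_cos).shape == (batch, seq, heads, head_dim)
```
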
@@ -269,33 +419,24 @@ def __init__(

         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
         )

         # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

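
Note: the block now constructs `WanAttention` directly; the old `Attention` arguments `query_dim`, `kv_heads`, `qk_norm`, `bias`, `out_bias`, `cross_attention_dim`, and `added_proj_bias` are gone (bias and QK RMSNorm are always on), and cross-attention is signalled through `cross_attention_dim_head`. A hedged construction sketch, assuming a diffusers build that contains this refactor; the dimensions below are placeholders, not a real Wan checkpoint config:

```python
# Assumes a diffusers version that includes this commit; dim/heads are illustrative placeholders.
from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor

dim, num_heads = 128, 8

# Self-attention: no cross-attention head dim, no extra image KV projections.
attn1 = WanAttention(
    dim=dim,
    heads=num_heads,
    dim_head=dim // num_heads,
    eps=1e-6,
    cross_attention_dim_head=None,
    processor=WanAttnProcessor(),
)

# Cross-attention: added_kv_proj_dim (placeholder value here) enables the extra I2V key/value branch.
attn2 = WanAttention(
    dim=dim,
    heads=num_heads,
    dim_head=dim // num_heads,
    eps=1e-6,
    added_kv_proj_dim=dim,
    cross_attention_dim_head=dim // num_heads,
    processor=WanAttnProcessor(),
)

print(attn1.to_q.weight.shape, attn2.add_k_proj.weight.shape)
```
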
@@ -332,12 +473,12 @@ def forward(

         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output

         # 3. Feed-forward
@@ -350,7 +491,9 @@ def forward(
         return hidden_states


-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan model.

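
Note: `WanAttnProcessor2_0` survives only as a deprecation shim whose `__new__` warns and returns a `WanAttnProcessor`, so code that still instantiates the old name keeps working. A small check of that behavior, assuming a diffusers build that contains this commit:

```python
# Assumes a diffusers version that includes this commit.
from diffusers.models.transformers.transformer_wan import WanAttnProcessor, WanAttnProcessor2_0

processor = WanAttnProcessor2_0()  # emits a deprecation message via diffusers.utils.deprecate

# __new__ hands back an instance of the new class, not of the deprecated shim.
assert isinstance(processor, WanAttnProcessor)
assert type(processor) is WanAttnProcessor
```
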