
Commit 215f947

a-r-r-o-w authored and DN6 committed
[refactor] Wan single file implementation (huggingface#11918)
* update * update * update * add coauthor Co-Authored-By: Dhruv Nair <[email protected]> * improve test * handle ip adapter params correctly * fix chroma qkv fusion test * fix fastercache implementation * remove set_attention_backend related code * fix more tests * fight more tests * add back set_attention_backend * update * update * make style * make fix-copies * make ip adapter processor compatible with attention dispatcher * refactor chroma as well * attnetion dispatcher support * remove transpose; fix rope shape * remove rmsnorm assert * minify and deprecate npu/xla processors * remove rmsnorm assert * minify and deprecate npu/xla processors * update * Update src/diffusers/models/transformers/transformer_wan.py --------- Co-authored-by: Dhruv Nair <[email protected]>
1 parent ba84995 commit 215f947
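The squash log above mentions the attention dispatcher, set_attention_backend, and QKV fusion. Below is a rough usage sketch of how a user might exercise those pieces on the refactored model; the checkpoint id, the backend name, and the set_attention_backend / fuse_qkv_projections method names are assumptions about the surrounding diffusers API, not something this commit defines.

# Hedged sketch, not part of this commit: route Wan attention through the dispatcher
# and fuse projections. The checkpoint id, backend name, and the set_attention_backend /
# fuse_qkv_projections helpers are assumptions about the surrounding diffusers API.
import torch
from diffusers import WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.set_attention_backend("flash")  # assumed ModelMixin helper; requires flash-attn to be installed
transformer.fuse_qkv_projections()          # assumed AttentionMixin helper; exercises the fused to_qkv / to_kv paths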

File tree

2 files changed: +237, -81 lines

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 222 additions & 60 deletions
@@ -21,9 +21,10 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -34,72 +35,117 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
+
+class WanAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
 
     def __call__(
         self,
-        attn: Attention,
+        attn: "WanAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             # 512 is the context length of the text encoder, hardcoded for now
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         if rotary_emb is not None:
 
-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
-                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
 
             query = apply_rotary_emb(query, rotary_emb)
             key = apply_rotary_emb(key, rotary_emb)
 
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
 
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
         if hidden_states_img is not None:
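The rewritten apply_rotary_emb above replaces the complex-number rotation (and the head-dimension transpose) with an interleaved cos/sin form. Here is a standalone sketch, not part of the diff, that checks the two formulations agree on random inputs; shapes and angle values are illustrative only.

# Standalone sketch: the interleaved cos/sin rotary embedding equals the old
# complex-multiplication form. All tensors here are made up for the comparison.
import torch

batch, seq_len, heads, dim_head = 1, 8, 2, 16
q = torch.randn(batch, seq_len, heads, dim_head)
theta = torch.randn(seq_len, dim_head // 2)  # one angle per channel pair

# Old path: complex rotation on a (batch, heads, seq, dim) layout.
freqs_complex = torch.polar(torch.ones_like(theta), theta)              # cos + i*sin
q_old = q.transpose(1, 2).contiguous().to(torch.float64)                # (b, h, s, d)
q_rot = torch.view_as_complex(q_old.unflatten(3, (-1, 2)))
q_old_out = torch.view_as_real(q_rot * freqs_complex).flatten(3, 4).transpose(1, 2)

# New path: interleaved cos/sin on the (batch, seq, heads, dim) layout, no transpose.
freqs_cos = theta.cos().repeat_interleave(2, dim=-1).view(1, seq_len, 1, dim_head)
freqs_sin = theta.sin().repeat_interleave(2, dim=-1).view(1, seq_len, 1, dim_head)
x1, x2 = q.unflatten(-1, (-1, 2)).unbind(-1)
cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
q_new_out = torch.empty_like(q)
q_new_out[..., 0::2] = x1 * cos - x2 * sin
q_new_out[..., 1::2] = x1 * sin + x2 * cos

print(torch.allclose(q_old_out.float(), q_new_out, atol=1e-5))          # True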
@@ -110,6 +156,119 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
         return hidden_states
 
 
+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead. "
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
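fuse_projections above folds separate Linear layers into one by concatenating their weights and biases along the output dimension. A minimal standalone sketch, independent of diffusers, of why the fused projection reproduces the separate ones (dimensions are illustrative):

# Standalone sketch of the QKV-fusion trick: three Linears are folded into one,
# and the fused output is split back into query/key/value with chunk().
import torch
from torch import nn

dim = 32
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# Concatenate along the output dimension, the same way fuse_projections() does.
weight = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
bias = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])
to_qkv = nn.Linear(dim, 3 * dim, bias=True)
to_qkv.load_state_dict({"weight": weight, "bias": bias})

x = torch.randn(2, 7, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-5)
assert torch.allclose(k, to_k(x), atol=1e-5)
assert torch.allclose(v, to_v(x), atol=1e-5)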
@@ -217,11 +376,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             dim=1,
         )
 
-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+        return freqs_cos, freqs_sin
 
 
 class WanTransformerBlock(nn.Module):
@@ -239,33 +408,24 @@ def __init__(
 
         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
         )
 
         # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
@@ -302,12 +462,12 @@ def forward(
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
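WanTransformerBlock now calls the attention layers positionally, matching the WanAttention.forward signature defined earlier in the diff. A hypothetical construction sketch, assuming a diffusers build that contains this refactor and a recent PyTorch (nn.RMSNorm); the dimensions are illustrative, not a specific Wan checkpoint config:

# Hypothetical sketch, not part of this commit: build a cross-attention WanAttention
# layer the way WanTransformerBlock does, then fuse and unfuse its projections.
from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor

attn = WanAttention(
    dim=1536,                       # illustrative sizes, not a real checkpoint config
    heads=12,
    dim_head=128,
    eps=1e-6,
    added_kv_proj_dim=1536,
    cross_attention_dim_head=128,
    processor=WanAttnProcessor(),
)

attn.fuse_projections()                                       # creates fused to_kv and to_added_kv linears
print(hasattr(attn, "to_kv"), hasattr(attn, "to_added_kv"))   # True True
attn.unfuse_projections()                                     # drops them again
print(attn.fused_projections)                                 # False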
@@ -320,7 +480,9 @@
         return hidden_states
 
 
-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan model.
 