
Commit e6d4612

Support unit tests for Z-Image ⚡️ (#12715)
* Add support for Z-Image.
* Reformatting with make style, black & isort.
* Remove init, modify import utils, merge forward in transformer block, remove once-func in pipeline.
* Modify main model forward; freqs_cis left.
* Refactor to add batch dim.
* Fix stack issue.
* Fix modulation bug.
* Fix modulation bug.
* Fix bug.
* Remove value_from_time_aware_config.
* Styling.
* Fix neg embed and divide bug; reuse pad zero tensor; turn cat -> repeat; add hint for attn processor.
* Replace padding with pad_sequence; add gradient checkpointing.
* Fix flash_attn3 in dispatch attn backend via _flash_attn_forward, replacing its original implementation; add docstring in pipeline for that.
* Fix docstring and make style.
* Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, replace its origin implement; Add DocString in pipeline for that." (reverts commit fbf26b7)
* Update Z-Image docstring.
* Revert attention dispatcher.
* Update Z-Image docstring.
* Styling.
* Recover attention_dispatch.py with its original implementation; a later dedicated commit will handle FA3 compatibility.
* Fix previous bug, and support passing prompt_embeds in args as a list of torch Tensors after prompt pre-encoding.
* Remove einops dependency.
* Remove redundant imports & make fix-copies.
* Fix import.
* Support num_images_per_prompt > 1; remove redundant variables.
* Fix bugs for num_images_per_prompt with actual batch.
* Add unit tests for Z-Image.
* Refine unit tests and skip cases that need a separate test env; fix compatibility with unit tests in the model, mostly precision formatting.
* Add a clean env for a separate test_save_load_float16 test; add note; styling.
* Update dtype as mentioned by yiyi.

Co-authored-by: liudongyang <[email protected]>
1 parent a88a7b4 commit e6d4612
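For orientation, a minimal, hedged sketch of how the multi-image path touched by this commit might be exercised. The class name `ZImagePipeline` and the checkpoint id are assumptions inferred from the file paths, not verified against a released API.

```python
# Hedged usage sketch: exercises num_images_per_prompt > 1, which this commit fixes.
# "ZImagePipeline" and the repo id below are assumptions, not confirmed identifiers.
import torch
from diffusers import ZImagePipeline  # assumed class name

pipe = ZImagePipeline.from_pretrained("org/z-image", torch_dtype=torch.bfloat16)  # placeholder repo id
pipe.to("cuda")

images = pipe(
    prompt="a watercolor fox in a snowy forest",
    num_images_per_prompt=2,  # prompt embeddings are repeated, actual batch = 2
    num_inference_steps=20,
    guidance_scale=4.0,
).images
```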

3 files changed: +336 additions, -28 deletions


src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -69,7 +69,10 @@ def timestep_embedding(t, dim, max_period=10000):
 
     def forward(self, t):
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
+        weight_dtype = self.mlp[0].weight.dtype
+        if weight_dtype.is_floating_point:
+            t_freq = t_freq.to(weight_dtype)
+        t_emb = self.mlp(t_freq)
         return t_emb
 
 
```
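The guard above only casts the timestep embedding when the MLP weights are stored in a floating-point dtype. A small, hedged illustration of the behavior (dtypes chosen for illustration; float8 needs a recent PyTorch):

```python
# Cast t_freq to the weight dtype only when that dtype is floating point, so
# integer/quantized weight storage does not truncate the embedding.
import torch

t_freq = torch.randn(4, 256, dtype=torch.float32)

for weight_dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.int8):
    casted = t_freq.to(weight_dtype) if weight_dtype.is_floating_point else t_freq
    print(weight_dtype, weight_dtype.is_floating_point, casted.dtype)
# torch.bfloat16       True   torch.bfloat16
# torch.float8_e4m3fn  True   torch.float8_e4m3fn
# torch.int8           False  torch.float32
```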

```diff
@@ -126,6 +129,10 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
         dtype = query.dtype
         query, key = query.to(dtype), key.to(dtype)
 
+        # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
+        if attention_mask is not None and attention_mask.ndim == 2:
+            attention_mask = attention_mask[:, None, None, :]
+
         # Compute joint attention
         hidden_states = dispatch_attention_fn(
             query,
```
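The new branch expands a 2-D key-padding mask so it broadcasts over heads and query positions. A self-contained sketch of the same reshape with `scaled_dot_product_attention` (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 16, 8
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# [batch, seq_len] key-padding mask: True = attend, False = padding
attention_mask = torch.ones(batch, seq_len, dtype=torch.bool)
attention_mask[:, -3:] = False
attention_mask = attention_mask[:, None, None, :]  # -> [batch, 1, 1, seq_len]

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)  # torch.Size([2, 4, 16, 8])
```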
```diff
@@ -306,6 +313,10 @@ def __call__(self, ids: torch.Tensor):
         if self.freqs_cis is None:
             self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
             self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+        else:
+            # Ensure freqs_cis are on the same device as ids
+            if self.freqs_cis[0].device != device:
+                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
 
         result = []
         for i in range(len(self.axes_dims)):
```
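The `else` branch keeps the cached rotary tables usable when the call device changes after the first call. A hedged, stand-alone sketch of the same lazy device sync (the cache contents here are random placeholders):

```python
import torch

class RopeCache:
    def __init__(self):
        self.freqs_cis = None

    def __call__(self, ids: torch.Tensor):
        device = ids.device
        if self.freqs_cis is None:
            self.freqs_cis = [torch.randn(64, 8, device=device) for _ in range(3)]
        elif self.freqs_cis[0].device != device:
            # cache exists but lives on a stale device: move it once, then reuse
            self.freqs_cis = [f.to(device) for f in self.freqs_cis]
        return self.freqs_cis

cache = RopeCache()
cache(torch.zeros(4, dtype=torch.long))  # builds the cache on CPU
# cache(torch.zeros(4, dtype=torch.long, device="cuda"))  # would relocate the cache
```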
```diff
@@ -317,6 +328,7 @@
 class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["ZImageTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"]  # precision sensitive layers
 
     @register_to_config
     def __init__(
```
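`_skip_layerwise_casting_patterns` keeps the listed embedders at the compute dtype when layerwise casting is enabled. A hedged sketch of how that is typically driven via the diffusers `ModelMixin` layerwise-casting API (method availability depends on your diffusers version; the repo id is a placeholder):

```python
import torch
from diffusers import ZImageTransformer2DModel  # import path assumed

transformer = ZImageTransformer2DModel.from_pretrained(
    "org/z-image", subfolder="transformer", torch_dtype=torch.bfloat16  # placeholder repo id
)
# Store most weights in float8 while computing in bfloat16; modules matching
# _skip_layerwise_casting_patterns (t_embedder, cap_embedder) stay in bfloat16.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```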
```diff
@@ -553,8 +565,6 @@ def forward(
         t = t * self.t_scale
         t = self.t_embedder(t)
 
-        adaln_input = t
-
         (
             x,
             cap_feats,
```

```diff
@@ -572,6 +582,9 @@
 
         x = torch.cat(x, dim=0)
         x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+
+        # Match t_embedder output dtype to x for layerwise casting compatibility
+        adaln_input = t.type_as(x)
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
```
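Moving `adaln_input` after the patch embedder and using `type_as` ties the conditioning dtype to whatever precision `x` ends up in. A tiny illustration:

```python
import torch

t = torch.randn(2, 1024, dtype=torch.float32)        # t_embedder kept at higher precision
x = torch.randn(2, 77, 1024, dtype=torch.bfloat16)   # x_embedder output in compute dtype

adaln_input = t.type_as(x)
print(adaln_input.dtype)  # torch.bfloat16
```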

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 14 additions & 25 deletions
```diff
@@ -165,21 +165,16 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         prompt_embeds: Optional[List[torch.FloatTensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
-        lora_scale: Optional[float] = None,
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         prompt_embeds = self._encode_prompt(
             prompt=prompt,
             device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             max_sequence_length=max_sequence_length,
         )
```

```diff
@@ -193,8 +188,6 @@ def encode_prompt(
             negative_prompt_embeds = self._encode_prompt(
                 prompt=negative_prompt,
                 device=device,
-                dtype=dtype,
-                num_images_per_prompt=num_images_per_prompt,
                 prompt_embeds=negative_prompt_embeds,
                 max_sequence_length=max_sequence_length,
             )
```

```diff
@@ -206,12 +199,9 @@ def _encode_prompt(
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
         prompt_embeds: Optional[List[torch.FloatTensor]] = None,
         max_sequence_length: int = 512,
     ) -> List[torch.FloatTensor]:
-        assert num_images_per_prompt == 1
         device = device or self._execution_device
 
         if prompt_embeds is not None:
```

```diff
@@ -417,8 +407,6 @@ def __call__(
                 f"Please adjust the width to a multiple of {vae_scale}."
             )
 
-        assert self.dtype == torch.bfloat16
-        dtype = self.dtype
         device = self._execution_device
 
         self._guidance_scale = guidance_scale
```

```diff
@@ -434,10 +422,6 @@
         else:
            batch_size = len(prompt_embeds)
 
-        lora_scale = (
-            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-        )
-
         # If prompt_embeds is provided and prompt is None, skip encoding
         if prompt_embeds is not None and prompt is None:
             if self.do_classifier_free_guidance and negative_prompt_embeds is None:
```

```diff
@@ -455,11 +439,8 @@
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 prompt_embeds=prompt_embeds,
                 negative_prompt_embeds=negative_prompt_embeds,
-                dtype=dtype,
                 device=device,
-                num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
-                lora_scale=lora_scale,
             )
 
         # 4. Prepare latent variables
```

```diff
@@ -475,6 +456,14 @@
             generator,
             latents,
         )
+
+        # Repeat prompt_embeds for num_images_per_prompt
+        if num_images_per_prompt > 1:
+            prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+            if self.do_classifier_free_guidance and negative_prompt_embeds:
+                negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+        actual_batch_size = batch_size * num_images_per_prompt
         image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
 
         # 5. Prepare timesteps
```
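The repetition keeps images for the same prompt adjacent in the batch. A small sketch of the ordering produced by the list comprehension (ragged sequence lengths chosen to show each prompt's embedding stays intact):

```python
import torch

prompt_embeds = [torch.zeros(3, 8), torch.ones(5, 8)]  # two prompts, different token counts
num_images_per_prompt = 2

repeated = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
print([pe.shape[0] for pe in repeated])  # [3, 3, 5, 5] -> prompt 0 twice, then prompt 1 twice
```
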
```diff
@@ -523,12 +512,12 @@
                 apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
 
                 if apply_cfg:
-                    latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
+                    latents_typed = latents.to(self.transformer.dtype)
                     latent_model_input = latents_typed.repeat(2, 1, 1, 1)
                     prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
                     timestep_model_input = timestep.repeat(2)
                 else:
-                    latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
+                    latent_model_input = latents.to(self.transformer.dtype)
                     prompt_embeds_model_input = prompt_embeds
                     timestep_model_input = timestep
 
```
```diff
@@ -543,11 +532,11 @@
 
                 if apply_cfg:
                     # Perform CFG
-                    pos_out = model_out_list[:batch_size]
-                    neg_out = model_out_list[batch_size:]
+                    pos_out = model_out_list[:actual_batch_size]
+                    neg_out = model_out_list[actual_batch_size:]
 
                     noise_pred = []
-                    for j in range(batch_size):
+                    for j in range(actual_batch_size):
                         pos = pos_out[j].float()
                         neg = neg_out[j].float()
 
```
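With the expanded batch, the conditional and unconditional outputs must be split at `actual_batch_size`, not at the original prompt count. A hedged sketch of that bookkeeping (the guidance combination below is the standard CFG form, shown for illustration only):

```python
import torch

batch_size, num_images_per_prompt = 2, 2
actual_batch_size = batch_size * num_images_per_prompt

# 2 * actual_batch_size outputs: conditional first, then unconditional
model_out_list = [torch.randn(4, 8, 8) for _ in range(2 * actual_batch_size)]
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]

guidance_scale = 4.0
noise_pred = [neg + guidance_scale * (pos - neg) for pos, neg in zip(pos_out, neg_out)]
print(len(noise_pred))  # 4
```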

```diff
@@ -588,11 +577,11 @@
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
-        latents = latents.to(dtype)
         if output_type == "latent":
             image = latents
 
         else:
+            latents = latents.to(self.vae.dtype)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
             image = self.vae.decode(latents, return_dict=False)[0]
```
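Deferring the cast means latents returned with `output_type="latent"` keep their working dtype, and the VAE cast happens only when decoding. A minimal sketch (scaling and shift values are illustrative, not the real config):

```python
import torch

latents = torch.randn(1, 16, 64, 64, dtype=torch.float32)
vae_dtype = torch.bfloat16                 # stand-in for self.vae.dtype
scaling_factor, shift_factor = 0.36, 0.12  # illustrative values only

latents = latents.to(vae_dtype)
latents = latents / scaling_factor + shift_factor
print(latents.dtype)  # torch.bfloat16
# image = vae.decode(latents, return_dict=False)[0]  # would run the actual decoder
```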
