Commits (38)
42658fa  Add Support for Z-Image. (JerryWu-code, Nov 23, 2025)
3e74bb2  Reformatting with make style, black & isort. (JerryWu-code, Nov 23, 2025)
a4b89a0  Remove init, Modify import utils, Merge forward in transformers block… (JerryWu-code, Nov 24, 2025)
7df350d  modified main model forward, freqs_cis left (ChrisLiu6, Nov 24, 2025)
1dd587b  Merge remote-tracking branch 'JerryWu-code/z-image-dev' into fork/Jer… (ChrisLiu6, Nov 24, 2025)
aae03cf  refactored to add B dim (ChrisLiu6, Nov 24, 2025)
21d8130  fixed stack issue (ChrisLiu6, Nov 24, 2025)
e3dfa9e  fixed modulation bug (ChrisLiu6, Nov 24, 2025)
a7fa731  fixed modulation bug (ChrisLiu6, Nov 24, 2025)
1e0cefe  fix bug (ChrisLiu6, Nov 24, 2025)
7adaae8  remove value_from_time_aware_config (ChrisLiu6, Nov 24, 2025)
5b4c907  styling (ChrisLiu6, Nov 24, 2025)
2bb39f4  Fix neg embed and devide / bug; Reuse pad zero tensor; Turn cat -> re… (JerryWu-code, Nov 24, 2025)
71e8049  Replace padding with pad_sequence; Add gradient checkpointing. (JerryWu-code, Nov 24, 2025)
fbf26b7  Fix flash_attn3 in dispatch attn backend by _flash_attn_forward, repl… (JerryWu-code, Nov 24, 2025)
6c0c059  Fix Docstring and Make Style. (JerryWu-code, Nov 24, 2025)
28685dd  Revert "Fix flash_attn3 in dispatch attn backend by _flash_attn_forwa… (ChrisLiu6, Nov 25, 2025)
8e391b7  update z-image docstring (ChrisLiu6, Nov 25, 2025)
3b22e84  Revert attention dispatcher (ChrisLiu6, Nov 25, 2025)
3d1a7aa  update z-image docstring (ChrisLiu6, Nov 25, 2025)
336c5ce  styling (ChrisLiu6, Nov 25, 2025)
38a89ed  Recover attention_dispatch.py with its origin impl, later would speci… (JerryWu-code, Nov 25, 2025)
69d61e5  Fix prev bug, and support for prompt_embeds pass in args after prompt… (JerryWu-code, Nov 25, 2025)
549ad57  Merge branch 'z-image-dev-ql' into z-image-dev (JerryWu-code, Nov 25, 2025)
1dd8f3c  Remove einop dependency. (JerryWu-code, Nov 25, 2025)
2f2d8c3  Merge branch 'z-image-dev' into z-image (JerryWu-code, Nov 25, 2025)
a74a0c4  Merge remote-tracking branch 'origin/main' into z-image (JerryWu-code, Nov 25, 2025)
e49a1f9  remove redundant imports & make fix-copies (ChrisLiu6, Nov 25, 2025)
1048d0a  fix import (ChrisLiu6, Nov 25, 2025)
266e169  Support for num_images_per_prompt>1; Remove redundant unquote variables. (JerryWu-code, Nov 25, 2025)
12d2fb2  Fix bugs for num_images_per_prompt with actual batch. (JerryWu-code, Nov 25, 2025)
9a049f0  Add unit tests for Z-Image. (JerryWu-code, Nov 25, 2025)
c4e4a57  Refine unitest and skip for cases needed separate test env; Fix compa… (JerryWu-code, Nov 25, 2025)
6f2808b  Add clean env for test_save_load_float16 separ test; Add Note; Styling. (JerryWu-code, Nov 25, 2025)
e48060c  Merge current branch into ours for next pr compatibility. (JerryWu-code, Nov 26, 2025)
27a37cd  Merge branch 'main' into z-image (JerryWu-code, Nov 26, 2025)
aeed890  Update dtype mentioned by yiyi. (JerryWu-code, Nov 26, 2025)
e277137  Merge branch 'main' into z-image (JerryWu-code, Nov 26, 2025)
19 changes: 16 additions & 3 deletions src/diffusers/models/transformers/transformer_z_image.py
@@ -69,7 +69,10 @@ def timestep_embedding(t, dim, max_period=10000):

def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
weight_dtype = self.mlp[0].weight.dtype
if weight_dtype.is_floating_point:
t_freq = t_freq.to(weight_dtype)
t_emb = self.mlp(t_freq)
return t_emb


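For context on the new guard above: weight.dtype is only a safe cast target when it is a floating-point dtype; if the MLP weights were stored in an integer format (for example by a quantization backend), casting the sinusoidal features to that dtype would destroy them. A minimal sketch of the behaviour, assuming a hypothetical int8 weight dtype:

import torch

t_freq = torch.randn(4, 256)          # sinusoidal timestep features, float32
weight_dtype = torch.int8             # hypothetical quantized storage dtype

if weight_dtype.is_floating_point:    # False here, so the cast is skipped
    t_freq = t_freq.to(weight_dtype)  # would truncate every value to an integer

print(t_freq.dtype)                   # torch.float32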
@@ -126,6 +129,10 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
dtype = query.dtype
query, key = query.to(dtype), key.to(dtype)

# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]

# Compute joint attention
hidden_states = dispatch_attention_fn(
query,
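The reshape added above turns a per-token key-padding mask of shape [batch, seq_len] into [batch, 1, 1, seq_len] so it broadcasts over heads and query positions inside the attention kernel. A small sketch of the same broadcast with PyTorch SDPA, assuming a boolean mask where True marks positions that may be attended to:

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 6, 8
query = key = value = torch.randn(batch, heads, seq_len, head_dim)

attention_mask = torch.ones(batch, seq_len, dtype=torch.bool)
attention_mask[1, 4:] = False                            # last two tokens of sample 1 are padding

if attention_mask is not None and attention_mask.ndim == 2:
    attention_mask = attention_mask[:, None, None, :]    # -> [batch, 1, 1, seq_len]

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)                                         # torch.Size([2, 4, 6, 8])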
@@ -306,6 +313,10 @@ def __call__(self, ids: torch.Tensor):
if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
else:
# Ensure freqs_cis are on the same device as ids
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

result = []
for i in range(len(self.axes_dims)):
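The new else branch above keeps the cached frequency tables on whatever device the incoming ids live on, which matters once the model is moved after the first forward pass (cpu to cuda, offloading, device_map). A rough sketch of the lazy-cache-plus-device-check pattern with illustrative names, not the actual ZImage rope embedder API:

import torch

class LazyTables:
    def __init__(self):
        self.freqs_cis = None

    def __call__(self, ids: torch.Tensor):
        device = ids.device
        if self.freqs_cis is None:
            # build once on first use, then move to the input device
            self.freqs_cis = [torch.randn(64, 32) for _ in range(3)]
            self.freqs_cis = [t.to(device) for t in self.freqs_cis]
        elif self.freqs_cis[0].device != device:
            # follow the inputs if the module was moved after the cache was built
            self.freqs_cis = [t.to(device) for t in self.freqs_cis]
        return self.freqs_cis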
@@ -317,6 +328,7 @@
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["ZImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers

@register_to_config
def __init__(
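The new class attribute marks the timestep and caption embedders as precision sensitive, so they are excluded when the rest of the transformer is stored in a low-precision dtype via layerwise casting. A hedged usage sketch, assuming ZImageTransformer2DModel is exported from diffusers once this PR lands; the checkpoint id is a placeholder and the float8 storage dtype is only an example:

import torch
from diffusers import ZImageTransformer2DModel

transformer = ZImageTransformer2DModel.from_pretrained(
    "<z-image-checkpoint>", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store most weights in float8 and upcast per layer at runtime; submodules whose
# names match _skip_layerwise_casting_patterns ("t_embedder", "cap_embedder")
# are expected to stay in the compute dtype.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)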
@@ -553,8 +565,6 @@ def forward(
t = t * self.t_scale
t = self.t_embedder(t)

adaln_input = t

(
x,
cap_feats,
@@ -572,6 +582,9 @@

x = torch.cat(x, dim=0)
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)

# Match t_embedder output dtype to x for layerwise casting compatibility
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
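Because t_embedder can be skipped by layerwise casting, its output may sit in a different dtype than the patchified tokens; type_as keeps the AdaLN input aligned with the hidden states. A tiny illustration with made-up shapes and dtypes:

import torch

x = torch.randn(10, 64, dtype=torch.bfloat16)   # image tokens after the patch embedder
t = torch.randn(10, 64, dtype=torch.float32)    # timestep embedding kept in higher precision

adaln_input = t.type_as(x)                      # follow the hidden-state dtype
print(adaln_input.dtype)                        # torch.bfloat16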
39 changes: 14 additions & 25 deletions src/diffusers/pipelines/z_image/pipeline_z_image.py
@@ -165,21 +165,16 @@ def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
)
@@ -193,8 +188,6 @@ def encode_prompt(
negative_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
)
@@ -206,12 +199,9 @@ def _encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
assert num_images_per_prompt == 1
device = device or self._execution_device

if prompt_embeds is not None:
@@ -417,8 +407,6 @@ def __call__(
f"Please adjust the width to a multiple of {vae_scale}."
)

assert self.dtype == torch.bfloat16
dtype = self.dtype
device = self._execution_device

self._guidance_scale = guidance_scale
@@ -434,10 +422,6 @@
else:
batch_size = len(prompt_embeds)

lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)

# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is not None and prompt is None:
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -455,11 +439,8 @@
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
dtype=dtype,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)

# 4. Prepare latent variables
@@ -475,6 +456,14 @@
generator,
latents,
)

# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
if self.do_classifier_free_guidance and negative_prompt_embeds:
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]

actual_batch_size = batch_size * num_images_per_prompt
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)

# 5. Prepare timesteps
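The block added above duplicates each per-prompt embedding num_images_per_prompt times so the list of embeddings lines up with latents prepared for batch_size * num_images_per_prompt samples, with repeats for the same prompt kept consecutive. A small sketch with made-up embeddings (2 prompts, 3 images each):

import torch

num_images_per_prompt = 3
prompt_embeds = [torch.randn(5, 16), torch.randn(7, 16)]   # variable-length per-prompt embeds

prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
print(len(prompt_embeds))                     # 6 == 2 prompts * 3 images
print([pe.shape[0] for pe in prompt_embeds])  # [5, 5, 5, 7, 7, 7]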
@@ -523,12 +512,12 @@
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0

if apply_cfg:
latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
latents_typed = latents.to(self.transformer.dtype)
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
else:
latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
latent_model_input = latents.to(self.transformer.dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep


if apply_cfg:
# Perform CFG
pos_out = model_out_list[:batch_size]
neg_out = model_out_list[batch_size:]
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]

noise_pred = []
for j in range(batch_size):
for j in range(actual_batch_size):
pos = pos_out[j].float()
neg = neg_out[j].float()

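With CFG enabled the transformer runs on a doubled batch (conditional inputs first, then unconditional), so the per-sample outputs are split at actual_batch_size before being combined. A compact sketch of that split; the combination shown is the textbook guidance formula, which may differ in detail from what the pipeline does after this point:

import torch

actual_batch_size, guidance_scale = 2, 4.0
model_out_list = [torch.randn(16, 32, 32) for _ in range(2 * actual_batch_size)]

pos_out = model_out_list[:actual_batch_size]    # conditional predictions
neg_out = model_out_list[actual_batch_size:]    # unconditional predictions

noise_pred = []
for pos, neg in zip(pos_out, neg_out):
    pos, neg = pos.float(), neg.float()
    # textbook classifier-free guidance: uncond + scale * (cond - uncond)
    noise_pred.append(neg + guidance_scale * (pos - neg))

print(len(noise_pred), noise_pred[0].shape)     # 2 torch.Size([16, 32, 32])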
@@ -588,11 +577,11 @@
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

latents = latents.to(dtype)
if output_type == "latent":
image = latents

else:
latents = latents.to(self.vae.dtype)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

image = self.vae.decode(latents, return_dict=False)[0]