Skip to content

Conversation

@JerryWu-code
Copy link
Contributor

@JerryWu-code JerryWu-code commented Nov 26, 2025

What does this PR do?

This PR adds unittest for Z-image Series⚡️ as discussed in #12703 (comment). Z-Image-Turbo, the distillation variant of our Z-Image, could generate 1K resolution photorealistic photo while excels in complex en/zh text rendering within 1-second in H800/H100 cards in bf16-precision.

Z-Image is a powerful and highly efficient 6B-parameter image generation model that is friendly for consumer-grade hardware, with strong capabilities in photorealistic image generation, accurate rendering of both complex Chinese and English text, and robust adherence to bilingual instructions.

The technical report and Z-Image-Turbo checkpoint will be released very soon !!!

Thanks for the support of @yiyixuxu.

Fixes # (issue)

  • Fix bugs for num_images_per_prompt with actual batch.
  • Refine unitest and skip for cases needed separate test env; Fix compatibility with unitest in model, mostly precision formating.
  • Add clean environment for test_save_load_float16 separate environment test and add notes for that; Styling.
  • Merge remote main branch for easy integration.

JerryWu-code and others added 30 commits November 23, 2025 19:54
…ryWu-code/z-image

# Conflicts:
#	src/diffusers/models/transformers/transformer_z_image.py
…ace its origin implement; Add DocString in pipeline for that.
…rd, replace its origin implement; Add DocString in pipeline for that."

This reverts commit fbf26b7.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks so much for the PR! I left some small suggestions:)
let me know what you think

t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
weight_dtype = self.mlp[0].weight.dtype
if weight_dtype in [torch.float32, torch.float16, torch.bfloat16]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if weight_dtype in [torch.float32, torch.float16, torch.bfloat16]:
if weight_dtype.is_floating_point:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, initially change this for compatible with precision autocasting, but yeah "is_floating_point" works ~

assert self.dtype == torch.bfloat16
dtype = self.dtype
# assert self.dtype == torch.bfloat16
dtype = self.dtype if hasattr(self, "dtype") and self.dtype is not None else torch.float32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we usually don't use self.dtype , this is the logic behind it https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_utils.py#L578

you can see that it is not very useful when components can sometimes have different dtype
so instead, we try to use specific dtype at each step, e.g. you will see a lot of patterns like this
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py#L257

prompt_embeds = encode_prompt(..., dtype= self.text_encoder.dtype)

or

latents = prepare_latents(...,  dtype=torch.float32)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect !! 😊 We've changed to trying to get dtype of the components of pipeline as your mentioned in the first format. Already styling may ready to ort merge ❤️

@JerryWu-code
Copy link
Contributor Author

Hi yiyi, this commit e277137 is ready to merged !! 🚀🚀🚀 Thanks 😊

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants