
Commit e51da27

Merge branch 'main' into docs-device-agnostic
2 parents fe5bb63 + 555b6cc commit e51da27

File tree: 19 files changed (+1384, -110 lines)

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 22 additions & 5 deletions
@@ -16,7 +16,12 @@
 
 Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
 
-Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
+Qwen-Image comes in the following variants:
+
+| model type | model id |
+|:----------:|:--------:|
+| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
+| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
 
 <Tip>
 
@@ -84,16 +89,18 @@ image.save("qwen_fewsteps.png")
 
 </details>
 
+<Tip>
+
+The `guidance_scale` parameter in the pipeline exists to support future guidance-distilled models; passing it currently has no effect. To enable classifier-free guidance, pass `true_cfg_scale` together with a `negative_prompt` (even an empty negative prompt like " " enables the classifier-free guidance computation).
+
+</Tip>
+
 ## QwenImagePipeline
 
 [[autodoc]] QwenImagePipeline
 - all
 - __call__
 
-## QwenImagePipelineOutput
-
-[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
-
 ## QwenImageImg2ImgPipeline
 
 [[autodoc]] QwenImageImg2ImgPipeline

@@ -105,3 +112,13 @@ image.save("qwen_fewsteps.png")
 [[autodoc]] QwenImageInpaintPipeline
 - all
 - __call__
+
+## QwenImageEditPipeline
+
+[[autodoc]] QwenImageEditPipeline
+- all
+- __call__
+
+## QwenImagePipelineOutput
+
+[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
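To make the tip above concrete, here is a minimal sketch of enabling classifier-free guidance with `QwenImagePipeline`. The prompt, scale, and step count are illustrative choices, not values prescribed by the docs, and an available CUDA device is assumed:

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # assumes an available CUDA device

# `guidance_scale` is ignored by this pipeline; classifier-free guidance is
# driven by `true_cfg_scale` together with a (possibly empty) negative prompt.
image = pipe(
    prompt="A poster with the text 'Qwen-Image' rendered in a neon style",
    negative_prompt=" ",   # even an empty negative prompt enables CFG
    true_cfg_scale=4.0,    # illustrative value
    num_inference_steps=50,
).images[0]
image.save("qwen_cfg.png")
```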

examples/dreambooth/README_qwen.md

Lines changed: 2 additions & 2 deletions
@@ -75,9 +75,9 @@ Now, we can launch training using:
 ```bash
 export MODEL_NAME="Qwen/Qwen-Image"
 export INSTANCE_DIR="dog"
-export OUTPUT_DIR="trained-sana-lora"
+export OUTPUT_DIR="trained-qwenimage-lora"
 
-accelerate launch train_dreambooth_lora_sana.py \
+accelerate launch train_dreambooth_lora_qwenimage.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
   --instance_data_dir=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
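Once training finishes, the LoRA saved in `$OUTPUT_DIR` can be loaded back into the base pipeline for inference. A hedged sketch (the instance prompt follows the usual DreamBooth `sks` convention, and the output directory is assumed to contain the default weight file written by the script):

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # assumes an available CUDA device

# Load the adapter produced by the training run above ($OUTPUT_DIR).
pipe.load_lora_weights("trained-qwenimage-lora")

image = pipe(prompt="A photo of sks dog in a bucket", num_inference_steps=50).images[0]
image.save("dog_lora.png")
```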

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -489,6 +489,7 @@
             "PixArtAlphaPipeline",
             "PixArtSigmaPAGPipeline",
             "PixArtSigmaPipeline",
+            "QwenImageEditPipeline",
             "QwenImageImg2ImgPipeline",
             "QwenImageInpaintPipeline",
             "QwenImagePipeline",

@@ -1123,6 +1124,7 @@
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        QwenImageEditPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
         QwenImagePipeline,
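With the export in place, the new pipeline can be imported from the top-level package. A minimal usage sketch (the input path, prompt, and guidance value are placeholders; an available CUDA device is assumed):

```python
import torch
from diffusers import QwenImageEditPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")

init_image = load_image("input.png")  # placeholder path
image = pipe(
    image=init_image,
    prompt="Change the sign text to read 'Qwen-Image-Edit'",
    negative_prompt=" ",
    true_cfg_scale=4.0,  # illustrative; see the tip in the pipeline docs
).images[0]
image.save("edited.png")
```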

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 68 additions & 0 deletions
@@ -2080,6 +2080,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
 
 
 def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    if has_lora_unet:
+        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+    def convert_key(key: str) -> str:
+        prefix = "transformer_blocks"
+        if "." in key:
+            base, suffix = key.rsplit(".", 1)
+        else:
+            base, suffix = key, ""
+
+        start = f"{prefix}_"
+        rest = base[len(start) :]
+
+        if "." in rest:
+            head, tail = rest.split(".", 1)
+            tail = "." + tail
+        else:
+            head, tail = rest, ""
+
+        # Protected n-grams that must keep their internal underscores
+        protected = {
+            # pairs
+            ("to", "q"),
+            ("to", "k"),
+            ("to", "v"),
+            ("to", "out"),
+            ("add", "q"),
+            ("add", "k"),
+            ("add", "v"),
+            ("txt", "mlp"),
+            ("img", "mlp"),
+            ("txt", "mod"),
+            ("img", "mod"),
+            # triplets
+            ("add", "q", "proj"),
+            ("add", "k", "proj"),
+            ("add", "v", "proj"),
+            ("to", "add", "out"),
+        }
+
+        prot_by_len = {}
+        for ng in protected:
+            prot_by_len.setdefault(len(ng), set()).add(ng)
+
+        parts = head.split("_")
+        merged = []
+        i = 0
+        lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
+        while i < len(parts):
+            matched = False
+            for L in lengths_desc:
+                if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+                    merged.append("_".join(parts[i : i + L]))
+                    i += L
+                    matched = True
+                    break
+            if not matched:
+                merged.append(parts[i])
+                i += 1
+
+        head_converted = ".".join(merged)
+        converted_base = f"{prefix}.{head_converted}{tail}"
+        return converted_base + (("." + suffix) if suffix else "")
+
+    state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
     converted_state_dict = {}
     all_keys = list(state_dict.keys())
     down_key = ".lora_down.weight"
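The net effect of `convert_key` is to turn Kohya-style underscore-joined keys (after stripping the `lora_unet_` prefix) into the dot-separated module paths diffusers expects, while fragments such as `to_q` or `add_q_proj` keep their internal underscores. A simplified, standalone re-implementation of that merge step (not the library function itself; block indices and module names below are illustrative):

```python
# Longest protected n-grams are listed first so they win over their prefixes.
PROTECTED = [
    ("add", "q", "proj"), ("add", "k", "proj"), ("add", "v", "proj"), ("to", "add", "out"),
    ("to", "q"), ("to", "k"), ("to", "v"), ("to", "out"),
    ("add", "q"), ("add", "k"), ("add", "v"),
    ("txt", "mlp"), ("img", "mlp"), ("txt", "mod"), ("img", "mod"),
]

def rename(head: str) -> str:
    parts, merged, i = head.split("_"), [], 0
    while i < len(parts):
        for ng in PROTECTED:
            if tuple(parts[i : i + len(ng)]) == ng:
                merged.append("_".join(ng))
                i += len(ng)
                break
        else:  # no protected n-gram starts at this position
            merged.append(parts[i])
            i += 1
    return ".".join(merged)

print(rename("0_attn_to_q"))        # -> 0.attn.to_q
print(rename("5_attn_add_q_proj"))  # -> 5.attn.add_q_proj
```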

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -6643,7 +6643,8 @@ def lora_state_dict(
         state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
-        if has_alphas_in_sd:
+        has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+        if has_alphas_in_sd or has_lora_unet:
             state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
 
         out = (state_dict, metadata) if return_lora_metadata else state_dict

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 14 additions & 1 deletion
@@ -299,6 +299,7 @@ def __init__(
         act_fn: Union[str, Tuple[str]] = "silu",
         upsample_block_type: str = "pixel_shuffle",
         in_shortcut: bool = True,
+        conv_act_fn: str = "relu",
     ):
         super().__init__()
 
@@ -349,7 +350,7 @@ def __init__(
         channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
 
         self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
-        self.conv_act = nn.ReLU()
+        self.conv_act = get_activation(conv_act_fn)
         self.conv_out = None
 
         if layers_per_block[0] > 0:

@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             The normalization type(s) to use in the decoder.
         decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
             The activation function(s) to use in the decoder.
+        encoder_out_shortcut (`bool`, defaults to `True`):
+            Whether to use a shortcut at the end of the encoder.
+        decoder_in_shortcut (`bool`, defaults to `True`):
+            Whether to use a shortcut at the beginning of the decoder.
+        decoder_conv_act_fn (`str`, defaults to `"relu"`):
+            The activation function to use at the end of the decoder.
         scaling_factor (`float`, defaults to `1.0`):
             The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
             space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =

@@ -441,6 +448,9 @@ def __init__(
         downsample_block_type: str = "pixel_unshuffle",
         decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
         decoder_act_fns: Union[str, Tuple[str]] = "silu",
+        encoder_out_shortcut: bool = True,
+        decoder_in_shortcut: bool = True,
+        decoder_conv_act_fn: str = "relu",
         scaling_factor: float = 1.0,
     ) -> None:
         super().__init__()

@@ -454,6 +464,7 @@ def __init__(
             layers_per_block=encoder_layers_per_block,
             qkv_multiscales=encoder_qkv_multiscales,
             downsample_block_type=downsample_block_type,
+            out_shortcut=encoder_out_shortcut,
         )
         self.decoder = Decoder(
             in_channels=in_channels,

@@ -466,6 +477,8 @@
             norm_type=decoder_norm_types,
             act_fn=decoder_act_fns,
             upsample_block_type=upsample_block_type,
+            in_shortcut=decoder_in_shortcut,
+            conv_act_fn=decoder_conv_act_fn,
         )
 
         self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
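A quick sketch of the newly exposed knobs, assuming a diffusers build that includes this change; the model is randomly initialized (no checkpoint is loaded) and every other config value stays at its default:

```python
from diffusers import AutoencoderDC

vae = AutoencoderDC(
    decoder_conv_act_fn="silu",   # previously hard-coded to nn.ReLU()
    encoder_out_shortcut=False,
    decoder_in_shortcut=False,
)
print(vae.decoder.conv_act)  # -> SiLU()
```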

src/diffusers/models/model_loading_utils.py

Lines changed: 15 additions & 9 deletions
@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
     very large margin.
     """
     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
-    # Remove disk and cpu devices, and cast to proper torch.device
+
+    # Keep only accelerator devices
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    total_byte_count = defaultdict(lambda: 0)
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
     for param_name, device in accelerator_device_map.items():
         try:
-            param = model.get_parameter(param_name)
+            p = model.get_parameter(param_name)
         except AttributeError:
-            param = model.get_buffer(param_name)
-        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = param.numel() * param.element_size()
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
         # TODO: account for TP when needed.
-        total_byte_count[device] += param_byte_count
+        elements_per_device[device] += p.numel()
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
-        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
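For context, the warmup just asks the accelerator's caching allocator for one large block up front so the later per-parameter copies do not each trigger a fresh device malloc; counting elements (rather than bytes) and sizing the block in the load dtype is what the rewrite above changes. A standalone sketch of the same idea (device, dtype, and element count are illustrative, and an accelerator must be available):

```python
import torch

device = torch.device("cuda:0")   # illustrative accelerator device
dtype = torch.bfloat16            # dtype the weights will be loaded in
expected_elems = 5_000_000        # hypothetical element count destined for this device
factor = 2                        # halving used when no quantizer is involved

# One throwaway allocation primes the caching allocator for later weight copies.
_ = torch.empty(max(1, expected_elems // factor), dtype=dtype, device=device, requires_grad=False)
```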

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
2929
from ..modeling_outputs import Transformer2DModelOutput
3030
from ..modeling_utils import ModelMixin
31-
from ..normalization import AdaLayerNormContinuous
31+
from ..normalization import LayerNorm, RMSNorm
3232

3333

3434
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -584,6 +584,38 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
584584
return (freqs.cos(), freqs.sin())
585585

586586

587+
class CogView4AdaLayerNormContinuous(nn.Module):
588+
"""
589+
CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
590+
Linear on conditioning embedding.
591+
"""
592+
593+
def __init__(
594+
self,
595+
embedding_dim: int,
596+
conditioning_embedding_dim: int,
597+
elementwise_affine: bool = True,
598+
eps: float = 1e-5,
599+
bias: bool = True,
600+
norm_type: str = "layer_norm",
601+
):
602+
super().__init__()
603+
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
604+
if norm_type == "layer_norm":
605+
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
606+
elif norm_type == "rms_norm":
607+
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
608+
else:
609+
raise ValueError(f"unknown norm_type {norm_type}")
610+
611+
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
612+
# *** NO SiLU here ***
613+
emb = self.linear(conditioning_embedding.to(x.dtype))
614+
scale, shift = torch.chunk(emb, 2, dim=1)
615+
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
616+
return x
617+
618+
587619
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
588620
r"""
589621
Args:
@@ -666,7 +698,7 @@ def __init__(
666698
)
667699

668700
# 4. Output projection
669-
self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
701+
self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
670702
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
671703

672704
self.gradient_checkpointing = False
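The only behavioral difference from the stock `AdaLayerNormContinuous` is the missing SiLU on the conditioning embedding before the linear projection. A small shape-level sketch using the new module (dimensions are illustrative; the import path is module-internal and assumes a build containing this change):

```python
import torch
from diffusers.models.transformers.transformer_cogview4 import CogView4AdaLayerNormContinuous

norm = CogView4AdaLayerNormContinuous(embedding_dim=64, conditioning_embedding_dim=128, elementwise_affine=False)
x = torch.randn(2, 16, 64)   # (batch, tokens, embedding_dim)
cond = torch.randn(2, 128)   # (batch, conditioning_embedding_dim)
out = norm(x, cond)          # LN(x) * (1 + scale) + shift, with no SiLU applied to `cond`
print(out.shape)             # torch.Size([2, 16, 64])
```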
