diff --git a/docs/diffusers/_toctree.yml b/docs/diffusers/_toctree.yml index c8a8b720dd..5569dcad2b 100644 --- a/docs/diffusers/_toctree.yml +++ b/docs/diffusers/_toctree.yml @@ -233,6 +233,8 @@ title: ControlNetUnionModel - local: api/models/controlnet_flux title: FluxControlNetModel + - local: api/models/hidream_image_transformer + title: HiDreamImageTransformer2DModel - local: api/models/controlnet_hunyuandit title: HunyuanDiT2DControlNetModel - local: api/models/controlnet_sana @@ -419,6 +421,8 @@ title: Flux - local: api/pipelines/control_flux_inpaint title: FluxControlInpaint + - local: api/pipelines/hidream + title: HiDream-I1 - local: api/pipelines/framepack title: Framepack - local: api/pipelines/hunyuandit diff --git a/docs/diffusers/api/models/hidream_image_transformer.md b/docs/diffusers/api/models/hidream_image_transformer.md new file mode 100644 index 0000000000..fe73cd43a8 --- /dev/null +++ b/docs/diffusers/api/models/hidream_image_transformer.md @@ -0,0 +1,27 @@ + + +# HiDreamImageTransformer2DModel + +A Transformer model for image-like data from [HiDream-I1](https://huggingface.co/HiDream-ai). + +The model can be loaded with the following code snippet. + +```python +import mindspore +from mindone.diffusers import HiDreamImageTransformer2DModel + +transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", mindspore_dtype=mindspore.bfloat16) +``` + +::: mindone.diffusers.HiDreamImageTransformer2DModel + +::: mindone.diffusers.models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/diffusers/api/pipelines/hidream.md b/docs/diffusers/api/pipelines/hidream.md new file mode 100644 index 0000000000..8c84347bea --- /dev/null +++ b/docs/diffusers/api/pipelines/hidream.md @@ -0,0 +1,37 @@ + + +# HiDreamImage + +[HiDream-I1](https://huggingface.co/HiDream-ai) by HiDream.ai + +!!! tip + + Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. 
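A minimal text-to-image sketch, mirroring the example in `HiDreamImagePipeline`'s docstring later in this diff (the Llama-3.1 repo id, prompt, and sampling settings are taken from that example):

```python
import mindspore
import numpy as np
from transformers import AutoTokenizer

from mindone.diffusers import HiDreamImagePipeline
from mindone.transformers import LlamaForCausalLM

# Llama-3.1-8B-Instruct serves as the fourth text encoder of HiDream-I1.
tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    mindspore_dtype=mindspore.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    mindspore_dtype=mindspore.bfloat16,
)

image = pipe(
    'A cat holding a sign that says "Hi-Dreams.ai".',
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    generator=np.random.Generator(np.random.PCG64(42)),
)[0][0]
image.save("output.png")
```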
+ + +## Available models + +The following models are available for the [`HiDreamImagePipeline`](text-to-image) pipeline: + +| Model name | Description | +|:---|:---| +| [`HiDream-ai/HiDream-I1-Full`](https://huggingface.co/HiDream-ai/HiDream-I1-Full) | - | +| [`HiDream-ai/HiDream-I1-Dev`](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) | - | +| [`HiDream-ai/HiDream-I1-Fast`](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) | - | + + +::: mindone.diffusers.HiDreamImagePipeline + +::: mindone.diffusers.pipelines.hidream_image.pipeline_output.HiDreamImagePipelineOutput diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index ae45d11ca0..425242abe0 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -85,6 +85,7 @@ "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", + "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", @@ -197,6 +198,7 @@ "FluxKontextInpaintPipeline", "FluxPipeline", "FluxPriorReduxPipeline", + "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", @@ -457,6 +459,7 @@ FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, + HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, @@ -580,6 +583,7 @@ FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, + HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, diff --git a/mindone/diffusers/models/__init__.py b/mindone/diffusers/models/__init__.py index f1c501b423..b5d234e5af 100644 --- a/mindone/diffusers/models/__init__.py +++ b/mindone/diffusers/models/__init__.py @@ -75,6 +75,7 @@ "transformers.transformer_cosmos": ["CosmosTransformer3DModel"], "transformers.transformer_easyanimate": ["EasyAnimateTransformer3DModel"], "transformers.transformer_flux": ["FluxTransformer2DModel"], + "transformers.transformer_hidream_image": ["HiDreamImageTransformer2DModel"], "transformers.transformer_hunyuan_video": ["HunyuanVideoTransformer3DModel"], "transformers.transformer_hunyuan_video_framepack": ["HunyuanVideoFramepackTransformer3DModel"], "transformers.transformer_ltx": ["LTXVideoTransformer3DModel"], @@ -151,6 +152,7 @@ DualTransformer2DModel, EasyAnimateTransformer3DModel, FluxTransformer2DModel, + HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, diff --git a/mindone/diffusers/models/transformers/__init__.py b/mindone/diffusers/models/transformers/__init__.py index 5c04faffc7..b65ecf1bbe 100644 --- a/mindone/diffusers/models/transformers/__init__.py +++ b/mindone/diffusers/models/transformers/__init__.py @@ -21,6 +21,7 @@ from .transformer_cosmos import CosmosTransformer3DModel from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel +from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel diff --git a/mindone/diffusers/models/transformers/transformer_hidream_image.py b/mindone/diffusers/models/transformers/transformer_hidream_image.py new file mode 100644 index 0000000000..36c0116f02 --- /dev/null +++ 
b/mindone/diffusers/models/transformers/transformer_hidream_image.py @@ -0,0 +1,960 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import mindspore as ms +from mindspore import mint, nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.modeling_outputs import Transformer2DModelOutput +from ...models.modeling_utils import ModelMixin +from ...utils import deprecate, logging +from ..attention import Attention +from ..embeddings import TimestepEmbedding, Timesteps +from ..normalization import RMSNorm + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HiDreamImageFeedForwardSwiGLU(nn.Cell): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = mint.nn.Linear(dim, hidden_dim, bias=False) + self.w2 = mint.nn.Linear(hidden_dim, dim, bias=False) + self.w3 = mint.nn.Linear(dim, hidden_dim, bias=False) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + return self.w2(mint.functional.silu(self.w1(x)) * self.w3(x)) + + +class HiDreamImagePooledEmbed(nn.Cell): + def __init__(self, text_emb_dim, hidden_size): + super().__init__() + self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size) + + def construct(self, pooled_embed: ms.Tensor) -> ms.Tensor: + return self.pooled_embedder(pooled_embed) + + +class HiDreamImageTimestepEmbed(nn.Cell): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) + + def construct(self, timesteps: ms.Tensor, wdtype: Optional[ms.Type] = None): + t_emb = self.time_proj(timesteps).to(dtype=wdtype) + t_emb = self.timestep_embedder(t_emb) + return t_emb + + +class HiDreamImageOutEmbed(nn.Cell): + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = mint.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = mint.nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.SequentialCell( + mint.nn.SiLU(), mint.nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def construct(self, hidden_states: ms.Tensor, temb: ms.Tensor) -> ms.Tensor: + shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1) + hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + hidden_states = self.linear(hidden_states) + return hidden_states + + +class HiDreamImagePatchEmbed(nn.Cell): + def __init__( + self, + patch_size=2, + in_channels=4, + out_channels=1024, + ): + super().__init__() + self.patch_size = patch_size + self.out_channels = out_channels + self.proj = mint.nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True) + + def construct(self, latent): + latent = self.proj(latent) + return latent + + +def rope(pos: ms.Tensor, dim: int, theta: int) -> ms.Tensor: + assert dim % 2 == 0, "The dimension must be even." 
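+    # Maps each position index to dim // 2 rotation angles (omega = theta**(-2k / dim)) and stacks
+    # [cos, -sin, sin, cos] into rotation matrices of shape (batch, seq_len, dim // 2, 2, 2),
+    # which apply_rope below uses to rotate even/odd channel pairs of the query and key projections.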
+ + # TODO check here, notes: hf use ms.float32 if npu else + # dtype = ms.float32 if (is_mps or is_npu) else ms.float64 + dtype = ms.float32 + + scale = mint.arange(0, dim, 2, dtype=dtype) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = mint.einsum("...n,d->...nd", pos, omega) + cos_out = mint.cos(out) + sin_out = mint.sin(out) + + stacked_out = mint.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + +class HiDreamImageEmbedND(nn.Cell): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def construct(self, ids: ms.Tensor) -> ms.Tensor: + n_axes = ids.shape[-1] + emb = mint.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(2) + + +def apply_rope(xq: ms.Tensor, xk: ms.Tensor, freqs_cis: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +class HiDreamAttention(Attention): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + upcast_softmax: bool = False, + scale_qk: bool = True, + eps: float = 1e-5, + processor=None, + out_dim: int = None, + single: bool = False, + ): + super(Attention, self).__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.out_dim = out_dim if out_dim is not None else query_dim + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.single = single + + self.to_q = mint.nn.Linear(query_dim, self.inner_dim) + self.to_k = mint.nn.Linear(self.inner_dim, self.inner_dim) + self.to_v = mint.nn.Linear(self.inner_dim, self.inner_dim) + self.to_out = mint.nn.Linear(self.inner_dim, self.out_dim) + self.q_rms_norm = RMSNorm(self.inner_dim, eps) + self.k_rms_norm = RMSNorm(self.inner_dim, eps) + + if not single: + self.to_q_t = mint.nn.Linear(query_dim, self.inner_dim) + self.to_k_t = mint.nn.Linear(self.inner_dim, self.inner_dim) + self.to_v_t = mint.nn.Linear(self.inner_dim, self.inner_dim) + self.to_out_t = mint.nn.Linear(self.inner_dim, self.out_dim) + self.q_rms_norm_t = RMSNorm(self.inner_dim, eps) + self.k_rms_norm_t = RMSNorm(self.inner_dim, eps) + + self.set_processor(processor) + + def scaled_dot_product_attention( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + ): + if query.dtype in (ms.float16, ms.bfloat16): + return self.flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale) + else: + return self.flash_attention_op( + query.to(ms.float16), + key.to(ms.float16), + value.to(ms.float16), + attn_mask, + keep_prob=1 - dropout_p, + scale=scale, + ).to(query.dtype) + + def flash_attention_op( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, 
+ attn_mask: Optional[ms.Tensor] = None, + keep_prob: float = 1.0, + scale: Optional[float] = None, + ): + # For most scenarios, qkv has been processed into a BNSD layout before sdp + input_layout = "BNSD" + head_num = self.heads + + # In case qkv is 3-dim after `head_to_batch_dim` + if query.ndim == 3: + input_layout = "BSH" + head_num = 1 + + # process `attn_mask` as logic is different between PyTorch and Mindspore + # In MindSpore, False indicates retention and True indicates discard, in PyTorch it is the opposite + if attn_mask is not None: + attn_mask = mint.logical_not(attn_mask) if attn_mask.dtype == ms.bool_ else attn_mask.bool() + attn_mask = mint.broadcast_to( + attn_mask, (attn_mask.shape[0], attn_mask.shape[1], query.shape[-2], key.shape[-2]) + )[:, :1, :, :] + + return ops.operations.nn_ops.FlashAttentionScore( + head_num=head_num, keep_prob=keep_prob, scale_value=scale or self.scale, input_layout=input_layout + )(query, key, value, None, None, None, attn_mask)[3] + + def construct( + self, + norm_hidden_states: ms.Tensor, + hidden_states_masks: ms.Tensor = None, + norm_encoder_hidden_states: ms.Tensor = None, + image_rotary_emb: ms.Tensor = None, + ) -> ms.Tensor: + return self.processor( + self, + hidden_states=norm_hidden_states, + hidden_states_masks=hidden_states_masks, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + +class HiDreamAttnProcessor: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn: HiDreamAttention, + hidden_states: ms.Tensor, + hidden_states_masks: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + image_rotary_emb: ms.Tensor = None, + *args, + **kwargs, + ) -> ms.Tensor: + dtype = hidden_states.dtype + batch_size = hidden_states.shape[0] + + query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype) + key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype) + value_i = attn.to_v(hidden_states) + + inner_dim = key_i.shape[-1] + head_dim = inner_dim // attn.heads + + query_i = query_i.view(batch_size, -1, attn.heads, head_dim) + key_i = key_i.view(batch_size, -1, attn.heads, head_dim) + value_i = value_i.view(batch_size, -1, attn.heads, head_dim) + if hidden_states_masks is not None: + key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1) + + if not attn.single: + query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype) + key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype) + value_t = attn.to_v_t(encoder_hidden_states) + + query_t = query_t.view(batch_size, -1, attn.heads, head_dim) + key_t = key_t.view(batch_size, -1, attn.heads, head_dim) + value_t = value_t.view(batch_size, -1, attn.heads, head_dim) + + num_image_tokens = query_i.shape[1] + num_text_tokens = query_t.shape[1] + query = mint.cat([query_i, query_t], dim=1) + key = mint.cat([key_i, key_t], dim=1) + value = mint.cat([value_i, value_t], dim=1) + else: + query = query_i + key = key_i + value = value_i + + if query.shape[-1] == image_rotary_emb.shape[-3] * 2: + query, key = apply_rope(query, key, image_rotary_emb) + + else: + query_1, query_2 = query.chunk(2, dim=-1) + key_1, key_2 = key.chunk(2, dim=-1) + query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb) + query = mint.cat([query_1, query_2], dim=-1) + key = mint.cat([key_1, key_2], dim=-1) + + hidden_states = attn.scaled_dot_product_attention( + query.transpose(1, 2), key.transpose(1, 2), 
value.transpose(1, 2), dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if not attn.single: + hidden_states_i, hidden_states_t = mint.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) + hidden_states_i = attn.to_out(hidden_states_i) + hidden_states_t = attn.to_out_t(hidden_states_t) + return hidden_states_i, hidden_states_t + else: + hidden_states = attn.to_out(hidden_states) + return hidden_states + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MoEGate(nn.Cell): + def __init__( + self, + embed_dim, + num_routed_experts=4, + num_activated_experts=2, + aux_loss_alpha=0.01, + _force_inference_output=False, + ): + super().__init__() + self.top_k = num_activated_experts + self.n_routed_experts = num_routed_experts + + self.scoring_func = "softmax" + self.alpha = aux_loss_alpha + self.seq_aux = False + + # topk selection algorithm + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = ms.Parameter(mint.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5) + + self._force_inference_output = _force_inference_output + + def construct(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + # compute gating score + hidden_states = hidden_states.view(-1, h) + logits = mint.functional.linear(hidden_states, self.weight, None) + if self.scoring_func == "softmax": + scores = logits.softmax(axis=-1) + else: + raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + + # select top-k experts + topk_weight, topk_idx = mint.topk(scores, k=self.top_k, dim=-1, sorted=False) + + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + # expert-level computation auxiliary loss + if self.training and self.alpha > 0.0 and not self._force_inference_output: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = mint.zeros((bsz, self.n_routed_experts)) + ce.scatter_add_(1, topk_idx_for_aux_loss, mint.ones(bsz, seq_len * aux_topk)).div_( + seq_len * aux_topk / self.n_routed_experts + ) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha + else: + mask_ce = mint.functional.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) + ce = mask_ce.float().mean(0) + + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MOEFeedForwardSwiGLU(nn.Cell): + def __init__( + self, + dim: int, + hidden_dim: int, + num_routed_experts: int, + num_activated_experts: int, + _force_inference_output: bool = False, + ): + super().__init__() + self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2) + self.experts = nn.CellList([HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]) + self._force_inference_output = _force_inference_output + self.gate = MoEGate( + embed_dim=dim, + num_routed_experts=num_routed_experts, + 
num_activated_experts=num_activated_experts, + _force_inference_output=_force_inference_output, + ) + self.num_activated_experts = num_activated_experts + + def construct(self, x): + wtype = x.dtype + identity = x + orig_shape = x.shape + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training and not self._force_inference_output: + x = x.repeat_interleave(self.num_activated_experts, dim=0) + y = mint.empty_like(x, dtype=wtype) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape).to(dtype=wtype) + # y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + y = y + self.shared_experts(identity) + return y + + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = mint.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().numpy().cumsum(0) + token_idxs = idxs // self.num_activated_experts + for i, end_idx in enumerate(tokens_per_expert): + end_idx = int(end_idx) + start_idx = 0 if i == 0 else int(tokens_per_expert[i - 1]) + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # for fp16 and other dtype + expert_cache = expert_cache.to(expert_out.dtype) + # FIXME: mindspore lacks tensor.scatter_reduce_, use an scatter_add_ instead, which is safe here. + # expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum") + expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out) + return expert_cache + + +class TextProjection(nn.Cell): + def __init__(self, in_features, hidden_size): + super().__init__() + self.linear = mint.nn.Linear(in_features=in_features, out_features=hidden_size, bias=False) + + def construct(self, caption): + hidden_states = self.linear(caption) + return hidden_states + + +class HiDreamImageSingleTransformerBlock(nn.Cell): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + _force_inference_output: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.SequentialCell(mint.nn.SiLU(), mint.nn.Linear(dim, 6 * dim, bias=True)) + + # 1. Attention + self.norm1_i = mint.nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor=HiDreamAttnProcessor(), + single=True, + ) + + # 3. 
Feed-forward + self.norm3_i = mint.nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim=dim, + hidden_dim=4 * dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + _force_inference_output=_force_inference_output, + ) + else: + self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) + + def construct( + self, + hidden_states: ms.Tensor, + hidden_states_masks: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + temb: Optional[ms.Tensor] = None, + image_rotary_emb: ms.Tensor = None, + ) -> ms.Tensor: + wtype = hidden_states.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[ + :, None + ].chunk(6, dim=-1) + + # 1. MM-Attention + norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype) + norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i + attn_output_i = self.attn1( + norm_hidden_states, + hidden_states_masks, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = gate_msa_i * attn_output_i + hidden_states + + # 2. Feed-forward + norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i + ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype)) + hidden_states = ff_output_i + hidden_states + return hidden_states + + +class HiDreamImageTransformerBlock(nn.Cell): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + _force_inference_output: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.SequentialCell(mint.nn.SiLU(), mint.nn.Linear(dim, 12 * dim, bias=True)) + + # 1. Attention + self.norm1_i = mint.nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + self.norm1_t = mint.nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor=HiDreamAttnProcessor(), + single=False, + ) + + # 3. Feed-forward + self.norm3_i = mint.nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim=dim, + hidden_dim=4 * dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + _force_inference_output=_force_inference_output, + ) + else: + self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) + self.norm3_t = mint.nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False) + self.ff_t = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim) + + def construct( + self, + hidden_states: ms.Tensor, + hidden_states_masks: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + temb: Optional[ms.Tensor] = None, + image_rotary_emb: ms.Tensor = None, + ) -> ms.Tensor: + wtype = hidden_states.dtype + ( + shift_msa_i, + scale_msa_i, + gate_msa_i, + shift_mlp_i, + scale_mlp_i, + gate_mlp_i, + shift_msa_t, + scale_msa_t, + gate_msa_t, + shift_mlp_t, + scale_mlp_t, + gate_mlp_t, + ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1) + + # 1. 
MM-Attention + norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype) + norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i + norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t + + attn_output_i, attn_output_t = self.attn1( + norm_hidden_states, + hidden_states_masks, + norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = gate_msa_i * attn_output_i + hidden_states + encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states + + # 2. Feed-forward + norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i + norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t + + ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states) + ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states) + hidden_states = ff_output_i + hidden_states + encoder_hidden_states = ff_output_t + encoder_hidden_states + return hidden_states, encoder_hidden_states + + +class HiDreamBlock(nn.Cell): + def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]): + super().__init__() + self.block = block + + def construct( + self, + hidden_states: ms.Tensor, + hidden_states_masks: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + temb: Optional[ms.Tensor] = None, + image_rotary_emb: ms.Tensor = None, + ) -> ms.Tensor: + return self.block( + hidden_states=hidden_states, + hidden_states_masks=hidden_states_masks, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + +class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: Optional[int] = None, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 16, + num_single_layers: int = 32, + attention_head_dim: int = 128, + num_attention_heads: int = 20, + caption_channels: List[int] = None, + text_emb_dim: int = 2048, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + axes_dims_rope: Tuple[int, int] = (32, 32), + max_resolution: Tuple[int, int] = (128, 128), + llama_layers: List[int] = None, + force_inference_output: bool = False, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim) + self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim) + self.x_embedder = HiDreamImagePatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + out_channels=self.inner_dim, + ) + self.pe_embedder = HiDreamImageEmbedND(theta=10000, axes_dim=axes_dims_rope) + + self.double_stream_blocks = nn.CellList( + [ + HiDreamBlock( + HiDreamImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + _force_inference_output=force_inference_output, + ) + ) + for _ in range(num_layers) 
+ ] + ) + + self.single_stream_blocks = nn.CellList( + [ + HiDreamBlock( + HiDreamImageSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + _force_inference_output=force_inference_output, + ) + ) + for _ in range(num_single_layers) + ] + ) + + self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels) + + caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]] + caption_projection = [] + for caption_channel in caption_channels: + caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim)) + self.caption_projection = nn.CellList(caption_projection) + self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) + + self.gradient_checkpointing = False + + self.patch_size = self.config.patch_size + self.force_inference_output = self.config.force_inference_output + self.llama_layers = self.config.llama_layers + + def unpatchify(self, x: ms.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[ms.Tensor]: + if is_training and not self.force_inference_output: + B, S, F = x.shape + C = F // (self.patch_size * self.patch_size) + x = ( + x.reshape((B, S, self.patch_size, self.patch_size, C)) + .permute(0, 4, 1, 2, 3) + .reshape((B, C, S, self.patch_size * self.patch_size)) + ) + else: + x_arr = [] + p1 = self.patch_size + p2 = self.patch_size + for i, img_size in enumerate(img_sizes): + pH, pW = img_size + t = x[i, : pH * pW].reshape((1, pH, pW, -1)) + F_token = t.shape[-1] + C = F_token // (p1 * p2) + t = t.reshape((1, pH, pW, p1, p2, C)) + t = t.permute(0, 5, 1, 3, 2, 4) + t = t.reshape((1, C, pH * p1, pW * p2)) + x_arr.append(t) + x = mint.cat(x_arr, dim=0) + return x + + def patchify(self, hidden_states): + batch_size, channels, height, width = hidden_states.shape + patch_size = self.patch_size + patch_height, patch_width = height // patch_size, width // patch_size + dtype = hidden_states.dtype + + # create img_sizes + img_sizes = ms.tensor([patch_height, patch_width], dtype=ms.int64).reshape(-1) + img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1) + + # create hidden_states_masks + if hidden_states.shape[-2] != hidden_states.shape[-1]: + hidden_states_masks = mint.zeros((batch_size, self.max_seq), dtype=dtype) + hidden_states_masks[:, : patch_height * patch_width] = 1.0 + else: + hidden_states_masks = None + + # create img_ids + img_ids = mint.zeros((patch_height, patch_width, 3)) + row_indices = mint.arange(patch_height)[:, None] + col_indices = mint.arange(patch_width)[None, :] + img_ids[..., 1] = img_ids[..., 1] + row_indices + img_ids[..., 2] = img_ids[..., 2] + col_indices + img_ids = img_ids.reshape(patch_height * patch_width, -1) + + if hidden_states.shape[-2] != hidden_states.shape[-1]: + # Handle non-square latents + img_ids_pad = mint.zeros((self.max_seq, 3)) + img_ids_pad[: patch_height * patch_width, :] = img_ids + img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1) + else: + img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1) + + # patchify hidden_states + if hidden_states.shape[-2] != hidden_states.shape[-1]: + # Handle non-square latents + out = mint.zeros( + (batch_size, channels, self.max_seq, patch_size * patch_size), + dtype=dtype, + ) + hidden_states = hidden_states.reshape( + batch_size, channels, patch_height, patch_size, patch_width, patch_size + 
) + hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5) + hidden_states = hidden_states.reshape( + batch_size, channels, patch_height * patch_width, patch_size * patch_size + ) + out[:, :, 0 : patch_height * patch_width] = hidden_states + hidden_states = out + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch_size, self.max_seq, patch_size * patch_size * channels + ) + + else: + # Handle square latents + hidden_states = hidden_states.reshape( + batch_size, channels, patch_height, patch_size, patch_width, patch_size + ) + hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1) + hidden_states = hidden_states.reshape( + batch_size, patch_height * patch_width, patch_size * patch_size * channels + ) + + return hidden_states, hidden_states_masks, img_sizes, img_ids + + def construct( + self, + hidden_states: ms.Tensor, + timesteps: ms.Tensor = None, + encoder_hidden_states_t5: ms.Tensor = None, + encoder_hidden_states_llama3: ms.Tensor = None, + pooled_embeds: ms.Tensor = None, + img_ids: Optional[ms.Tensor] = None, + img_sizes: Optional[List[Tuple[int, int]]] = None, + hidden_states_masks: Optional[ms.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + **kwargs, + ): + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + + if encoder_hidden_states is not None: + deprecation_message = "The `encoder_hidden_states` argument is deprecated. \ + Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead." + deprecate("encoder_hidden_states", "0.35.0", deprecation_message) + encoder_hidden_states_t5 = encoder_hidden_states[0] + encoder_hidden_states_llama3 = encoder_hidden_states[1] + + if img_ids is not None and img_sizes is not None and hidden_states_masks is None: + deprecation_message = ( + "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored." + ) + deprecate("img_ids", "0.35.0", deprecation_message) + + if hidden_states_masks is not None and (img_ids is None or img_sizes is None): + raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.") + elif hidden_states_masks is not None and hidden_states.ndim != 3: + raise ValueError( + "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape \ + (batch_size, patch_height * patch_width, patch_size * patch_size * channels)" + ) + + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.") + + # spatial forward + batch_size = hidden_states.shape[0] + hidden_states_type = hidden_states.dtype + + # Patchify the input + if hidden_states_masks is None: + hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states) + + # Embed the hidden states + hidden_states = self.x_embedder(hidden_states) + + # 0. 
time + timesteps = self.t_embedder(timesteps, hidden_states_type) + p_embedder = self.p_embedder(pooled_embeds) + temb = timesteps + p_embedder + + encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.llama_layers] + + if self.caption_projection is not None: + new_encoder_hidden_states = [] + for i, enc_hidden_state in enumerate(encoder_hidden_states): + enc_hidden_state = self.caption_projection[i](enc_hidden_state) + enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) + new_encoder_hidden_states.append(enc_hidden_state) + encoder_hidden_states = new_encoder_hidden_states + encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5) + encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states.append(encoder_hidden_states_t5) + + txt_ids = mint.zeros( + ( + batch_size, + encoder_hidden_states[-1].shape[1] + + encoder_hidden_states[-2].shape[1] + + encoder_hidden_states[0].shape[1], + 3, + ), + dtype=img_ids.dtype, + ) + ids = mint.cat((img_ids, txt_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids) + + # 2. Blocks + block_id = 0 + initial_encoder_hidden_states = mint.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) + initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + cur_encoder_hidden_states = mint.cat( + [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1 + ) + + hidden_states, initial_encoder_hidden_states = block( + hidden_states=hidden_states, + hidden_states_masks=hidden_states_masks, + encoder_hidden_states=cur_encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + block_id += 1 + + image_tokens_seq_len = hidden_states.shape[1] + hidden_states = mint.cat([hidden_states, initial_encoder_hidden_states], dim=1) + hidden_states_seq_len = hidden_states.shape[1] + if hidden_states_masks is not None: + encoder_attention_mask_ones = mint.ones( + (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), + dtype=hidden_states_masks.dtype, + ) + hidden_states_masks = mint.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1) + + for bid, block in enumerate(self.single_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + hidden_states = mint.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) + + hidden_states = block( + hidden_states=hidden_states, + hidden_states_masks=hidden_states_masks, + encoder_hidden_states=None, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states[:, :hidden_states_seq_len] + block_id += 1 + + hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
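+        # Only the image tokens (the first `image_tokens_seq_len` positions) are kept for the final
+        # projection; the text tokens concatenated for the single-stream blocks are discarded above.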
+ output = self.final_layer(hidden_states, temb) + output = self.unpatchify(output, img_sizes, self.training) + if hidden_states_masks is not None: + hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len] + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/mindone/diffusers/pipelines/__init__.py b/mindone/diffusers/pipelines/__init__.py index 8522a4b379..c5630cf838 100644 --- a/mindone/diffusers/pipelines/__init__.py +++ b/mindone/diffusers/pipelines/__init__.py @@ -100,6 +100,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ], + "hidream_image": ["HiDreamImagePipeline"], "hunyuandit": ["HunyuanDiTPipeline"], "hunyuan_video": [ "HunyuanVideoPipeline", @@ -343,6 +344,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) + from .hidream_image import HiDreamImagePipeline from .hunyuan_video import ( HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoFramepackPipeline, diff --git a/mindone/diffusers/pipelines/hidream_image/__init__.py b/mindone/diffusers/pipelines/hidream_image/__init__.py new file mode 100644 index 0000000000..cfdcf1ba82 --- /dev/null +++ b/mindone/diffusers/pipelines/hidream_image/__init__.py @@ -0,0 +1,20 @@ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule + +_import_structure = { + "pipeline_output": ["HiDreamImagePipelineOutput"], + "pipeline_hidream_image": ["HiDreamImagePipeline"], +} + +if TYPE_CHECKING: + from .pipeline_hidream_image import HiDreamImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/mindone/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/mindone/diffusers/pipelines/hidream_image/pipeline_hidream_image.py new file mode 100644 index 0000000000..3d25c384ba --- /dev/null +++ b/mindone/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -0,0 +1,982 @@ +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +from transformers import CLIPTokenizer, PreTrainedTokenizerFast, T5Tokenizer + +import mindspore as ms +from mindspore import mint + +from mindone.transformers import CLIPTextModelWithProjection, LlamaForCausalLM, T5EncoderModel + +from ...image_processor import VaeImageProcessor +from ...loaders import HiDreamImageLoraLoaderMixin +from ...models import AutoencoderKL, HiDreamImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler +from ...utils import deprecate, logging +from ...utils.mindspore_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HiDreamImagePipelineOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import mindspore + >>> import numpy as np + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import LlamaForCausalLM + >>> from mindone.diffusers import HiDreamImagePipeline + + + >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( + ... "meta-llama/Meta-Llama-3.1-8B-Instruct", + ... output_hidden_states=True, + ... output_attentions=True, + ... mindspore_dtype=mindspore.bfloat16, + ... ) + + >>> pipe = HiDreamImagePipeline.from_pretrained( + ... "HiDream-ai/HiDream-I1-Full", + ... tokenizer_4=tokenizer_4, + ... text_encoder_4=text_encoder_4, + ... 
mindspore_dtype=mindspore.bfloat16, + ... ) + + >>> image = pipe( + ... 'A cat holding a sign that says "Hi-Dreams.ai".', + ... height=1024, + ... width=1024, + ... guidance_scale=5.0, + ... num_inference_steps=50, + ... generator=np.random.Generator(np.random.PCG64(42)), + ... )[0][0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5Tokenizer, + text_encoder_4: LlamaForCausalLM, + tokenizer_4: PreTrainedTokenizerFast, + transformer: HiDreamImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + text_encoder_4=text_encoder_4, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + tokenizer_4=tokenizer_4, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + if getattr(self, "tokenizer_4", None) is not None: + self.tokenizer_4.pad_token = self.tokenizer_4.eos_token + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 128, + dtype: Optional[ms.Type] = None, + ): + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer_3.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(ms.tensor(text_input_ids), attention_mask=ms.tensor(attention_mask))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype) + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + max_sequence_length: int = 128, + dtype: Optional[ms.Type] = None, + ): + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = tokenizer( 
+ prompt, + padding="max_length", + max_length=min(max_sequence_length, 218), + truncation=True, + return_tensors="np", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {218} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(ms.tensor(text_input_ids), output_hidden_states=True) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=dtype) + return prompt_embeds + + def _get_llama3_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 128, + dtype: Optional[ms.Type] = None, + ): + dtype = dtype or self.text_encoder_4.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_4( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer_4.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}" + ) + + outputs = self.text_encoder_4( + ms.tensor(text_input_ids), + attention_mask=ms.tensor(attention_mask), + output_hidden_states=True, + output_attentions=True, + ) + + # prompt_embeds = outputs.hidden_states[1:] + prompt_embeds = outputs[1][1:] + prompt_embeds = mint.stack(prompt_embeds, dim=0) + return prompt_embeds + + def encode_prompt( + self, + prompt: Optional[Union[str, List[str]]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + prompt_4: Optional[Union[str, List[str]]] = None, + dtype: Optional[ms.Type] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + negative_prompt_4: Optional[Union[str, List[str]]] = None, + prompt_embeds_t5: Optional[List[ms.Tensor]] = None, + prompt_embeds_llama3: Optional[List[ms.Tensor]] = None, + negative_prompt_embeds_t5: Optional[List[ms.Tensor]] = None, + negative_prompt_embeds_llama3: Optional[List[ms.Tensor]] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + max_sequence_length: int = 128, + lora_scale: Optional[float] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = pooled_prompt_embeds.shape[0] + + if pooled_prompt_embeds is None: + pooled_prompt_embeds_1 = 
self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, prompt, max_sequence_length, dtype + ) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if len(negative_prompt) > 1 and len(negative_prompt) != batch_size: + raise ValueError(f"negative_prompt must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, dtype + ) + + if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + if len(prompt_2) > 1 and len(prompt_2) != batch_size: + raise ValueError(f"prompt_2 must be of length 1 or {batch_size}") + + pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, dtype + ) + + if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + + if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size: + raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, dtype + ) + + if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: + pooled_prompt_embeds = mint.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_pooled_prompt_embeds = mint.cat( + [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1 + ) + + if prompt_embeds_t5 is None: + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + if len(prompt_3) > 1 and len(prompt_3) != batch_size: + raise ValueError(f"prompt_3 must be of length 1 or {batch_size}") + + prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, dtype) + + if prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_t5 is None: + negative_prompt_3 = negative_prompt_3 or negative_prompt + negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + + if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size: + raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}") + + negative_prompt_embeds_t5 = self._get_t5_prompt_embeds(negative_prompt_3, max_sequence_length, dtype) + + if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + + if prompt_embeds_llama3 is None: + prompt_4 = prompt_4 or prompt + 
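            # prompt_4 feeds the Llama-3 text encoder (tokenizer_4 / text_encoder_4); reuse the primary prompt when it is not given separately.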
prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 + + if len(prompt_4) > 1 and len(prompt_4) != batch_size: + raise ValueError(f"prompt_4 must be of length 1 or {batch_size}") + + prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, dtype) + + if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None: + negative_prompt_4 = negative_prompt_4 or negative_prompt + negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 + + if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size: + raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}") + + negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds( + negative_prompt_4, max_sequence_length, dtype + ) + + if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + # duplicate pooled_prompt_embeds for each generation per prompt + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}") + prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}") + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + + if do_classifier_free_guidance: + # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len = negative_pooled_prompt_embeds.shape + if bs_embed == 1 and batch_size > 1: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}") + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}") + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, 
num_images_per_prompt, 1) + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}") + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view( + -1, batch_size * num_images_per_prompt, seq_len, dim + ) + + return ( + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + prompt_4, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + negative_prompt_4=None, + prompt_embeds_t5=None, + prompt_embeds_llama3=None, + negative_prompt_embeds_t5=None, + negative_prompt_embeds_llama3=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs},\ + but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds_t5 is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt_4 is not None and prompt_embeds_llama3 is not None: + raise ValueError( + f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and pooled_prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_t5 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined." + ) + elif prompt is None and prompt_embeds_llama3 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)): + raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}") + + if negative_prompt is not None and negative_pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:" + f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:" + f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:" + f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two." + ) + elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:" + f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two." + ) + + if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None: + if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape: + raise ValueError( + "`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but" + f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`" + f" {negative_pooled_prompt_embeds.shape}." 
+ ) + if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None: + if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape: + raise ValueError( + "`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but" + f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`" + f" {negative_prompt_embeds_t5.shape}." + ) + if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None: + if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape: + raise ValueError( + "`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but" + f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`" + f" {negative_prompt_embeds_llama3.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + prompt_4: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + negative_prompt_4: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[ms.Tensor] = None, + prompt_embeds_t5: Optional[ms.Tensor] = None, + prompt_embeds_llama3: Optional[ms.Tensor] = None, + negative_prompt_embeds_t5: Optional[ms.Tensor] = None, + negative_prompt_embeds_llama3: Optional[ms.Tensor] = None, + pooled_prompt_embeds: Optional[ms.Tensor] = None, + negative_pooled_prompt_embeds: Optional[ms.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds_t5`, `prompt_embeds_llama3`, and `pooled_prompt_embeds` + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` + will be used instead. + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` + will be used instead. + prompt_4 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_4` and `text_encoder_4`. If not defined, `prompt` + will be used instead. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2 + of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_4 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_4` and + `text_encoder_4`. If not defined, `negative_prompt` is used in all the text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + One or a list of [np.random.Generator(s)](https://numpy.org/doc/stable/reference/random/generator.html) + to make generation deterministic. + latents (`ms.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. 
+ prompt_embeds_t5 (`ms.Tensor`, *optional*): + Pre-generated T5 text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, they will be generated from the `prompt_3` (or `prompt`) input argument. + prompt_embeds_llama3 (`ms.Tensor`, *optional*): + Pre-generated Llama-3.1 text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, they will be generated from the `prompt_4` (or `prompt`) input argument. + negative_prompt_embeds_t5 (`ms.Tensor`, *optional*): + Pre-generated negative T5 text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, they will be generated from the `negative_prompt_3` (or `negative_prompt`) + input argument. + negative_prompt_embeds_llama3 (`ms.Tensor`, *optional*): + Pre-generated negative Llama-3.1 text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, they will be generated from the `negative_prompt_4` (or `negative_prompt`) + input argument. + pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`ms.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] or `tuple`: + [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + prompt_embeds = kwargs.get("prompt_embeds", None) + negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None) + + if prompt_embeds is not None: + deprecation_message = "The `prompt_embeds` argument is deprecated. \ + Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead." + deprecate("prompt_embeds", "0.35.0", deprecation_message) + prompt_embeds_t5 = prompt_embeds[0] + prompt_embeds_llama3 = prompt_embeds[1] + + if negative_prompt_embeds is not None: + deprecation_message = "The `negative_prompt_embeds` argument is deprecated. \ + Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead." 
+ deprecate("negative_prompt_embeds", "0.35.0", deprecation_message) + negative_prompt_embeds_t5 = negative_prompt_embeds[0] + negative_prompt_embeds_llama3 = negative_prompt_embeds[1] + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + division = self.vae_scale_factor * 2 + S_max = (self.default_sample_size * self.vae_scale_factor) ** 2 + scale = S_max / (width * height) + scale = math.sqrt(scale) + width, height = int(width * scale // division * division), int(height * scale // division * division) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + elif pooled_prompt_embeds is not None: + batch_size = pooled_prompt_embeds.shape[0] + + # 3. Encode prompt + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + ( + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_4=prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds_t5 = mint.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0) + prompt_embeds_llama3 = mint.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1) + pooled_prompt_embeds = mint.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + pooled_prompt_embeds.dtype, + generator, + latents, + ) + + # 5. 
Prepare timesteps + mu = calculate_shift(self.transformer.max_seq) + scheduler_kwargs = {"mu": mu} + if isinstance(self.scheduler, UniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps) # , shift=math.exp(mu)) + timesteps = self.scheduler.timesteps + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = mint.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.broadcast_to((latent_model_input.shape[0],)) # .to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timesteps=timestep, + encoder_hidden_states_t5=prompt_embeds_t5, + encoder_hidden_states_llama3=prompt_embeds_llama3, + pooled_embeds=pooled_prompt_embeds, + return_dict=False, + )[0] + noise_pred = -noise_pred + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5) + prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3) + pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return HiDreamImagePipelineOutput(images=image) diff --git a/mindone/diffusers/pipelines/hidream_image/pipeline_output.py b/mindone/diffusers/pipelines/hidream_image/pipeline_output.py new file mode 100644 index 0000000000..1890a8a3f5 --- /dev/null +++ b/mindone/diffusers/pipelines/hidream_image/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class HiDreamImagePipelineOutput(BaseOutput): + """ + Output class for HiDreamImage pipelines. 
+ + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/tests/diffusers_tests/modules/modules_test_cases.py b/tests/diffusers_tests/modules/modules_test_cases.py index af03cff8bd..01e2c5f211 100644 --- a/tests/diffusers_tests/modules/modules_test_cases.py +++ b/tests/diffusers_tests/modules/modules_test_cases.py @@ -1012,6 +1012,41 @@ ] +HIDREAM_IMAGE_TRANSFORMER2D_CASES = [ + [ + "HiDreamImageTransformer2DModel", + "diffusers.models.transformers.transformer_hidream_image.HiDreamImageTransformer2DModel", + "mindone.diffusers.models.transformers.transformer_hidream_image.HiDreamImageTransformer2DModel", + (), + { + "patch_size": 2, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_channels": [8, 4], + "text_emb_dim": 8, + "num_routed_experts": 2, + "num_activated_experts": 2, + "axes_dims_rope": (4, 2, 2), + "max_resolution": (32, 32), + "llama_layers": (0, 1), + "force_inference_output": True, + }, + (), + { + "hidden_states": np.random.default_rng(42).standard_normal((2, 4, 32, 32)), + "encoder_hidden_states_t5": np.random.default_rng(42).standard_normal((2, 8, 8)), + "encoder_hidden_states_llama3": np.random.default_rng(42).standard_normal((2, 2, 8, 4)), + "pooled_embeds": np.random.default_rng(42).standard_normal((2, 8)), + "timesteps": np.random.default_rng(42).integers(0, 1000, size=(2,)), + }, + ], +] + + HUNYUAN_VIDEO_TRANSFORMER3D_CASES = [ [ "HunyuanVideoTransformer3DModel", @@ -1531,6 +1566,7 @@ + COGVIEW3PLUS_TRANSFORMER2D_CASES + CONSISID_TRANSFORMER3D_CASES + DIT_TRANSFORMER2D_CASES + + HIDREAM_IMAGE_TRANSFORMER2D_CASES + HUNYUAN_VIDEO_TRANSFORMER3D_CASES + HUNYUAN_VIDEO_FRAMEPACK_TRANSFORMER3D_CASES + LTX_VIDEO_TRANSFORMER3D_CASES diff --git a/tests/diffusers_tests/pipelines/hidream/__init__.py b/tests/diffusers_tests/pipelines/hidream/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/diffusers_tests/pipelines/hidream/test_pipeline_hidream.py b/tests/diffusers_tests/pipelines/hidream/test_pipeline_hidream.py new file mode 100644 index 0000000000..2a841a1718 --- /dev/null +++ b/tests/diffusers_tests/pipelines/hidream/test_pipeline_hidream.py @@ -0,0 +1,218 @@ +import unittest + +import numpy as np +import torch +from ddt import data, ddt, unpack +from transformers import CLIPTextConfig + +import mindspore as ms + +from mindone.diffusers.utils.testing_utils import load_downloaded_numpy_from_hf_hub, slow # noqa F401 + +from ..pipeline_test_utils import ( + THRESHOLD_FP16, + THRESHOLD_FP32, + PipelineTesterMixin, + get_module, + get_pipeline_components, +) + +test_cases = [ + {"dtype": "float32"}, + {"dtype": "float16"}, + {"dtype": "bfloat16"}, +] + + +@ddt +class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_config = [ + [ + "transformer", + "diffusers.models.transformers.transformer_hidream_image.HiDreamImageTransformer2DModel", + "mindone.diffusers.models.transformers.transformer_hidream_image.HiDreamImageTransformer2DModel", + dict( + patch_size=2, + in_channels=4, + out_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=4, + caption_channels=[32, 16], + text_emb_dim=64, + 
num_routed_experts=4, + num_activated_experts=2, + axes_dims_rope=(4, 2, 2), + max_resolution=(32, 32), + llama_layers=(0, 1), + ), + ], + [ + "scheduler", + "diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + "mindone.diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler", + dict(), + ], + [ + "vae", + "diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL", + "mindone.diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL", + dict( + shift_factor=0.1159, + scaling_factor=1.3611, + ), + ], + [ + "text_encoder", + "transformers.models.clip.modeling_clip.CLIPTextModelWithProjection", + "mindone.transformers.models.clip.modeling_clip.CLIPTextModelWithProjection", + dict( + config=CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + max_position_embeddings=128, + ), + ), + ], + [ + "text_encoder_2", + "transformers.models.clip.modeling_clip.CLIPTextModelWithProjection", + "mindone.transformers.models.clip.modeling_clip.CLIPTextModelWithProjection", + dict( + config=CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + max_position_embeddings=128, + ), + ), + ], + [ + "text_encoder_3", + "transformers.models.t5.modeling_t5.T5EncoderModel", + "mindone.transformers.models.t5.modeling_t5.T5EncoderModel", + dict( + pretrained_model_name_or_path="hf-internal-testing/tiny-random-t5", + revision="refs/pr/1", + ), + ], + [ + "text_encoder_4", + "transformers.models.llama.modeling_llama.LlamaForCausalLM", + "mindone.transformers.models.llama.modeling_llama.LlamaForCausalLM", + dict( + pretrained_model_name_or_path="hf-internal-testing/tiny-random-LlamaForCausalLM", + ), + ], + [ + "tokenizer", + "transformers.models.clip.tokenization_clip.CLIPTokenizer", + "transformers.models.clip.tokenization_clip.CLIPTokenizer", + dict( + pretrained_model_name_or_path="hf-internal-testing/tiny-random-clip", + ), + ], + [ + "tokenizer_2", + "transformers.models.clip.tokenization_clip.CLIPTokenizer", + "transformers.models.clip.tokenization_clip.CLIPTokenizer", + dict( + pretrained_model_name_or_path="hf-internal-testing/tiny-random-clip", + ), + ], + [ + "tokenizer_3", + "transformers.models.auto.tokenization_auto.AutoTokenizer", + "transformers.models.auto.tokenization_auto.AutoTokenizer", + dict( + pretrained_model_name_or_path="hf-internal-testing/tiny-random-t5", + revision="refs/pr/1", + ), + ], + [ + "tokenizer_4", + "transformers.models.auto.tokenization_auto.AutoTokenizer", + "transformers.models.auto.tokenization_auto.AutoTokenizer", + dict( + pretrained_model_name_or_path="hf-internal-testing/tiny-random-LlamaForCausalLM", + ), + ], + ] + + def get_dummy_components(self): + components = { + key: None + for key in [ + "scheduler", + "text_encoder", + "text_encoder_2", + "text_encoder_3", + "text_encoder_4", + "tokenizer", + "tokenizer_2", + "tokenizer_3", + "tokenizer_4", + "transformer", + "vae", + ] + } + + return get_pipeline_components(components, self.pipeline_config) + + def get_dummy_inputs(self, seed=0): + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 2, + "guidance_scale": 5.0, + 
"output_type": "np", + } + return inputs + + @data(*test_cases) + @unpack + def test_hidream(self, dtype): + pt_components, ms_components = self.get_dummy_components() + pt_pipe_cls = get_module("diffusers.pipelines.hidream_image.HiDreamImagePipeline") + ms_pipe_cls = get_module("mindone.diffusers.pipelines.hidream_image.HiDreamImagePipeline") + + pt_pipe = pt_pipe_cls(**pt_components) + ms_pipe = ms_pipe_cls(**ms_components) + + pt_pipe.text_encoder_4.generation_config.pad_token_id = 1 + ms_pipe.text_encoder_4.generation_config.pad_token_id = 1 + pt_pipe.transformer.eval() + + ms_dtype, pt_dtype = getattr(ms, dtype), getattr(torch, dtype) + pt_pipe = pt_pipe.to(pt_dtype) + ms_pipe = ms_pipe.to(ms_dtype) + + inputs = self.get_dummy_inputs() + + torch.manual_seed(0) + pt_image = pt_pipe(**inputs) + torch.manual_seed(0) + ms_image = ms_pipe(**inputs) + + pt_image_slice = pt_image.images[0, -3:, -3:, -1] + ms_image_slice = ms_image[0][0, -3:, -3:, -1] + + threshold = THRESHOLD_FP32 if dtype == "float32" else THRESHOLD_FP16 + assert np.linalg.norm(pt_image_slice - ms_image_slice) / np.linalg.norm(pt_image_slice) < threshold