diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index be4b86321..013b4b537 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -48,6 +48,12 @@ def check_qaic_sdk(): QEFFCommonLoader, ) from QEfficient.compile.compile_helper import compile + + # Imports for the diffusers + from QEfficient.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import QEFFStableDiffusionPipeline + from QEfficient.diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion3 import ( + QEFFStableDiffusion3Pipeline, + ) from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv from QEfficient.peft import QEffAutoPeftModelForCausalLM @@ -67,6 +73,8 @@ def check_qaic_sdk(): "QEFFAutoModelForImageTextToText", "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", + "QEFFStableDiffusionPipeline", + "QEFFStableDiffusion3Pipeline", ] else: diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index d9d6823ae..8c51ffdd8 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -22,7 +22,7 @@ from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json +from QEfficient.utils import constants, create_json, generate_mdp_partition_config, load_json from QEfficient.utils.cache import QEFF_HOME, to_hashable logger = logging.getLogger(__name__) @@ -179,7 +179,8 @@ def _export( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=constants.ONNX_EXPORT_OPSET, + opset_version=17, + # verbose=True, **export_kwargs, ) logger.info("Pytorch export successful") @@ -213,7 +214,7 @@ def _export( self.onnx_path = onnx_path return onnx_path - @dump_qconfig + # @dump_qconfig def _compile( self, onnx_path: Optional[str] = None, @@ -352,6 +353,7 @@ def _compile( command.append(f"-aic-binary-dir={qpc_path}") logger.info(f"Running compiler: {' '.join(command)}") + print(command) try: subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: diff --git a/QEfficient/diffusers/README.md b/QEfficient/diffusers/README.md new file mode 100644 index 000000000..088108461 --- /dev/null +++ b/QEfficient/diffusers/README.md @@ -0,0 +1,110 @@ + +
+ + +# **Diffusion Models on Qualcomm Cloud AI 100** + + +
+ +### 🎨 **Experience the Future of AI Image Generation** + +*Optimized for Qualcomm Cloud AI 100* + +Sample Output + +**Generated with**: `stabilityai/stable-diffusion-3.5-large` • `"A girl laughing"` • 28 steps • 2.0 guidance scale • ⚡ + + + +
+ + + +[![Diffusers](https://img.shields.io/badge/Diffusers-0.31.0-orange.svg)](https://github.com/huggingface/diffusers) +
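+## 🚀 Quick Start
+
+The snippet below is a minimal usage sketch, not an excerpt from this repository's documentation: it assumes the compiled SD3 pipeline exposes a diffusers-style call returning an object with an `images` list, uses default export/compile settings, and picks an arbitrary output filename. See the Installation section below for environment setup.
+
+```python
+from QEfficient import QEFFStableDiffusion3Pipeline
+
+# Wrap the HuggingFace pipeline with QEfficient-optimized components
+pipeline = QEFFStableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large")
+
+# Export each component to ONNX (if not already exported) and compile for Cloud AI 100
+pipeline.compile(num_cores=16)
+
+# Generate an image (argument names assumed to mirror diffusers' StableDiffusion3Pipeline)
+image = pipeline("A girl laughing", num_inference_steps=28, guidance_scale=2.0).images[0]
+image.save("girl_laughing.png")
+```
+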
+ +--- + +## ✨ Overview + +QEfficient Diffusers brings the power of state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, our optimized pipeline provides seamless inference on Qualcomm Cloud AI 100 hardware. + +## 🛠️ Installation + +### Prerequisites + +Ensure you have Python 3.8+ and the required dependencies: + +```bash +# Create Python virtual environment (Recommended Python 3.10) +sudo apt install python3.10-venv +python3.10 -m venv qeff_env +source qeff_env/bin/activate +pip install -U pip +``` + +### Install QEfficient + +```bash +# Install from GitHub (includes diffusers support) +pip install git+https://github.com/quic/efficient-transformers + +# Or build from source +git clone https://github.com/quic/efficient-transformers.git +cd efficient-transformers +pip install build wheel +python -m build --wheel --outdir dist +pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl +``` + +### Install Diffusers Dependencies + +```bash +# Install diffusers optional dependencies +pip install "QEfficient[diffusers]" +``` + +--- + +## 🎯 Supported Models + +### Stable Diffusion 3.x Series +- ✅ [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) +- ✅ [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) +--- + + +## 📚 Examples + +Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory: + +--- + +## 🤝 Contributing + +We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details. + +### Development Setup + +```bash +git clone https://github.com/quic/efficient-transformers.git +cd efficient-transformers +pip install -e ".[diffusers,test]" +``` + +--- + +## 🙏 Acknowledgments + +- **HuggingFace Diffusers**: For the excellent foundation library +- **Stability AI**: For the amazing Stable Diffusion models +--- + +## 📞 Support + +- 📖 **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/) +- 🐛 **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues) + +--- + diff --git a/QEfficient/diffusers/__init__.py b/QEfficient/diffusers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/diffusers/models/__init__.py b/QEfficient/diffusers/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/diffusers/models/attention.py b/QEfficient/diffusers/models/attention.py new file mode 100644 index 000000000..3c9cc268d --- /dev/null +++ b/QEfficient/diffusers/models/attention.py @@ -0,0 +1,75 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import torch +from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward + + +class QEffJointTransformerBlock(JointTransformerBlock): + def forward( + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + ): + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + if self.use_dual_attention: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + # ff_output = self.ff(norm_hidden_states) + ff_output = self.ff(norm_hidden_states, block_size=4096) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + context_ff_output = _chunked_feed_forward( + self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + # context_ff_output = self.ff_context(norm_encoder_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states, block_size=333) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states diff --git a/QEfficient/diffusers/models/attention_processor.py b/QEfficient/diffusers/models/attention_processor.py new file mode 100644 index 000000000..01954e55e --- /dev/null +++ b/QEfficient/diffusers/models/attention_processor.py @@ -0,0 +1,155 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import Optional + +import torch +from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0 + + +class QEffAttention(Attention): + def __qeff_init__(self): + processor = QEffJointAttnProcessor2_0() + self.processor = processor + processor.query_block_size = 64 + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key, + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + +class QEffJointAttnProcessor2_0(JointAttnProcessor2_0): + def __call__( + self, + attn: QEffAttention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. 
+ if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + query = query.reshape(-1, query.shape[-2], query.shape[-1]) + key = key.reshape(-1, key.shape[-2], key.shape[-1]) + value = value.reshape(-1, value.shape[-2], value.shape[-1]) + + # pre-transpose the key + key = key.transpose(-1, -2) + if query.size(-2) != value.size(-2): # cross-attention, use regular attention + # QKV done in single block + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + else: # self-attention, use blocked attention + # QKV done with block-attention (a la FlashAttentionV2) + query_block_size = self.query_block_size + query_seq_len = query.size(-2) + num_blocks = (query_seq_len + query_block_size - 1) // query_block_size + for qidx in range(num_blocks): + query_block = query[:, qidx * query_block_size : (qidx + 1) * query_block_size, :] + attention_probs = attn.get_attention_scores(query_block, key, attention_mask) + hidden_states_block = torch.bmm(attention_probs, value) + if qidx == 0: + hidden_states = hidden_states_block + else: + hidden_states = torch.cat((hidden_states, hidden_states_block), -2) + hidden_states = attn.batch_to_head_dim(hidden_states) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states diff --git a/QEfficient/diffusers/models/autoencoders/__init__.py b/QEfficient/diffusers/models/autoencoders/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/models/autoencoders/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py b/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py new file mode 100644 index 000000000..c652f07d2 --- /dev/null +++ b/QEfficient/diffusers/models/autoencoders/autoencoder_kl.py @@ -0,0 +1,31 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import torch +from diffusers import AutoencoderKL + + +class QEffAutoencoderKL(AutoencoderKL): + def encode(self, x: torch.Tensor, return_dict: bool = True): + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + return h diff --git a/QEfficient/diffusers/models/pytorch_transforms.py b/QEfficient/diffusers/models/pytorch_transforms.py new file mode 100644 index 000000000..a3ab3939c --- /dev/null +++ b/QEfficient/diffusers/models/pytorch_transforms.py @@ -0,0 +1,42 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import Tuple + +from diffusers.models.attention import JointTransformerBlock +from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0 +from diffusers.models.normalization import RMSNorm +from torch import nn + +from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.customop.rms_norm import CustomRMSNormAIC +from QEfficient.diffusers.models.attention import QEffJointTransformerBlock +from QEfficient.diffusers.models.attention_processor import ( + QEffAttention, + QEffJointAttnProcessor2_0, +) + + +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = {RMSNorm: CustomRMSNormAIC} + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed + + +class AttentionTransform(ModuleMappingTransform): + _module_mapping = { + Attention: QEffAttention, + JointAttnProcessor2_0: QEffJointAttnProcessor2_0, + JointTransformerBlock: QEffJointTransformerBlock, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed diff --git a/QEfficient/diffusers/pipelines/__init__.py b/QEfficient/diffusers/pipelines/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. 
and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/pipeline_utils.py b/QEfficient/diffusers/pipelines/pipeline_utils.py new file mode 100644 index 000000000..ce6c4fba6 --- /dev/null +++ b/QEfficient/diffusers/pipelines/pipeline_utils.py @@ -0,0 +1,437 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import copy +import hashlib + +import torch +import torch.nn as nn + +from QEfficient.base.modeling_qeff import QEFFBaseModel +from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform +from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform, CustomOpsTransform +from QEfficient.transformers.models.pytorch_transforms import ( + T5ModelTransform, +) +from QEfficient.utils import constants +from QEfficient.utils.cache import to_hashable + + +class QEffTextEncoder(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform, T5ModelTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + """ + QEffTextEncoder is a wrapper class for text encoder models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle text encoder models (like T5EncoderModel) with specific + transformations and optimizations for efficient inference on Qualcomm AI hardware. + """ + + def __init__(self, model: nn.modules): + super().__init__(model) + self.model = copy.deepcopy(model) + + def get_onnx_config(self): + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len = self.tokenizer.model_max_length + + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + } + + dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}} + + output_names = ["pooler_output", "last_hidden_state"] + if self.model.__class__.__name__ == "T5EncoderModel": + output_names = ["last_hidden_state"] + else: + example_inputs["output_hidden_states"] = (True,) + + return example_inputs, dynamic_axes, output_names + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def get_specializations( + self, + batch_size: int, + seq_len: int, + ): + specializations = [ + {"batch_size": batch_size, "seq_len": seq_len}, + ] + + return specializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or 
mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.vision_model.config.__dict__ + + +class QEffUNet(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + """ + QEffUNet is a wrapper class for UNet models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle UNet models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. It is commonly used in diffusion models for image + generation tasks. + """ + + def __init__(self, model: nn.modules): + super().__init__(model.unet) + self.model = model.unet + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(dict(self.model.config))) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.vision_model.config.__dict__ + + +class QEffVAE(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + """ + QEffVAE is a wrapper class for Variational Autoencoder (VAE) models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle VAE models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. VAE models are commonly used in diffusion pipelines + for encoding images to latent space and decoding latent representations back to images. 
+ """ + + def __init__(self, model: nn.modules, type: str): + super().__init__(model.vae) + self.model = copy.deepcopy(model.vae) + self.type = type + + def get_onnx_config(self): + # VAE decode + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + example_inputs = { + "latent_sample": torch.randn(bs, 16, 64, 64), + "return_dict": False, + } + + output_names = ["sample"] + + dynamic_axes = { + "latent_sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"}, + } + return example_inputs, dynamic_axes, output_names + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def get_specializations( + self, + batch_size: int, + ): + sepcializations = [ + { + "batch_size": batch_size, + "channels": 16, + "height": 128, + "width": 128, + } + ] + return sepcializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(dict(self.model.config))) + mhash.update(to_hashable(self._transform_names())) + mhash.update(to_hashable(self.type)) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.vision_model.config.__dict__ + + +class QEffSafetyChecker(QEFFBaseModel): + _pytorch_transforms = [CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + """ + QEffSafetyChecker is a wrapper class for safety checker models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle safety checker models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. Safety checker models are commonly used in diffusion pipelines + to filter out potentially harmful or inappropriate generated content. 
+ """ + + def __init__(self, model: nn.modules): + super().__init__(model.vae) + self.model = model.safety_checker + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname + + @property + def get_model_config(self) -> dict: + return self.model.model.vision_model.config.__dict__ + + +class QEffSD3Transformer2DModel(QEFFBaseModel): + _pytorch_transforms = [AttentionTransform, CustomOpsTransform] + _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + + """ + QEffSD3Transformer2DModel is a wrapper class for Stable Diffusion 3 Transformer2D models that provides ONNX export and compilation capabilities. + + This class extends QEFFBaseModel to handle SD3 Transformer2D models with specific transformations and optimizations + for efficient inference on Qualcomm AI hardware. It is designed for the newer Stable Diffusion 3 architecture + that uses transformer-based diffusion models instead of traditional UNet architectures. 
+ """ + + def __init__(self, model: nn.modules): + super().__init__(model) + self.model = model + + def get_onnx_config(self): + example_inputs = { + "hidden_states": torch.randn( + 2, + self.model.config.in_channels, + self.model.config.sample_size, + self.model.config.sample_size, + ), + "encoder_hidden_states": torch.randn(2, 333, self.model.config.joint_attention_dim), + "pooled_projections": torch.randn(2, self.model.config.pooled_projection_dim), + "timestep": torch.randint(0, 20, (2,), dtype=torch.int64), + } + + output_names = ["output"] + + dynamic_axes = { + "hidden_states": {0: "batch_size", 1: "latent_channels", 2: "latent_height", 3: "latent_width"}, + "encoder_hidden_states": {0: "batch_size", 1: "seq_len"}, + "pooled_projections": {0: "batch_size"}, + "timestep": {0: "steps"}, + "output": {0: "batch_size", 1: "latent_channels", 2: "latent_height", 3: "latent_width"}, + } + return example_inputs, dynamic_axes, output_names + + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) + + def get_specializations( + self, + batch_size: int, + seq_len: int, + ): + specializations = [ + { + "batch_size": 2 * batch_size, + "latent_channels": 16, + "latent_height": self.model.config.sample_size, + "latent_width": self.model.config.sample_size, + "seq_len": seq_len, + "steps": 1, + } + ] + + return specializations + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + **compiler_options, + ) -> str: + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + **compiler_options, + ) + + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(dict(self.model.config))) + mhash.update(to_hashable(self._transform_names())) + mhash = mhash.hexdigest()[:16] + return mhash + + @property + def model_name(self) -> str: + mname = self.model.__class__.__name__ + if mname.startswith("QEff") or mname.startswith("QEFF"): + mname = mname[4:] + return mname diff --git a/QEfficient/diffusers/pipelines/stable_diffusion/__init__.py b/QEfficient/diffusers/pipelines/stable_diffusion/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/stable_diffusion/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/QEfficient/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 000000000..7f14f47d7 --- /dev/null +++ b/QEfficient/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,481 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import os +from typing import List, Optional, Union + +import numpy as np +import torch +from diffusers import StableDiffusionPipeline +from diffusers.image_processor import VaeImageProcessor + +from QEfficient.diffusers.pipelines.pipeline_utils import QEffSafetyChecker, QEffTextEncoder, QEffUNet, QEffVAE +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants + + +class QEFFStableDiffusionPipeline(StableDiffusionPipeline): + _hf_auto_class = StableDiffusionPipeline + + def __init__(self, model, *args, **kwargs): + # super().__init__(*args, **kwargs) + self.tokenizer = model.tokenizer + self.scheduler = model.scheduler + self.feature_extractor = model.feature_extractor + + self.text_encoder = QEffTextEncoder(model) + self.unet = QEffUNet(model) + + # VAE Encoder + self.vae_encoder = QEffVAE(model, "encoder") + self.vae_encoder.model.forward = lambda sample, return_dict: self.vae_encoder.model.encode(sample, return_dict) + + # VAE Decoder + self.vae_decoder = QEffVAE(model, "decoder") + self.vae_decoder.model.forward = lambda latent_sample, return_dict: self.vae_decoder.model.decode( + latent_sample, return_dict + ) + + # Saftey Checker + self.safety_checker = QEffSafetyChecker(model) + self.safety_checker.model.forward = model.safety_checker.forward_onnx + + self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.model.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + kwargs.update({"attn_implementation": "eager"}) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float32, **kwargs) + model.to("cpu") + return cls(model, pretrained_model_name_or_path) + + def export(self, export_dir: Optional[str] = None) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. + + ``Optional`` Args: + :export_dir (str, optional): The directory path to store ONNX-graph. + + Returns: + :str: Path of the generated ``ONNX`` graph. 
+ """ + + # Text encoder export + + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len = self.tokenizer.model_max_length + + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int32), + # "attention_mask": torch.ones((bs, seq_len), dtype=bool), + } + + dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}} + + output_names = ["last_hidden_state", "pooler_output"] + + # self.text_encoder.model.set_attn_processor(AttnProcessor()) + + # config = self.text_encoder.model.text_model.config + # for layer in self.text_encoder.model.text_model.encoder.layers: + # layer.self_attn = CLIPAttention(config) + + self.text_encoder_onnx_path = self.text_encoder.export( + example_inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + ) + + # UNET Export + + print("###################### Text Encoder Exported #####################") + + unet_example_input = { + "sample": torch.randn( + bs, self.unet.model.in_channels, self.unet.model.config.sample_size, self.unet.model.config.sample_size + ), + "timestep": torch.tensor([1]), + "encoder_hidden_states": torch.randn(bs, seq_len, self.unet.model.config.cross_attention_dim), + "return_dict": False, + } + + output_names = ["out_sample"] + + dynamic_axes = { + "sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"}, + "timestep": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size", 1: "seq_len"}, + } + # self.unet.model.set_attn_processor(AttnProcessor()) + + self.unet_onnx_path = self.unet.export( + unet_example_input, + output_names, + dynamic_axes, + export_dir=export_dir, + ) + + print("###################### UNet Exported #####################") + + vae_encoder_input = { + "sample": torch.randn(bs, 3, 512, 512), + "return_dict": False, + } + + output_names = ["latent_sample"] + + dynamic_axes = { + "sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"}, + } + + # self.vae_encoder.model.set_attn_processor(AttnProcessor()) + + self.vae_encoder_onnx_path = self.vae_encoder.export( + vae_encoder_input, + output_names, + dynamic_axes, + export_dir=None, + ) + + print("###################### VAE Encoder Exported #####################") + + vae_decoder_input = { + "latent_sample": torch.randn(bs, 4, 64, 64), + "return_dict": False, + } + + output_names = ["sample"] + + dynamic_axes = { + "latent_sample": {0: "batch_size", 1: "channels", 2: "height", 3: "width"}, + } + + # self.vae_decoder.model.set_attn_processor(AttnProcessor()) + + self.vae_decoder_onnx_path = self.vae_decoder.export( + vae_decoder_input, + output_names, + dynamic_axes, + export_dir=None, + ) + + print("###################### VAE Decoder Exported #####################") + + saftey_checker_input = {"clip_input": torch.randn(bs, 3, 224, 224), "images": torch.randn(bs, 3, 512, 512)} + output_names = ["out_images", "has_nsfw_concepts"] + + dynamic_axes = { + "clip_input": {0: "batch_size", 1: "channels", 2: "clip_height", 3: "clip_width"}, + "images": {0: "batch_size", 1: "channels", 2: "height", 3: "width"}, + } + + # self.safety_checker.model.set_attn_processor(AttnProcessor()) + + # for layer in self.safety_checker.model.vision_model.vision_model.encoder.layers: + # config = self.safety_checker.model.config.vision_config + # layer.self_attn = CLIPAttention(config) + # Replace with eager version + + self.safety_checker_onnx_path = self.safety_checker.export( + saftey_checker_input, + output_names, + dynamic_axes, + export_dir=None, + ) + + print("###################### 
Safety Checker Exported #####################") + + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + seq_len: Union[int, List[int]] = 32, + batch_size: int = 1, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + **compiler_options, + ) -> str: + # Compile text_encoder + + # Make specilization + + seq_len = self.tokenizer.model_max_length + + specializations = [ + {"batch_size": batch_size, "seq_len": seq_len}, + ] + + self.text_encoder_compile_path = self.text_encoder._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + **compiler_options, + ) + + print("###################### Text Encoder Compiled #####################") + + # Compile unet + + specializations = [ + { + "batch_size": batch_size, + "channels": 4, + "height": self.unet.model.config.sample_size, + "width": self.unet.model.config.sample_size, + "seq_len": seq_len, + } + ] + + self.compiled_unet_path = self.unet._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + **compiler_options, + ) + + print("###################### Unet Compiled #####################") + + # Compile vae_encoder + + # encoder_specializations = [ + # { + # "batch_size": batch_size, + # "channels": self.vae_encoder.model.config.in_channels, + # "height": self.vae_encoder.model.config.sample_size, + # "width": self.vae_encoder.model.config.sample_size, + # } + # ] + + # self.vae_encoder_compile_path=self.vae_encoder._compile( + # onnx_path, + # compile_dir, + # compile_only=True, + # specializations=encoder_specializations, + # convert_to_fp16=True, + # ) + + print("###################### VAE Encoder Compiled #####################") + + # compile vae decoder + + # decoder_sepcializations = [ + # { + # "batch_size": batch_size, + # "channels": 4, + # "height": self.vae_decoder.model.config.sample_size, + # "width": self.vae_decoder.model.config.sample_size, + # } + # ] + + # self.vae_decoder_compile_path=self.vae_decoder._compile( + # onnx_path, + # compile_dir, + # compile_only=True, + # specializations=decoder_sepcializations, + # convert_to_fp16=True, + # ) + + # TODO: Add support of comilation for now it will run on host + + print("###################### VAE Decoder Compiled #####################") + + # compile safety check + + safety_check_specializations = [ + { + "batch_size": batch_size, + "channels": 3, + "height": 512, + "width": 512, + "clip_height": 224, + "clip_width": 224, + } + ] + + self.compiled_safety_checker_path = self.safety_checker._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=safety_check_specializations, + convert_to_fp16=True, + ) + + print("###################### Safety Checker Compiled #####################") + + # def generate() + + @property + def model_name(self) -> str: + pass + + @property + def model_hash(self) -> str: + pass + + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + device_ids: List[int] = [0], + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + 
num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + **kwargs, + ): + # # Get output for text_encoder + if self.text_encoder.qpc_session is None: + self.text_encoder.qpc_session = QAICInferenceSession(str(self.text_encoder_compile_path), device_ids) + + # Dynamic switching to closest seq_Len based on input_ids_len + + # find the inputs/attention mask shape for which qpc compiled + bs, compield_inputs_shape = self.text_encoder.qpc_session.bindings[0].dims + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="np", + ) + text_encoder_output = { + "last_hidden_state": np.random.rand(bs, 77, 768).astype(np.float32), + "pooler_output": np.random.rand(bs, 768).astype(np.float32), + } + self.text_encoder.qpc_session.set_buffers(text_encoder_output) + ## Testing with the ORT output ## + + import onnxruntime as ort + + ort_session = ort.InferenceSession(str(self.text_encoder.onnx_path)) + + onnx_inputs = {k: v for k, v in text_inputs.items() if k in [i.name for i in ort_session.get_inputs()]} + + onnx_inputs["input_ids"] = onnx_inputs["input_ids"].astype(np.int32) + + ort_outputs = ort_session.run(None, onnx_inputs) + text_inputs_pt = {k: torch.from_numpy(v) for k, v in onnx_inputs.items()} + + pt_output = self.text_encoder.model(**text_inputs_pt) + mad = torch.mean(torch.abs(pt_output[0] - torch.tensor(ort_outputs[0]))) + print("CLIP: MAD onnx vs pytorch", mad) + + self.text_encoder.qpc_session.set_buffers(text_encoder_output) + ai100_output = self.text_encoder.qpc_session.run(onnx_inputs) + mad_ai100_onnnx = np.mean(np.abs(ai100_output["last_hidden_state"] - ort_outputs[0])) + + print("CLIP: MAD ai100 vs onnx", mad_ai100_onnnx) + + ai100_output = ai100_output["last_hidden_state"] + + ## CLIP done here + # 4. Prepare timesteps + + from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps, sigmas) + timesteps = timesteps.numpy() + # 5. Prepare latent variables + # 0. Default height and width to unet + # timesteps = timesteps.astype(np.float32) + + width = height = self.unet.model.config.sample_size + height, width = height * self.vae_scale_factor, width * self.vae_scale_factor + + num_channels_latents = self.unet.model.config.in_channels + latents = self.prepare_latents( + bs, + num_channels_latents, + height, + width, + torch.float32, + generator, + latents, + ) + + # Load qpc + self.unet_qpc_session = QAICInferenceSession(str(self.compiled_unet_path), [1]) + + unet_output = {"out_sample": np.random.rand(bs, 4, 64, 64).astype(np.float32)} + self.unet_qpc_session.set_buffers(unet_output) + + # 3. 
Denoising loop + for t in timesteps: + latent_input = latents + latent_input = self.scheduler.scale_model_input(latent_input, t) + noise_pred = self.unet_qpc_session.run( + {"encoder_hidden_states": ai100_output, "timestep": t, "sample": latent_input.numpy()} + ) + latents = self.scheduler.step(noise_pred["out_sample"], t, latents).prev_sample + + # VAE decode step + # TODO: Add QPC for VAE decode + image = self.vae_decoder.model(latents / self.vae_decoder.model.config.scaling_factor, return_dict=False)[0] + + # Saftey check + + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image.detach(), output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt") + + self.safety_checker_session = QAICInferenceSession(str(self.compiled_safety_checker_path), [2]) + + safety_checker_output = { + "out_images": np.random.rand(1, 3, 512, 512).astype(np.float32), + "has_nsfw_concepts": np.bool_(1), + } + self.safety_checker_session.set_buffers(safety_checker_output) + + checker_output = self.safety_checker_session.run( + {"clip_input": safety_checker_input["pixel_values"].numpy(), "images": image.detach().numpy()} + ) + + has_nsfw_concept = checker_output["has_nsfw_concepts"].astype("bool") + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + image = self.image_processor.postprocess(image.detach(), output_type=output_type, do_denormalize=do_denormalize) + + # self.maybe_free_model_hooks() + + from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/QEfficient/diffusers/pipelines/stable_diffusion_3/__init__.py b/QEfficient/diffusers/pipelines/stable_diffusion_3/__init__.py new file mode 100644 index 000000000..75daf1953 --- /dev/null +++ b/QEfficient/diffusers/pipelines/stable_diffusion_3/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- diff --git a/QEfficient/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion3.py b/QEfficient/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion3.py new file mode 100644 index 000000000..ceb197feb --- /dev/null +++ b/QEfficient/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion3.py @@ -0,0 +1,915 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import os +from typing import Any, Callable, Dict, List, Optional, Union +from venv import logger + +import numpy as np +import torch +from diffusers import StableDiffusion3Pipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput + +from QEfficient.diffusers.pipelines.pipeline_utils import QEffSD3Transformer2DModel, QEffTextEncoder, QEffVAE +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.utils import constants + + +class QEFFStableDiffusion3Pipeline(StableDiffusion3Pipeline): + _hf_auto_class = StableDiffusion3Pipeline + """ + A QEfficient-optimized Stable Diffusion 3 pipeline, inheriting from `diffusers.StableDiffusion3Pipeline`. + + This class integrates QEfficient components (e.g., optimized models for text encoder, + transformer, and VAE) to enhance performance, particularly for deployment on Qualcomm AI hardware. + It provides methods for text-to-image generation leveraging these optimized components. + """ + + def __init__(self, model, *args, **kwargs): + self.text_encoder = QEffTextEncoder(model.text_encoder) + self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2) + self.text_encoder_3 = QEffTextEncoder(model.text_encoder_3) + self.transformer = QEffSD3Transformer2DModel(model.transformer) + self.vae_decode = QEffVAE(model, "decoder") + + self.tokenizer = model.tokenizer + self.text_encoder.tokenizer = model.tokenizer + self.text_encoder_2.tokenizer = model.tokenizer_2 + self.text_encoder_3.tokenizer = model.tokenizer_3 + self.tokenizer_max_length = model.tokenizer_max_length + self.scheduler = model.scheduler + + self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode( + latent_sample, return_dict + ) + + self.vae_scale_factor = ( + 2 ** (len(model.vae.config.block_out_channels) - 1) if getattr(model, "vae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=model.vae_scale_factor) + + self.t_max_length = ( + model.tokenizer.model_max_length if hasattr(model, "tokenizer") and model.tokenizer is not None else 77 + ) + self.default_sample_size = ( + model.transformer.config.sample_size + if hasattr(model, "transformer") and model.transformer is not None + else 128 + ) + self.patch_size = ( + model.transformer.config.patch_size + if hasattr(model, "transformer") and model.transformer is not None + else 2 + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + """ + Instantiate a QEFFStableDiffusion3Pipeline from pretrained Diffusers models. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + The path to the pretrained model or its name. + **kwargs (additional keyword arguments): + Additional arguments that can be passed to the underlying `StableDiffusion3Pipeline.from_pretrained` + method. + """ + model = cls._hf_auto_class.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch.float32, + **kwargs, + ) + model.to("cpu") + return cls(model, pretrained_model_name_or_path) + + def export(self, export_dir: Optional[str] = None) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. 
+ + ``Optional`` Args: + :export_dir (str, optional): The directory path to store ONNX-graph. + + Returns: + :str: Path of the generated ``ONNX`` graph. + """ + + # text_encoder + example_inputs_text_encoder, dynamic_axes_text_encoder, output_names_text_encoder = ( + self.text_encoder.get_onnx_config() + ) + + for i in range(0, 13): + output_names_text_encoder.append("hidden_states_" + str(i)) + self.text_encoder.export( + inputs=example_inputs_text_encoder, + output_names=output_names_text_encoder, + dynamic_axes=dynamic_axes_text_encoder, + export_dir=export_dir, + ) + + # text_encoder_2 + example_inputs_text_encoder_2, dynamic_axes_text_encoder_2, output_names_text_encoder_2 = ( + self.text_encoder_2.get_onnx_config() + ) + + for i in range(0, 33): + output_names_text_encoder_2.append("hidden_states_" + str(i)) + + self.text_encoder_2.export( + inputs=example_inputs_text_encoder_2, + output_names=output_names_text_encoder_2, + dynamic_axes=dynamic_axes_text_encoder_2, + export_dir=export_dir, + ) + + # t5_text_encoder + example_inputs_text_encoder_3, dynamic_axes_text_encoder_3, output_names_text_encoder_3 = ( + self.text_encoder_3.get_onnx_config() + ) + + with torch.no_grad(): + prev_sf = 1 + for i in range(len(self.text_encoder_3.model.encoder.block)): + wosf = constants.WO_SFS[i] + self.text_encoder_3.model.encoder.block[i].layer[0].SelfAttention.o.weight *= 1 / wosf + self.text_encoder_3.model.encoder.block[i].layer[0].scaling_factor *= prev_sf / wosf + self.text_encoder_3.model.encoder.block[i].layer[1].DenseReluDense.wo.weight *= 1 / wosf + prev_sf = wosf + + self.text_encoder_3.export( + inputs=example_inputs_text_encoder_3, + output_names=output_names_text_encoder_3, + dynamic_axes=dynamic_axes_text_encoder_3, + export_dir=export_dir, + ) + + # transformers + example_inputs_transformer, dynamic_axes_transformer, output_names_transformer = ( + self.transformer.get_onnx_config() + ) + + self.transformer.export( + inputs=example_inputs_transformer, + output_names=output_names_transformer, + dynamic_axes=dynamic_axes_transformer, + export_dir=export_dir, + ) + + # vae + example_inputs_vae, dynamic_axes_vae, output_names_vae = self.vae_decode.get_onnx_config() + + self.vae_decoder_onnx_path = self.vae_decode.export( + example_inputs_vae, + output_names_vae, + dynamic_axes_vae, + export_dir=export_dir, + ) + + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + seq_len: Union[int, List[int]] = 32, + batch_size: int = 1, + num_devices_text_encoder: int = 1, + num_devices_transformer: int = 4, + num_devices_vae_decoder: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + **compiler_options, + ) -> str: + """ + Compiles the ONNX graphs of the different model components for deployment on Qualcomm AI hardware. + + This method takes the ONNX paths of the text encoders, transformer, and VAE decoder, + and compiles them into an optimized format for inference. + + Args: + onnx_path (`str`, *optional*): + The base directory where ONNX files were exported. If None, it assumes the ONNX + paths are already set as attributes (e.g., `self.text_encoder_onnx_path`). + This parameter is currently not fully utilized as individual ONNX paths are derived + from the `export` method. + compile_dir (`str`, *optional*): + The directory path to store the compiled artifacts. If None, a default location + might be used by the underlying compilation process. 
+ seq_len (`Union[int, List[int]]`, *optional*, defaults to 32): + The sequence length(s) to use for compiling the text encoders. Can be a single + integer or a list of integers for multiple sequence lengths. + batch_size (`int`, *optional*, defaults to 1): + The batch size to use for compilation. + num_devices_text_encoder (`int`, *optional*, defaults to 1): + The number of AI devices to deploy the text encoder models on. + num_devices_transformer (`int`, *optional*, defaults to 4): + The number of AI devices to deploy the transformer model on. + num_devices_vae_decoder (`int`, *optional*, defaults to 1): + The number of AI devices to deploy the VAE decoder model on. + num_cores (`int`, *optional*, defaults to 16): + The number of cores to use for compilation. This argument is currently marked + as `FIXME: Make this mandatory arg`. + mxfp6_matmul (`bool`, *optional*, defaults to `False`): + If `True`, enables mixed-precision floating-point 6-bit matrix multiplication + optimization during compilation. + **compiler_options: + Additional keyword arguments to pass to the underlying compiler. + + Returns: + `str`: A message indicating the compilation status or path to compiled artifacts. + (Note: The current implementation might need to return specific paths for each compiled model). + """ + if any( + path is None + for path in [ + self.text_encoder.onnx_path, + self.text_encoder_2.onnx_path, + self.text_encoder_3.onnx_path, + self.transformer.onnx_path, + self.vae_decode.onnx_path, + ] + ): + self.export() + + # text_encoder + specializations_text_encoder = self.text_encoder.get_specializations( + batch_size, self.tokenizer.model_max_length + ) + + self.text_encoder_compile_path = self.text_encoder._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_text_encoder, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices_text_encoder, + aic_num_cores=num_cores, + **compiler_options, + ) + + # text encoder 2 + specializations_text_encoder_2 = self.text_encoder_2.get_specializations( + batch_size, self.tokenizer.model_max_length + ) + + self.text_encoder_2_compile_path = self.text_encoder_2._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_text_encoder_2, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices_text_encoder, + aic_num_cores=num_cores, + **compiler_options, + ) + + # text_encoder 3 + specializations_text_encoder_3 = self.text_encoder_3.get_specializations(batch_size, 256) + + self.text_encoder_3_compile_path = self.text_encoder_3._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_text_encoder_3, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices_text_encoder, + aic_num_cores=num_cores, + **compiler_options, + ) + + # transformer + specializations_transformer = self.transformer.get_specializations(batch_size, 333) + + compiler_options = {"mos": 1, "ols": 2} + self.trasformers_compile_path = self.transformer._compile( + onnx_path, + compile_dir, + compile_only=True, + specializations=specializations_transformer, + convert_to_fp16=True, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices_transformer, + aic_num_cores=num_cores, + **compiler_options, + ) + + # vae + specializations_vae = self.vae_decode.get_specializations(batch_size) + + self.vae_decoder_compile_path = self.vae_decode._compile( + onnx_path, + compile_dir, + compile_only=True, + 
specializations=specializations_vae, + convert_to_fp16=True, + mdp_ts_num_devices=num_devices_vae_decoder, + ) + + def _get_clip_prompt_embeds( + self, + text_encoder, + tokenizer, + clip_index: bool, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int] = 1, + clip_skip: Optional[int] = None, + device_ids: List[int] = None, + ): + """ + Get CLIP prompt embeddings for a given text encoder and tokenizer. + + Args: + text_encoder: The QEffTextEncoder instance to use for encoding. + tokenizer: The tokenizer to use for text preprocessing. + clip_index (int): Index of the CLIP model (0 or 1) to determine embedding dimensions and hidden state range. + prompt (Union[str, List[str]]): The input prompt(s) to encode. + num_images_per_prompt (Optional[int], defaults to 1): Number of images to generate per prompt. + clip_skip (Optional[int], optional): Number of layers to skip from the end when extracting hidden states. + device_ids (List[int], optional): List of device IDs to use for inference. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - prompt_embd_text_encoder: The prompt embeddings from the text encoder. + - pooled_prompt_embeds_text_encoder: The pooled prompt embeddings. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + # to pick correct hidden_state range for each clip model + hidden_state_range = 33 if clip_index else 13 + + # choose embed_dim based on the clip model index. + embed_dim = 1280 if clip_index else 768 + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + + if text_encoder.qpc_session is None: + text_encoder.qpc_session = QAICInferenceSession(text_encoder.qpc_path, device_ids=device_ids) + + text_encoder_output = { + "pooler_output": np.random.rand(batch_size, embed_dim).astype(np.int32), + "last_hidden_state": np.random.rand(batch_size, self.tokenizer_max_length, embed_dim).astype(np.int32), + } + + for i in range(0, hidden_state_range): + text_encoder_output[f"hidden_states_{i}"] = np.random.rand( + batch_size, self.tokenizer_max_length, embed_dim + ).astype(np.int32) + text_encoder.qpc_session.set_buffers(text_encoder_output) + + aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)} + aic_embeddings = text_encoder.qpc_session.run(aic_text_input) + aic_text_encoder_emb = aic_embeddings["pooler_output"] + + ## [TEMP] CHECK ACC ## + # prompt_embeds_pytorch = text_encoder.model(text_input_ids, output_hidden_states=True) + # pt_pooled_embed = prompt_embeds_pytorch[0].detach().numpy() + # mad = np.mean(np.abs(pt_pooled_embed - aic_text_encoder_emb)) + # print(f"CLIP text encoder {clip_index}- pooled embed MAD: ", mad) + ### END CHECK ACC ## + + pooled_prompt_embeds = torch.tensor(aic_text_encoder_emb) + if clip_skip is None: + prompt_embd_text_encoder = torch.tensor(aic_embeddings[list(aic_embeddings.keys())[-2]]) + else: + prompt_embd_text_encoder = 
torch.tensor(aic_embeddings[list(aic_embeddings.keys())[-(clip_skip + 2)]])
+        _, seq_len, _ = prompt_embd_text_encoder.shape
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embd_text_encoder = prompt_embd_text_encoder.repeat(1, num_images_per_prompt, 1)
+        prompt_embd_text_encoder = prompt_embd_text_encoder.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        pooled_prompt_embeds_text_encoder = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        pooled_prompt_embeds_text_encoder = pooled_prompt_embeds_text_encoder.view(batch_size * num_images_per_prompt, -1)
+
+        return prompt_embd_text_encoder, pooled_prompt_embeds_text_encoder
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 256,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        """
+        Get T5 prompt embeddings for the given prompt(s).
+
+        Args:
+            prompt (Union[str, List[str]], optional): The input prompt(s) to encode.
+            num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt.
+            max_sequence_length (int, defaults to 256): Maximum sequence length for tokenization.
+            device (Optional[torch.device], optional): The device to place tensors on.
+            dtype (Optional[torch.dtype], optional): The data type for tensors.
+
+        Returns:
+            torch.Tensor: The T5 prompt embeddings with shape (batch_size * num_images_per_prompt, seq_len, hidden_size).
+        """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        text_inputs = self.text_encoder_3.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.text_encoder_3.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.text_encoder_3.tokenizer.batch_decode(
+                untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
+            )
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+        if self.text_encoder_3.qpc_session is None:
+            self.text_encoder_3.qpc_session = QAICInferenceSession(str(self.text_encoder_3_compile_path))
+
+        aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
+        prompt_embeds = torch.tensor(self.text_encoder_3.qpc_session.run(aic_text_input)["last_hidden_state"])
+
+        # AIC accuracy check (debug)
+        # prompt_embeds_torch = self.text_encoder_3.model(text_input_ids.to(device))[0]
+        # mad = torch.abs(prompt_embeds_torch - prompt_embeds).mean()
+        # print("T5 text-encoder-3 Pytorch vs AI 100:", mad)
+
+        _, seq_len, _ = prompt_embeds.shape
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        prompt_3: Union[str, List[str]],
+        device_ids: List[int] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, 
List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + ): + """ + Encode the given prompts into text embeddings using the three text encoders (CLIP and T5). + + This method processes prompts through multiple text encoders to generate embeddings suitable + for Stable Diffusion 3 generation. It handles both positive and negative prompts for + classifier-free guidance. + + Args: + prompt (Union[str, List[str]]): The primary prompt(s) to encode. + prompt_2 (Union[str, List[str]]): The secondary prompt(s) for the second CLIP encoder. + prompt_3 (Union[str, List[str]]): The tertiary prompt(s) for the T5 encoder. + device_ids (List[int], optional): List of device IDs to use for inference. + num_images_per_prompt (int, defaults to 1): Number of images to generate per prompt. + do_classifier_free_guidance (bool, defaults to True): Whether to use classifier-free guidance. + negative_prompt (Optional[Union[str, List[str]]], optional): The negative prompt(s) to encode. + negative_prompt_2 (Optional[Union[str, List[str]]], optional): The negative prompt(s) for the second CLIP encoder. + negative_prompt_3 (Optional[Union[str, List[str]]], optional): The negative prompt(s) for the T5 encoder. + prompt_embeds (Optional[torch.FloatTensor], optional): Pre-computed prompt embeddings. + negative_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-computed negative prompt embeddings. + pooled_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-computed pooled prompt embeddings. + negative_pooled_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-computed negative pooled prompt embeddings. + clip_skip (Optional[int], optional): Number of layers to skip from the end when extracting CLIP hidden states. + max_sequence_length (int, defaults to 256): Maximum sequence length for T5 tokenization. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - prompt_embeds: The combined prompt embeddings from all encoders. + - negative_prompt_embeds: The combined negative prompt embeddings (if classifier-free guidance is enabled). + - pooled_prompt_embeds: The pooled prompt embeddings from CLIP encoders. + - negative_pooled_prompt_embeds: The pooled negative prompt embeddings (if classifier-free guidance is enabled). 
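+
+        Example:
+            An illustrative sketch of direct use (not part of the patch); it assumes the pipeline
+            has already been compiled, and the `pipeline` object mirrors the usage shown in the
+            `__call__` example below.
+
+            ```python
+            prompt_embeds, negative_prompt_embeds, pooled_embeds, negative_pooled_embeds = pipeline.encode_prompt(
+                prompt="A girl laughing",
+                prompt_2=None,  # falls back to `prompt`
+                prompt_3=None,  # falls back to `prompt`
+                num_images_per_prompt=1,
+            )
+            ```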
+ """ + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + self.text_encoder, + self.text_encoder.tokenizer, + clip_index=0, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + device_ids=device_ids, + ) + + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + self.text_encoder_2, + self.text_encoder_2.tokenizer, + clip_index=1, + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + device_ids=device_ids, + ) + + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + self.text_encoder, + self.text_encoder.tokenizer, + clip_index=0, + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + device_ids=device_ids, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + self.text_encoder_2, + self.text_encoder_2.tokenizer, + clip_index=1, + prompt=negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + device_ids=device_ids, + ) + + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + """ + Generate images from text prompts using the QEfficient-optimized Stable Diffusion 3 pipeline. + + This method performs text-to-image generation by encoding the input prompts through multiple + text encoders, running the diffusion process with the transformer model, and decoding the + final latents to images using the VAE decoder. All components are optimized for Qualcomm AI hardware. + + Args: + prompt (Union[str, List[str]], optional): The primary text prompt(s) to guide image generation. + prompt_2 (Optional[Union[str, List[str]]], optional): Secondary prompt(s) for the second CLIP encoder. + If None, defaults to `prompt`. + prompt_3 (Optional[Union[str, List[str]]], optional): Tertiary prompt(s) for the T5 encoder. + If None, defaults to `prompt`. + height (Optional[int], optional): Height of the generated image in pixels. If None, uses default + sample size scaled by VAE scale factor. 
+ width (Optional[int], optional): Width of the generated image in pixels. If None, uses default + sample size scaled by VAE scale factor. + num_inference_steps (int, defaults to 28): Number of denoising steps during generation. + timesteps (List[int], optional): Custom timesteps to use for denoising. If provided, overrides + `num_inference_steps`. + guidance_scale (float, defaults to 7.0): Guidance scale for classifier-free guidance. Higher values + result in images more closely linked to the prompt at the expense of lower image quality. + negative_prompt (Optional[Union[str, List[str]]], optional): Negative prompt(s) to guide what not + to include in image generation. + negative_prompt_2 (Optional[Union[str, List[str]]], optional): Negative prompt(s) for the second + CLIP encoder. + negative_prompt_3 (Optional[Union[str, List[str]]], optional): Negative prompt(s) for the T5 encoder. + num_images_per_prompt (Optional[int], defaults to 1): Number of images to generate per prompt. + generator (Optional[Union[torch.Generator, List[torch.Generator]]], optional): Random number + generator(s) for reproducible generation. + latents (Optional[torch.FloatTensor], optional): Pre-generated noisy latents sampled from a Gaussian + distribution to be used as inputs for image generation. + prompt_embeds (Optional[torch.FloatTensor], optional): Pre-generated text embeddings. Can be used + to easily tweak text inputs (prompt weighting). + negative_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-generated negative text embeddings. + pooled_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-generated pooled text embeddings. + negative_pooled_prompt_embeds (Optional[torch.FloatTensor], optional): Pre-generated negative pooled + text embeddings. + output_type (Optional[str], defaults to "pil"): Output format of the generated images. Choose between + "pil", "np", "pt", or "latent". + return_dict (bool, defaults to True): Whether to return a `StableDiffusion3PipelineOutput` instead + of a plain tuple. + joint_attention_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments to pass + to the joint attention layers. + clip_skip (Optional[int], optional): Number of layers to skip from the end when extracting CLIP + hidden states. + callback_on_step_end (Optional[Callable[[int, int, Dict], None]], optional): Callback function + called at the end of each denoising step. + callback_on_step_end_tensor_inputs (List[str], defaults to ["latents"]): List of tensor inputs + to pass to the callback function. + max_sequence_length (int, defaults to 256): Maximum sequence length for T5 text encoder tokenization. + + Returns: + Union[StableDiffusion3PipelineOutput, Tuple]: If `return_dict` is True, returns a + `StableDiffusion3PipelineOutput` object containing the generated images. Otherwise, + returns a tuple with the generated images. 
+ + Examples: + ```python + # Basic text-to-image generation + from QEfficient import QEFFStableDiffusion3Pipeline + + pipeline = QEFFStableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large") + pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1) + + # NOTE: guidance_scale <=1 is not supported + image = pipeline("A girl laughing", num_inference_steps=28, guidance_scale=2.0).images[0] + image.save("girl_laughing.png") + ``` + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + device = "cpu" + + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.model.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + ###### AIC related changes of transformers ###### + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession(str(self.transformer.qpc_path)) + + output_buffer = { + "output": np.random.rand( + 2 * batch_size, num_channels_latents, self.default_sample_size, self.default_sample_size + ).astype(np.int32), + } + + self.transformer.qpc_session.set_buffers(output_buffer) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + timestep = np.array([t], dtype=np.int64) + + # noise_pred_torch = self.transformer.model( + # hidden_states=latent_model_input, + # timestep=torch.tensor(timestep), + # encoder_hidden_states=prompt_embeds, + # pooled_projections=pooled_prompt_embeds, + # joint_attention_kwargs=self.joint_attention_kwargs, + # return_dict=False, + # )[0] + + noise_pred = self.transformer.qpc_session.run( + { + "encoder_hidden_states": prompt_embeds.detach().numpy(), + "pooled_projections": pooled_prompt_embeds.numpy(), + "timestep": timestep, + "hidden_states": latent_model_input.numpy(), + } + ) + + # ###### ACCURACY TESTING ####### + # mad=np.mean(np.abs(noise_pred_torch.detach().numpy()-noise_pred['output'])) + # print("transfromer model MAD:", mad) + + noise_pred = torch.tensor(noise_pred["output"]) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = ( + latents / self.vae_decode.model.config.scaling_factor + ) + self.vae_decode.model.config.shift_factor + + # image_torch = self.vae_decode.model(latents, return_dict=False)[0] + + vae_session = QAICInferenceSession(str(self.vae_decoder_compile_path)) + + output_buffer = { + "sample": np.random.rand( + batch_size, 3, self.vae_decode.model.config.sample_size, self.vae_decode.model.config.sample_size + ).astype(np.int32) + } + + vae_session.set_buffers(output_buffer) + inputs = {"latent_sample": latents.numpy()} + image = vae_session.run(inputs) + # mad= np.mean(np.abs(image['sample']-image_torch.detach().numpy())) + # print("VAE mad: ",mad) + image = self.image_processor.postprocess(torch.tensor(image["sample"]), output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py index 8519d824c..8fe1c0868 100644 --- a/QEfficient/generation/cloud_infer.py +++ b/QEfficient/generation/cloud_infer.py @@ -84,7 +84,7 @@ def __init__( self.binding_index_map = {binding.name: binding.index for binding in self.bindings} # Create and load Program prog_properties = qaicrt.QAicProgramProperties() - prog_properties.SubmitRetryTimeoutMs = 60_000 + prog_properties.SubmitRetryTimeoutMs = 60_00000 if device_ids and len(device_ids) > 1: prog_properties.devMapping = ":".join(map(str, device_ids)) self.program = qaicrt.Program(self.context, None, qpc, prog_properties) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ca74c0ddd..6719396c0 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -142,6 +142,13 @@ Starcoder2ForCausalLM, Starcoder2Model, ) +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerCrossAttention, + T5LayerFF, + T5LayerNorm, + T5LayerSelfAttention, +) from transformers.models.whisper.modeling_whisper import ( WhisperAttention, WhisperDecoder, @@ -309,6 +316,13 @@ QEffStarcoder2ForCausalLM, QEffStarcoder2Model, ) +from QEfficient.transformers.models.t5.modeling_t5 import ( + QEffT5Attention, + QEffT5LayerCrossAttention, + QEffT5LayerFF, + QEffT5LayerNorm, + QEffT5LayerSelfAttention, +) from QEfficient.transformers.models.whisper.modeling_whisper import ( QEffWhisperAttention, QEffWhisperDecoder, @@ -617,6 +631,22 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): 
_match_class_replace_method = {} +class T5ModelTransform(ModuleMappingTransform): + # supported architectures + _module_mapping = { + T5LayerFF: QEffT5LayerFF, + T5LayerSelfAttention: QEffT5LayerSelfAttention, + T5LayerCrossAttention: QEffT5LayerCrossAttention, + T5Attention: QEffT5Attention, + T5LayerNorm: QEffT5LayerNorm, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed + + class PoolingTransform: """ Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. diff --git a/QEfficient/transformers/models/t5/__init__.py b/QEfficient/transformers/models/t5/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py new file mode 100644 index 000000000..9ba5869d7 --- /dev/null +++ b/QEfficient/transformers/models/t5/modeling_t5.py @@ -0,0 +1,217 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import torch.nn as nn +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5LayerCrossAttention, + T5LayerFF, + T5LayerNorm, + T5LayerSelfAttention, +) + + +class QEffT5LayerNorm(T5LayerNorm): + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + div_first = hidden_states * torch.rsqrt(torch.tensor(hidden_states.shape[-1], dtype=torch.float32)) + variance = div_first.pow(2).sum(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class QEffT5LayerFF(T5LayerFF): + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states * 1.0 + self.dropout(forwarded_states) + return hidden_states + + +class QEffT5Attention(T5Attention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + # Original line: position_bias = position_bias[:, :, -seq_length:, :] + if past_key_value is not None: # This block is where the patch applies + # position_bias = position_bias[:, :, -hidden_states.size(1) :, :] # Original line (commented in patch) + position_bias = position_bias[:, :, -1:, :] # Added by patch + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = 
nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class QEffT5LayerSelfAttention(T5LayerSelfAttention): + def __qeff_init__(self): + self.scaling_factor = 1.0 + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states * 1.0 + self.dropout(attention_output[0]) # Modified by patch + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class QEffT5LayerCrossAttention(T5LayerCrossAttention): + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states * 1.0 + self.dropout(attention_output[0]) # Modified by patch + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 50f36ea32..e458fe5b2 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -68,7 +68,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 17 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] @@ -103,6 +103,35 @@ def get_models_dir(): GEMMA3_MAX_POSITION_EMBEDDINGS = 32768 +# wo_sfs: weight output scaling factors (used to normalize T5 encoder output weights before export) +WO_SFS = [ + 61, + 203, + 398, + 615, + 845, + 1190, + 1402, + 2242, + 1875, + 2393, + 3845, + 3213, + 3922, + 4429, + 5020, + 5623, + 6439, + 6206, + 5165, + 4593, + 2802, + 2618, + 1891, + 1419, +] + + class Constants: # Export Constants. 
SEQ_LEN = 32 diff --git a/docs/image/girl_laughing.png b/docs/image/girl_laughing.png new file mode 100644 index 000000000..f3ad34a7a Binary files /dev/null and b/docs/image/girl_laughing.png differ diff --git a/examples/diffusers/stable_diffusion_3/stable_diffusion_35_example.py b/examples/diffusers/stable_diffusion_3/stable_diffusion_35_example.py new file mode 100644 index 000000000..ad90605b0 --- /dev/null +++ b/examples/diffusers/stable_diffusion_3/stable_diffusion_35_example.py @@ -0,0 +1,15 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from QEfficient import QEFFStableDiffusion3Pipeline + +pipeline = QEFFStableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large-turbo") +pipeline.compile(num_devices_text_encoder=1, num_devices_transformer=4, num_devices_vae_decoder=1) + +# NOTE: guidance_scale <=1 is not supported +image = pipeline("A girl laughing", num_inference_steps=28, guidance_scale=2.0).images[0] +image.save("girl_laughing_turbo.png") diff --git a/pyproject.toml b/pyproject.toml index 479736c22..bf439548a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,18 +39,18 @@ dependencies = [ "fire", "py7zr", "torchmetrics==1.7.0", - "torch==2.4.1; platform_machine=='aarch64'", + "torch==2.7.1; platform_machine=='aarch64'", # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11 "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'", - "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'", - "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'", + "torch@https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_machine=='x86_64'", + "torch@https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_machine=='x86_64'", ] [project.optional-dependencies] test = ["pytest","pytest-mock"] docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"] quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"] - +diffusers = ["diffusers== 0.31.0"] [build-system] requires = ["setuptools>=62.0.0"] build-backend = "setuptools.build_meta" @@ -71,4 +71,4 @@ target-version = "py310" [tool.pytest.ini_options] addopts = "-W ignore -s -v" junit_logging = "all" -doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS" +doctest_optionflags = "NUMBER NORMALIZE_WHITESPACE ELLIPSIS" \ No newline at end of file
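
The `QEffT5LayerNorm` override above computes the same RMSNorm as upstream `T5LayerNorm`, but pre-scales the activations by `1/sqrt(d_model)` and accumulates a sum of squares instead of a mean, presumably to keep the squared intermediates small enough for fp16-friendly export. A minimal numerical check (a sketch in plain PyTorch, not part of the patch; shapes and epsilon are illustrative assumptions) showing the two formulations agree:

```python
import torch

# Compare upstream T5's RMSNorm with the pre-scaled sum-of-squares
# formulation used by QEffT5LayerNorm.
d_model, eps = 4096, 1e-6
x = torch.randn(2, 8, d_model)
weight = torch.randn(d_model)

# Upstream T5LayerNorm: variance as the mean of squares
variance_ref = x.pow(2).mean(-1, keepdim=True)
out_ref = weight * (x * torch.rsqrt(variance_ref + eps))

# QEffT5LayerNorm: divide by sqrt(d_model) first, then accumulate with a sum
div_first = x * torch.rsqrt(torch.tensor(d_model, dtype=torch.float32))
variance_alt = div_first.pow(2).sum(-1, keepdim=True)
out_alt = weight * (x * torch.rsqrt(variance_alt + eps))

assert torch.allclose(out_ref, out_alt, atol=1e-6)
```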