Skip to content

Commit ab0e6e8

Browse files
committed
Working with cleaned code
Signed-off-by: Amit Raj <[email protected]>
1 parent 6856ffc commit ab0e6e8

File tree

12 files changed

+673
-966
lines changed

12 files changed

+673
-966
lines changed

QEfficient/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,17 @@ def check_qaic_sdk():
4848
QEFFCommonLoader,
4949
)
5050
from QEfficient.compile.compile_helper import compile
51+
52+
# Imports for the diffusers
53+
from QEfficient.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import QEFFStableDiffusionPipeline
54+
from QEfficient.diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion3 import (
55+
QEFFStableDiffusion3Pipeline,
56+
)
5157
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
5258
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
5359
from QEfficient.peft import QEffAutoPeftModelForCausalLM
5460
from QEfficient.transformers.transform import transform
55-
56-
57-
# Imports for the diffusers
58-
59-
from QEfficient.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import QEFFStableDiffusionPipeline
60-
from QEfficient.diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion3 import QEFFStableDiffusion3Pipeline
61+
6162
# Users can use QEfficient.export for exporting models to ONNX
6263
export = qualcomm_efficient_converter
6364

QEfficient/base/modeling_qeff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from QEfficient.base.pytorch_transforms import PytorchTransform
2323
from QEfficient.compile.qnn_compiler import compile as qnn_compile
2424
from QEfficient.generation.cloud_infer import QAICInferenceSession
25-
from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json
25+
from QEfficient.utils import constants, create_json, generate_mdp_partition_config, load_json
2626
from QEfficient.utils.cache import QEFF_HOME, to_hashable
2727

2828
logger = logging.getLogger(__name__)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward
2+
import torch
3+
import torch as nn
4+
from QEfficient.diffusers.models.attention_processor import QEffJointAttnProcessor2_0
5+
from QEfficient.diffusers.models.attention_processor import QEffAttention
6+
from typing import Optional
7+
8+
9+
class QEffJointTransformerBlock(JointTransformerBlock):
    """Joint transformer block whose feed-forward calls carry a ``block_size`` hint.

    The forward pass runs joint attention over the sample stream
    (``hidden_states``) and the context stream (``encoder_hidden_states``),
    followed by a gated feed-forward on each stream. When chunked
    feed-forward is disabled (``self._chunk_size is None``) the feed-forward
    modules are called with an extra ``block_size`` keyword argument.
    """

    # ``block_size`` values forwarded to ``self.ff`` / ``self.ff_context`` on the
    # non-chunked path. How the value is consumed depends on those modules,
    # which are not defined in this class.
    # NOTE(review): 4096 and 333 presumably correspond to the sample and
    # context sequence lengths of the exported model -- confirm before changing.
    ff_block_size = 4096
    ff_context_block_size = 333

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
    ):
        """Apply the block to both streams.

        Args:
            hidden_states: Sample-stream input.
            encoder_hidden_states: Context-stream input.
            temb: Conditioning embedding passed to the ``norm1`` layers.

        Returns:
            ``(encoder_hidden_states, hidden_states)``. The first element is
            ``None`` when ``self.context_pre_only`` is set.
        """
        # norm1 yields the modulation terms; with dual attention it also
        # yields a second normalized input and gate for ``self.attn2``.
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                hidden_states, emb=temb
            )
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb
            )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
            hidden_states = hidden_states + attn_output2

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states, block_size=self.ff_block_size)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                context_ff_output = _chunked_feed_forward(
                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
                )
            else:
                context_ff_output = self.ff_context(
                    norm_encoder_hidden_states, block_size=self.ff_context_block_size
                )
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        # Context stream first, sample stream second.
        return encoder_hidden_states, hidden_states
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from diffusers.models.attention_processor import Attention
2+
import torch
3+
from typing import Optional
4+
import torch as nn
5+
from diffusers.models.attention_processor import JointAttnProcessor2_0
6+
7+
class QEffAttention(Attention):
    """``Attention`` variant used together with ``QEffJointAttnProcessor2_0``."""

    def __qeff_init__(self):
        """Install a fresh ``QEffJointAttnProcessor2_0`` with a query block size of 64."""
        qeff_processor = QEffJointAttnProcessor2_0()
        self.processor = qeff_processor
        qeff_processor.query_block_size = 64

    def get_attention_scores(
        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return softmax attention probabilities for ``query`` against ``key``.

        ``key`` is handed to ``torch.baddbmm`` as-is (no transpose here), so it
        must already be laid out so that ``query @ key`` is valid; its third
        dimension sizes the last axis of the result.

        Args:
            query: Query tensor.
            key: Key tensor, already transposed by the caller.
            attention_mask: Optional additive term; when given it is added to
                the scaled scores (``beta=1``).

        Returns:
            Attention probabilities cast back to the dtype ``query`` had on entry.
        """
        result_dtype = query.dtype
        if self.upcast_attention:
            query, key = query.float(), key.float()

        if attention_mask is not None:
            bias, bias_weight = attention_mask, 1
        else:
            # With beta=0 the contents of this buffer are ignored by baddbmm,
            # so an uninitialized tensor of the right shape is sufficient.
            bias = torch.empty(
                query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device
            )
            bias_weight = 0

        scores = torch.baddbmm(bias, query, key, beta=bias_weight, alpha=self.scale)
        del bias

        if self.upcast_softmax:
            scores = scores.float()

        probs = scores.softmax(dim=-1)
        del scores

        return probs.to(result_dtype)
49+
50+
class QEffJointAttnProcessor2_0(JointAttnProcessor2_0):
    """Joint attention processor that computes attention explicitly, in query blocks.

    Sample and (optional) context projections are concatenated along the
    sequence axis. Attention probabilities come from
    ``attn.get_attention_scores`` and are applied with ``torch.bmm``. When the
    query and value sequence lengths match, the query is processed in slices of
    ``query_block_size`` rows and the per-slice outputs are concatenated.
    """

    # Default number of query rows per block, so the processor also works when
    # it is constructed without going through ``QEffAttention.__qeff_init__``
    # (which assigns the same value on the instance it creates).
    query_block_size = 64

    def __call__(
        self,
        attn: QEffAttention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run joint attention for ``attn``.

        Args:
            attn: Attention module providing the projections, norms and
                ``get_attention_scores``.
            hidden_states: Sample-stream input.
            encoder_hidden_states: Optional context-stream input.
            attention_mask: Optional mask forwarded unchanged to
                ``attn.get_attention_scores``.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` when
            ``encoder_hidden_states`` was given, otherwise ``hidden_states``.
        """
        residual = hidden_states

        batch_size = hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # Append the context tokens after the sample tokens (sequence axis).
            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

        # Fold batch and heads into one leading dimension for baddbmm / bmm.
        query = query.reshape(-1, query.shape[-2], query.shape[-1])
        key = key.reshape(-1, key.shape[-2], key.shape[-1])
        value = value.reshape(-1, value.shape[-2], value.shape[-1])

        # pre-transpose the key
        key = key.transpose(-1, -2)
        if query.size(-2) != value.size(-2):  # cross-attention, use regular attention
            # QKV done in single block
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
        else:  # self-attention, use blocked attention
            # QKV done with block-attention (a la FlashAttentionV2)
            query_block_size = self.query_block_size
            query_seq_len = query.size(-2)
            num_blocks = (query_seq_len + query_block_size - 1) // query_block_size
            # Collect per-block outputs and concatenate once, instead of
            # re-copying the growing result on every iteration.
            output_blocks = []
            for qidx in range(num_blocks):
                block_start = qidx * query_block_size
                query_block = query[:, block_start : block_start + query_block_size, :]
                # NOTE(review): the mask is passed unsliced for every query
                # block; confirm callers only supply masks that broadcast
                # against a single block.
                attention_probs = attn.get_attention_scores(query_block, key, attention_mask)
                output_blocks.append(torch.bmm(attention_probs, value))
            if len(output_blocks) == 1:
                hidden_states = output_blocks[0]
            elif output_blocks:
                hidden_states = torch.cat(output_blocks, dim=-2)
            # With zero blocks (empty query) ``hidden_states`` is left as passed in.
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if encoder_hidden_states is not None:
            # Split the attention outputs.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : residual.shape[1]],
                hidden_states[:, residual.shape[1] :],
            )
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states

QEfficient/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,15 @@
55
#
66
# ----------------------------------------------------------------------------
77

8+
import torch
9+
810
from diffusers import AutoencoderKL
911
from diffusers.utils.accelerate_utils import apply_forward_hook
10-
import torch
1112

1213

1314
class QEffAutoencoderKL(AutoencoderKL):
14-
1515
@apply_forward_hook
16-
def encode(
17-
self, x: torch.Tensor, return_dict: bool = True
18-
):
16+
def encode(self, x: torch.Tensor, return_dict: bool = True):
1917
"""
2018
Encode a batch of images into latents.
2119
@@ -34,4 +32,3 @@ def encode(
3432
else:
3533
h = self._encode(x)
3634
return h
37-

QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,33 @@
55
#
66
# -----------------------------------------------------------------------------
77
from typing import Tuple
8-
from diffusers import AutoencoderKL
9-
from QEfficient.diffusers.models.autoencoders.autoencoder_kl import QEffAutoencoderKL
10-
from QEfficient.base.pytorch_transforms import ModuleMappingTransform
8+
119
from torch import nn
10+
from QEfficient.customop import CustomRMSNormAIC
11+
12+
13+
from diffusers import AutoencoderKL
14+
from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ExternalModuleMapperTransform
15+
from diffusers.models.attention import JointTransformerBlock
16+
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
17+
1218

19+
from QEfficient.diffusers.models.attention_processor import QEffAttention, QEffJointAttnProcessor2_0, JointAttnProcessor2_0
20+
from QEfficient.diffusers.models.attention import QEffJointTransformerBlock
21+
22+
class CustomOpsTransform(ModuleMappingTransform):
    """Module-mapping transform that currently registers no replacements."""

    # Intentionally empty: no source class is mapped to a custom-op class yet.
    _module_mapping = {}
1325

14-
class AutoencoderKLTransform(ModuleMappingTransform):
15-
"""Transforms a Diffusers AutoencoderKL model to a QEfficientAutoencoderKL model."""
1626

27+
class AttentionTransform(ModuleMappingTransform):
    """Map diffusers attention classes to their QEff counterparts."""

    # Source class -> QEff replacement class.
    # NOTE(review): ``JointAttnProcessor2_0`` is a processor rather than a
    # module; confirm the base transform actually visits processor objects.
    _module_mapping = {
        Attention: QEffAttention,
        JointAttnProcessor2_0: QEffJointAttnProcessor2_0,
        JointTransformerBlock: QEffJointTransformerBlock,
    }

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """Delegate to the parent transform and return its ``(model, transformed)`` pair."""
        mapped_model, was_transformed = super().apply(model)
        return mapped_model, was_transformed

QEfficient/diffusers/models/t5_demo/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)