Commit 910f065

Working with cleaned code
Signed-off-by: Amit Raj <[email protected]>
1 parent ab0e6e8 · commit 910f065

File tree

6 files changed: +48 −59 lines changed


QEfficient/diffusers/models/attention.py

Lines changed: 3 additions & 8 deletions
@@ -1,13 +1,8 @@
-from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward
 import torch
-import torch as nn
-from QEfficient.diffusers.models.attention_processor import QEffJointAttnProcessor2_0
-from QEfficient.diffusers.models.attention_processor import QEffAttention
-from typing import Optional
+from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward
 
 
 class QEffJointTransformerBlock(JointTransformerBlock):
-
     def forward(
         self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
@@ -45,7 +40,7 @@ def forward(
             # "feed_forward_chunk_size" can be used to save memory
             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            #ff_output = self.ff(norm_hidden_states)
+            # ff_output = self.ff(norm_hidden_states)
             ff_output = self.ff(norm_hidden_states, block_size=4096)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
 
@@ -66,7 +61,7 @@ def forward(
                     self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
                 )
             else:
-                #context_ff_output = self.ff_context(norm_encoder_hidden_states)
+                # context_ff_output = self.ff_context(norm_encoder_hidden_states)
                 context_ff_output = self.ff_context(norm_encoder_hidden_states, block_size=333)
             encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
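
Note on the change above: the block_size argument passed to self.ff(...) is not part of the stock diffusers FeedForward.forward signature, so it presumably comes from a QEfficient-patched feed-forward. A minimal, hypothetical sketch of the idea (names and signature are illustrative, not this repo's actual helper): run the MLP over the sequence in fixed-size blocks so peak activation memory stays bounded while the result is unchanged.

import torch
import torch.nn as nn


def blocked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, block_size: int) -> torch.Tensor:
    # Apply `ff` to slices of at most `block_size` tokens along the sequence dim.
    # Numerically equivalent to ff(hidden_states); only peak memory changes.
    outputs = [
        ff(hidden_states[:, start : start + block_size, :])
        for start in range(0, hidden_states.size(1), block_size)
    ]
    return torch.cat(outputs, dim=1)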

QEfficient/diffusers/models/attention_processor.py

Lines changed: 10 additions & 11 deletions
@@ -1,14 +1,13 @@
-from diffusers.models.attention_processor import Attention
-import torch
 from typing import Optional
-import torch as nn
-from diffusers.models.attention_processor import JointAttnProcessor2_0
+
+import torch
+from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
+
 
 class QEffAttention(Attention):
-
     def __qeff_init__(self):
-        processor=QEffJointAttnProcessor2_0()
-        self.processor=processor
+        processor = QEffJointAttnProcessor2_0()
+        self.processor = processor
         processor.query_block_size = 64
 
     def get_attention_scores(
@@ -47,8 +46,8 @@ def get_attention_scores(
 
         return attention_probs
 
+
 class QEffJointAttnProcessor2_0(JointAttnProcessor2_0):
-
     def __call__(
         self,
         attn: QEffAttention,
@@ -110,18 +109,18 @@ def __call__(
 
         # pre-transpose the key
         key = key.transpose(-1, -2)
-        if query.size(-2) != value.size(-2): # cross-attention, use regular attention
+        if query.size(-2) != value.size(-2):  # cross-attention, use regular attention
             # QKV done in single block
             attention_probs = attn.get_attention_scores(query, key, attention_mask)
             hidden_states = torch.bmm(attention_probs, value)
-        else: # self-attention, use blocked attention
+        else:  # self-attention, use blocked attention
             # QKV done with block-attention (a la FlashAttentionV2)
             print(f"{query.shape = }, {key.shape = }, {value.shape = }")
             query_block_size = self.query_block_size
             query_seq_len = query.size(-2)
             num_blocks = (query_seq_len + query_block_size - 1) // query_block_size
             for qidx in range(num_blocks):
-                query_block = query[:,qidx*query_block_size:(qidx+1)*query_block_size,:]
+                query_block = query[:, qidx * query_block_size : (qidx + 1) * query_block_size, :]
                 attention_probs = attn.get_attention_scores(query_block, key, attention_mask)
                 hidden_states_block = torch.bmm(attention_probs, value)
                 if qidx == 0:
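
The hunk above is truncated at "if qidx == 0:". The loop splits the query along its sequence dimension and runs plain attention one block at a time against the full key/value, which bounds the size of the intermediate attention_probs matrix. A standalone, hypothetical sketch of that pattern (attention_mask omitted, and without the online-softmax rescaling of true FlashAttention; not this repo's exact code):

import torch


def blocked_query_attention(query, key, value, query_block_size: int = 64):
    # query, key, value: (batch * heads, seq_len, head_dim)
    scale = query.size(-1) ** -0.5
    key_t = key.transpose(-1, -2)  # pre-transpose once, as in the diff above
    outputs = []
    for start in range(0, query.size(-2), query_block_size):
        query_block = query[:, start : start + query_block_size, :]
        # each query block attends to the full key/value, so the softmax is exact
        attention_probs = torch.softmax(torch.bmm(query_block, key_t) * scale, dim=-1)
        outputs.append(torch.bmm(attention_probs, value))
    # stitch the per-block outputs back together along the sequence dimension
    return torch.cat(outputs, dim=-2)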

QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 13 additions & 15 deletions
@@ -6,32 +6,30 @@
 # -----------------------------------------------------------------------------
 from typing import Tuple
 
-from torch import nn
-from QEfficient.customop import CustomRMSNormAIC
-
-
-from diffusers import AutoencoderKL
-from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ExternalModuleMapperTransform
 from diffusers.models.attention import JointTransformerBlock
 from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
+from torch import nn
 
-
-from QEfficient.diffusers.models.attention_processor import QEffAttention, QEffJointAttnProcessor2_0, JointAttnProcessor2_0
+from QEfficient.base.pytorch_transforms import ModuleMappingTransform
 from QEfficient.diffusers.models.attention import QEffJointTransformerBlock
+from QEfficient.diffusers.models.attention_processor import (
+    QEffAttention,
+    QEffJointAttnProcessor2_0,
+)
+
 
 class CustomOpsTransform(ModuleMappingTransform):
-    _module_mapping = {
-    }
+    _module_mapping = {}
 
 
 class AttentionTransform(ModuleMappingTransform):
     _module_mapping = {
-        Attention: QEffAttention,
-        JointAttnProcessor2_0: QEffJointAttnProcessor2_0,
-        JointTransformerBlock: QEffJointTransformerBlock
+        Attention: QEffAttention,
+        JointAttnProcessor2_0: QEffJointAttnProcessor2_0,
+        JointTransformerBlock: QEffJointTransformerBlock,
     }
-
+
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         model, transformed = super().apply(model)
-        return model, transformed
+        return model, transformed
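
For context on how these mappings are consumed: ModuleMappingTransform.apply presumably walks the model and swaps the class of every module whose type appears in _module_mapping, calling a __qeff_init__ hook when the replacement defines one (which would explain the hook added to QEffAttention earlier in this commit). A minimal, illustrative stand-in under that assumption, not the actual QEfficient base class:

from typing import Dict, Tuple, Type

import torch.nn as nn


class SimpleModuleMappingTransform:
    # Hypothetical stand-in: re-class matching modules in place.
    _module_mapping: Dict[Type[nn.Module], Type[nn.Module]] = {}

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        transformed = False
        for module in model.modules():
            if type(module) in cls._module_mapping:
                module.__class__ = cls._module_mapping[type(module)]
                if hasattr(module, "__qeff_init__"):
                    module.__qeff_init__()  # optional extra init on the new class
                transformed = True
        return model, transformed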

QEfficient/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 4 deletions
@@ -5,14 +5,14 @@
 
 from QEfficient.base.modeling_qeff import QEFFBaseModel
 from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform
 from QEfficient.transformers.models.pytorch_transforms import (
     CustomOpsTransform,
     KVCacheExternalModuleMapperTransform,
     KVCacheTransform,
 )
 from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform
 from QEfficient.utils.cache import to_hashable
-from QEfficient. diffusers.models.pytorch_transforms import CustomOpsTransform, AttentionTransform
 
 
 class QEffTextEncoder(QEFFBaseModel):
@@ -266,9 +266,7 @@ def get_model_config(self) -> dict:
 
 
 class QEffSD3Transformer2DModel(QEFFBaseModel):
-    _pytorch_transforms = [
-        AttentionTransform, CustomOpsTransform
-    ]
+    _pytorch_transforms = [AttentionTransform, CustomOpsTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
     def __init__(self, model: nn.modules):

QEfficient/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 17 additions & 17 deletions
@@ -3,9 +3,9 @@
 
 import numpy as np
 import torch
-
 from diffusers import StableDiffusionPipeline
 from diffusers.image_processor import VaeImageProcessor
+
 from QEfficient.diffusers.pipelines.pipeline_utils import QEffSafetyChecker, QEffTextEncoder, QEffUNet, QEffVAE
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils import constants
@@ -252,14 +252,14 @@ def compile(
 
         # Compile vae_encoder
 
-        encoder_specializations = [
-            {
-                "batch_size": batch_size,
-                "channels": self.vae_encoder.model.config.in_channels,
-                "height": self.vae_encoder.model.config.sample_size,
-                "width": self.vae_encoder.model.config.sample_size,
-            }
-        ]
+        # encoder_specializations = [
+        #     {
+        #         "batch_size": batch_size,
+        #         "channels": self.vae_encoder.model.config.in_channels,
+        #         "height": self.vae_encoder.model.config.sample_size,
+        #         "width": self.vae_encoder.model.config.sample_size,
+        #     }
+        # ]
 
         # self.vae_encoder_compile_path=self.vae_encoder._compile(
         #     onnx_path,
@@ -273,14 +273,14 @@ def compile(
 
         # compile vae decoder
 
-        decoder_sepcializations = [
-            {
-                "batch_size": batch_size,
-                "channels": 4,
-                "height": self.vae_decoder.model.config.sample_size,
-                "width": self.vae_decoder.model.config.sample_size,
-            }
-        ]
+        # decoder_sepcializations = [
+        #     {
+        #         "batch_size": batch_size,
+        #         "channels": 4,
+        #         "height": self.vae_decoder.model.config.sample_size,
+        #         "width": self.vae_decoder.model.config.sample_size,
+        #     }
+        # ]
 
         # self.vae_decoder_compile_path=self.vae_decoder._compile(
         #     onnx_path,

QEfficient/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion3.py

Lines changed: 3 additions & 4 deletions
@@ -1,15 +1,14 @@
 import os
-import time
 from typing import Any, Callable, Dict, List, Optional, Union
 from venv import logger
 
 import numpy as np
 import torch
-
 from diffusers import StableDiffusion3Pipeline
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
 from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
+
 from QEfficient.diffusers.pipelines.pipeline_utils import QEffSD3Transformer2DModel, QEffTextEncoder, QEffVAE
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils import constants
@@ -310,10 +309,10 @@ def _get_clip_prompt_embeds(
         device_ids: List[int] = [0],
     ):
         if clip_model_index == 0:
-            text_encoder = self.text_encoder
+            # text_encoder = self.text_encoder
             tokenizer = self.tokenizer
         else:
-            text_encoder = self.text_encoder_2
+            # text_encoder = self.text_encoder_2
             tokenizer = self.tokenizer_2
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
