
Commit bd3dede

wenba0 and lijiaojiao authored
support qwen25 vl w8a8 quantization (#2778)
### What this PR does / why we need it?
Support Qwen2.5-VL W8A8 quantization.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@62f66be

Signed-off-by: lijiaojiao <[email protected]>
Co-authored-by: lijiaojiao <[email protected]>
1 parent 2b9269b commit bd3dede

File tree: 3 files changed (+103, -3 lines)

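For orientation, below is a minimal sketch of how a W8A8-quantized Qwen2.5-VL checkpoint might be loaded once this change is in place. The checkpoint path is a placeholder, and the `quantization="ascend"` argument assumes the Ascend quantization backend this config plugs into is registered under that name; none of this snippet is part of the PR itself.

```python
# Hypothetical usage sketch (not part of this PR).
# Assumes a locally produced W8A8 checkpoint and that vllm-ascend's
# quantization backend is registered under the name "ascend".
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/Qwen2.5-VL-w8a8",   # placeholder checkpoint path
    quantization="ascend",              # assumed backend name
    max_model_len=8192,
)
outputs = llm.generate(["Describe the image."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```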

tests/ut/models/test_qwen2_5_vl.py (40 additions, 0 deletions)

```diff
@@ -353,6 +353,46 @@ def test_cal_cos_sin(self, interleaved, expected, mocker: MockerFixture):
         cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
         assert cos_new.shape == (1, 32, 1, 2)
 
+    def test_pad_qkv_bias(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_bias(torch.rand((300)))
+        assert res.shape[0] == 384
+
+    def test_pad_qkv_weight(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_weight(torch.rand((300, 300)))
+        assert res.shape == (384, 300)
+
+    def test_pad_proj_weight(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_proj_weight(torch.rand((300, 300)))
+        assert res.shape == (300, 384)
+
+    def test_pad_qkv_weight_scale_offset(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_weight_scale_offset(torch.rand((300, 1)))
+        assert res.shape == (384, 1)
+
+    def test_pad_qkv_deq_scale_quant_bias(self, mocker: MockerFixture):
+        attention = self.init_vision_transformer(mocker)
+        mocker.patch("torch.nn.Module.__setattr__")
+        mocker.patch("torch.nn.Module.__getattr__")
+        mocker.patch("torch.nn.Module.__delattr__")
+        res = attention.pad_qkv_deq_scale_quant_bias(torch.rand((300)))
+        assert res.shape[0] == 384
+
     def test_forward(self, mocker: MockerFixture):
         vision_transformer = self.init_vision_transformer(mocker)
         mocker.patch("torch.nn.Module.__setattr__")
```

vllm_ascend/models/qwen2_5_vl.py (60 additions, 3 deletions)

```diff
@@ -291,6 +291,40 @@ def pad_proj_weight(self, data):
                 self.hidden_size, -1)
         return out_weight
 
+    def pad_qkv_weight_scale_offset(self, data):
+        reshaped_data = data.reshape(
+            -1, 3, self.origin_hidden_size_per_attention_head, 1)
+        data1 = reshaped_data[:, :, :self.
+                              half_origin_hidden_size_per_attention_head, :]
+        data2 = reshaped_data[:, :, self.
+                              half_origin_hidden_size_per_attention_head:, :]
+        data1_paded = torch.nn.functional.pad(
+            data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
+                    0, 0, 0))
+        data2_paded = torch.nn.functional.pad(
+            data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
+                    0, 0, 0))
+        res = torch.cat([data1_paded, data2_paded], dim=2)
+        res = res.reshape(-1, 1)
+        return res
+
+    def pad_qkv_deq_scale_quant_bias(self, data):
+        reshaped_data = data.reshape(
+            -1, 3, self.origin_hidden_size_per_attention_head)
+        data1 = reshaped_data[:, :, :self.
+                              half_origin_hidden_size_per_attention_head]
+        data2 = reshaped_data[:, :,
+                              self.half_origin_hidden_size_per_attention_head:]
+
+        data1_paded = torch.nn.functional.pad(
+            data1, (0, self.half_pad_hidden_size_per_attention_head))
+        data2_paded = torch.nn.functional.pad(
+            data2, (0, self.half_pad_hidden_size_per_attention_head))
+
+        res = torch.cat([data1_paded, data2_paded], dim=2)
+        res = res.reshape(-1)
+        return res
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
@@ -318,11 +352,23 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if ("attn.proj.weight" in name) and self.enable_pad:
+                if ("attn.proj.weight_scale" in name or
+                        "attn.proj.weight_offset" in name) and self.enable_pad:
+                    continue
+                elif ("attn.proj.deq_scale" in name
+                      or "attn.proj.quant_bias" in name) and self.enable_pad:
+                    continue
+                elif ("attn.qkv.weight_scale" in name
+                      or "attn.qkv.weight_offset" in name) and self.enable_pad:
+                    param.data = self.pad_qkv_weight_scale_offset(param.data)
+                elif ("attn.qkv.deq_scale" in name
+                      or "attn.qkv.quant_bias" in name) and self.enable_pad:
+                    param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
+                elif ("attn.proj.weight" in name) and self.enable_pad:
                     param.data = self.pad_proj_weight(param.data)
-                if ("attn.qkv.weight" in name) and self.enable_pad:
+                elif ("attn.qkv.weight" in name) and self.enable_pad:
                     param.data = self.pad_qkv_weight(param.data)
-                if ("attn.qkv.bias" in name) and self.enable_pad:
+                elif ("attn.qkv.bias" in name) and self.enable_pad:
                     param.data = self.pad_qkv_bias(param.data)
                 loaded_params.add(name)
         return loaded_params
@@ -445,6 +491,17 @@ def forward(
                         dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class AscendQwen2_5_VLForConditionalGeneration(
         Qwen2_5_VLForConditionalGeneration):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
```
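The new class-level `packed_modules_mapping` tells the quantization config which per-projection weights were fused into a single module, so decisions recorded for `q_proj`/`k_proj`/`v_proj` (or `gate_proj`/`up_proj`) in the quantization description can be applied to the fused `qkv_proj`/`gate_up_proj` layers. The sketch below only illustrates that expansion; it is not the actual vLLM lookup code:

```python
# Illustrative only: expands a fused-module prefix back into the per-weight
# names a quantization description would typically be keyed by.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand_prefix(prefix: str) -> list[str]:
    """Return the candidate prefixes to look up in a quant description."""
    parent, _, module = prefix.rpartition(".")
    if module in packed_modules_mapping:
        return [f"{parent}.{sub}" if parent else sub
                for sub in packed_modules_mapping[module]]
    return [prefix]

print(expand_prefix("model.layers.0.self_attn.qkv_proj"))
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']
```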

vllm_ascend/quantization/quant_config.py (3 additions, 0 deletions)

```diff
@@ -53,6 +53,7 @@ class AscendQuantConfig(QuantizationConfig):
     """
 
     def __init__(self, quant_config: Dict[str, Any]):
+        super().__init__()
         self.quant_description = quant_config
 
     def __repr__(self) -> str:
@@ -89,6 +90,8 @@ def override_quantization_method(cls, hf_quant_cfg,
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention
+        if prefix.startswith("language_model"):
+            prefix = prefix.split('.', 1)[-1]
         if isinstance(layer, LinearBase):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
```
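The prefix normalization added to `get_quant_method` matters for multimodal models such as Qwen2.5-VL: the text decoder's layers are created under a `language_model.` prefix, while the keys in the quantization description presumably omit it, so the leading component is stripped before the lookup. A quick illustration with a hypothetical layer path:

```python
# Hypothetical layer prefix as seen from the multimodal wrapper.
prefix = "language_model.model.layers.0.self_attn.qkv_proj"

# Same normalization as the added lines: drop the leading "language_model"
# component before consulting the quantization description.
if prefix.startswith("language_model"):
    prefix = prefix.split('.', 1)[-1]

print(prefix)  # model.layers.0.self_attn.qkv_proj
```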
