Skip to content

Commit 24c2bff

Browse files
authored
Gaudi gptq gidx support (#3297)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent fc2405c commit 24c2bff

File tree

2 files changed

+32
-11
lines changed
  • backends/gaudi/server/text_generation_server/layers/gptq
  • server/text_generation_server/layers

2 files changed

+32
-11
lines changed

backends/gaudi/server/text_generation_server/layers/gptq/hpu.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,31 @@ def _preprocessing(self):
8989
g_idx_trivial = torch.tensor(
9090
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
9191
)
92-
assert torch.equal(
93-
self.g_idx, g_idx_trivial
94-
), "Non-trivial tensor g_idx is not supported"
92+
sort_zeros = not (torch.equal(self.g_idx, g_idx_trivial))
9593
self.qzeros = self.qzeros.cpu()
9694
zeros = self.unpack_zeros_from_cuda_old_format()
97-
new_qzeros = pack_tensor(zeros)
98-
self.qzeros = new_qzeros.to(orig_device)
95+
if sort_zeros:
96+
zeros_group_1 = torch.zeros(
97+
(self.infeatures, self.outfeatures),
98+
dtype=zeros.dtype,
99+
device=zeros.device,
100+
)
101+
scales = self.scales.cpu()
102+
scale_group_1 = torch.zeros(
103+
(self.infeatures, self.outfeatures),
104+
dtype=scales.dtype,
105+
device=scales.device,
106+
)
107+
for i in range(self.infeatures):
108+
zeros_group_1[i] = zeros[self.g_idx[i]]
109+
scale_group_1[i] = self.scales[self.g_idx[i]]
110+
self.qzeros = pack_tensor(zeros_group_1).to(orig_device)
111+
self.scales = scale_group_1.to(orig_device)
112+
self.groupsize = 1
113+
self.g_idx = None
114+
else:
115+
new_qzeros = pack_tensor(zeros)
116+
self.qzeros = new_qzeros.to(orig_device)
99117

100118
@classmethod
101119
def new(cls, bits, groupsize, infeatures, outfeatures, bias):

server/text_generation_server/layers/lora.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
punica_sgmv = None
1717

1818
if SYSTEM == "ipex":
19-
from intel_extension_for_pytorch.llm.functional import (
20-
bgmv_expand,
21-
bgmv_shrink,
22-
sgmv_expand,
23-
sgmv_shrink,
24-
)
19+
try:
20+
from intel_extension_for_pytorch.llm.functional import (
21+
bgmv_expand,
22+
bgmv_shrink,
23+
sgmv_expand,
24+
sgmv_shrink,
25+
)
26+
except ImportError:
27+
pass
2528

2629

2730
if TYPE_CHECKING:

0 commit comments

Comments (0)