@@ -89,13 +89,31 @@ def _preprocessing(self):
89
89
g_idx_trivial = torch .tensor (
90
90
g_idx_trivial , dtype = torch .int32 , device = self .g_idx .device
91
91
)
92
- assert torch .equal (
93
- self .g_idx , g_idx_trivial
94
- ), "Non-trivial tensor g_idx is not supported"
92
+ sort_zeros = not (torch .equal (self .g_idx , g_idx_trivial ))
95
93
self .qzeros = self .qzeros .cpu ()
96
94
zeros = self .unpack_zeros_from_cuda_old_format ()
97
- new_qzeros = pack_tensor (zeros )
98
- self .qzeros = new_qzeros .to (orig_device )
95
+ if sort_zeros :
96
+ zeros_group_1 = torch .zeros (
97
+ (self .infeatures , self .outfeatures ),
98
+ dtype = zeros .dtype ,
99
+ device = zeros .device ,
100
+ )
101
+ scales = self .scales .cpu ()
102
+ scale_group_1 = torch .zeros (
103
+ (self .infeatures , self .outfeatures ),
104
+ dtype = scales .dtype ,
105
+ device = scales .device ,
106
+ )
107
+ for i in range (self .infeatures ):
108
+ zeros_group_1 [i ] = zeros [self .g_idx [i ]]
109
+ scale_group_1 [i ] = self .scales [self .g_idx [i ]]
110
+ self .qzeros = pack_tensor (zeros_group_1 ).to (orig_device )
111
+ self .scales = scale_group_1 .to (orig_device )
112
+ self .groupsize = 1
113
+ self .g_idx = None
114
+ else :
115
+ new_qzeros = pack_tensor (zeros )
116
+ self .qzeros = new_qzeros .to (orig_device )
99
117
100
118
@classmethod
101
119
def new (cls , bits , groupsize , infeatures , outfeatures , bias ):
0 commit comments