@@ -109,7 +109,7 @@ def build(self, input_shape):
109
109
kernel_shape = (input_shape [- 1 ], self .units )
110
110
if self .quantization_mode :
111
111
self .quantized_build (kernel_shape , mode = self .quantization_mode )
112
- if self .quantization_mode not in ("int8" , "int4" ):
112
+ if self .quantization_mode not in ("int8" , "int4" , "gptq" ):
113
113
# If the layer is quantized to int8, int4 or gptq, `self._kernel` will
114
114
# be added in `_int8_build`, `_int4_build` or `_gptq_build`.
115
115
# Therefore, we skip it here.
@@ -332,10 +332,36 @@ def quantized_build(self, kernel_shape, mode):
332
332
self ._int4_build (kernel_shape )
333
333
elif mode == "float8" :
334
334
self ._float8_build ()
335
+ elif mode == "gptq" :
336
+ self ._gptq_build (kernel_shape )
335
337
else :
336
338
raise self ._quantization_mode_error (mode )
337
339
self ._is_quantized = True
338
340
341
def _gptq_build(self, kernel_shape):
    """Create the non-trainable variables used by the "gptq" mode.

    Three weights are registered, in this order: an int8 `kernel`, a
    float32 `scale` (stored as `self.kernel_scale`) and a float32
    `zero_point` (stored as `self.zero_point`). All of them have shape
    `kernel_shape` and are zero-initialized.

    Args:
        kernel_shape: Shape used for the kernel and for both
            quantization parameters.
    """
    # Settings shared by the two float32 quantization parameters.
    quant_param_kwargs = {
        "shape": kernel_shape,
        "dtype": "float32",
        "initializer": "zeros",
        "trainable": False,
    }

    # TODO: choose this based on weight bits
    storage_dtype = "int8"
    self._kernel = self.add_weight(
        name="kernel",
        shape=kernel_shape,
        dtype=storage_dtype,
        initializer="zeros",
        trainable=False,
    )
    self.kernel_scale = self.add_weight(name="scale", **quant_param_kwargs)
    self.zero_point = self.add_weight(
        name="zero_point", **quant_param_kwargs
    )
364
+
339
365
def _int8_build (self , kernel_shape ):
340
366
self .inputs_quantizer = quantizers .AbsMaxQuantizer (axis = - 1 )
341
367
self ._kernel = self .add_weight (
@@ -526,6 +552,12 @@ def grad_fn(*args, upstream=None):
526
552
x = self .activation (x )
527
553
return x
528
554
555
def _gptq_call(self, inputs, training=None):
    """Forward pass for a layer quantized in "gptq" mode.

    The stored kernel is dequantized element-wise as
    `(kernel - zero_point) * kernel_scale` and `inputs` is then
    projected with the result. Bias and activation are applied
    afterwards when they are set.

    Args:
        inputs: Input tensor whose last dimension matches the first
            dimension of the kernel.
        training: Unused; accepted for signature compatibility with
            the other call implementations.

    Returns:
        The output tensor of the layer.
    """
    del training  # Unused.
    # `kernel_scale` and `zero_point` are created with the same shape as
    # the kernel (see `_gptq_build`), so the scale must be applied
    # element-wise to the kernel. Applying it with a second matmul on
    # the layer output is shape-incompatible whenever the input
    # dimension differs from `units`.
    kernel = ops.multiply(
        ops.subtract(self._kernel, self.zero_point), self.kernel_scale
    )
    x = ops.matmul(inputs, kernel)
    # NOTE(review): assumes `self.bias` is None when the layer was
    # created without a bias, as in the other call paths - confirm.
    if self.bias is not None:
        x = ops.add(x, self.bias)
    if self.activation is not None:
        x = self.activation(x)
    return x
560
+
529
561
def _float8_call (self , inputs , training = None ):
530
562
if self .lora_enabled :
531
563
raise NotImplementedError (
@@ -654,6 +686,9 @@ def quantize(self, mode, type_check=True):
654
686
self .kernel_scale .assign (kernel_scale )
655
687
elif mode == "float8" :
656
688
self .quantized_build (kernel_shape , mode )
689
+ elif mode == "gptq" :
690
+ del self ._kernel
691
+ self .quantized_build (kernel_shape , mode )
657
692
else :
658
693
raise self ._quantization_mode_error (mode )
659
694
0 commit comments