Skip to content

Commit 5f762ba

Browse files
initial restructure
1 parent b9ff57a commit 5f762ba

File tree

6 files changed

+223
-107
lines changed

6 files changed: +223 / -107 lines

keras/src/dtype_policies/dtype_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,11 @@ def _get_quantized_dtype_policy_by_str(policy):
350350
f"Received: policy={policy}"
351351
)
352352
mode, source_name = split_name
353-
if policy.startswith("int8") or policy.startswith("int4"):
353+
if (
354+
policy.startswith("int8")
355+
or policy.startswith("int4")
356+
or policy.startswith("gptq")
357+
):
354358
return QuantizedDTypePolicy(mode, source_name)
355359
elif policy.startswith("float8"):
356360
return QuantizedFloat8DTypePolicy(mode, source_name)

keras/src/layers/core/dense.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def build(self, input_shape):
109109
kernel_shape = (input_shape[-1], self.units)
110110
if self.quantization_mode:
111111
self.quantized_build(kernel_shape, mode=self.quantization_mode)
112-
if self.quantization_mode not in ("int8", "int4"):
112+
if self.quantization_mode not in ("int8", "int4", "gptq"):
113113
# If the layer is quantized to int8 or int4, `self._kernel` will be
114114
# added in `self._int8_build` or `_int4_build`. Therefore, we skip
115115
# it here.
@@ -332,10 +332,36 @@ def quantized_build(self, kernel_shape, mode):
332332
self._int4_build(kernel_shape)
333333
elif mode == "float8":
334334
self._float8_build()
335+
elif mode == "gptq":
336+
self._gptq_build(kernel_shape)
335337
else:
336338
raise self._quantization_mode_error(mode)
337339
self._is_quantized = True
338340

341+
def _gptq_build(self, kernel_shape):
    """Create the variables needed for GPTQ-quantized inference.

    Adds three non-trainable weights:
      * `_kernel` — the quantized integer weight matrix.
      * `kernel_scale` — per-weight dequantization scales.
      * `zero_point` — per-weight zero points.

    Args:
        kernel_shape: Tuple `(input_dim, units)` of the dense kernel.

    NOTE(review): scale and zero-point are given the full kernel shape
    here (per-weight granularity). GPTQ implementations typically use
    per-group or per-column scales — confirm the intended granularity
    before finalizing.
    """
    self._kernel = self.add_weight(
        name="kernel",
        shape=kernel_shape,
        # TODO: choose this based on weight bits
        dtype="int8",
        initializer="zeros",
        trainable=False,
    )
    # `shape=kernel_shape` (no redundant parens): `(kernel_shape)` is
    # just `kernel_shape`, not a tuple — the parens were misleading.
    self.kernel_scale = self.add_weight(
        name="scale",
        shape=kernel_shape,
        dtype="float32",
        initializer="zeros",
        trainable=False,
    )
    self.zero_point = self.add_weight(
        name="zero_point",
        shape=kernel_shape,
        dtype="float32",
        initializer="zeros",
        trainable=False,
    )
364+
339365
def _int8_build(self, kernel_shape):
340366
self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
341367
self._kernel = self.add_weight(
@@ -526,6 +552,12 @@ def grad_fn(*args, upstream=None):
526552
x = self.activation(x)
527553
return x
528554

555+
def _gptq_call(self, inputs, training=None):
    """Forward pass using GPTQ-dequantized weights.

    Dequantizes the stored integer kernel as
    `scale * (q_kernel - zero_point)` and applies the dense matmul.

    Args:
        inputs: Input tensor of shape `(..., input_dim)`.
        training: Unused; dequantized inference is the same in both modes.

    Returns:
        Output tensor of shape `(..., units)`.
    """
    del training
    # BUG FIX: the original applied the scale with a second matmul,
    # `ops.matmul(x, self.kernel_scale)`, which is dimensionally invalid —
    # `x` is (..., units) while `kernel_scale` was built with the kernel's
    # (input_dim, units) shape in `_gptq_build`. Scales must be applied
    # elementwise to the dequantized kernel instead.
    dequantized_kernel = ops.multiply(
        ops.subtract(self._kernel, self.zero_point),
        self.kernel_scale,
    )
    x = ops.matmul(inputs, dequantized_kernel)
    # NOTE(review): unlike the other quantized call paths, no bias or
    # activation is applied here — confirm whether that is intentional
    # for this initial restructure.
    return x
560+
529561
def _float8_call(self, inputs, training=None):
530562
if self.lora_enabled:
531563
raise NotImplementedError(
@@ -654,6 +686,9 @@ def quantize(self, mode, type_check=True):
654686
self.kernel_scale.assign(kernel_scale)
655687
elif mode == "float8":
656688
self.quantized_build(kernel_shape, mode)
689+
elif mode == "gptq":
690+
del self._kernel
691+
self.quantized_build(kernel_shape, mode)
657692
else:
658693
raise self._quantization_mode_error(mode)
659694

keras/src/layers/layer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,8 @@ def quantized_call(self, *args, **kwargs):
13181318
return self._float8_call(*args, **kwargs)
13191319
elif self.quantization_mode == "int4":
13201320
return self._int4_call(*args, **kwargs)
1321+
elif self.quantization_mode == "gptq":
1322+
return self._gptq_call(*args, **kwargs)
13211323
else:
13221324
raise self._quantization_mode_error(self.quantization_mode)
13231325

@@ -1330,6 +1332,9 @@ def _int8_call(self, *args, **kwargs):
13301332
def _float8_call(self, *args, **kwargs):
13311333
raise self._not_implemented_error(self._float8_call)
13321334

1335+
def _gptq_call(self, *args, **kwargs):
    # Base-class stub: layers that support GPTQ quantization must
    # override this; otherwise dispatching "gptq" mode raises.
    raise self._not_implemented_error(self._gptq_call)
1337+
13331338
def _not_implemented_error(self, attr, msg=None):
13341339
if callable(attr):
13351340
attr_name = attr.__name__

0 commit comments

Comments
 (0)