Commit b0a38f2

code restructure trying to fix dense shaping
Parent: 92858f0

10 files changed: 91 additions & 104 deletions


keras/src/layers/core/dense.py

Lines changed: 20 additions & 7 deletions
@@ -323,8 +323,6 @@ def _check_load_own_variables(self, store):
                 f"Expected: {[v.name for v in all_vars]}"
             )
 
-    # Quantization-related (int8 and float8) methods
-
     def quantized_build(self, kernel_shape, mode):
         if mode == "int8":
             self._int8_build(kernel_shape)
@@ -553,10 +551,24 @@ def grad_fn(*args, upstream=None):
         return x
 
     def _gptq_call(self, inputs, training=None):
-        del training
-        x = ops.matmul(inputs, ops.subtract(self._kernel, self.zero_point))
-        x = ops.cast(x, self.compute_dtype)
-        x = ops.matmul(x, self.kernel_scale)
+        zero_point = self.zero_point
+        if self.gptq_config.symmetric:
+            zero_point = ops.zeros_like(self.zero_point, dtype="int8")
+
+        # Elementwise dequantization (works for per-weight or
+        # broadcastable S/ZP)
+        dequant_kernel = ops.multiply(
+            ops.subtract(self._kernel, zero_point), self.kernel_scale
+        )
+
+        # Standard Dense matmul
+        x = ops.matmul(inputs, dequant_kernel)
+
+        # Add bias/activation to mirror Dense.call behavior
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
         return x
 
     def _float8_call(self, inputs, training=None):
@@ -650,7 +662,7 @@ def grad(*args, upstream=None, variables=None):
             x = self.activation(x)
         return x
 
-    def quantize(self, mode, type_check=True):
+    def quantize(self, mode, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)
@@ -689,6 +701,7 @@ def quantize(self, mode, type_check=True):
             self.quantized_build(kernel_shape, mode)
         elif mode == "gptq":
             del self._kernel
+            self.gptq_config = config
             self.quantized_build(kernel_shape, mode)
         else:
             raise self._quantization_mode_error(mode)
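
The new `_gptq_call` path dequantizes the stored integer kernel elementwise before the matmul, rather than applying the scale after the matmul as the old code did, and it now mirrors `Dense.call` by adding the bias and activation. A minimal standalone sketch of that order of operations using `keras.ops` with toy float tensors (names and shapes here are illustrative, not part of the commit):

import numpy as np

from keras import ops

def gptq_dense_forward(inputs, q_kernel, kernel_scale, zero_point, bias=None):
    # (W_q - zero_point) * scale recovers a float kernel; in symmetric mode
    # the zero point is pinned to 0, as in the diff above.
    dequant_kernel = ops.multiply(
        ops.subtract(q_kernel, zero_point), kernel_scale
    )
    x = ops.matmul(inputs, dequant_kernel)
    if bias is not None:
        x = ops.add(x, bias)
    return x

inputs = np.random.rand(2, 4).astype("float32")
q_kernel = np.random.randint(-8, 8, size=(4, 3)).astype("float32")
out = gptq_dense_forward(
    inputs, q_kernel, kernel_scale=np.float32(0.05), zero_point=np.float32(0.0)
)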

keras/src/layers/core/einsum_dense.py

Lines changed: 15 additions & 4 deletions
@@ -12,6 +12,7 @@
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
+from keras.src.backend.config import backend
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
 
@@ -608,11 +609,20 @@ def grad_fn(*args, upstream=None):
         return x
 
     def _gptq_call(self, inputs, training=None):
-        zero_point = self._adjust_scale_for_dequant(self.zero_point)
+        zero_point = self.zero_point
+        if self.gptq_config.symmetric:
+            # Constant zero-point (symmetric): integer 0
+            zero_point = ops.zeros_like(zero_point, dtype="int8")
 
-        dequantized_kernel = ops.subtract(self._kernel, zero_point)
+        zero_point = self._adjust_scale_for_dequant(zero_point)
 
-        x = ops.einsum(self.equation, inputs, dequantized_kernel)
+        # handle zero point with kernel
+        kernel = ops.subtract(self._kernel, zero_point)
+
+        # if backend is torch, do a cast
+        if backend() == "torch":
+            kernel = ops.cast(kernel, self.compute_dtype)
+        x = ops.einsum(self.equation, inputs, kernel)
         x = ops.cast(x, self.compute_dtype)
         x = ops.multiply(x, self.kernel_scale)
 
@@ -798,7 +808,7 @@ def grad(*args, upstream=None, variables=None):
             x = self.activation(x)
         return x
 
-    def quantize(self, mode, type_check=True):
+    def quantize(self, mode, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not EinsumDense):
             raise self._not_implemented_error(self.quantize)
@@ -834,6 +844,7 @@ def quantize(self, mode, type_check=True):
             del self._kernel
         elif mode == "gptq":
             del self._kernel
+            self.gptq_config = config
         self.quantized_build(kernel_shape, mode)
 
         # Assign values to the newly created variables.
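
For EinsumDense the ordering differs from Dense: the zero point is subtracted from the integer kernel, the einsum runs on that shifted kernel, and `kernel_scale` is applied to the output afterwards (with an extra cast on the torch backend). A rough standalone sketch of that sequence with `keras.ops`; the equation, shapes, and names are illustrative only:

import numpy as np

from keras import ops

equation = "ab,bc->ac"  # illustrative; a real layer uses its own equation
inputs = np.random.rand(2, 4).astype("float32")
q_kernel = np.random.randint(-8, 8, size=(4, 3)).astype("float32")
kernel_scale = np.float32(0.05)
zero_point = np.float32(0.0)  # symmetric mode forces this to zero

kernel = ops.subtract(q_kernel, zero_point)  # un-shift the integer kernel
x = ops.einsum(equation, inputs, kernel)     # einsum on the shifted kernel
x = ops.cast(x, "float32")
x = ops.multiply(x, kernel_scale)            # scale applied to the output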

keras/src/layers/core/embedding.py

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ def _int8_call(self, inputs, training=None):
         )
         return outputs
 
-    def quantize(self, mode, type_check=True):
+    def quantize(self, mode, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Embedding):
             raise self._not_implemented_error(self.quantize)

keras/src/layers/layer.py

Lines changed: 1 addition & 1 deletion
@@ -1268,7 +1268,7 @@ def _clear_losses(self):
     def quantized_build(self, input_shape, mode):
         raise self._not_implemented_error(self.quantized_build)
 
-    def quantize(self, mode, type_check=True):
+    def quantize(self, mode, type_check=True, config=None):
         raise self._not_implemented_error(self.quantize)
 
     def _check_quantize_args(self, mode, compute_dtype):
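
Since the base `Layer.quantize` signature now carries `config`, any subclass that overrides it should accept the argument as well, even if it ignores it. A hypothetical override, purely to illustrate the widened contract (not code from this commit):

from keras import layers

class MyQuantizableLayer(layers.Layer):
    def quantize(self, mode, type_check=True, config=None):
        # Hypothetical subclass: stash the config for a later quantized call,
        # mirroring what Dense/EinsumDense do for the "gptq" mode above.
        if mode != "gptq":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")
        self.gptq_config = config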

keras/src/models/model.py

Lines changed: 3 additions & 3 deletions
@@ -9,6 +9,7 @@
 from keras.src.layers.layer import Layer
 from keras.src.models.variable_mapping import map_saveable_variables
 from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.gptq_core import apply_gptq
 from keras.src.saving import saving_api
 from keras.src.trainers import trainer as base_trainer
 from keras.src.utils import summary_utils
@@ -421,7 +422,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
             **kwargs,
         )
 
-    def quantize(self, mode, config=None, **kwargs):
+    def quantize(self, mode, type_check=True, config=None, **kwargs):
         """Quantize the weights of the model.
 
         Note that the model must be built first before calling this method.
@@ -440,8 +441,7 @@ def quantize(self, mode, config=None, **kwargs):
                     "The `config` argument must be of type "
                     "`keras.quantizers.GPTQConfig`."
                 )
-            # The config object's own quantize method drives the process
-            config.quantize(self)
+            apply_gptq(self, config=config)
             return
 
         # For all other modes, verify that a config object was not passed.
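
`Model.quantize` now routes the GPTQ mode straight to `gptq_core.apply_gptq` instead of calling `config.quantize(self)`. A hedged usage sketch of the new entry point; the `dataset`/`tokenizer` keyword names below are assumptions inferred from `gptq_core.get_dataloader`, and the helper is only defined, not run:

from keras.quantizers import GPTQConfig

def quantize_with_gptq(model, tokenizer, calibration_texts):
    # Assumed constructor keywords; the fields read elsewhere in this diff
    # (weight_bits, group_size, symmetric, activation_order, num_samples,
    # hessian_damping) are known, dataset/tokenizer are inferred.
    config = GPTQConfig(
        dataset=calibration_texts,
        tokenizer=tokenizer,
        weight_bits=4,
        group_size=128,
        symmetric=False,
        activation_order=False,
    )
    # With this commit, the call below ends up in apply_gptq(model, config).
    model.quantize("gptq", config=config)
    return model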

keras/src/quantizers/gptq.py

Lines changed: 8 additions & 17 deletions
@@ -3,13 +3,15 @@
 from keras.src import ops
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
+from keras.src.quantizers.gptq_quant import GPTQQuantizer
 
 
 class GPTQ:
-    def __init__(self, layer):
+    def __init__(self, layer, config):
         self.original_layer = layer
+        self.config = config
         self.num_samples = 0
-        self.quantizer = None
+        self.quantizer = GPTQQuantizer()
 
         # Explicitly handle each supported layer type
         if isinstance(layer, Dense) or (
@@ -63,9 +65,7 @@ def __init__(self, layer):
            raise ValueError(
                "The EinsumDense layer must be built before applying GPTQ. "
            )
-            # This populates self.original_layer with attributes like
-            # `_kernel_reduced_axes`, `_kernel_transpose_axes`, etc.
-            layer._set_quantization_info()
+
         self.hessian = ops.zeros((self.rows, self.rows), dtype="float32")
 
     def update_hessian_with_batch(self, input_batch):
@@ -186,7 +186,8 @@ def quantize_and_correct_block(
         based
         on their activation's second-order information.
         """
-        self.original_layer.quantize("gptq")
+
+        self.original_layer.quantize("gptq", config=self.config)
 
         weights_matrix = ops.transpose(ops.cast(self.layer.kernel, "float32"))
         hessian_matrix = ops.cast(self.hessian, "float32")
@@ -271,17 +272,13 @@ def quantize_and_correct_block(
                     group_slice = weights_matrix[:, group_start:group_end]
                     self.quantizer.find_params(group_slice, weight=True)
                 else:
-                    # Per-column params
                     self.quantizer.find_params(
                         ops.expand_dims(weight_column, 1), weight=True
                     )
 
                 # Quantize the current column and store the results
                 quantized_column = self.quantizer.quantize(
-                    ops.expand_dims(weight_column, 1),
-                    self.quantizer.scale,
-                    self.quantizer.zero,
-                    self.quantizer.maxq,
+                    ops.expand_dims(weight_column, 1)
                 )[:, 0]
 
                 # Write integer weights
@@ -302,16 +299,12 @@ def quantize_and_correct_block(
                     zero_col = ops.expand_dims(
                         ops.cast(self.quantizer.zero, "float32")[0, :], 1
                     )
-
                     scales = ops.slice_update(scales, (0, abs_col), scale_col)
                     zeros = ops.slice_update(zeros, (0, abs_col), zero_col)
 
                 # Dequantize back to float32 for error correction.
                 dequantized_column = self.quantizer.dequantize(
                     ops.expand_dims(weight_column, 1),
-                    self.quantizer.scale,
-                    self.quantizer.zero,
-                    self.quantizer.maxq,
                 )[:, 0]
 
                 quantization_error = ops.divide(
@@ -408,9 +401,7 @@ def quantize_and_correct_block(
         )
 
         self.original_layer.kernel_scale.assign(scale)
-
         self.original_layer.zero_point.assign(zero_point)
-
         self.original_layer._kernel.assign(quantized_kernel)
 
     def free(self):
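
The `quantize`/`dequantize` calls above no longer pass `scale`, `zero`, and `maxq` explicitly; `GPTQQuantizer` now holds the parameters found by `find_params`. The underlying per-column affine round-trip looks roughly like the generic NumPy sketch below (not the Keras implementation; the names and min/max parameter search are simplified):

import numpy as np

def find_params(w, bits=4):
    # Simplified asymmetric parameter search over one column.
    maxq = 2**bits - 1
    w_min, w_max = float(w.min()), float(w.max())
    scale = (w_max - w_min) / maxq if w_max > w_min else 1.0
    zero = round(-w_min / scale)
    return scale, zero, maxq

def quantize(w, scale, zero, maxq):
    return np.clip(np.round(w / scale) + zero, 0, maxq)

def dequantize(q, scale, zero):
    return scale * (q - zero)

column = np.random.randn(8, 1).astype("float32")
scale, zero, maxq = find_params(column)
q_column = quantize(column, scale, zero, maxq)
reconstructed = dequantize(q_column, scale, zero)
error = column - reconstructed  # this residual drives GPTQ's weight updates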

keras/src/quantizers/gptq_config.py

Lines changed: 0 additions & 13 deletions
@@ -1,7 +1,4 @@
-from absl import logging
-
 from keras.src.api_export import keras_export
-from keras.src.quantizers.gptq_core import quantize_model
 
 
 @keras_export("keras.quantizers.GPTQConfig")
@@ -157,13 +154,3 @@ def __init__(
         self.group_size = group_size
         self.symmetric = symmetric
         self.activation_order = activation_order
-
-    def quantize(self, model):
-        """
-        Applies GPTQ quantization to the provided model using this
-        configuration.
-        """
-        logging.info("Initiating quantization from GPTQConfig...")
-        # The core logic is now delegated to gptqutils, which will handle
-        # the dynamic imports and data loading.
-        quantize_model(model=model, config=self)

keras/src/quantizers/gptq_core.py

Lines changed: 16 additions & 38 deletions
@@ -9,7 +9,7 @@
 from keras.src.layers import EinsumDense
 from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
-from keras.src.quantizers.gptq_quant import GPTQQuantization
+from keras.src.quantizers.gptq_quant import GPTQQuantizer
 
 
 def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
@@ -93,16 +93,7 @@ def find_layers_in_block(block):
     return found_layers
 
 
-def apply_gptq_layerwise(
-    model,
-    dataloader,
-    num_samples,
-    hessian_damping,
-    group_size,
-    symmetric,
-    activation_order,
-    weight_bits,
-):
+def apply_gptq_layerwise(model, dataloader, config):
     """Applies GPTQ quantization layer-by-layer to a Keras model.
 
     This function is designed to work with common transformer architectures,
@@ -134,26 +125,21 @@ def apply_gptq_layerwise(
             attempt to automatically discover its structure.
         dataloader: An iterable providing calibration data. Each item should
             be a batch of token IDs suitable for the model's embedding layer.
-        num_samples: (int) The number of samples from the dataloader to use for
-            calibration.
-        hessian_damping: (float) The percentage of dampening to add to the
-            Hessian diagonal for stabilization during inverse calculation.
-            A value of 0.01 is common.
-        group_size: (int) The size of the groups to use for quantization. A
-            value of 128 means that 128 weights will share the same scaling
-            factor. Use -1 for per-channel quantization.
-        symmetric: (bool) If True, symmetric quantization is used. Otherwise,
-            asymmetric quantization is used.
-        activation_order: (bool) If True, reorders the weight columns based on
-            activation magnitude, which can improve quantization accuracy.
-        weight_bits: (int) The number of bits to use for the quantized weights,
-            e.g., 4 for 4-bit quantization.
+        config: A `GPTQConfiguration` object.
 
     Raises:
         ValueError: If the function cannot automatically find an embedding
             layer or any transformer-like blocks to quantize within the model.
     """
     logging.info("Starting model quantization...")
+
+    num_samples = config.num_samples
+    hessian_damping = config.hessian_damping
+    group_size = config.group_size
+    symmetric = config.symmetric
+    activation_order = config.activation_order
+    weight_bits = config.weight_bits
+
     embedding_layer = None
     transformer_blocks = []
     if hasattr(model, "backbone"):
@@ -221,7 +207,8 @@ def apply_gptq_layerwise(
         else:
             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
         gptq_objects = {
-            name: GPTQ(layer) for name, layer in sub_layers_map.items()
+            name: GPTQ(layer, config)
+            for name, layer in sub_layers_map.items()
         }
 
         captured_inputs = {name: [] for name in sub_layers_map.keys()}
@@ -271,7 +258,7 @@ def hook(*args, **kwargs):
             input_reshaped = ops.reshape(layer_inputs, (-1, num_features))
             gptq_object.update_hessian_with_batch(input_reshaped)
 
-        quantizer = GPTQQuantization(
+        quantizer = GPTQQuantizer(
             weight_bits,
             per_channel=True,
             symmetric=symmetric,
@@ -304,7 +291,7 @@ def hook(*args, **kwargs):
     logging.info("Quantization process complete.")
 
 
-def quantize_model(model, config):
+def apply_gptq(model, config):
     """
     Top-level function to quantize a Keras model using GPTQ.
     """
@@ -323,13 +310,4 @@ def quantize_model(model, config):
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = full_dataloader[: config.num_samples]
 
-    apply_gptq_layerwise(
-        model,
-        calibration_dataloader,  # Use the calibration slice
-        config.num_samples,  # Use the configured number of samples
-        config.hessian_damping,
-        config.group_size,
-        config.symmetric,
-        config.activation_order,
-        config.weight_bits,
-    )
+    apply_gptq_layerwise(model, calibration_dataloader, config)
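
`apply_gptq_layerwise` now reads all hyperparameters from the config and, as before, accumulates a Hessian from the inputs captured for each target layer before quantizing it. The running X^T X-style update that `update_hessian_with_batch` performs is roughly the generic NumPy sketch below (an assumption-labeled illustration, not the library code):

import numpy as np

def update_hessian(hessian, seen_samples, input_batch):
    # input_batch: (batch, features), flattened the same way the hook above
    # reshapes captured inputs. The Hessian stays (features, features).
    batch = input_batch.shape[0]
    total = seen_samples + batch
    hessian *= seen_samples / total              # down-weight earlier batches
    x = (input_batch.T * np.sqrt(2.0 / total)).astype("float32")
    hessian += x @ x.T                           # running average of 2 * X X^T
    return hessian, total

features = 16
hessian = np.zeros((features, features), dtype="float32")
seen = 0
for _ in range(4):  # stand-in for calibration batches
    batch = np.random.randn(8, features).astype("float32")
    hessian, seen = update_hessian(hessian, seen, batch)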
