from keras.src.layers import EinsumDense
from keras.src.layers import Embedding
from keras.src.quantizers.gptq import GPTQ
-from keras.src.quantizers.gptq_quant import GPTQQuantization
+from keras.src.quantizers.gptq_quant import GPTQQuantizer


def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
@@ -93,16 +93,7 @@ def find_layers_in_block(block):
    return found_layers


-def apply_gptq_layerwise(
-    model,
-    dataloader,
-    num_samples,
-    hessian_damping,
-    group_size,
-    symmetric,
-    activation_order,
-    weight_bits,
-):
+def apply_gptq_layerwise(model, dataloader, config):
    """Applies GPTQ quantization layer-by-layer to a Keras model.

    This function is designed to work with common transformer architectures,
@@ -134,26 +125,21 @@ def apply_gptq_layerwise(
            attempt to automatically discover its structure.
        dataloader: An iterable providing calibration data. Each item should
            be a batch of token IDs suitable for the model's embedding layer.
-        num_samples: (int) The number of samples from the dataloader to use for
-            calibration.
-        hessian_damping: (float) The percentage of dampening to add to the
-            Hessian diagonal for stabilization during inverse calculation.
-            A value of 0.01 is common.
-        group_size: (int) The size of the groups to use for quantization. A
-            value of 128 means that 128 weights will share the same scaling
-            factor. Use -1 for per-channel quantization.
-        symmetric: (bool) If True, symmetric quantization is used. Otherwise,
-            asymmetric quantization is used.
-        activation_order: (bool) If True, reorders the weight columns based on
-            activation magnitude, which can improve quantization accuracy.
-        weight_bits: (int) The number of bits to use for the quantized weights,
-            e.g., 4 for 4-bit quantization.
+        config: A `GPTQConfiguration` object.

    Raises:
        ValueError: If the function cannot automatically find an embedding
            layer or any transformer-like blocks to quantize within the model.
    """
    logging.info("Starting model quantization...")
+
+    num_samples = config.num_samples
+    hessian_damping = config.hessian_damping
+    group_size = config.group_size
+    symmetric = config.symmetric
+    activation_order = config.activation_order
+    weight_bits = config.weight_bits
+
    embedding_layer = None
    transformer_blocks = []
    if hasattr(model, "backbone"):
@@ -221,7 +207,8 @@ def apply_gptq_layerwise(
        else:
            logging.info(f"Found layers: {list(sub_layers_map.keys())}")
            gptq_objects = {
-                name: GPTQ(layer) for name, layer in sub_layers_map.items()
+                name: GPTQ(layer, config)
+                for name, layer in sub_layers_map.items()
            }

            captured_inputs = {name: [] for name in sub_layers_map.keys()}
@@ -271,7 +258,7 @@ def hook(*args, **kwargs):
                input_reshaped = ops.reshape(layer_inputs, (-1, num_features))
                gptq_object.update_hessian_with_batch(input_reshaped)

-            quantizer = GPTQQuantization(
+            quantizer = GPTQQuantizer(
                weight_bits,
                per_channel=True,
                symmetric=symmetric,
@@ -304,7 +291,7 @@ def hook(*args, **kwargs):
    logging.info("Quantization process complete.")


-def quantize_model(model, config):
+def apply_gptq(model, config):
    """
    Top-level function to quantize a Keras model using GPTQ.
    """
@@ -323,13 +310,4 @@ def quantize_model(model, config):
    # is now a NumPy array, which can be sliced and reused.
    calibration_dataloader = full_dataloader[: config.num_samples]

-    apply_gptq_layerwise(
-        model,
-        calibration_dataloader,  # Use the calibration slice
-        config.num_samples,  # Use the configured number of samples
-        config.hessian_damping,
-        config.group_size,
-        config.symmetric,
-        config.activation_order,
-        config.weight_bits,
-    )
+    apply_gptq_layerwise(model, calibration_dataloader, config)
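For reference, this refactor collapses the eight positional arguments into a single config object, so a new GPTQ option only needs to be added in one place instead of being threaded through every call. Below is a minimal sketch of what the new call site might look like; the `GPTQConfiguration` stand-in and its defaults are assumptions for illustration, and only the attribute names and the `apply_gptq(model, config)` signature come from the diff above.

```python
# Hypothetical usage sketch. The stand-in config class below is NOT the real
# GPTQConfiguration; only its attribute names (the ones read inside
# apply_gptq_layerwise) and the apply_gptq(model, config) entry point appear
# in this diff.
from dataclasses import dataclass


@dataclass
class GPTQConfiguration:
    num_samples: int = 128          # calibration samples drawn from the dataloader
    hessian_damping: float = 0.01   # dampening added to the Hessian diagonal
    group_size: int = 128           # weights per scale factor; -1 = per-channel
    symmetric: bool = False         # symmetric vs. asymmetric quantization
    activation_order: bool = False  # reorder columns by activation magnitude
    weight_bits: int = 4            # e.g. 4 for 4-bit weights


config = GPTQConfiguration(weight_bits=4, group_size=128)
# apply_gptq(model, config)  # new entry point, renamed from quantize_model
```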