38
38
39
39
""" Adaround wrapper """
40
40
41
- from typing import Union , Dict
41
+ from typing import Union , Dict , List
42
42
import numpy as np
43
43
import tensorflow as tf
44
44
from tensorflow import keras
@@ -58,8 +58,9 @@ class AdaroundWrapper(keras.layers.Layer):
58
58
"""
59
59
Adaround Wrapper base class
60
60
"""
61
+ # pylint: disable=too-many-arguments
61
62
def __init__ (self , session : tf .compat .v1 .Session , op : tf .Operation , param_bw : int , quant_scheme : QuantScheme ,
62
- is_symmetric : bool , strict_symmetric : bool , unsigned_symmetric : bool ):
63
+ is_symmetric : bool , strict_symmetric : bool , unsigned_symmetric : bool , enable_per_channel : bool ):
63
64
"""
64
65
:param session: Tf session.
65
66
:param op: Tf op.
@@ -68,6 +69,7 @@ def __init__(self, session: tf.compat.v1.Session, op: tf.Operation, param_bw: in
68
69
:param is_symmetric: Symmetric vs Asymmetric encodings.
69
70
:param strict_symmetric: Strict symmetric flag.
70
71
:param unsigned_symmetric: Unsigned symmetric flag.
72
+ :param enable_per_channel: if set to True, use per channel quantization
71
73
"""
72
74
super (AdaroundWrapper , self ).__init__ ()
73
75
@@ -80,26 +82,33 @@ def __init__(self, session: tf.compat.v1.Session, op: tf.Operation, param_bw: in
80
82
self ._bias_tensor = tf .convert_to_tensor (bias , dtype = 'float32' )
81
83
82
84
self .use_soft_rounding = tf .compat .v1 .placeholder_with_default (True , shape = [])
85
+ self .enable_per_channel = enable_per_channel
83
86
87
+ # use the last dimension as the default channel index
88
+ self .ch_axis = len (list (self ._weight_tensor .shape )) - 1
84
89
self .encoding = self .compute_encodings (weight , param_bw , quant_scheme , is_symmetric ,
85
- strict_symmetric , unsigned_symmetric )
86
- self .alpha = self ._initialize_alpha (self ._weight_tensor , self .encoding )
90
+ strict_symmetric , unsigned_symmetric , enable_per_channel , self . ch_axis )
91
+ self .alpha = self ._initialize_alpha (self ._weight_tensor , self .encoding , enable_per_channel , self . ch_axis )
87
92
88
93
def adaround_weights(self) -> tf.Tensor:
    """
    Apply AdaRound to the weight tensor held by this wrapper, using the wrapper's own alpha, encoding,
    rounding-mode placeholder and per-channel settings
    :return: AdaRounded weight tensor
    """
    return self.get_adarounded_weight(alpha=self.alpha,
                                      weight_tensor=self._weight_tensor,
                                      encoding=self.encoding,
                                      use_soft_rounding=self.use_soft_rounding,
                                      enable_per_channel=self.enable_per_channel,
                                      ch_axis=self.ch_axis)
94
100
95
101
@staticmethod
96
- def get_adarounded_weight (alpha , weight_tensor , encoding , use_soft_rounding ) -> tf .Tensor :
102
+ def get_adarounded_weight (alpha , weight_tensor , encoding , use_soft_rounding , enable_per_channel : bool ,
103
+ ch_axis : int ) -> tf .Tensor :
97
104
"""
98
105
Get the adarounded weight
99
106
:param alpha: Alpha parameter
100
107
:param weight_tensor: Weight to adaround
101
108
:param encoding: Encodings corresponding to weights
102
109
:param use_soft_rounding: True if soft rounding is to be used, False if hard rounding is to be used
110
+ :param enable_per_channel: True if per-channel mode, else False
111
+ :param ch_axis: channel axis to be used in the per-channel mode
103
112
:return: Adarounded weight tensor
104
113
"""
105
114
# Soft rounding maps alpha parameter between zero and one using rectified sigmoid function
@@ -111,8 +120,21 @@ def compute_soft_rounding():
111
120
def compute_hard_rounding ():
112
121
return tf .cast (alpha > 0 , dtype = alpha .dtype )
113
122
123
+ if enable_per_channel :
124
+ assert isinstance (encoding , list ), "Per-channel expects encoding to be a list"
125
+
126
+ delta = AdaroundWrapper ._broadcast_to_tensor (weight_tensor , [enc .delta for enc in encoding ],
127
+ ch_axis )
128
+ offset = AdaroundWrapper ._broadcast_to_tensor (weight_tensor , [enc .offset for enc in encoding ],
129
+ ch_axis )
130
+ bw = encoding [0 ].bw
131
+ else :
132
+ delta = encoding .delta
133
+ offset = encoding .offset
134
+ bw = encoding .bw
135
+
114
136
# Scale the tensor
115
- tensor = tf .floor (weight_tensor / encoding . delta )
137
+ tensor = tf .floor (weight_tensor / delta )
116
138
117
139
# Compute h_alpha depending on soft or hard rounding
118
140
h_alpha = tf .cond (use_soft_rounding , compute_soft_rounding , compute_hard_rounding )
@@ -121,8 +143,8 @@ def compute_hard_rounding():
121
143
tensor = tf .add (tensor , h_alpha )
122
144
123
145
# Quantize and de-quantize the tensor
124
- tensor_quant = tf .clip_by_value (tensor - encoding . offset , 0 , 2 ** encoding . bw - 1 )
125
- tensor_dequant = (tensor_quant + encoding . offset ) * encoding . delta
146
+ tensor_quant = tf .clip_by_value (tensor - offset , 0 , 2 ** bw - 1 )
147
+ tensor_dequant = (tensor_quant + offset ) * delta
126
148
127
149
return tensor_dequant
128
150
@@ -165,16 +187,14 @@ def call(self, inputs, **kwargs): # pylint: disable=unused-argument
165
187
return adaround_out_tensor
166
188
167
189
@staticmethod
168
- def _initialize_alpha (tensor : tf .Tensor , encoding : libpymo . TfEncoding ) -> tf .Variable :
190
+ def _create_alpha_var (tensor : tf .Tensor ) -> tf .Variable :
169
191
"""
170
- Initializes alpha parameter, same shape as the weight tensor
171
- :param tensor: The weight tensor to be ada rounded
192
+ Helper method to create the alpha variable
193
+ :param tensor: tensor to be used to generate alpha
172
194
"""
173
- tensor_floor = tf .floor (tensor / encoding .delta )
174
- tensor = (tensor / encoding .delta ) - tensor_floor
175
-
176
195
# pylint: disable=invalid-unary-operand-type
177
- alpha = - tf .math .log ((AdaroundConstants .ZETA - AdaroundConstants .GAMMA ) / (tensor - AdaroundConstants .GAMMA ) - 1 )
196
+ alpha = - tf .math .log (
197
+ (AdaroundConstants .ZETA - AdaroundConstants .GAMMA ) / (tensor - AdaroundConstants .GAMMA ) - 1 )
178
198
179
199
# pylint: disable=unexpected-keyword-arg
180
200
# Resource variable is default in TF2.x
@@ -185,6 +205,67 @@ def _initialize_alpha(tensor: tf.Tensor, encoding: libpymo.TfEncoding) -> tf.Var
185
205
186
206
return alpha_var
187
207
208
@staticmethod
def _broadcast_to_tensor(tensor: tf.Tensor, encoding: list, ch_axis: int) -> tf.Tensor:
    """
    Broadcast per-channel delta/offset values to the shape of the given tensor
    :param tensor: The weight tensor to be ada-rounded
    :param encoding: list of per-channel delta/offset values, one entry per channel along ch_axis
    :param ch_axis: dimension to be used for per channel quantization. Negative values index from the end
    :return: tensor with the shape and dtype of `tensor`, where every element holds the value of its channel

    weight: (A, B, C, D)
    ch_axis: 2
    encoding: (C,)
    after broadcast of encoding: (A, B, D, C)
    after transpose of encoding: (A, B, C, D)
    """
    shape = list(tensor.shape)
    rank = len(shape)

    # Normalize a negative axis up front. list.insert() with a negative index inserts *before* that position,
    # so using ch_axis=-1 directly would put the channel axis second-to-last instead of last.
    if ch_axis < 0:
        ch_axis += rank

    # Move the channel dimension to the end so that the (num_channels,) vector broadcasts against it.
    # An out-of-range ch_axis raises IndexError here.
    channels = shape.pop(ch_axis)
    broadcast_shape = shape + [channels]

    # Permutation that moves the trailing channel dimension back to position ch_axis
    restore_perm = list(range(rank - 1))
    restore_perm.insert(ch_axis, rank - 1)

    tensor_encoding = tf.constant(encoding, dtype=tensor.dtype)
    # broadcast delta/offset of shape (num_channels,) to broadcast_shape
    tensor_encoding = tf.broadcast_to(tensor_encoding, broadcast_shape)
    return tf.transpose(tensor_encoding, perm=restore_perm)
241
+
242
@staticmethod
def _initialize_alpha(tensor: tf.Tensor, encoding: Union[libpymo.TfEncoding, List[libpymo.TfEncoding]],
                      enable_per_channel: bool, ch_axis: int) -> tf.Variable:
    """
    Initializes the alpha parameter from the weight tensor and its encoding(s)
    :param tensor: The weight tensor to be ada rounded
    :param encoding: Encoding object for per-tensor flow or list of per-channel Encoding objects for
     per-channel flow
    :param enable_per_channel: if set to True, use per channel quantization
    :param ch_axis: dimension to be used for per channel quantization. This field is unused for per-tensor flow
    :return: alpha variable created by _create_alpha_var from the fractional part of (tensor / delta)

    Per-channel example:
    weight: (A, B, C, D)
    ch_axis: 2 (this holds good for other values in [0, len(shape)))
    encodings: (C,)
    delta/offset: (C,)
    after broadcast of encoding: (A, B, D, C)
    after transpose of encoding: (A, B, C, D)
    """
    if enable_per_channel:
        assert isinstance(encoding, list), "Per-channel expects encoding to be a list"
        # Expand the per-channel deltas to the weight's shape so the division below is element-wise
        delta = AdaroundWrapper._broadcast_to_tensor(tensor, [enc.delta for enc in encoding], ch_axis)
    else:
        delta = encoding.delta

    # Fractional part of the scaled weight, i.e. the distance of each value above its floor on the grid
    scaled = tensor / delta
    fraction = scaled - tf.floor(scaled)
    return AdaroundWrapper._create_alpha_var(fraction)
268
+
188
269
@staticmethod
189
270
def _get_weight_bias (session : tf .compat .v1 .Session , op : tf .Operation ) -> (np .ndarray , Union [None , np .ndarray ]):
190
271
"""
@@ -202,9 +283,28 @@ def _get_weight_bias(session: tf.compat.v1.Session, op: tf.Operation) -> (np.nda
202
283
return weight , bias
203
284
204
285
@staticmethod
205
- def compute_encodings (weight_data : np .ndarray , param_bw : int , quant_scheme : QuantScheme ,
206
- is_symmetric : bool , strict_symmetric : bool , unsigned_symmetric : bool ) \
207
- -> libpymo .TfEncoding :
286
+ def _generate_weight_transpose_perm (shape : tuple , ch_axis : int ) -> List :
287
+ """
288
+ Given shape of tensor/np.ndarray and channel axis, this function generates the permutation list to be used for
289
+ the transpose operation of the tensor/np.ndarray
290
+ shape = (A, B, C, D)
291
+ ch_axis = 2
292
+ return = (C, A, B, D)
293
+
294
+ :param shape: tuple representing the shape of the tensor/np.ndarray
295
+ :ch_axis: dimension to be used for per channel quantization
296
+ :return permutation list
297
+ """
298
+ perm = list (range (len (shape )))
299
+ ch_dim = perm .pop (ch_axis )
300
+ # make ch_idx dimension the first one
301
+ perm .insert (0 , ch_dim )
302
+ return perm
303
+
304
+ @staticmethod
305
+ def compute_encodings (weight_data : np .ndarray , param_bw : int , quant_scheme : QuantScheme , is_symmetric : bool ,
306
+ strict_symmetric : bool , unsigned_symmetric : bool , enable_per_channel : bool , ch_axis : int ) \
307
+ -> Union [libpymo .TfEncoding , List [libpymo .TfEncoding ]]:
208
308
"""
209
309
:param weight_data: Weight data of Adaround supported ops
210
310
:param param_bw: bitwidth (4-31) to use for quantizing weight data
@@ -217,17 +317,37 @@ def compute_encodings(weight_data: np.ndarray, param_bw: int, quant_scheme: Quan
217
317
have collected are for +ve numbers. If yes, use quantized int values (0:255). This is a special case,
218
318
where we have double the resolution for the computed encodings while still preserving the zero-point to
219
319
be absolute 0.
220
- :return: Encodings object. (max, min, delta and offset)
320
+ :param enable_per_channel: if set to True, use per channel quantization
321
+ :param ch_axis: dimension to be used for per channel quantization. This field is unused for per-tensor flow
322
+ :return: Encodings object for per-tensor flow or list of Encoding objects for per-channel flow
323
+ Encoding object to contain (bw, max, min, delta and offset)
221
324
"""
325
+ # pylint: disable=too-many-locals
222
326
quant_scheme = QUANT_SCHEME_TO_PYMO [quant_scheme ]
223
-
224
327
# Create Encodings Analyzer and collect statistical data to compute encodings
225
328
# Since the weight data is numpy and on CPU memory, useCuda is False
226
- analyzer = libpymo .EncodingAnalyzerForPython (quant_scheme )
227
- analyzer .updateStats (weight_data , False )
329
+ if enable_per_channel :
330
+ encoding = []
331
+ shape = list (weight_data .shape )
332
+ assert ch_axis < len (shape ), 'ch_axis is pointing to an incorrect dimension'
333
+ num_channels = shape .pop (ch_axis )
334
+
335
+ # reshape weights based on the ch_axis - ch_axis has to be the first index to slice and be used for encoding
336
+ weight_data = weight_data .transpose (
337
+ AdaroundWrapper ._generate_weight_transpose_perm (weight_data .shape , ch_axis ))
338
+
339
+ for ch_idx in range (num_channels ):
340
+ analyzer = libpymo .EncodingAnalyzerForPython (quant_scheme )
341
+ analyzer .updateStats (weight_data [ch_idx ], False )
342
+ channel_encoding , _ = analyzer .computeEncoding (param_bw , is_symmetric , strict_symmetric ,
343
+ unsigned_symmetric )
344
+ encoding .append (channel_encoding )
228
345
229
- # Compute the encodings for the weight data using collected stats
230
- encoding , _ = analyzer .computeEncoding (param_bw , is_symmetric , strict_symmetric , unsigned_symmetric )
346
+ else :
347
+ # Compute the encodings for the weight data using collected stats
348
+ analyzer = libpymo .EncodingAnalyzerForPython (quant_scheme )
349
+ analyzer .updateStats (weight_data , False )
350
+ encoding , _ = analyzer .computeEncoding (param_bw , is_symmetric , strict_symmetric , unsigned_symmetric )
231
351
232
352
return encoding
233
353
0 commit comments