
Commit 7e49dde

Add per-channel support for adaround
Signed-off-by: yathindra kota <[email protected]>
1 parent 5e88d21 commit 7e49dde
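
In short: per-tensor AdaRound uses a single encoding (min/max/delta/offset) for the whole weight tensor, while per-channel computes one encoding per slice along a channel axis, which this commit defaults to the tensor's last dimension. A minimal numpy sketch of the difference (not AIMET code; shapes and values are made up):

import numpy as np

weight = np.random.randn(3, 3, 8, 16).astype(np.float32)   # e.g. a conv kernel
ch_axis = weight.ndim - 1                                   # commit's default: last dim

# Per-tensor: one (min, max) pair for the whole tensor.
per_tensor_min, per_tensor_max = weight.min(), weight.max()

# Per-channel: one (min, max) pair per slice along ch_axis.
channel_first = np.moveaxis(weight, ch_axis, 0).reshape(weight.shape[ch_axis], -1)
per_channel_min = channel_first.min(axis=1)                 # shape (16,)
per_channel_max = channel_first.max(axis=1)                 # shape (16,)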

File tree

9 files changed: +383 -67 lines changed


TrainingExtensions/tensorflow/src/python/aimet_tensorflow/adaround/adaround_weight.py

Lines changed: 30 additions & 16 deletions
@@ -191,12 +191,14 @@ def _apply_adaround_helper( # pylint: disable=too-many-locals
         # Create copies which will have model's weights quantized with hard and soft rounding.
         session_hard_rounded_weight = graph_saver.save_and_load_graph(WORKING_DIR, session)
         session_soft_rounded_weight = graph_saver.save_and_load_graph(WORKING_DIR, session)
-
         configs = JsonConfigImporter.import_json_config_file(config_file)
         # Strict_symmetric and unsigned_symmetric flags have default value False and True respectively.
         strict_symmetric = configs[ConfigDictKeys.DEFAULTS].get(ConfigDictKeys.STRICT_SYMMETRIC, False)
         unsigned_symmetric = configs[ConfigDictKeys.DEFAULTS].get(ConfigDictKeys.UNSIGNED_SYMMETRIC, True)

+        # read per-channel quantization field. Default = False
+        enable_per_channel = configs[ConfigDictKeys.DEFAULTS].get(ConfigDictKeys.PER_CHANNEL_QUANTIZATION, False)
+
         # Optimization Hyper parameters
         opt_params = AdaroundHyperParameters(params.num_iterations, params.reg_param, params.beta_range,
                                              params.warm_start)
@@ -206,6 +208,7 @@ def _apply_adaround_helper( # pylint: disable=too-many-locals
         # Get Adaround supported ops based on occurrence in the model
         ordered_ops = cls._get_ordered_list_of_ops(session.graph, starting_op_names, output_op_names)

+
         param_encodings = {}
         for op in tqdm(ordered_ops):
             logger.info("Started Optimizing weight rounding of op: %s", op.name)
@@ -220,17 +223,19 @@ def _apply_adaround_helper( # pylint: disable=too-many-locals
                                                            params.num_batches)
                 is_symmetric = cls._get_is_symmetric_flag_for_op_param(configs, op.type, param_name="weight")

+
                 # Find next following activation function
                 act_func = cls._get_act_func(op)

                 # Perform Adaround optimization in separate graph
                 graph = tf.Graph()
                 with graph.as_default():
                     wrapper = AdaroundWrapper(session, op, param_bw, quant_scheme, is_symmetric,
-                                              strict_symmetric, unsigned_symmetric)
-                    hard_rounded_weight, \
-                    soft_rounded_weight = AdaroundOptimizer().adaround_wrapper(wrapper, act_func, all_inp_data,
-                                                                               all_out_data, opt_params)
+                                              strict_symmetric, unsigned_symmetric, enable_per_channel)
+                    hard_rounded_weight, soft_rounded_weight = AdaroundOptimizer().adaround_wrapper(wrapper, act_func,
+                                                                                                    all_inp_data,
+                                                                                                    all_out_data,
+                                                                                                    opt_params)

                 # Update param encodings dictionary
                 cls._update_param_encodings_dict(param_encodings, op, wrapper.encoding, is_symmetric)
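
The enable_per_channel flag above comes from the quantsim JSON config's defaults section. A minimal sketch of the lookup against a plain dict; the literal key strings here are assumptions mirroring the ConfigDictKeys constants used by the real code:

# Hypothetical config dict; key names are assumed, not taken from the commit.
configs = {
    "defaults": {
        "strict_symmetric": False,
        "unsigned_symmetric": True,
        "per_channel_quantization": True,   # the new field this commit reads
    }
}
enable_per_channel = configs["defaults"].get("per_channel_quantization", False)
assert enable_per_channel is True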
@@ -304,7 +309,8 @@ def export_encoding_to_json(cls, path: str, filename_prefix: str, param_encoding
             json.dump(param_encodings, encoding_fp, sort_keys=True, indent=4)

     @staticmethod
-    def _update_param_encodings_dict(encoding_dict: Dict, op: tf.Operation, encoding: libpymo.TfEncoding,
+    def _update_param_encodings_dict(encoding_dict: Dict, op: tf.Operation,
+                                     encoding: Union[libpymo.TfEncoding, List[libpymo.TfEncoding]],
                                      is_symmetric: bool):
         """
         Add op's parameter encoding to dictionary to be used for exporting
@@ -314,12 +320,20 @@ def _update_param_encodings_dict(encoding_dict: Dict, op: tf.Operation, encoding
         :param is_symmetric: Symmetric vs Asymmetric boolean
         """
         tensor_name = op.inputs[1].name
-        encoding_dict[tensor_name] = [{'min': encoding.min,
-                                       'max': encoding.max,
-                                       'scale': encoding.delta,
-                                       'offset': encoding.offset,
-                                       'bitwidth': encoding.bw,
-                                       'is_symmetric': is_symmetric}]
+        if isinstance(encoding, list):
+            encoding_dict[tensor_name] = [{'min': [enc.min for enc in encoding],
+                                           'max': [enc.max for enc in encoding],
+                                           'scale': [enc.delta for enc in encoding],
+                                           'offset': [enc.offset for enc in encoding],
+                                           'bitwidth': encoding[0].bw,
+                                           'is_symmetric': is_symmetric}]
+        else:
+            encoding_dict[tensor_name] = [{'min': encoding.min,
+                                           'max': encoding.max,
+                                           'scale': encoding.delta,
+                                           'offset': encoding.offset,
+                                           'bitwidth': encoding.bw,
+                                           'is_symmetric': is_symmetric}]

     @staticmethod
     def _get_is_symmetric_flag_for_op_param(configs: ConfigDictType, tf_op_type: str, param_name: str):
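
For reference, the two entry shapes this method now exports, shown as plain dicts with made-up values: per-channel entries carry per-channel lists for min/max/scale/offset but a single bitwidth taken from encoding[0].bw.

# Per-tensor entry (unchanged behavior):
per_tensor_entry = [{'min': -0.5, 'max': 0.49, 'scale': 0.0039, 'offset': -128,
                     'bitwidth': 8, 'is_symmetric': False}]

# Per-channel entry (new): scalar fields become per-channel lists.
per_channel_entry = [{'min': [-0.5, -0.3], 'max': [0.49, 0.29],
                      'scale': [0.0039, 0.0023], 'offset': [-128.0, -128.0],
                      'bitwidth': 8, 'is_symmetric': False}]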
@@ -351,16 +365,16 @@ def _get_is_symmetric_flag_for_op_param(configs: ConfigDictType, tf_op_type: str

         # Second level of specificity which applies to all parameters only.
         try:
-            return configs[ConfigDictKeys.PARAMS]\
-                [param_name]\
+            return configs[ConfigDictKeys.PARAMS] \
+                [param_name] \
                 [ConfigDictKeys.IS_SYMMETRIC]
         except KeyError:
             pass

         # First level of specificity which applies to all the ops and parameters.
         try:
-            return configs[ConfigDictKeys.DEFAULTS]\
-                [ConfigDictKeys.PARAMS]\
+            return configs[ConfigDictKeys.DEFAULTS] \
+                [ConfigDictKeys.PARAMS] \
                 [ConfigDictKeys.IS_SYMMETRIC]
         except KeyError:
             pass
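
The fall-through above goes from most specific to least: a param-level setting wins over the global default. A sketch of the same pattern with a plain dict (key strings assumed to mirror ConfigDictKeys; the op-type level that precedes these two in the real method is omitted):

configs = {
    "params": {"weight": {"is_symmetric": True}},      # second level: this param, all ops
    "defaults": {"params": {"is_symmetric": False}},   # first level: global default
}

def is_symmetric_for_param(configs, param_name):
    try:
        return configs["params"][param_name]["is_symmetric"]
    except KeyError:
        pass
    return configs["defaults"]["params"]["is_symmetric"]

assert is_symmetric_for_param(configs, "weight") is True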

TrainingExtensions/tensorflow/src/python/aimet_tensorflow/adaround/adaround_wrapper.py

Lines changed: 145 additions & 25 deletions
@@ -38,7 +38,7 @@

 """ Adaround wrapper """

-from typing import Union, Dict
+from typing import Union, Dict, List
 import numpy as np
 import tensorflow as tf
 from tensorflow import keras
@@ -58,8 +58,9 @@ class AdaroundWrapper(keras.layers.Layer):
     """
     Adaround Wrapper base class
     """
+    # pylint: disable=too-many-arguments
     def __init__(self, session: tf.compat.v1.Session, op: tf.Operation, param_bw: int, quant_scheme: QuantScheme,
-                 is_symmetric: bool, strict_symmetric: bool, unsigned_symmetric: bool):
+                 is_symmetric: bool, strict_symmetric: bool, unsigned_symmetric: bool, enable_per_channel: bool):
         """
         :param session: Tf session.
         :param op: Tf op.
@@ -68,6 +69,7 @@ def __init__(self, session: tf.compat.v1.Session, op: tf.Operation, param_bw: int, quant_scheme: QuantScheme,
         :param is_symmetric: Symmetric vs Asymmetric encodings.
         :param strict_symmetric: Strict symmetric flag.
         :param unsigned_symmetric: Unsigned symmetric flag.
+        :param enable_per_channel: if set to True, use per channel quantization
         """
         super(AdaroundWrapper, self).__init__()

@@ -80,26 +82,33 @@ def __init__(self, session: tf.compat.v1.Session, op: tf.Operation, param_bw: int, quant_scheme: QuantScheme,
             self._bias_tensor = tf.convert_to_tensor(bias, dtype='float32')

         self.use_soft_rounding = tf.compat.v1.placeholder_with_default(True, shape=[])
+        self.enable_per_channel = enable_per_channel

+        # use the last dimension as the default channel index
+        self.ch_axis = len(list(self._weight_tensor.shape)) - 1
         self.encoding = self.compute_encodings(weight, param_bw, quant_scheme, is_symmetric,
-                                               strict_symmetric, unsigned_symmetric)
-        self.alpha = self._initialize_alpha(self._weight_tensor, self.encoding)
+                                               strict_symmetric, unsigned_symmetric, enable_per_channel, self.ch_axis)
+        self.alpha = self._initialize_alpha(self._weight_tensor, self.encoding, enable_per_channel, self.ch_axis)

     def adaround_weights(self) -> tf.Tensor:
         """
         Adaround the weight tensor
         :return: AdaRounded weight tensor
         """
-        return self.get_adarounded_weight(self.alpha, self._weight_tensor, self.encoding, self.use_soft_rounding)
+        return self.get_adarounded_weight(self.alpha, self._weight_tensor, self.encoding, self.use_soft_rounding,
+                                          self.enable_per_channel, self.ch_axis)

     @staticmethod
-    def get_adarounded_weight(alpha, weight_tensor, encoding, use_soft_rounding) -> tf.Tensor:
+    def get_adarounded_weight(alpha, weight_tensor, encoding, use_soft_rounding, enable_per_channel: bool,
+                              ch_axis: int) -> tf.Tensor:
         """
         Get the adarounded weight
         :param alpha: Alpha parameter
         :param weight_tensor: Weight to adaround
         :param encoding: Encodings corresponding to weights
         :param use_soft_rounding: True if soft rounding is to be used, False if hard rounding is to be used
+        :param enable_per_channel: True if per-channel mode, else False
+        :param ch_axis: channel axis to be used in the per-channel mode
         :return: Adarounded weight tensor
         """
         # Soft rounding maps alpha parameter between zero and one using rectified sigmoid function
@@ -111,8 +120,21 @@ def compute_soft_rounding():
         def compute_hard_rounding():
             return tf.cast(alpha > 0, dtype=alpha.dtype)

+        if enable_per_channel:
+            assert isinstance(encoding, list), "Per-channel expects encoding to be a list"
+
+            delta = AdaroundWrapper._broadcast_to_tensor(weight_tensor, [enc.delta for enc in encoding],
+                                                         ch_axis)
+            offset = AdaroundWrapper._broadcast_to_tensor(weight_tensor, [enc.offset for enc in encoding],
+                                                          ch_axis)
+            bw = encoding[0].bw
+        else:
+            delta = encoding.delta
+            offset = encoding.offset
+            bw = encoding.bw
+
         # Scale the tensor
-        tensor = tf.floor(weight_tensor / encoding.delta)
+        tensor = tf.floor(weight_tensor / delta)

         # Compute h_alpha depending on soft or hard rounding
         h_alpha = tf.cond(use_soft_rounding, compute_soft_rounding, compute_hard_rounding)
@@ -121,8 +143,8 @@ def compute_hard_rounding():
         tensor = tf.add(tensor, h_alpha)

         # Quantize and de-quantize the tensor
-        tensor_quant = tf.clip_by_value(tensor - encoding.offset, 0, 2 ** encoding.bw - 1)
-        tensor_dequant = (tensor_quant + encoding.offset) * encoding.delta
+        tensor_quant = tf.clip_by_value(tensor - offset, 0, 2 ** bw - 1)
+        tensor_dequant = (tensor_quant + offset) * delta

         return tensor_dequant

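The rounding math above, traced with concrete numbers in a small numpy sketch (delta, offset, and weights are made up; the per-tensor case is shown for simplicity, since per-channel merely broadcasts delta/offset to the weight shape first):

import numpy as np

delta, offset, bw = 0.02, -128.0, 8
weight = np.array([-0.7, 0.013, 0.51], dtype=np.float32)
h_alpha = np.array([0.0, 1.0, 1.0], dtype=np.float32)   # hard rounding: alpha > 0

tensor = np.floor(weight / delta) + h_alpha             # scale, then add rounding mask
tensor_quant = np.clip(tensor - offset, 0, 2 ** bw - 1) # quantize with clipping
tensor_dequant = (tensor_quant + offset) * delta        # de-quantize
print(tensor_dequant)                                   # [-0.7, 0.02, 0.52]
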
@@ -165,16 +187,14 @@ def call(self, inputs, **kwargs): # pylint: disable=unused-argument
         return adaround_out_tensor

     @staticmethod
-    def _initialize_alpha(tensor: tf.Tensor, encoding: libpymo.TfEncoding) -> tf.Variable:
+    def _create_alpha_var(tensor: tf.Tensor) -> tf.Variable:
         """
-        Initializes alpha parameter, same shape as the weight tensor
-        :param tensor: The weight tensor to be ada rounded
+        Helper method to create the alpha variable
+        :param tensor: tensor to be used to generate alpha
         """
-        tensor_floor = tf.floor(tensor / encoding.delta)
-        tensor = (tensor / encoding.delta) - tensor_floor
-
         # pylint: disable=invalid-unary-operand-type
-        alpha = -tf.math.log((AdaroundConstants.ZETA - AdaroundConstants.GAMMA) / (tensor - AdaroundConstants.GAMMA) - 1)
+        alpha = -tf.math.log(
+            (AdaroundConstants.ZETA - AdaroundConstants.GAMMA) / (tensor - AdaroundConstants.GAMMA) - 1)

         # pylint: disable=unexpected-keyword-arg
         # Resource variable is default in TF2.x
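
The alpha initialization feeding _create_alpha_var inverts the rectified sigmoid so that soft rounding initially reproduces the float weights exactly. A numpy sketch, assuming the AdaRound paper's constants GAMMA = -0.1 and ZETA = 1.1 for AdaroundConstants:

import numpy as np

GAMMA, ZETA = -0.1, 1.1                              # assumed AdaroundConstants values
delta = 0.02
weight = np.array([-0.7, 0.013, 0.51], dtype=np.float32)

rest = (weight / delta) - np.floor(weight / delta)   # fractional part in [0, 1)
alpha = -np.log((ZETA - GAMMA) / (rest - GAMMA) - 1)
# sigmoid(alpha) * (ZETA - GAMMA) + GAMMA recovers `rest`, so the soft-rounded
# weights start out equal to the original float weights.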
@@ -185,6 +205,67 @@ def _initialize_alpha(tensor: tf.Tensor, encoding: libpymo.TfEncoding) -> tf.Variable:

         return alpha_var

+    @staticmethod
+    def _broadcast_to_tensor(tensor: tf.Tensor, encoding: list, ch_axis: int) -> tf.constant:
+        """
+        Broadcast per-channel delta/offset using the encodings array
+        :param tensor: The weight tensor to be ada-rounded
+        :param encoding: list of per-channel encoding delta/offset to generate broadcasted encoding
+        :param ch_axis: dimension to be used for per channel quantization
+        """
+
+        def _get_broadcast_shape() -> List:
+            """
+            compute the broadcast shape based on the channel index
+            """
+            shape = list(tensor.shape)
+            channels = shape.pop(ch_axis)
+            broadcast_shape = shape + [channels]
+            return broadcast_shape
+
+        def _get_encoding_rotate_perm() -> List:
+            """
+            Generate the permutation list to apply on delta/offset(which is broadcasted) to match the original shape
+            """
+            length = len(list(tensor.shape))
+            ret_perm = list(range(length))
+            channel_swap = ret_perm.pop()
+            ret_perm.insert(ch_axis, channel_swap)
+            return ret_perm
+
+        tensor_encoding = tf.constant(encoding, dtype=tensor.dtype)
+        # broadcast delta/offset of shape (num_channels,) to broadcast_shape
+        tensor_encoding = tf.broadcast_to(tensor_encoding, _get_broadcast_shape())
+        tensor_encoding = tf.transpose(tensor_encoding, perm=_get_encoding_rotate_perm())
+        return tensor_encoding
+
+    @staticmethod
+    def _initialize_alpha(tensor: tf.Tensor, encoding: libpymo.TfEncoding, enable_per_channel: bool,
+                          ch_axis: int) -> tf.Variable:
+        """
+        Initializes alpha parameter, same shape as the weight tensor
+        :param tensor: The weight tensor to be ada rounded
+        :param enable_per_channel: if set to True, use per channel quantization
+        :param ch_axis: dimension to be used for per channel quantization. This field is unused for per-tensor flow
+        weight: (A, B, C, D)
+        ch_axis: 2 (this holds good for other values in [0, len(shape)))
+        encodings: (C,)
+        delta/offset: (C,)
+        after broadcast of encoding: (A, B, D, C) -> _get_broadcast_shape
+        after transpose of encoding: (A, B, C, D)
+        """
+
+        if enable_per_channel:
+            assert isinstance(encoding, list), "Per-channel expects encoding to be a list"
+            delta = AdaroundWrapper._broadcast_to_tensor(tensor, [enc.delta for enc in encoding], ch_axis)
+        else:
+            delta = encoding.delta
+
+        tensor_floor = tf.floor(tensor / delta)
+        tensor = (tensor / delta) - tensor_floor
+        alpha_var = AdaroundWrapper._create_alpha_var(tensor)
+        return alpha_var
+
     @staticmethod
     def _get_weight_bias(session: tf.compat.v1.Session, op: tf.Operation) -> (np.ndarray, Union[None, np.ndarray]):
         """
@@ -202,9 +283,28 @@ def _get_weight_bias(session: tf.compat.v1.Session, op: tf.Operation) -> (np.ndarray, Union[None, np.ndarray]):
         return weight, bias

     @staticmethod
-    def compute_encodings(weight_data: np.ndarray, param_bw: int, quant_scheme: QuantScheme,
-                          is_symmetric: bool, strict_symmetric: bool, unsigned_symmetric: bool) \
-            -> libpymo.TfEncoding:
+    def _generate_weight_transpose_perm(shape: tuple, ch_axis: int) -> List:
+        """
+        Given shape of tensor/np.ndarray and channel axis, this function generates the permutation list to be used for
+        the transpose operation of the tensor/np.ndarray
+        shape = (A, B, C, D)
+        ch_axis = 2
+        return = (C, A, B, D)
+
+        :param shape: tuple representing the shape of the tensor/np.ndarray
+        :ch_axis: dimension to be used for per channel quantization
+        :return permutation list
+        """
+        perm = list(range(len(shape)))
+        ch_dim = perm.pop(ch_axis)
+        # make ch_idx dimension the first one
+        perm.insert(0, ch_dim)
+        return perm
+
+    @staticmethod
+    def compute_encodings(weight_data: np.ndarray, param_bw: int, quant_scheme: QuantScheme, is_symmetric: bool,
+                          strict_symmetric: bool, unsigned_symmetric: bool, enable_per_channel: bool, ch_axis: int) \
+            -> Union[libpymo.TfEncoding, List[libpymo.TfEncoding]]:
         """
         :param weight_data: Weight data of Adaround supported ops
         :param param_bw: bitwidth (4-31) to use for quantizing weight data
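
The commit's two permutation helpers, replayed in numpy: _broadcast_to_tensor moves the channel axis last so a (C,) vector of per-channel deltas can broadcast, then permutes it back to the weight's layout; _generate_weight_transpose_perm moves the channel axis to the front so indexing the first dimension yields one channel's weights. Shapes below are made up:

import numpy as np

# --- _broadcast_to_tensor, replayed: (C,) -> (A, B, D, C) -> (A, B, C, D) ---
A, B, C, D = 2, 3, 4, 5
ch_axis = 2
weight = np.zeros((A, B, C, D))
per_channel_delta = np.arange(1.0, C + 1.0)         # one delta per channel, made up

expanded = np.broadcast_to(per_channel_delta, [A, B, D, C])   # channel axis last
perm = list(range(weight.ndim))                     # [0, 1, 2, 3]
perm.insert(ch_axis, perm.pop())                    # [0, 1, 3, 2]
delta = np.transpose(expanded, perm)                # back to (A, B, C, D)
assert delta.shape == weight.shape
assert delta[0, 0, 1, 0] == per_channel_delta[1]    # varies only along ch_axis

# --- _generate_weight_transpose_perm, replayed: channel axis to the front ---
def generate_weight_transpose_perm(shape, ch_axis):
    perm = list(range(len(shape)))
    perm.insert(0, perm.pop(ch_axis))               # (A, B, C, D), 2 -> [2, 0, 1, 3]
    return perm

channel_major = weight.transpose(generate_weight_transpose_perm(weight.shape, ch_axis))
assert channel_major.shape == (C, A, B, D)          # channel_major[i] is channel i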
@@ -217,17 +317,37 @@ def compute_encodings(weight_data: np.ndarray, param_bw: int, quant_scheme: QuantScheme,
         have collected are for +ve numbers. If yes, use quantized int values (0:255). This is a special case,
         where we have double the resolution for the computed encodings while still preserving the zero-point to
         be absolute 0.
-        :return: Encodings object. (max, min, delta and offset)
+        :param enable_per_channel: if set to True, use per channel quantization
+        :param ch_axis: dimension to be used for per channel quantization. This field is unused for per-tensor flow
+        :return: Encodings object for per-tensor flow or list of Encoding objects for per-channel flow
+                 Encoding object to contain (bw, max, min, delta and offset)
         """
+        # pylint: disable=too-many-locals
         quant_scheme = QUANT_SCHEME_TO_PYMO[quant_scheme]
-
         # Create Encodings Analyzer and collect statistical data to compute encodings
         # Since the weight data is numpy and on CPU memory, useCuda is False
-        analyzer = libpymo.EncodingAnalyzerForPython(quant_scheme)
-        analyzer.updateStats(weight_data, False)
+        if enable_per_channel:
+            encoding = []
+            shape = list(weight_data.shape)
+            assert ch_axis < len(shape), 'ch_axis is pointing to an incorrect dimension'
+            num_channels = shape.pop(ch_axis)
+
+            # reshape weights based on the ch_axis - ch_axis has to be the first index to slice and be used for encoding
+            weight_data = weight_data.transpose(
+                AdaroundWrapper._generate_weight_transpose_perm(weight_data.shape, ch_axis))
+
+            for ch_idx in range(num_channels):
+                analyzer = libpymo.EncodingAnalyzerForPython(quant_scheme)
+                analyzer.updateStats(weight_data[ch_idx], False)
+                channel_encoding, _ = analyzer.computeEncoding(param_bw, is_symmetric, strict_symmetric,
+                                                               unsigned_symmetric)
+                encoding.append(channel_encoding)

-        # Compute the encodings for the weight data using collected stats
-        encoding, _ = analyzer.computeEncoding(param_bw, is_symmetric, strict_symmetric, unsigned_symmetric)
+        else:
+            # Compute the encodings for the weight data using collected stats
+            analyzer = libpymo.EncodingAnalyzerForPython(quant_scheme)
+            analyzer.updateStats(weight_data, False)
+            encoding, _ = analyzer.computeEncoding(param_bw, is_symmetric, strict_symmetric, unsigned_symmetric)

         return encoding
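
A numpy sketch of the per-channel branch above, with libpymo's EncodingAnalyzerForPython swapped for a naive asymmetric min/max grid purely for illustration (the real analyzer's output depends on the quant scheme and symmetry flags):

import numpy as np

def naive_encoding(data, bw):
    # Naive stand-in for the analyzer: asymmetric grid spanning [min, max] incl. 0.
    t_min, t_max = min(data.min(), 0.0), max(data.max(), 0.0)
    delta = (t_max - t_min) / (2 ** bw - 1)
    offset = round(t_min / delta)
    return {'min': t_min, 'max': t_max, 'delta': delta, 'offset': offset, 'bw': bw}

weight_data = np.random.randn(3, 3, 8, 16).astype(np.float32)
ch_axis, param_bw = 3, 8

moved = np.moveaxis(weight_data, ch_axis, 0)    # channel-first view for slicing
encodings = [naive_encoding(moved[ch], param_bw) for ch in range(moved.shape[0])]
assert len(encodings) == weight_data.shape[ch_axis]   # one encoding per channel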
