
Commit dae21f6

Xhark authored and tensorflower-gardener committed
Expose QuantizationScheme in the QAT API.
PiperOrigin-RevId: 341923968
1 parent 20b2b38 commit dae21f6

5 files changed, +104 -7 lines changed

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 10 additions & 1 deletion
@@ -243,8 +243,8 @@ py_library(
         ":quantize_layer",
         ":quantize_wrapper",
         # tensorflow dep1,
-        "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_layout_transform",
         "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
+        "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_scheme",
         "//tensorflow_model_optimization/python/core/quantization/keras/layers:conv_batchnorm",
     ],
 )
@@ -329,3 +329,12 @@ py_library(
         # tensorflow dep1,
     ],
 )
+
+py_library(
+    name = "quantize_scheme",
+    srcs = [
+        "quantize_scheme.py",
+    ],
+    srcs_version = "PY3",
+    visibility = ["//visibility:public"],
+)

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/BUILD

Lines changed: 13 additions & 0 deletions
@@ -147,3 +147,16 @@ py_test(
         "//tensorflow_model_optimization/python/core/quantization/keras:utils",
     ],
 )
+
+py_library(
+    name = "default_8bit_quantize_scheme",
+    srcs = [
+        "default_8bit_quantize_scheme.py",
+    ],
+    srcs_version = "PY3",
+    deps = [
+        ":default_8bit_quantize_layout_transform",
+        ":default_8bit_quantize_registry",
+        "//tensorflow_model_optimization/python/core/quantization/keras:quantize_scheme",
+    ],
+)
tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_scheme.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Quantization scheme which specifies how quantization should be applied."""
+
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_scheme
+from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_layout_transform
+from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
+
+
+class Default8BitQuantizeScheme(quantize_scheme.QuantizeScheme):
+
+  def get_layout_transformer(self):
+    return default_8bit_quantize_layout_transform.QuantizeLayoutTransform()
+
+  def get_quantize_registry(self):
+    return default_8bit_quantize_registry.QuantizeRegistry()
+
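For context, a minimal usage sketch (not part of the commit): the default scheme simply hands back the same two collaborators that quantize_apply previously constructed inline, as the quantize.py diff below shows.

from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_scheme

scheme = default_8bit_quantize_scheme.Default8BitQuantizeScheme()

# The getters return the default 8-bit layout transform and registry.
transform = scheme.get_layout_transformer()  # QuantizeLayoutTransform
registry = scheme.get_quantize_registry()    # QuantizeRegistry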

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 8 additions & 6 deletions
@@ -22,8 +22,8 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
-from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_layout_transform
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
+from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_scheme
 from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm
 
 keras = tf.keras
@@ -263,7 +263,9 @@ def quantize_annotate_layer(to_annotate, quantize_config=None):
       layer=to_annotate, quantize_config=quantize_config)
 
 
-def quantize_apply(model):
+def quantize_apply(
+    model,
+    scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme()):
   """Quantize a `tf.keras` model that has been annotated for quantization.
 
   Quantization constructs a model which emulates quantization during training.
@@ -298,6 +300,8 @@ def quantize_apply(model):
   Args:
     model: A `tf.keras` Sequential or Functional model which has been annotated
       with `quantize_annotate`. It can have pre-trained weights.
+    scheme: A `QuantizeScheme` which specifies transformer and quantization
+      registry. The default is `Default8BitQuantizeScheme()`.
 
   Returns:
     Returns a new `tf.keras` model in which the annotated layers have been
@@ -403,15 +407,13 @@ def _quantize(layer):  # pylint: disable=missing-docstring
 
   # 3. Apply the graph transformations required to match model passes on
   # target device/dialect.
-  quantize_transform = \
-      default_8bit_quantize_layout_transform.QuantizeLayoutTransform()
+  quantize_transform = scheme.get_layout_transformer()
   # layer_quantize_map gets modified by the transformations.
   transformed_model, layer_quantize_map = quantize_transform.apply(
       unwrapped_model, layer_quantize_map)
 
   # TODO(pulkitb): Think more about how to introduce Default specific code.
-  quantize_registry = default_8bit_quantize_registry.QuantizeRegistry(
-  )
+  quantize_registry = scheme.get_quantize_registry()
 
   # 4. Actually quantize all the relevant layers in the model. This is done by
   # wrapping the layers with QuantizeWrapper, and passing the associated
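The change is backward compatible: `scheme` defaults to `Default8BitQuantizeScheme()`, so existing calls behave exactly as before. A minimal sketch of the new call pattern (the annotated Dense model is illustrative, not part of this commit):

import tensorflow as tf

from tensorflow_model_optimization.python.core.quantization.keras import quantize
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_scheme

# Annotate a trivial model for quantization (illustrative layer choice).
annotated_model = tf.keras.Sequential([
    quantize.quantize_annotate_layer(
        tf.keras.layers.Dense(8, input_shape=(4,))),
])

# Equivalent calls: `scheme` defaults to Default8BitQuantizeScheme().
quantized_model = quantize.quantize_apply(annotated_model)
quantized_model = quantize.quantize_apply(
    annotated_model,
    scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme())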
tensorflow_model_optimization/python/core/quantization/keras/quantize_scheme.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Quantization scheme which specifies how quantization should be applied.
+
+Module: tfmot.quantization.keras
+"""
+
+import abc
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class QuantizeScheme(object):
+  """ABC interface which specifies transformer and quantization registry."""
+
+  @abc.abstractmethod
+  def get_layout_transformer(self):
+    """Returns the layout transforms for this scheme.
+
+    Returns:
+      Returns the QuantizeLayoutTransform for this quantization scheme.
+    """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+  @abc.abstractmethod
+  def get_quantize_registry(self):
+    """Returns the quantization registry for this scheme.
+
+    Returns:
+      Returns the QuantizeRegistry for this quantization scheme.
+    """
+    raise NotImplementedError('Must be implemented in subclasses.')
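Because QuantizeScheme is an abstract base class, a backend can route its own transform and registry through quantize_apply by subclassing it. A hypothetical sketch (MyQuantizeScheme and its constructor arguments are placeholders, not part of this commit):

from tensorflow_model_optimization.python.core.quantization.keras import quantize_scheme


class MyQuantizeScheme(quantize_scheme.QuantizeScheme):
  """Hypothetical scheme injecting caller-supplied collaborators."""

  def __init__(self, layout_transform, registry):
    # `layout_transform` should behave like QuantizeLayoutTransform and
    # `registry` like QuantizeRegistry; both are placeholder names.
    self._layout_transform = layout_transform
    self._registry = registry

  def get_layout_transformer(self):
    return self._layout_transform

  def get_quantize_registry(self):
    return self._registry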
