Skip to content

Commit 3a0c22d

Browse files
Xharktensorflower-gardener
authored andcommitted
Add public API classes for quantization scheme.
tfmot.quantization.keras.QuantizeScheme tfmot.quantization.keras.QuantizeRegistry tfmot.quantization.keras.QuantizeLayoutTransform tfmot.quantization.keras.default_8bit.Default8BitQuantizeScheme tfmot.quantization.keras.default_8bit.Default8BitQuantizeRegistry tfmot.quantization.keras.default_8bit.Default8BitQuantizeLayoutTransform PiperOrigin-RevId: 343009136
1 parent dae21f6 commit 3a0c22d

File tree

9 files changed

+41
-9
lines changed

9 files changed

+41
-9
lines changed

tensorflow_model_optimization/python/core/api/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ py_library(
1010
"clustering/keras/__init__.py",
1111
"quantization/__init__.py",
1212
"quantization/keras/__init__.py",
13+
"quantization/keras/default_8bit/__init__.py",
1314
"quantization/keras/quantizers/__init__.py",
1415
"sparsity/__init__.py",
1516
"sparsity/keras/__init__.py",

tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
# submodules
1919
from tensorflow_model_optimization.python.core.api.quantization.keras import quantizers
20+
from tensorflow_model_optimization.python.core.api.quantization.keras import default_8bit
2021

2122
# quantize all layers with default quantization implementation.
2223
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_model
@@ -33,4 +34,9 @@
3334
# Deserialize quantized model for Keras h5 format.
3435
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_scope
3536

37+
# Quantization Scheme classes.
38+
from tensorflow_model_optimization.python.core.quantization.keras.quantize_scheme import QuantizeScheme
39+
from tensorflow_model_optimization.python.core.quantization.keras.quantize_layout_transform import QuantizeLayoutTransform
40+
from tensorflow_model_optimization.python.core.quantization.keras.quantize_registry import QuantizeRegistry
41+
3642
# pylint: enable=g-bad-import-order
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Module containing 8bit default quantization scheme."""
16+
# pylint: disable=g-bad-import-order
17+
18+
# The 8bit default quantization scheme classes.
19+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_scheme import Default8BitQuantizeScheme
20+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_layout_transform import Default8BitQuantizeLayoutTransform
21+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit.default_8bit_quantize_registry import Default8BitQuantizeRegistry
22+
23+
# pylint: enable=g-bad-import-order

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_layout_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
keras = tf.keras
2828

2929

30-
class QuantizeLayoutTransform(
30+
class Default8BitQuantizeLayoutTransform(
3131
quantize_layout_transform.QuantizeLayoutTransform):
3232
"""Default model transformations."""
3333

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def _get_rnn_cells(self, rnn_layer):
6969
return [rnn_layer.cell]
7070

7171

72-
class QuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
72+
class Default8BitQuantizeRegistry(
73+
quantize_registry.QuantizeRegistry, _RNNHelper):
7374
"""QuantizationRegistry for built-in Keras classes for default 8-bit scheme."""
7475

7576
# TODO(tfmot): expand layers test in quantize_functional_test.py

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ class QuantizeRegistryTest(
7979

8080
def setUp(self):
8181
super(QuantizeRegistryTest, self).setUp()
82-
self.quantize_registry = default_8bit_quantize_registry.QuantizeRegistry(
83-
)
82+
self.quantize_registry = default_8bit_quantize_registry.\
83+
Default8BitQuantizeRegistry()
8484

8585
class CustomLayer(l.Layer):
8686
pass

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_scheme.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
class Default8BitQuantizeScheme(quantize_scheme.QuantizeScheme):
2323

2424
def get_layout_transformer(self):
25-
return default_8bit_quantize_layout_transform.QuantizeLayoutTransform()
25+
return default_8bit_quantize_layout_transform.\
26+
Default8BitQuantizeLayoutTransform()
2627

2728
def get_quantize_registry(self):
28-
return default_8bit_quantize_registry.QuantizeRegistry()
29+
return default_8bit_quantize_registry.Default8BitQuantizeRegistry()
2930

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,8 @@ def replacement(self, match_layer):
584584
concat_layer_node = match_layer
585585
feeding_layer_nodes = match_layer.input_layers
586586

587-
default_registry = default_8bit_quantize_registry.QuantizeRegistry(
588-
)
587+
default_registry = default_8bit_quantize_registry.\
588+
Default8BitQuantizeRegistry()
589589

590590
feed_quantize_configs = []
591591
for feed_layer_node in feeding_layer_nodes:

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
3131
QuantizeWrapper = quantize_wrapper.QuantizeWrapper
32-
QuantizeRegistry = default_8bit_quantize_registry.QuantizeRegistry
32+
QuantizeRegistry = default_8bit_quantize_registry.Default8BitQuantizeRegistry
3333

3434
keras = tf.keras
3535
layers = tf.keras.layers

0 commit comments

Comments
 (0)