Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 72 additions & 4 deletions official/projects/yolo/modeling/layers/nn_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import functools
from typing import Callable, List, Tuple

import numpy as np
import tensorflow as tf, tf_keras

from official.modeling import tf_utils
Expand Down Expand Up @@ -1949,7 +1950,74 @@ def call(self, inputs, training=None):
return self._activation_fn(x + y + id_out)

def fuse(self):
  """Collapses the branch convolutions into the single `_rbr_reparam` conv.

  Folds the batch norm of the dense branch and of the 1x1 branch into their
  kernels, zero-pads the 1x1 kernel up to the dense kernel size, adds the
  identity branch when the layer has one, and loads the summed kernel and
  bias into `_rbr_reparam`. Afterwards `_rbr_reparam` and this layer are
  marked non-trainable and `_fuse` is set.

  Nothing happens when the layer is already fused, or when it uses separable
  convolutions (fusing those is not implemented, so the layer is left
  unfused).
  """
  if self._use_separable_conv or self._fuse:
    return

  # Fold each branch's batch norm into its convolution kernel.
  kernel_dense, bias_dense = self._fuse_conv_bn(
      self._rbr_dense.get_weights()[0], self._rbr_dense_bn.get_weights())
  kernel_1x1, bias_1x1 = self._fuse_conv_bn(
      self._rbr_1x1.get_weights()[0], self._rbr_1x1_bn.get_weights())

  kernel_size = self._kernel_size
  if not isinstance(kernel_size, (tuple, list)):
    kernel_size = (kernel_size, kernel_size)
  strides = self._strides
  if not isinstance(strides, (tuple, list)):
    strides = (strides, strides)

  # Embed the 1x1 kernel in a zero kernel of the dense size, at the tap that
  # reads the same input pixel as the 1x1 branch does.
  top = self._same_pad_before(kernel_size[0], strides[0])
  left = self._same_pad_before(kernel_size[1], strides[1])
  kernel_1x1 = tf.pad(
      kernel_1x1,
      [
          [top, kernel_size[0] - top - 1],
          [left, kernel_size[1] - left - 1],
          [0, 0],
          [0, 0],
      ],
  )

  kernel = kernel_dense + kernel_1x1
  bias = bias_dense + bias_1x1

  # The identity branch is a bare batch norm; express it as a convolution
  # with an identity kernel so it can be summed with the other branches.
  identity_bn = getattr(self, '_rbr_identity', None)
  if identity_bn is not None:
    kernel_id, bias_id = self._fuse_conv_bn(
        self._identity_conv_kernel(), identity_bn.get_weights())
    kernel += kernel_id
    bias += bias_id

  # NOTE(review): guarded so an already-built conv is not built a second
  # time - confirm `_rbr_reparam` is never built before this point.
  if not self._rbr_reparam.built:
    self._rbr_reparam.build([None, None, None, kernel.shape[-2]])
  self._rbr_reparam.set_weights([kernel, bias])
  self._rbr_reparam.trainable = False
  self.trainable = False
  # Flip the flag last so a failure above never leaves the layer marked as
  # fused while `_rbr_reparam` still holds no fused weights.
  self._fuse = True


def _same_pad_before(self, kernel_dim, stride):
  """Returns the leading 'SAME' padding of a conv along one spatial axis.

  A 1x1 convolution with the same stride needs no padding, so this is also
  the index of the tap in the larger kernel that lines up with the 1x1
  kernel.

  NOTE(review): assumes the branch convolutions use `padding='same'` and
  that the input size along this axis is a multiple of `stride`; for other
  input sizes the padding TensorFlow applies can differ - confirm against
  the layer's configuration.

  Args:
    kernel_dim: kernel size along the axis.
    stride: stride along the axis.

  Returns:
    The number of zero rows/columns to put before the 1x1 kernel.
  """
  return max(kernel_dim - stride, 0) // 2

def _fuse_conv_bn(self, conv_weights, bn_weights):
  """Folds a batch norm layer into the convolution in front of it.

  Analogous to `ConvBN.fuse`. Every output channel of the kernel is scaled
  by `gamma / sqrt(moving_variance + epsilon)`; this is done with a
  broadcast multiply over the last kernel axis, which gives the same result
  as multiplying the flattened kernel by a diagonal scale matrix without
  materialising that matrix.

  Args:
    conv_weights: convolution kernel whose last axis is the output channel
      axis.
    bn_weights: the batch norm weights, in the order gamma, beta,
      moving_mean, moving_variance.

  Returns:
    A `(kernel, bias)` tuple: the rescaled kernel, with the same shape as
    `conv_weights`, and the bias that replaces the batch norm shift.
  """
  gamma, beta, moving_mean, moving_variance = bn_weights
  base = tf.math.rsqrt(self._norm_epsilon + moving_variance)
  # Broadcasting over the trailing axis applies one factor per output
  # channel.
  w_conv = tf.multiply(conv_weights, gamma * base)
  b_bn = beta - gamma * moving_mean * base
  return w_conv, b_bn

def _identity_conv_kernel(self):
"""create a identity conv layer kernel
conv2d(inputs, kernel, padding='same', strides=1) == inputs"""
filters = self._filters
kernel_size = self._kernel_size
if not isinstance(kernel_size, (tuple, list)):
kernel_size = (kernel_size, kernel_size)
kernel = np.zeros(kernel_size + (filters, filters), dtype=np.float32)
half_h = [(kernel_size[0] - 1) // 2] * filters
half_w = [(kernel_size[1] - 1) // 2] * filters
kernel[half_h, half_w, range(filters), range(filters)] = 1
return kernel
18 changes: 18 additions & 0 deletions official/projects/yolo/modeling/layers/nn_blocks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,5 +376,23 @@ def test_gradient_pass_though(self, width, height, filters, strides):
self.assertNotIn(None, grad)
return

@parameterized.named_parameters(
    ('test', 3, 32, 32),
    ('test1 infilters!=outfilters', 3, 16, 32),
    ('test2 kernelsize=4', 4, 32, 32),
    ('test3 strides=2', 3, 32, 32, 2),
    ('test4 strides=2 infilters!=outfilters', 3, 16, 32, 2),
    ('test5 kernelsize=4 strides=2', 4, 32, 32, 2),
    ('test6 strides=4', 3, 32, 32, 4),
)
def test_fuse_and_unfuse_result(
    self, kernel_size, infilters, outfilters, strides=1
):
  """Checks that RepConv gives the same output before and after fuse()."""
  side = 224
  images = tf.random.uniform([1, side, side, infilters])
  layer = nn_blocks.RepConv(
      outfilters, kernel_size=kernel_size, strides=strides)

  # The first call runs the unfused branches; the second, after fuse(),
  # is compared against it.
  before = layer(images)
  layer.fuse()
  after = layer(images)

  self.assertAllClose(before, after, atol=1.0e-5)

# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
  tf.test.main()