-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnormalization_op.py
More file actions
84 lines (71 loc) · 3.45 KB
/
normalization_op.py
File metadata and controls
84 lines (71 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import check_ops
# Example:
# assume x is an image batch tensor of 4-D shape [batch, height, width, channel]
# n, m, s = batch_normalization(x)
# y = batch_denormalization(n, m, s)
def _Check3DImage(image, require_static=True):
    """Validate the static shape of a single image tensor.

    Args:
        image: Tensor whose static shape must be rank 3.
        require_static: When True, additionally require every dimension of the
            static shape to be known.

    Returns:
        An empty list when the static shape is fully defined; otherwise a
        one-element list holding a runtime assertion that every dimension of
        ``array_ops.shape(image)`` is positive.

    Raises:
        ValueError: if the static shape is not rank 3, is not fully defined
            while ``require_static`` is True, or has a dimension equal to 0.
    """
    try:
        shape = image.get_shape().with_rank(3)
    except ValueError:
        raise ValueError("'image' (shape %s) must be three-dimensional." %
                         image.shape)

    fully_defined = shape.is_fully_defined()
    if require_static and not fully_defined:
        raise ValueError("'image' (shape %s) must be fully defined." % shape)
    if any(dim == 0 for dim in shape):
        raise ValueError("all dims of 'image.shape' must be > 0: %s" % shape)

    # Statically known shapes were fully checked above; no runtime op needed.
    if fully_defined:
        return []
    # Unknown dims can only be checked when the graph runs.
    return [check_ops.assert_positive(
        array_ops.shape(image), ["all dims of 'image.shape' must be > 0."])]
def per_image_normalization(image):
    """Standardize one 3-D image to zero mean and (approximately) unit variance.

    The scheme appears to be adapted from TensorFlow's
    ``per_image_standardization``. Unlike that function, it also returns the
    statistics used, so ``per_image_denormalization`` can invert the transform.

    Args:
        image: A rank-3 tensor (or value convertible to one). It is cast to
            float32 before any statistics are computed.

    Returns:
        A tuple ``(normalized, mean, scale)``:
        - ``normalized`` is the float32 tensor ``(image - mean) / scale``.
        - ``mean`` is the scalar mean over all elements.
        - ``scale`` is ``max(stddev, 1 / sqrt(num_elements))``.

    Raises:
        ValueError: if the static shape is not rank 3 or has a zero-sized
            dimension (raised by ``_Check3DImage`` at graph-construction time).
    """
    image = ops.convert_to_tensor(image, name='image')
    image = control_flow_ops.with_dependencies(
        _Check3DImage(image, require_static=False), image)
    # Total number of elements (height * width * channels).
    num_pixels = math_ops.reduce_prod(array_ops.shape(image))
    image = math_ops.cast(image, dtype=dtypes.float32)
    image_mean = math_ops.reduce_mean(image)
    # Two-pass variance E[(x - mean)^2]. The one-pass E[x^2] - E[x]^2 loses
    # precision to cancellation in float32 when values are large relative to
    # their spread. This form is non-negative by construction, so no clamp is
    # needed.
    centered = math_ops.subtract(image, image_mean)
    variance = math_ops.reduce_mean(math_ops.square(centered))
    stddev = math_ops.sqrt(variance)
    # Apply a minimum normalization that protects us against uniform images
    # (stddev == 0 would otherwise divide by zero).
    min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, dtypes.float32))
    pixel_value_scale = math_ops.maximum(stddev, min_stddev)
    normalized = math_ops.divide(centered, pixel_value_scale)
    return normalized, image_mean, pixel_value_scale
def per_image_denormalization(image, mean, stddev):
    """Undo ``per_image_normalization`` for a single image.

    Args:
        image: Normalized image tensor. It is cast to float32 before use.
        mean: Offset returned by ``per_image_normalization``.
        stddev: Scale returned by ``per_image_normalization``.

    Returns:
        The float32 tensor ``image * stddev + mean``.
    """
    rescaled = math_ops.multiply(
        math_ops.cast(image, dtype=dtypes.float32), stddev)
    return math_ops.add(rescaled, mean)
def batch_normalization(images):
    """Standardize every image of a batch independently.

    Args:
        images: A batch tensor, presumably ``[batch, height, width, channel]``.
            Each slice along the first axis must be a valid rank-3 image for
            ``per_image_normalization``.

    Returns:
        A tuple ``(norm_images, means, stddevs)``:
        - ``norm_images`` is the float32 normalized batch.
        - ``means`` and ``stddevs`` are 1-D tensors holding each image's
          scalar mean and scale.
        Pass all three to ``batch_denormalization`` to invert the transform.
    """
    with tf.variable_scope("input_normalization"):
        # One map_fn already returns the three outputs, each stacked along the
        # batch axis, so they can be unpacked directly. Re-mapping them with
        # identity lambdas would only add redundant per-element loops.
        norm_images, means, stddevs = tf.map_fn(
            per_image_normalization, images,
            dtype=(tf.float32, tf.float32, tf.float32))
        # Keep the named ops inside this scope for graph inspection.
        norm_images = tf.identity(norm_images, name='norm_images')
        means = tf.identity(means, name='means')
        stddevs = tf.identity(stddevs, name='stddevs')
    return norm_images, means, stddevs
def batch_denormalization(norms, means, stddevs):
    """Invert ``batch_normalization`` image by image.

    Args:
        norms: Normalized image batch. Element i along the first axis is
            paired with ``means[i]`` and ``stddevs[i]``.
        means: Per-image offsets, one per element of ``norms``.
        stddevs: Per-image scales, one per element of ``norms``.

    Returns:
        A float32 tensor whose i-th element is ``norms[i] * stddevs[i] + means[i]``.
    """
    def _denormalize_one(elems):
        # map_fn supplies one (norm, mean, stddev) slice per call.
        return per_image_denormalization(*elems)

    with tf.variable_scope("output_denormalization"):
        restored = tf.map_fn(_denormalize_one, (norms, means, stddevs),
                             dtype=tf.float32, name='denorm_images')
    return restored