-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquantization_op.py
More file actions
55 lines (45 loc) · 2.03 KB
/
quantization_op.py
File metadata and controls
55 lines (45 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.client import timeline
from common import *
# define common quantization function
def quantization(x):
    """Stochastically round every element of ``x`` to a neighbouring integer.

    Each element is rounded up to ``floor(x) + 1`` with probability equal to
    its fractional part ``x - floor(x)`` and rounded down to ``floor(x)``
    otherwise, so the result equals ``x`` in expectation.

    Args:
        x: floating-point tensor (or value convertible to one) of any rank;
            dimensions may be unknown until run time.

    Returns:
        A tensor with the same shape and dtype as ``x`` whose values are
        integers.
    """
    x = tf.convert_to_tensor(x)
    # One uniform [0, 1) sample per element, drawn at x's runtime shape so any
    # rank and any unknown (None) dimension is supported.
    noise = tf.random_uniform(tf.shape(x), dtype=x.dtype)
    fx = tf.floor(x)
    # P(noise < frac) == frac, so this selects "round up" with probability frac.
    round_up = tf.less(noise, x - fx)
    return fx + tf.cast(round_up, x.dtype)
# quantization in forward propagation and identity in backward propagation
def my_quant(x):
    """Quantize ``x`` with a straight-through gradient.

    The forward value is ``x + (quantization(x) - x)``, which is numerically
    ``quantization(x)`` (up to floating-point rounding). The quantization
    residual is wrapped in ``tf.stop_gradient``, so only the ``x`` term
    contributes to the gradient and d(output)/dx is the identity.
    """
    residual = tf.stop_gradient(quantization(x) - x)
    return x + residual
if __name__ == '__main__':
    # Smoke run: quantize a random batch, log a histogram of the output plus the
    # traced run metadata to TensorBoard, and dump a Chrome-trace timeline.
    with tf.Session() as sess:
        x = tf.placeholder(name="x", shape=[None, 128, 128, 3], dtype=tf.float32)
        y = my_quant(x)
        tf.summary.histogram('y', y)
        merged = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter('log/build_model/quantization/', sess.graph)
        try:
            # Full tracing so per-op step stats are available for the timeline.
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            summary, oy = sess.run(
                [merged, y],
                feed_dict={x: np.random.random((32, 128, 128, 3))},
                options=run_options,
                run_metadata=run_metadata,
            )
            summary_writer.add_run_metadata(run_metadata, "this")
            tl = timeline.Timeline(step_stats=run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open('quantization_timeline.json', 'w') as f:
                f.write(ctf)
            summary_writer.add_summary(summary)
        finally:
            # Flush pending events to disk and release the writer even if the
            # run above fails; without this the summaries may never be written.
            summary_writer.close()
        print(oy)