-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecoder.py
More file actions
54 lines (38 loc) · 1.69 KB
/
decoder.py
File metadata and controls
54 lines (38 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import tensorflow as tf
import numpy as np
from common import *
from SubPixCell import *
from ConvResCell import ConvResCell
from SubPixCell import SubPixCell
def build_decoder(code, hps=None):
    """Build the decoder sub-graph under the 'decoder_net' variable scope.

    Layers are applied in this order:
      1. SubPixCell 'subpix_1'  (kernel 3, stride 1, 512 channels, r=2)
      2. ConvResCell 'res_1'    (two 3x3 convs, stride 1, 128 channels each)
      3. ConvResCell 'res_2'    (two 3x3 convs, stride 1, 128 channels each)
      4. SubPixCell 'subpix_2'  (kernel 3, stride 1, 256 channels, r=2)
      5. SubPixCell 'subpix_3'  (kernel 3, stride 1, 12 channels, r=2)
    NOTE(review): r is presumably the sub-pixel upscale factor -- confirm in
    SubPixCell. A histogram summary named 'output' is recorded for the final
    tensor.

    Args:
        code: Input tensor for the first layer. build_model() feeds a float32
            placeholder of shape [None, 16, 16, 96].
        hps: Hyper-parameter object passed to every cell. When None, the
            module-level ``HParams`` is used instead.

    Returns:
        The output tensor of the 'subpix_3' layer.
    """
    # Use one hyper-parameter set for all cells. Previously the sub-pixel
    # cells ignored the caller's hps and always used HParams. Falling back to
    # HParams keeps default calls configured as before for those cells.
    if hps is None:
        hps = HParams
    with tf.variable_scope('decoder_net'):
        subpix_1 = SubPixCell(kernel_size=3, stride=1, channel=512, r=2,
                              scope='subpix_1', hps=hps)
        output_flow = subpix_1(code)
        res_1 = ConvResCell(kernel_sizes=[3, 3], strides=[1, 1],
                            channels=[128, 128], scope='res_1', hps=hps)
        output_flow = res_1(output_flow)
        res_2 = ConvResCell(kernel_sizes=[3, 3], strides=[1, 1],
                            channels=[128, 128], scope='res_2', hps=hps)
        output_flow = res_2(output_flow)
        subpix_2 = SubPixCell(kernel_size=3, stride=1, channel=256, r=2,
                              scope='subpix_2', hps=hps)
        output_flow = subpix_2(output_flow)
        subpix_3 = SubPixCell(kernel_size=3, stride=1, channel=12, r=2,
                              scope='subpix_3', hps=hps)
        output_flow = subpix_3(output_flow)
        tf.summary.histogram('output', output_flow)
    return output_flow
def build_model():
    """Build the decoder graph, run it once on zeros and log summaries.

    Feeds a batch of two all-zero [16, 16, 96] inputs through build_decoder()
    and writes both the graph and the merged summaries (at step 1) to
    'log/build_model/decoder/' for inspection in TensorBoard.
    """
    with tf.Graph().as_default():
        input_image = tf.placeholder(name="input_image",
                                     shape=[None, 16, 16, 96],
                                     dtype=tf.float32)
        # The decoder's output tensor is not needed here. Building the graph
        # registers the 'output' histogram summary, which merge_all() collects.
        build_decoder(input_image, hps=HParams)
        init = tf.global_variables_initializer()
        # A context-managed session is closed even if a run call raises.
        with tf.Session() as sess:
            sess.run(init)
            summary_writer = tf.summary.FileWriter('log/build_model/decoder/',
                                                   sess.graph)
            try:
                merged = tf.summary.merge_all()
                zeros = np.zeros([2, 16, 16, 96], dtype=np.float32)
                summary = sess.run(merged, feed_dict={input_image: zeros})
                summary_writer.add_summary(summary, 1)
            finally:
                # FileWriter writes events asynchronously. close() flushes
                # them so the graph and summary actually reach disk.
                summary_writer.close()
def main(_):
    """Entry point invoked by tf.app.run(); the argv argument is unused."""
    return build_model()
if __name__ == '__main__':
    # tf.app.run parses command-line flags, then calls main(argv).
    tf.app.run(main=main)