@@ -63,8 +63,8 @@ def test_print_summary(self):
63
63
fpath = os .path .join (temp_dir , file_name )
64
64
writer = open (fpath , 'w' )
65
65
66
- def print_to_file (text , end = ' \n ' ):
67
- print (text , end = end , file = writer )
66
+ def print_to_file (text ):
67
+ print (text , file = writer )
68
68
69
69
try :
70
70
layer_utils .print_summary (model , print_fn = print_to_file )
@@ -88,7 +88,6 @@ def make_model():
88
88
89
89
x = inner_inputs = keras .Input (shape )
90
90
x = make_model ()(x )
91
- x = make_model ()(x )
92
91
inner_model = keras .Model (inner_inputs , x )
93
92
94
93
inputs = keras .Input (shape )
@@ -100,8 +99,8 @@ def make_model():
100
99
fpath = os .path .join (temp_dir , file_name )
101
100
writer = open (fpath , 'w' )
102
101
103
- def print_to_file (text , end = ' \n ' ):
104
- print (text , end = end , file = writer )
102
+ def print_to_file (text ):
103
+ print (text , file = writer )
105
104
106
105
try :
107
106
layer_utils .print_summary (
@@ -111,7 +110,39 @@ def print_to_file(text, end='\n'):
111
110
reader = open (fpath , 'r' )
112
111
lines = reader .readlines ()
113
112
reader .close ()
114
- self .assertEqual (len (lines ), 34 )
113
+ check_str = (
114
+ 'Model: "model_2"\n '
115
+ '_________________________________________________________________\n '
116
+ ' Layer (type) Output Shape Param # \n '
117
+ '=================================================================\n '
118
+ ' input_3 (InputLayer) [(None, None, None, 3)] 0 \n '
119
+ ' \n '
120
+ ' model_1 (Functional) (None, None, None, 3) 24 \n '
121
+ '|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|\n '
122
+ '| input_1 (InputLayer) [(None, None, None, 3)] 0 |\n '
123
+ '| |\n '
124
+ '| model (Functional) (None, None, None, 3) 24 |\n '
125
+ '||¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯||\n '
126
+ '|| input_2 (InputLayer) [(None, None, None, 3)] 0 ||\n '
127
+ '|| ||\n '
128
+ '|| conv2d (Conv2D) (None, None, None, 3) 12 ||\n '
129
+ '|| ||\n '
130
+ '|| batch_normalization (BatchN (None, None, None, 3) 12 ||\n '
131
+ '|| ormalization) ||\n '
132
+ '|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|\n '
133
+ '¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯\n '
134
+ '=================================================================\n '
135
+ 'Total params: 24\n '
136
+ 'Trainable params: 18\n '
137
+ 'Non-trainable params: 6\n '
138
+ '_________________________________________________________________\n ' )
139
+
140
+ fin_str = ''
141
+ for line in lines :
142
+ fin_str += line
143
+
144
+ self .assertIn (fin_str , check_str )
145
+ self .assertEqual (len (lines ), 25 )
115
146
except ImportError :
116
147
pass
117
148
@@ -172,8 +203,8 @@ def call(self, inputs):
172
203
fpath = os .path .join (temp_dir , file_name )
173
204
writer = open (fpath , 'w' )
174
205
175
- def print_to_file (text , end = ' \n ' ):
176
- print (text , end = end , file = writer )
206
+ def print_to_file (text ):
207
+ print (text , file = writer )
177
208
178
209
try :
179
210
layer_utils .print_summary (
0 commit comments