Skip to content

Commit e528fad

Browse files
Merge pull request #15355 from krishrustagi:changes
PiperOrigin-RevId: 398506469
2 parents 451055a + f406f56 commit e528fad

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

keras/utils/layer_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,15 @@ def print_row(fields, positions, nested_level=0):
207207
cutoff = min(candidate_cutoffs)
208208
fit_into_line = fit_into_line[:cutoff]
209209

210+
if col == 0:
211+
line += '|' * nested_level + ' '
210212
line += fit_into_line
211213
line += ' ' * space if space else ''
212214
left_to_print[col] = left_to_print[col][cutoff:]
213215

214216
# Pad out to the next position
215217
if nested_level:
216-
line += ' ' * (positions[col] - len(line) - (2 * nested_level) - 1)
218+
line += ' ' * (positions[col] - len(line) - nested_level)
217219
else:
218220
line += ' ' * (positions[col] - len(line))
219221
line += '|' * nested_level
@@ -294,14 +296,13 @@ def print_layer(layer, nested_level=0, is_nested_last=False):
294296
for i in range(len(nested_layer)):
295297
if i == len(nested_layer) - 1:
296298
is_nested_last = True
297-
print_fn('|' * (nested_level + 1), end=' ')
298299
print_layer(nested_layer[i], nested_level + 1, is_nested_last)
299300

300301
print_fn('|' * nested_level + '¯' * (line_length - 2 * nested_level) +
301302
'|' * nested_level)
302303

303304
if not is_nested_last:
304-
print_fn('|' * nested_level + '_' * (line_length - 2 * nested_level) +
305+
print_fn('|' * nested_level + ' ' * (line_length - 2 * nested_level) +
305306
'|' * nested_level)
306307

307308
layers = model.layers

keras/utils/layer_utils_test.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def test_print_summary(self):
6363
fpath = os.path.join(temp_dir, file_name)
6464
writer = open(fpath, 'w')
6565

66-
def print_to_file(text, end='\n'):
67-
print(text, end=end, file=writer)
66+
def print_to_file(text):
67+
print(text, file=writer)
6868

6969
try:
7070
layer_utils.print_summary(model, print_fn=print_to_file)
@@ -88,7 +88,6 @@ def make_model():
8888

8989
x = inner_inputs = keras.Input(shape)
9090
x = make_model()(x)
91-
x = make_model()(x)
9291
inner_model = keras.Model(inner_inputs, x)
9392

9493
inputs = keras.Input(shape)
@@ -100,8 +99,8 @@ def make_model():
10099
fpath = os.path.join(temp_dir, file_name)
101100
writer = open(fpath, 'w')
102101

103-
def print_to_file(text, end='\n'):
104-
print(text, end=end, file=writer)
102+
def print_to_file(text):
103+
print(text, file=writer)
105104

106105
try:
107106
layer_utils.print_summary(
@@ -111,7 +110,39 @@ def print_to_file(text, end='\n'):
111110
reader = open(fpath, 'r')
112111
lines = reader.readlines()
113112
reader.close()
114-
self.assertEqual(len(lines), 34)
113+
check_str = (
114+
'Model: "model_2"\n'
115+
'_________________________________________________________________\n'
116+
' Layer (type) Output Shape Param # \n'
117+
'=================================================================\n'
118+
' input_3 (InputLayer) [(None, None, None, 3)] 0 \n'
119+
' \n'
120+
' model_1 (Functional) (None, None, None, 3) 24 \n'
121+
'|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|\n'
122+
'| input_1 (InputLayer) [(None, None, None, 3)] 0 |\n'
123+
'| |\n'
124+
'| model (Functional) (None, None, None, 3) 24 |\n'
125+
'||¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯||\n'
126+
'|| input_2 (InputLayer) [(None, None, None, 3)] 0 ||\n'
127+
'|| ||\n'
128+
'|| conv2d (Conv2D) (None, None, None, 3) 12 ||\n'
129+
'|| ||\n'
130+
'|| batch_normalization (BatchN (None, None, None, 3) 12 ||\n'
131+
'|| ormalization) ||\n'
132+
'|¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯|\n'
133+
'¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯\n'
134+
'=================================================================\n'
135+
'Total params: 24\n'
136+
'Trainable params: 18\n'
137+
'Non-trainable params: 6\n'
138+
'_________________________________________________________________\n')
139+
140+
fin_str = ''
141+
for line in lines:
142+
fin_str += line
143+
144+
self.assertIn(fin_str, check_str)
145+
self.assertEqual(len(lines), 25)
115146
except ImportError:
116147
pass
117148

@@ -172,8 +203,8 @@ def call(self, inputs):
172203
fpath = os.path.join(temp_dir, file_name)
173204
writer = open(fpath, 'w')
174205

175-
def print_to_file(text, end='\n'):
176-
print(text, end=end, file=writer)
206+
def print_to_file(text):
207+
print(text, file=writer)
177208

178209
try:
179210
layer_utils.print_summary(

0 commit comments

Comments
 (0)