Skip to content

Commit dd07a4c

Browse files
gabrieldemarmiesseFrédéric Branchaud-Charron
authored andcommitted
Refactoring: Simplified some code by using the to_list function. (#10678)
### Summary We have the method `to_list` in keras. Let's use it to make the codebase simpler! ### Related Issues ### PR Overview - [ ] This PR requires new unit tests [y/n] (make sure tests are included) - [ ] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date) - [x] This PR is backwards compatible [y/n] - [ ] This PR changes the current API [y/n] (all API changes need to be approved by fchollet)
1 parent b2979c2 commit dd07a4c

File tree

7 files changed

+22
-32
lines changed

7 files changed

+22
-32
lines changed

keras/engine/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .. import losses
2626
from .. import metrics as metrics_module
2727
from ..utils.generic_utils import slice_arrays
28+
from ..utils.generic_utils import to_list
2829
from ..utils.generic_utils import unpack_singleton
2930
from ..legacy import interfaces
3031

@@ -155,8 +156,7 @@ def compile(self, optimizer,
155156
masks = self.compute_mask(self.inputs, mask=None)
156157
if masks is None:
157158
masks = [None for _ in self.outputs]
158-
if not isinstance(masks, list):
159-
masks = [masks]
159+
masks = to_list(masks)
160160

161161
# Prepare loss weights.
162162
if loss_weights is None:

keras/engine/training_arrays.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .. import callbacks as cbks
1515
from ..utils.generic_utils import Progbar
1616
from ..utils.generic_utils import slice_arrays
17+
from ..utils.generic_utils import to_list
1718
from ..utils.generic_utils import unpack_singleton
1819

1920

@@ -152,8 +153,7 @@ def fit_loop(model, f, ins,
152153
callbacks.on_batch_begin(step_index, batch_logs)
153154
outs = f(ins)
154155

155-
if not isinstance(outs, list):
156-
outs = [outs]
156+
outs = to_list(outs)
157157
for l, o in zip(out_labels, outs):
158158
batch_logs[l] = o
159159

@@ -165,8 +165,7 @@ def fit_loop(model, f, ins,
165165
val_outs = test_loop(model, val_f, val_ins,
166166
steps=validation_steps,
167167
verbose=0)
168-
if not isinstance(val_outs, list):
169-
val_outs = [val_outs]
168+
val_outs = to_list(val_outs)
170169
# Same labels assumed.
171170
for l, o in zip(out_labels, val_outs):
172171
epoch_logs['val_' + l] = o
@@ -198,8 +197,7 @@ def fit_loop(model, f, ins,
198197
ins_batch[i] = ins_batch[i].toarray()
199198

200199
outs = f(ins_batch)
201-
if not isinstance(outs, list):
202-
outs = [outs]
200+
outs = to_list(outs)
203201
for l, o in zip(out_labels, outs):
204202
batch_logs[l] = o
205203

@@ -212,8 +210,7 @@ def fit_loop(model, f, ins,
212210
val_outs = test_loop(model, val_f, val_ins,
213211
batch_size=batch_size,
214212
verbose=0)
215-
if not isinstance(val_outs, list):
216-
val_outs = [val_outs]
213+
val_outs = to_list(val_outs)
217214
# Same labels assumed.
218215
for l, o in zip(out_labels, val_outs):
219216
epoch_logs['val_' + l] = o
@@ -267,8 +264,7 @@ def predict_loop(model, f, ins, batch_size=32, verbose=0, steps=None):
267264
unconcatenated_outs = []
268265
for step in range(steps):
269266
batch_outs = f(ins)
270-
if not isinstance(batch_outs, list):
271-
batch_outs = [batch_outs]
267+
batch_outs = to_list(batch_outs)
272268
if step == 0:
273269
for batch_out in batch_outs:
274270
unconcatenated_outs.append([])
@@ -296,8 +292,7 @@ def predict_loop(model, f, ins, batch_size=32, verbose=0, steps=None):
296292
ins_batch[i] = ins_batch[i].toarray()
297293

298294
batch_outs = f(ins_batch)
299-
if not isinstance(batch_outs, list):
300-
batch_outs = [batch_outs]
295+
batch_outs = to_list(batch_outs)
301296
if batch_index == 0:
302297
# Pre-allocate the results arrays.
303298
for batch_out in batch_outs:

keras/engine/training_generator.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ..utils.data_utils import GeneratorEnqueuer
1313
from ..utils.data_utils import OrderedEnqueuer
1414
from ..utils.generic_utils import Progbar
15+
from ..utils.generic_utils import to_list
1516
from ..utils.generic_utils import unpack_singleton
1617
from .. import callbacks as cbks
1718

@@ -211,8 +212,7 @@ def fit_generator(model,
211212
sample_weight=sample_weight,
212213
class_weight=class_weight)
213214

214-
if not isinstance(outs, list):
215-
outs = [outs]
215+
outs = to_list(outs)
216216
for l, o in zip(out_labels, outs):
217217
batch_logs[l] = o
218218

@@ -236,8 +236,7 @@ def fit_generator(model,
236236
batch_size=batch_size,
237237
sample_weight=val_sample_weights,
238238
verbose=0)
239-
if not isinstance(val_outs, list):
240-
val_outs = [val_outs]
239+
val_outs = to_list(val_outs)
241240
# Same labels assumed.
242241
for l, o in zip(out_labels, val_outs):
243242
epoch_logs['val_' + l] = o
@@ -342,8 +341,7 @@ def evaluate_generator(model, generator,
342341
'or (x, y). Found: ' +
343342
str(generator_output))
344343
outs = model.test_on_batch(x, y, sample_weight=sample_weight)
345-
if not isinstance(outs, list):
346-
outs = [outs]
344+
outs = to_list(outs)
347345
outs_per_batch.append(outs)
348346

349347
if x is None or len(x) == 0:
@@ -450,8 +448,7 @@ def predict_generator(model, generator,
450448
x = generator_output
451449

452450
outs = model.predict_on_batch(x)
453-
if not isinstance(outs, list):
454-
outs = [outs]
451+
outs = to_list(outs)
455452

456453
if not all_outs:
457454
for out in outs:

keras/engine/training_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .. import backend as K
1111
from .. import losses
12+
from ..utils.generic_utils import to_list
1213

1314

1415
def standardize_single_array(x):
@@ -321,8 +322,7 @@ def collect_metrics(metrics, output_names):
321322
nested_metrics = []
322323
for name in output_names:
323324
output_metrics = metrics.get(name, [])
324-
if not isinstance(output_metrics, list):
325-
output_metrics = [output_metrics]
325+
output_metrics = to_list(output_metrics)
326326
nested_metrics.append(output_metrics)
327327
return nested_metrics
328328
else:

keras/legacy/layers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ..engine import Layer, InputSpec
99
from .. import backend as K
1010
from ..utils import conv_utils
11+
from ..utils.generic_utils import to_list
1112
from .. import regularizers
1213
from .. import constraints
1314
from .. import activations
@@ -521,10 +522,8 @@ def __call__(self, inputs, initial_state=None, **kwargs):
521522
# Compute the full input spec, including state
522523
input_spec = self.input_spec
523524
state_spec = self.state_spec
524-
if not isinstance(input_spec, list):
525-
input_spec = [input_spec]
526-
if not isinstance(state_spec, list):
527-
state_spec = [state_spec]
525+
input_spec = to_list(input_spec)
526+
state_spec = to_list(state_spec)
528527
self.input_spec = input_spec + state_spec
529528

530529
# Compute the full inputs, including state

keras/utils/multi_gpu_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,7 @@ def get_slice(data, i, parts):
224224
# Apply model on slice
225225
# (creating a model replica on the target device).
226226
outputs = model(inputs)
227-
if not isinstance(outputs, list):
228-
outputs = [outputs]
227+
outputs = to_list(outputs)
229228

230229
# Save the outputs for merging back together later.
231230
for o in range(len(outputs)):

keras/wrappers/scikit_learn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ..utils.np_utils import to_categorical
1313
from ..utils.generic_utils import has_arg
14+
from ..utils.generic_utils import to_list
1415
from ..models import Sequential
1516

1617

@@ -291,8 +292,7 @@ def score(self, x, y, **kwargs):
291292
y = to_categorical(y)
292293

293294
outputs = self.model.evaluate(x, y, **kwargs)
294-
if not isinstance(outputs, list):
295-
outputs = [outputs]
295+
outputs = to_list(outputs)
296296
for name, output in zip(self.model.metrics_names, outputs):
297297
if name == 'acc':
298298
return output

0 commit comments

Comments
 (0)