Skip to content

Commit c944612

Browse files
committed
use flatMap
1 parent a81ae6d commit c944612

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

src/main/scala/com/yahoo/tensorflowonspark/TFModel.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,7 @@ class TFModel(override val uid: String) extends Model[TFModel] with TFParams {
257257
TFModel.sess = TFModel.model.session
258258
}
259259

260-
var results = List.empty[Row]
261-
val groupedIter = iter.grouped(this.getBatchSize)
262-
while (groupedIter.hasNext) {
263-
val batch = groupedIter.next()
264-
260+
iter.grouped(this.getBatchSize).flatMap { batch =>
265261
// get input batch of Rows and convert to list of input Tensors
266262
val inputTensors = batch2tensors(batch, inputSchema)
267263

@@ -283,12 +279,8 @@ class TFModel(override val uid: String) extends Model[TFModel] with TFParams {
283279
"Cardinality of output tensors must match")
284280

285281
// convert the list of output Tensors to a batch of output Rows
286-
val batchResults = tensors2batch(outputTensors)
287-
288-
// and add the batch of output Rows to the partition output
289-
results ++= batchResults
282+
tensors2batch(outputTensors)
290283
}
291-
results.iterator
292284
}
293285

294286
spark.createDataFrame(outputRDD, outputSchema)

0 commit comments

Comments
 (0)