File tree Expand file tree Collapse file tree 1 file changed +2
-10
lines changed
src/main/scala/com/yahoo/tensorflowonspark Expand file tree Collapse file tree 1 file changed +2
-10
lines changed Original file line number Diff line number Diff line change @@ -257,11 +257,7 @@ class TFModel(override val uid: String) extends Model[TFModel] with TFParams {
257257 TFModel .sess = TFModel .model.session
258258 }
259259
260- var results = List .empty[Row ]
261- val groupedIter = iter.grouped(this .getBatchSize)
262- while (groupedIter.hasNext) {
263- val batch = groupedIter.next()
264-
260+ iter.grouped(this .getBatchSize).flatMap { batch =>
265261 // get input batch of Rows and convert to list of input Tensors
266262 val inputTensors = batch2tensors(batch, inputSchema)
267263
@@ -283,12 +279,8 @@ class TFModel(override val uid: String) extends Model[TFModel] with TFParams {
283279 " Cardinality of output tensors must match" )
284280
285281 // convert the list of output Tensors to a batch of output Rows
286- val batchResults = tensors2batch(outputTensors)
287-
288- // and add the batch of output Rows to the partition output
289- results ++= batchResults
282+ tensors2batch(outputTensors)
290283 }
291- results.iterator
292284 }
293285
294286 spark.createDataFrame(outputRDD, outputSchema)
You can’t perform that action at this time.
0 commit comments