Skip to content

Commit a81ae6d

Browse files
committed
remove driver-side model cache
1 parent 2091bb7 commit a81ae6d

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

src/main/scala/com/yahoo/tensorflowonspark/TFModel.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -295,18 +295,10 @@ class TFModel(override val uid: String) extends Model[TFModel] with TFParams {
295295
}
296296

297297
override def transformSchema(schema: StructType): StructType = {
298-
if (TFModel.model == null || TFModel.modelDir != this.getModel) {
299-
// load model into a driver singleton reference, if needed.
300-
// Note: this implies that the driver memory must be sized similarly to the executors.
301-
TFModel.modelDir = this.getModel
302-
TFModel.model = SavedModelBundle.load(this.getModel, this.getTag)
303-
TFModel.graph = TFModel.model.graph()
304-
TFModel.sess = TFModel.model.session()
305-
}
298+
val model = SavedModelBundle.load(this.getModel, this.getTag)
299+
val g = model.graph
306300

307301
val fields = this.getOutputMapping.map { case (tensorName, columnName) =>
308-
val g = TFModel.graph
309-
require(g != null, "graph is null")
310302
val op = g.operation(tensorName)
311303
// if a requested tensorName is not found, dump all operations in the graph to aid user and throw exception.
312304
if (op == null) {

0 commit comments

Comments
 (0)