File tree Expand file tree Collapse file tree 1 file changed +2
-10
lines changed
src/main/scala/com/yahoo/tensorflowonspark Expand file tree Collapse file tree 1 file changed +2
-10
lines changed Original file line number Diff line number Diff line change @@ -295,18 +295,10 @@ class TFModel(override val uid: String) extends Model[TFModel] with TFParams {
295295 }
296296
297297 override def transformSchema (schema : StructType ): StructType = {
298- if (TFModel .model == null || TFModel .modelDir != this .getModel) {
299- // load model into a driver singleton reference, if needed.
300- // Note: this implies that the driver memory must be sized similarly to the executors.
301- TFModel .modelDir = this .getModel
302- TFModel .model = SavedModelBundle .load(this .getModel, this .getTag)
303- TFModel .graph = TFModel .model.graph()
304- TFModel .sess = TFModel .model.session()
305- }
298+ val model = SavedModelBundle .load(this .getModel, this .getTag)
299+ val g = model.graph
306300
307301 val fields = this .getOutputMapping.map { case (tensorName, columnName) =>
308- val g = TFModel .graph
309- require(g != null , " graph is null" )
310302 val op = g.operation(tensorName)
311303 // if a requested tensorName is not found, dump all operations in the graph to aid user and throw exception.
312304 if (op == null ) {
You can’t perform that action at this time.
0 commit comments