Skip to content
This repository was archived by the owner on Apr 10, 2024. It is now read-only.

Commit 964d838

Browse files
Add initial versions of SerializedModel and FrozenGraphModel to vision_base
1 parent 3c74744 commit 964d838

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

lucid/modelzoo/vision_base.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
# ==============================================================================
1515

1616
from __future__ import absolute_import, division, print_function
17+
from os import path
1718

1819
import tensorflow as tf
1920
from lucid.modelzoo.util import load_text_labels, load_graphdef, forget_xy
21+
from lucid.misc.io import load
2022

2123
class Model(object):
2224
"""Base pretrained model importer."""
@@ -61,3 +63,39 @@ def import_graph(self, t_input=None, scope='import', forget_xy_shape=True):
6163
tf.import_graph_def(
6264
self.graph_def, {self.input_name: t_prep_input}, name=scope)
6365
self.post_import(scope)
66+
67+
68+
class SerializedModel(Model):
69+
"""Allows importing various types of serialized models from a directory.
70+
71+
(Currently only supports frozen graph models and relies on manifest.json file.
72+
In the future we may want to support automatically detecting the type and
73+
support loading more ways of saving models: tf.SavedModel, metagraphs, etc.)
74+
"""
75+
76+
@classmethod
77+
def from_directory(cls, model_path):
78+
manifest_path = path.join(model_path, 'manifest.json')
79+
try:
80+
manifest = load(manifest_path)
81+
except Exception as e:
82+
raise ValueError("Could not find manifest.json file in dir {}. Error: {}".format(model_path, e))
83+
if manifest.get('type', 'frozen') == 'frozen':
84+
return FrozenGraphModel(model_path, manifest)
85+
else: # TODO: add tf.SavedModel support, etc
86+
raise NotImplementedError("SerializedModel Manifest type '{}' has not been implemented!".format(manifest.get('type')))
87+
88+
89+
class FrozenGraphModel(SerializedModel):
90+
91+
def __init__(self, model_directory, manifest):
92+
model_path = manifest.get('model_path', 'graph.pb')
93+
if model_path.startswith("./"): # TODO: can we be less specific here?
94+
self.model_path = path.join(model_directory, model_path)
95+
else:
96+
self.model_path = model_path
97+
self.labels_path = manifest.get('labels_path', None)
98+
self.image_value_range = manifest.get('image_value_range', None)
99+
self.image_shape = manifest.get('image_shape', None)
100+
self.input_name = manifest.get('input_name', 'input:0')
101+
super().__init__()

0 commit comments

Comments
 (0)