|
14 | 14 | # ============================================================================== |
15 | 15 |
|
16 | 16 | from __future__ import absolute_import, division, print_function |
| 17 | +from os import path |
17 | 18 |
|
18 | 19 | import tensorflow as tf |
19 | 20 | from lucid.modelzoo.util import load_text_labels, load_graphdef, forget_xy |
| 21 | +from lucid.misc.io import load |
20 | 22 |
|
21 | 23 | class Model(object): |
22 | 24 | """Base pretrained model importer.""" |
@@ -61,3 +63,39 @@ def import_graph(self, t_input=None, scope='import', forget_xy_shape=True): |
61 | 63 | tf.import_graph_def( |
62 | 64 | self.graph_def, {self.input_name: t_prep_input}, name=scope) |
63 | 65 | self.post_import(scope) |
| 66 | + |
| 67 | + |
| 68 | +class SerializedModel(Model): |
| 69 | + """Allows importing various types of serialized models from a directory. |
| 70 | +
|
| 71 | + (Currently only supports frozen graph models and relies on manifest.json file. |
| 72 | + In the future we may want to support automatically detecting the type and |
| 73 | + support loading more ways of saving models: tf.SavedModel, metagraphs, etc.) |
| 74 | + """ |
| 75 | + |
| 76 | + @classmethod |
| 77 | + def from_directory(cls, model_path): |
| 78 | + manifest_path = path.join(model_path, 'manifest.json') |
| 79 | + try: |
| 80 | + manifest = load(manifest_path) |
| 81 | + except Exception as e: |
| 82 | + raise ValueError("Could not find manifest.json file in dir {}. Error: {}".format(model_path, e)) |
| 83 | + if manifest.get('type', 'frozen') == 'frozen': |
| 84 | + return FrozenGraphModel(model_path, manifest) |
| 85 | + else: # TODO: add tf.SavedModel support, etc |
| 86 | + raise NotImplementedError("SerializedModel Manifest type '{}' has not been implemented!".format(manifest.get('type'))) |
| 87 | + |
| 88 | + |
| 89 | +class FrozenGraphModel(SerializedModel): |
| 90 | + |
| 91 | + def __init__(self, model_directory, manifest): |
| 92 | + model_path = manifest.get('model_path', 'graph.pb') |
| 93 | + if model_path.startswith("./"): # TODO: can we be less specific here? |
| 94 | + self.model_path = path.join(model_directory, model_path) |
| 95 | + else: |
| 96 | + self.model_path = model_path |
| 97 | + self.labels_path = manifest.get('labels_path', None) |
| 98 | + self.image_value_range = manifest.get('image_value_range', None) |
| 99 | + self.image_shape = manifest.get('image_shape', None) |
| 100 | + self.input_name = manifest.get('input_name', 'input:0') |
| 101 | + super().__init__() |
0 commit comments