Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@
import tempfile
from zipfile import ZipFile

# Required to load saved models that use TFDF.
import tensorflow_decision_forests
# Optional: only required to load saved models that use TFDF custom ops.
# Not all environments have this package installed.
try:
import tensorflow_decision_forests
except ImportError:
tensorflow_decision_forests = None

import tensorflow as tf
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
Expand All @@ -48,7 +53,10 @@
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.saved_model.experimental import TrackableResource
from google.protobuf.json_format import MessageToDict
import tensorflow_hub as hub
try:
import tensorflow_hub as hub
except ImportError:
hub = None
from packaging import version

from tensorflowjs import write_weights
Expand Down Expand Up @@ -1009,6 +1017,10 @@ def load_and_initialize_hub_module(module_path, signature='default'):
Raises:
ValueError: If signature contains a SparseTensor on input or output.
"""
if hub is None:
raise ImportError(
'tensorflow_hub is required to convert TF-Hub modules. '
'Install it with: pip install tensorflow_hub')
graph = tf.Graph()
with graph.as_default():
tf.compat.v1.logging.info('Importing %s', module_path)
Expand Down Expand Up @@ -1153,6 +1165,10 @@ def convert_tf_hub_module(module_handle, output_dir,
experiments: Bool enable experimental features.
metadata: User defined metadata map.
"""
if hub is None:
raise ImportError(
'tensorflow_hub is required to convert TF-Hub modules. '
'Install it with: pip install tensorflow_hub')
module_path = hub.resolve(module_handle)
# TODO(vbardiovskyg): We can remove this v1 code path once loading of all v1
# modules is fixed on the TF side, or once the modules we cannot load become
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@

import tensorflow.compat.v2 as tf
import tf_keras
from tensorflow_decision_forests.keras import GradientBoostedTreesModel
try:
from tensorflow_decision_forests.keras import GradientBoostedTreesModel
_TFDF_AVAILABLE = True
except ImportError:
GradientBoostedTreesModel = None
_TFDF_AVAILABLE = False

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
Expand All @@ -34,7 +40,13 @@
from tensorflow.python.trackable import autotrackable
from tensorflow.python.tools import freeze_graph
from tensorflow.python.saved_model.save import save
import tensorflow_hub as hub
try:
import tensorflow_hub as hub
_TFHUB_AVAILABLE = True
except ImportError:
hub = None
_TFHUB_AVAILABLE = False

from tensorflowjs import version
from tensorflowjs.converters import graph_rewrite_util
from tensorflowjs.converters import tf_saved_model_conversion_v2
Expand Down Expand Up @@ -1072,6 +1084,7 @@ def test_convert_saved_model_with_control_flow_v2(self):
glob.glob(
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))

@unittest.skipUnless(_TFDF_AVAILABLE, 'tensorflow_decision_forests not installed')
def test_convert_saved_model_with_tfdf(self):
self._create_saved_model_with_tfdf()

Expand Down Expand Up @@ -1226,6 +1239,7 @@ def test_convert_saved_model_structured_outputs_false(self):
model_json = json.load(f)
self.assertIs(model_json.get('userDefinedMetadata'), None)

@unittest.skipUnless(_TFHUB_AVAILABLE, 'tensorflow_hub not installed')
def test_convert_hub_module_v2(self):
self._create_saved_model()
module_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
Expand Down Expand Up @@ -1253,6 +1267,7 @@ def test_convert_hub_module_v2(self):
glob.glob(
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))

@unittest.skipUnless(_TFHUB_AVAILABLE, 'tensorflow_hub not installed')
def test_convert_hub_module_v2_with_metadata(self):
self._create_saved_model()
module_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
Expand Down Expand Up @@ -1341,5 +1356,46 @@ def test_convert_keras_model_to_saved_model(self):
glob.glob(
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))

class OptionalDependencyImportTest(tf.test.TestCase):
"""Tests that the converter imports cleanly when optional deps are absent."""

def test_import_succeeds_without_tensorflow_decision_forests(self):
import sys
import importlib
from unittest.mock import patch
with patch.dict(sys.modules, {'tensorflow_decision_forests': None}):
try:
importlib.reload(
sys.modules[
'tensorflowjs.converters.tf_saved_model_conversion_v2'])
except ImportError as e:
self.fail(
'Importing tf_saved_model_conversion_v2 raised ImportError when '
'tensorflow_decision_forests was absent: %s' % e)

def test_import_succeeds_without_tensorflow_hub(self):
import sys
import importlib
from unittest.mock import patch
with patch.dict(sys.modules, {'tensorflow_hub': None}):
try:
importlib.reload(
sys.modules[
'tensorflowjs.converters.tf_saved_model_conversion_v2'])
except ImportError as e:
self.fail(
'Importing tf_saved_model_conversion_v2 raised ImportError when '
'tensorflow_hub was absent: %s' % e)

def test_hub_conversion_raises_clear_error_without_tensorflow_hub(self):
from tensorflowjs.converters import tf_saved_model_conversion_v2
import unittest.mock as mock
with mock.patch.object(tf_saved_model_conversion_v2, 'hub', None):
with self.assertRaises(ImportError) as ctx:
tf_saved_model_conversion_v2.convert_tf_hub_module(
'fake_handle', '/tmp/out')
self.assertIn('tensorflow_hub', str(ctx.exception))


if __name__ == '__main__':
tf.test.main()