diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index 014abe2b78..9f7bc4c196 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -24,8 +24,13 @@ import tempfile from zipfile import ZipFile -# Required to load saved models that use TFDF. -import tensorflow_decision_forests +# Optional: only required to load saved models that use TFDF custom ops. +# Not all environments have this package installed. +try: + import tensorflow_decision_forests +except ImportError: + tensorflow_decision_forests = None + import tensorflow as tf from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import graph_pb2 @@ -48,7 +53,10 @@ from tensorflow.python.training.saver import export_meta_graph from tensorflow.saved_model.experimental import TrackableResource from google.protobuf.json_format import MessageToDict -import tensorflow_hub as hub +try: + import tensorflow_hub as hub +except ImportError: + hub = None from packaging import version from tensorflowjs import write_weights @@ -1009,6 +1017,10 @@ def load_and_initialize_hub_module(module_path, signature='default'): Raises: ValueError: If signature contains a SparseTensor on input or output. """ + if hub is None: + raise ImportError( + 'tensorflow_hub is required to convert TF-Hub modules. ' + 'Install it with: pip install tensorflow_hub') graph = tf.Graph() with graph.as_default(): tf.compat.v1.logging.info('Importing %s', module_path) @@ -1153,6 +1165,10 @@ def convert_tf_hub_module(module_handle, output_dir, experiments: Bool enable experimental features. metadata: User defined metadata map. """ + if hub is None: + raise ImportError( + 'tensorflow_hub is required to convert TF-Hub modules. ' + 'Install it with: pip install tensorflow_hub') module_path = hub.resolve(module_handle) # TODO(vbardiovskyg): We can remove this v1 code path once loading of all v1 # modules is fixed on the TF side, or once the modules we cannot load become diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py index d5d97703c3..661ae9cf29 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py @@ -25,7 +25,13 @@ import tensorflow.compat.v2 as tf import tf_keras -from tensorflow_decision_forests.keras import GradientBoostedTreesModel +try: + from tensorflow_decision_forests.keras import GradientBoostedTreesModel + _TFDF_AVAILABLE = True +except ImportError: + GradientBoostedTreesModel = None + _TFDF_AVAILABLE = False + from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,7 +40,13 @@ from tensorflow.python.trackable import autotrackable from tensorflow.python.tools import freeze_graph from tensorflow.python.saved_model.save import save -import tensorflow_hub as hub +try: + import tensorflow_hub as hub + _TFHUB_AVAILABLE = True +except ImportError: + hub = None + _TFHUB_AVAILABLE = False + from tensorflowjs import version from tensorflowjs.converters import graph_rewrite_util from tensorflowjs.converters import tf_saved_model_conversion_v2 @@ -1072,6 +1084,7 @@ def test_convert_saved_model_with_control_flow_v2(self): glob.glob( os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*'))) + @unittest.skipUnless(_TFDF_AVAILABLE, 'tensorflow_decision_forests not installed') def test_convert_saved_model_with_tfdf(self): self._create_saved_model_with_tfdf() @@ -1226,6 +1239,7 @@ def test_convert_saved_model_structured_outputs_false(self): model_json = json.load(f) self.assertIs(model_json.get('userDefinedMetadata'), None) + @unittest.skipUnless(_TFHUB_AVAILABLE, 'tensorflow_hub not installed') def test_convert_hub_module_v2(self): self._create_saved_model() module_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) @@ -1253,6 +1267,7 @@ def test_convert_hub_module_v2(self): glob.glob( os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*'))) + @unittest.skipUnless(_TFHUB_AVAILABLE, 'tensorflow_hub not installed') def test_convert_hub_module_v2_with_metadata(self): self._create_saved_model() module_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) @@ -1341,5 +1356,46 @@ def test_convert_keras_model_to_saved_model(self): glob.glob( os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*'))) +class OptionalDependencyImportTest(tf.test.TestCase): + """Tests that the converter imports cleanly when optional deps are absent.""" + + def test_import_succeeds_without_tensorflow_decision_forests(self): + import sys + import importlib + from unittest.mock import patch + with patch.dict(sys.modules, {'tensorflow_decision_forests': None}): + try: + importlib.reload( + sys.modules[ + 'tensorflowjs.converters.tf_saved_model_conversion_v2']) + except ImportError as e: + self.fail( + 'Importing tf_saved_model_conversion_v2 raised ImportError when ' + 'tensorflow_decision_forests was absent: %s' % e) + + def test_import_succeeds_without_tensorflow_hub(self): + import sys + import importlib + from unittest.mock import patch + with patch.dict(sys.modules, {'tensorflow_hub': None}): + try: + importlib.reload( + sys.modules[ + 'tensorflowjs.converters.tf_saved_model_conversion_v2']) + except ImportError as e: + self.fail( + 'Importing tf_saved_model_conversion_v2 raised ImportError when ' + 'tensorflow_hub was absent: %s' % e) + + def test_hub_conversion_raises_clear_error_without_tensorflow_hub(self): + from tensorflowjs.converters import tf_saved_model_conversion_v2 + import unittest.mock as mock + with mock.patch.object(tf_saved_model_conversion_v2, 'hub', None): + with self.assertRaises(ImportError) as ctx: + tf_saved_model_conversion_v2.convert_tf_hub_module( + 'fake_handle', '/tmp/out') + self.assertIn('tensorflow_hub', str(ctx.exception)) + + if __name__ == '__main__': tf.test.main()