diff --git a/tensorflow_hub/keras_layer.py b/tensorflow_hub/keras_layer.py
index b78a7137..93dd553b 100644
--- a/tensorflow_hub/keras_layer.py
+++ b/tensorflow_hub/keras_layer.py
@@ -25,10 +25,8 @@
 # pylint: disable=g-import-not-at-top
 # Use Keras 2.
 version_fn = getattr(tf.keras, "version", None)
-if version_fn and version_fn().startswith("3."):
-  import tf_keras as keras
-else:
-  keras = tf.keras
+# Always align with tf.keras to avoid mismatched Layer types
+keras = tf.keras
 
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.framework import smart_cond
@@ -210,11 +208,34 @@ def _setup_layer(self, trainable=False, **kwargs):
       self.add_loss(self._call_loss_if_trainable(l))  # Supports callables.
 
   def _add_existing_weight(self, weight, trainable=None):
-    """Calls add_weight() to register but not create an existing weight."""
-    if trainable is None: trainable = weight.trainable
-    self.add_weight(name=weight.name, shape=weight.shape, dtype=weight.dtype,
-                    trainable=trainable, experimental_autocast=False,
-                    getter=lambda *_, **__: weight)
+    """Registers an existing tf.Variable with this layer."""
+    if trainable is None:
+      trainable = getattr(weight, "trainable", False)
+
+    # Create custom weight lists if they don't exist
+    if not hasattr(self, '_hub_trainable_weights'):
+      self._hub_trainable_weights = []
+    if not hasattr(self, '_hub_non_trainable_weights'):
+      self._hub_non_trainable_weights = []
+
+    # Add to appropriate list
+    if trainable:
+      self._hub_trainable_weights.append(weight)
+    else:
+      self._hub_non_trainable_weights.append(weight)
+  @property
+  def trainable_weights(self):
+    """Override to include hub weights."""
+    base_weights = super().trainable_weights
+    hub_weights = getattr(self, '_hub_trainable_weights', [])
+    return base_weights + hub_weights
+
+  @property
+  def non_trainable_weights(self):
+    """Override to include hub weights."""
+    base_weights = super().non_trainable_weights
+    hub_weights = getattr(self, '_hub_non_trainable_weights', [])
+    return base_weights + hub_weights
 
   def _call_loss_if_trainable(self, loss):
     """Returns `loss` conditioned on whether this layer is trainable."""
@@ -338,6 +359,7 @@ def get_config(self):
     if not isinstance(self._handle, str):
       # Need to raise this type in order for tf.saved_model.save() to fall back
       # to not using config, instead of crashing.
+      # TODO(b/134528831): Reconsider the usability implications.
       raise NotImplementedError(
           "Can only generate a valid config for `hub.KerasLayer(handle, ...)`"
          "that uses a string `handle`.\n\n"