diff --git a/tensorflow_model_optimization/python/core/keras/compat.py b/tensorflow_model_optimization/python/core/keras/compat.py index 034eca89..59a84f31 100644 --- a/tensorflow_model_optimization/python/core/keras/compat.py +++ b/tensorflow_model_optimization/python/core/keras/compat.py @@ -32,7 +32,11 @@ def _get_keras_instance(): # Use Keras 2. version_fn = getattr(tf.keras, 'version', None) if version_fn and version_fn().startswith('3.'): - import tf_keras as keras_internal # pylint: disable=g-import-not-at-top,unused-import + try: + import tf_keras as keras_internal # pylint: disable=g-import-not-at-top + except ImportError: + # tf_keras is not installed; fall back to tf.keras (Keras 2 bundled with TF). + keras_internal = tf.keras else: keras_internal = tf.keras return keras_internal diff --git a/tensorflow_model_optimization/python/core/keras/metrics_test.py b/tensorflow_model_optimization/python/core/keras/metrics_test.py index b42c518f..07f29a2a 100644 --- a/tensorflow_model_optimization/python/core/keras/metrics_test.py +++ b/tensorflow_model_optimization/python/core/keras/metrics_test.py @@ -15,10 +15,12 @@ """Tests for Metrics.""" import mock +import sys import tensorflow as tf from tensorflow.python.eager import monitoring from tensorflow_model_optimization.python.core.keras import metrics +from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.keras.compat import keras @@ -70,5 +72,26 @@ def test_SetTest(self): self.assertTrue(MetricsTest.gauge.get_cell(self.test_label).value()) +class CompatGetKerasInstanceTest(tf.test.TestCase): + """Tests for _get_keras_instance fallback when tf_keras is unavailable.""" + + def test_fallback_to_tf_keras_when_tf_keras_not_installed(self): + """When tf.keras reports v3 but tf_keras is missing, fall back to tf.keras.""" + original_modules = dict(sys.modules) + # Remove tf_keras from sys.modules so the import inside the function fails. + sys.modules.pop('tf_keras', None) + try: + with mock.patch.object(tf.keras, 'version', return_value='3.0.0', + create=True): + with mock.patch.dict(sys.modules, {'tf_keras': None}): + result = compat._get_keras_instance() + self.assertIs(result, tf.keras) + finally: + # Restore original sys.modules state. + for key in list(sys.modules): + if key not in original_modules: + del sys.modules[key] + + if __name__ == '__main__': tf.test.main()