From 3973a8265117ed0937b5ab7a70b7cff87bad70db Mon Sep 17 00:00:00 2001 From: Zakir Jiwani <108548454+JiwaniZakir@users.noreply.github.com> Date: Sat, 18 Apr 2026 06:24:07 +0000 Subject: [PATCH] Fallback to tf.keras when tf_keras not installed with Keras 3 Co-Authored-By: Claude Sonnet 4.6 --- .../python/core/keras/compat.py | 6 ++++- .../python/core/keras/metrics_test.py | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/tensorflow_model_optimization/python/core/keras/compat.py b/tensorflow_model_optimization/python/core/keras/compat.py index 034eca897..59a84f31c 100644 --- a/tensorflow_model_optimization/python/core/keras/compat.py +++ b/tensorflow_model_optimization/python/core/keras/compat.py @@ -32,7 +32,11 @@ def _get_keras_instance(): # Use Keras 2. version_fn = getattr(tf.keras, 'version', None) if version_fn and version_fn().startswith('3.'): - import tf_keras as keras_internal # pylint: disable=g-import-not-at-top,unused-import + try: + import tf_keras as keras_internal # pylint: disable=g-import-not-at-top + except ImportError: + # tf_keras is not installed; fall back to tf.keras (Keras 2 bundled with TF). + keras_internal = tf.keras else: keras_internal = tf.keras return keras_internal diff --git a/tensorflow_model_optimization/python/core/keras/metrics_test.py b/tensorflow_model_optimization/python/core/keras/metrics_test.py index b42c518f2..07f29a2ad 100644 --- a/tensorflow_model_optimization/python/core/keras/metrics_test.py +++ b/tensorflow_model_optimization/python/core/keras/metrics_test.py @@ -15,10 +15,12 @@ """Tests for Metrics.""" import mock +import sys import tensorflow as tf from tensorflow.python.eager import monitoring from tensorflow_model_optimization.python.core.keras import metrics +from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.keras.compat import keras @@ -70,5 +72,26 @@ def test_SetTest(self): self.assertTrue(MetricsTest.gauge.get_cell(self.test_label).value()) +class CompatGetKerasInstanceTest(tf.test.TestCase): + """Tests for _get_keras_instance fallback when tf_keras is unavailable.""" + + def test_fallback_to_tf_keras_when_tf_keras_not_installed(self): + """When tf.keras reports v3 but tf_keras is missing, fall back to tf.keras.""" + original_modules = dict(sys.modules) + # Remove tf_keras from sys.modules so the import inside the function fails. + sys.modules.pop('tf_keras', None) + try: + with mock.patch.object(tf.keras, 'version', return_value='3.0.0', + create=True): + with mock.patch.dict(sys.modules, {'tf_keras': None}): + result = compat._get_keras_instance() + self.assertIs(result, tf.keras) + finally: + # Restore original sys.modules state. + for key in list(sys.modules): + if key not in original_modules: + del sys.modules[key] + + if __name__ == '__main__': tf.test.main()