Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tensorflow_model_optimization/python/core/keras/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def _get_keras_instance():
# Use Keras 2.
version_fn = getattr(tf.keras, 'version', None)
if version_fn and version_fn().startswith('3.'):
import tf_keras as keras_internal # pylint: disable=g-import-not-at-top,unused-import
try:
import tf_keras as keras_internal # pylint: disable=g-import-not-at-top
except ImportError:
# tf_keras is not installed; fall back to tf.keras (Keras 2 bundled with TF).
keras_internal = tf.keras
else:
keras_internal = tf.keras
return keras_internal
Expand Down
23 changes: 23 additions & 0 deletions tensorflow_model_optimization/python/core/keras/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
"""Tests for Metrics."""

import mock
import sys
import tensorflow as tf

from tensorflow.python.eager import monitoring
from tensorflow_model_optimization.python.core.keras import metrics
from tensorflow_model_optimization.python.core.keras import compat
from tensorflow_model_optimization.python.core.keras.compat import keras


Expand Down Expand Up @@ -70,5 +72,26 @@ def test_SetTest(self):
self.assertTrue(MetricsTest.gauge.get_cell(self.test_label).value())


class CompatGetKerasInstanceTest(tf.test.TestCase):
"""Tests for _get_keras_instance fallback when tf_keras is unavailable."""

def test_fallback_to_tf_keras_when_tf_keras_not_installed(self):
"""When tf.keras reports v3 but tf_keras is missing, fall back to tf.keras."""
original_modules = dict(sys.modules)
# Remove tf_keras from sys.modules so the import inside the function fails.
sys.modules.pop('tf_keras', None)
try:
with mock.patch.object(tf.keras, 'version', return_value='3.0.0',
create=True):
with mock.patch.dict(sys.modules, {'tf_keras': None}):
result = compat._get_keras_instance()
self.assertIs(result, tf.keras)
finally:
# Restore original sys.modules state.
for key in list(sys.modules):
if key not in original_modules:
del sys.modules[key]


if __name__ == '__main__':
tf.test.main()
Loading