From e33e8f8a261eaf7a8d5ed342f22fcc8de6c9cd5d Mon Sep 17 00:00:00 2001 From: maxtext authors Date: Mon, 11 May 2026 09:49:24 -0700 Subject: [PATCH] Fix KeyError in MetricLogger when tensorboard is disabled. # Description Fix KeyError in MetricLogger when enable_tensorboard=False. `write_setup_info_to_tensorboard` returned early when TensorBoard was disabled, before populating `metadata[MetadataKey.PER_DEVICE_TFLOPS]` and `metadata[MetadataKey.PER_DEVICE_TOKENS]`, leaving those keys missing. The early return now happens after the metadata is computed, so only the TensorBoard writes are skipped. # Tests Unit tests (adds `MetricLoggerMetadataTest.test_metadata_init_without_tensorboard`) and Guitar integration. # Checklist Before submitting this PR, please make sure (put X in square brackets): - [X] I have performed a self-review of my code. - [X] I have necessary comments in my code, particularly in hard-to-understand areas. - [X] I have run end-to-end tests and provided workload links above if applicable. - [X] I have made or will make corresponding changes to the doc if needed. PiperOrigin-RevId: 913748302 --- src/maxtext/common/metric_logger.py | 4 +- tests/unit/metric_logger_abort_test.py | 61 +++++++++++++++++++------- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index dc90becbb8..114c0e1519 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -295,12 +295,12 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step): def write_setup_info_to_tensorboard(self, params): """Writes setup information like train config params, num model params, and XLA flags to TensorBoard.""" - if not self.config.enable_tensorboard: - return num_model_parameters = max_utils.calculate_num_params_from_pytree(params) self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = maxtext_utils.calculate_tflops_training_per_device(self.config) self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device(self.config) max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion") + if not self.config.enable_tensorboard: + return max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), self.writer) max_utils.add_text_to_summary_writer("libtpu_init_args", 
os.getenv("LIBTPU_INIT_ARGS", ""), self.writer) maxtext_utils.add_config_to_summary_writer(self.config, self.writer) diff --git a/tests/unit/metric_logger_abort_test.py b/tests/unit/metric_logger_abort_test.py index 95eefca88b..9ca996bad5 100644 --- a/tests/unit/metric_logger_abort_test.py +++ b/tests/unit/metric_logger_abort_test.py @@ -19,10 +19,11 @@ import numpy as np -from maxtext.common.metric_logger import MetricLogger +from maxtext.common.metric_logger import MetricLogger, MetadataKey class MetricLoggerAbortTest(unittest.TestCase): + def _make_logger(self, abort_on_nan_loss, abort_on_inf_loss): logger = MetricLogger.__new__(MetricLogger) # skip __init__ logger.config = SimpleNamespace( @@ -62,28 +63,56 @@ def test_abort_on_nan_exits_after_writes(self, _): @mock.patch("jax.process_index", return_value=0) def test_abort_on_inf_exits_after_writes(self, _): logger = self._make_logger(False, True) - with mock.patch.object(logger, "log_metrics"), \ - mock.patch.object(logger, "write_metrics_to_tensorboard"), \ - mock.patch.object(logger, "write_metrics_locally"), \ - mock.patch.object(logger, "write_metrics_for_gcs"), \ - mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"): + with ( + mock.patch.object(logger, "log_metrics"), + mock.patch.object(logger, "write_metrics_to_tensorboard"), + mock.patch.object(logger, "write_metrics_locally"), + mock.patch.object(logger, "write_metrics_for_gcs"), + mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), + ): with self.assertRaises(SystemExit): logger.write_metrics(self._metrics(np.inf), step=1, is_training=True) def test_finite_loss_does_not_exit(self): logger = self._make_logger(True, True) - with mock.patch.object(logger, "log_metrics"), \ - mock.patch.object(logger, "write_metrics_to_tensorboard"), \ - mock.patch.object(logger, "write_metrics_locally"), \ - mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), \ - mock.patch("jax.process_index", return_value=1): # skip 
gcs branch + with ( + mock.patch.object(logger, "log_metrics"), + mock.patch.object(logger, "write_metrics_to_tensorboard"), + mock.patch.object(logger, "write_metrics_locally"), + mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), + mock.patch("jax.process_index", return_value=1), + ): # skip gcs branch logger.write_metrics(self._metrics(1.23), step=1, is_training=True) def test_abort_flags_disabled_does_not_exit(self): logger = self._make_logger(False, False) - with mock.patch.object(logger, "log_metrics"), \ - mock.patch.object(logger, "write_metrics_to_tensorboard"), \ - mock.patch.object(logger, "write_metrics_locally"), \ - mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), \ - mock.patch("jax.process_index", return_value=1): + with ( + mock.patch.object(logger, "log_metrics"), + mock.patch.object(logger, "write_metrics_to_tensorboard"), + mock.patch.object(logger, "write_metrics_locally"), + mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), + mock.patch("jax.process_index", return_value=1), + ): logger.write_metrics(self._metrics(np.nan), step=1, is_training=True) + + +class MetricLoggerMetadataTest(unittest.TestCase): + """Tests for MetricLogger metadata and setup initialization.""" + + def test_metadata_init_without_tensorboard(self): + logger = MetricLogger.__new__(MetricLogger) + logger.config = SimpleNamespace(enable_tensorboard=False) + logger.metadata = {} + + with ( + mock.patch("maxtext.src.maxtext.utils.max_utils.calculate_num_params_from_pytree", return_value=1e9), + mock.patch( + "maxtext.src.maxtext.utils.maxtext_utils.calculate_tflops_training_per_device", return_value=(100.0, 0, 0) + ), + mock.patch("maxtext.src.maxtext.utils.maxtext_utils.calculate_tokens_training_per_device", return_value=1000.0), + mock.patch("maxtext.src.maxtext.utils.max_logging.log"), + ): + logger.write_setup_info_to_tensorboard({}) + + self.assertEqual(logger.metadata[MetadataKey.PER_DEVICE_TFLOPS], 100.0) + 
self.assertEqual(logger.metadata[MetadataKey.PER_DEVICE_TOKENS], 1000.0)