NOTE(review): line breaks and indentation were reconstructed from a whitespace-collapsed copy of this patch
(2-space block indent assumed - TODO confirm); if a hunk fails to match, retry with `git apply --ignore-whitespace`.
NOTE(review): the post-image blob id on the test file's `index` line predates the review edits to the added test.

diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py
index dc90becbb8..114c0e1519 100644
--- a/src/maxtext/common/metric_logger.py
+++ b/src/maxtext/common/metric_logger.py
@@ -295,12 +295,12 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):
 
   def write_setup_info_to_tensorboard(self, params):
     """Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
-    if not self.config.enable_tensorboard:
-      return
     num_model_parameters = max_utils.calculate_num_params_from_pytree(params)
     self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = maxtext_utils.calculate_tflops_training_per_device(self.config)
     self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device(self.config)
     max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
+    if not self.config.enable_tensorboard:
+      return
     max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), self.writer)
     max_utils.add_text_to_summary_writer("libtpu_init_args", os.getenv("LIBTPU_INIT_ARGS", ""), self.writer)
     maxtext_utils.add_config_to_summary_writer(self.config, self.writer)
diff --git a/tests/unit/metric_logger_abort_test.py b/tests/unit/metric_logger_abort_test.py
index 95eefca88b..9ca996bad5 100644
--- a/tests/unit/metric_logger_abort_test.py
+++ b/tests/unit/metric_logger_abort_test.py
@@ -19,10 +19,11 @@
 
 import numpy as np
 
-from maxtext.common.metric_logger import MetricLogger
+from maxtext.common.metric_logger import MetricLogger, MetadataKey
 
 
 class MetricLoggerAbortTest(unittest.TestCase):
+
   def _make_logger(self, abort_on_nan_loss, abort_on_inf_loss):
     logger = MetricLogger.__new__(MetricLogger)  # skip __init__
     logger.config = SimpleNamespace(
@@ -62,28 +63,59 @@ def test_abort_on_nan_exits_after_writes(self, _):
   @mock.patch("jax.process_index", return_value=0)
   def test_abort_on_inf_exits_after_writes(self, _):
     logger = self._make_logger(False, True)
-    with mock.patch.object(logger, "log_metrics"), \
-         mock.patch.object(logger, "write_metrics_to_tensorboard"), \
-         mock.patch.object(logger, "write_metrics_locally"), \
-         mock.patch.object(logger, "write_metrics_for_gcs"), \
-         mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"):
+    with (
+        mock.patch.object(logger, "log_metrics"),
+        mock.patch.object(logger, "write_metrics_to_tensorboard"),
+        mock.patch.object(logger, "write_metrics_locally"),
+        mock.patch.object(logger, "write_metrics_for_gcs"),
+        mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
+    ):
       with self.assertRaises(SystemExit):
         logger.write_metrics(self._metrics(np.inf), step=1, is_training=True)
 
   def test_finite_loss_does_not_exit(self):
     logger = self._make_logger(True, True)
-    with mock.patch.object(logger, "log_metrics"), \
-         mock.patch.object(logger, "write_metrics_to_tensorboard"), \
-         mock.patch.object(logger, "write_metrics_locally"), \
-         mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), \
-         mock.patch("jax.process_index", return_value=1):  # skip gcs branch
+    with (
+        mock.patch.object(logger, "log_metrics"),
+        mock.patch.object(logger, "write_metrics_to_tensorboard"),
+        mock.patch.object(logger, "write_metrics_locally"),
+        mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
+        mock.patch("jax.process_index", return_value=1),
+    ):  # skip gcs branch
       logger.write_metrics(self._metrics(1.23), step=1, is_training=True)
 
   def test_abort_flags_disabled_does_not_exit(self):
     logger = self._make_logger(False, False)
-    with mock.patch.object(logger, "log_metrics"), \
-         mock.patch.object(logger, "write_metrics_to_tensorboard"), \
-         mock.patch.object(logger, "write_metrics_locally"), \
-         mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"), \
-         mock.patch("jax.process_index", return_value=1):
+    with (
+        mock.patch.object(logger, "log_metrics"),
+        mock.patch.object(logger, "write_metrics_to_tensorboard"),
+        mock.patch.object(logger, "write_metrics_locally"),
+        mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
+        mock.patch("jax.process_index", return_value=1),
+    ):
       logger.write_metrics(self._metrics(np.nan), step=1, is_training=True)
+
+
+class MetricLoggerMetadataTest(unittest.TestCase):
+  """Tests for MetricLogger metadata and setup initialization."""
+
+  def test_metadata_init_without_tensorboard(self):
+    """Per-device metadata is populated even when enable_tensorboard is False."""
+    logger = MetricLogger.__new__(MetricLogger)
+    logger.config = SimpleNamespace(enable_tensorboard=False)
+    logger.metadata = {}
+
+    with (
+        mock.patch("maxtext.common.metric_logger.max_utils.calculate_num_params_from_pytree", return_value=1e9),
+        mock.patch(
+            "maxtext.common.metric_logger.maxtext_utils.calculate_tflops_training_per_device", return_value=(100.0, 0, 0)
+        ),
+        mock.patch(
+            "maxtext.common.metric_logger.maxtext_utils.calculate_tokens_training_per_device", return_value=1000.0
+        ),
+        mock.patch("maxtext.common.metric_logger.max_logging.log"),
+    ):
+      logger.write_setup_info_to_tensorboard({})
+
+    self.assertEqual(logger.metadata[MetadataKey.PER_DEVICE_TFLOPS], 100.0)
+    self.assertEqual(logger.metadata[MetadataKey.PER_DEVICE_TOKENS], 1000.0)