Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxtext/common/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,12 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):

def write_setup_info_to_tensorboard(self, params):
  """Writes setup information like train config params, num model params, and XLA flags to TensorBoard.

  Per-device TFLOPS/tokens metadata is always computed and stored in ``self.metadata``,
  and the parameter count is always logged, even when TensorBoard is disabled — only the
  TensorBoard summary writes are gated on ``config.enable_tensorboard``.

  Args:
    params: Model parameter pytree; used only to count parameters.
  """
  num_model_parameters = max_utils.calculate_num_params_from_pytree(params)
  # NOTE: metadata must be populated before the enable_tensorboard early-return so it
  # is available to non-TensorBoard consumers — presumably the managed-mldiagnostics
  # sink; confirm against write_metrics_to_managed_mldiagnostics.
  self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = maxtext_utils.calculate_tflops_training_per_device(self.config)
  self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device(self.config)
  max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
  if not self.config.enable_tensorboard:
    return
  max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), self.writer)
  max_utils.add_text_to_summary_writer("libtpu_init_args", os.getenv("LIBTPU_INIT_ARGS", ""), self.writer)
  maxtext_utils.add_config_to_summary_writer(self.config, self.writer)
Expand Down
61 changes: 45 additions & 16 deletions tests/unit/metric_logger_abort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@

import numpy as np

from maxtext.common.metric_logger import MetricLogger
from maxtext.common.metric_logger import MetricLogger, MetadataKey


class MetricLoggerAbortTest(unittest.TestCase):

def _make_logger(self, abort_on_nan_loss, abort_on_inf_loss):
logger = MetricLogger.__new__(MetricLogger) # skip __init__
logger.config = SimpleNamespace(
Expand Down Expand Up @@ -62,28 +63,56 @@ def test_abort_on_nan_exits_after_writes(self, _):
@mock.patch("jax.process_index", return_value=0)
def test_abort_on_inf_exits_after_writes(self, _):
  """An inf loss with abort_on_inf_loss=True raises SystemExit after the metric writes run."""
  logger = self._make_logger(False, True)
  # Stub out every metric sink so only the abort-on-inf path is exercised.
  with (
      mock.patch.object(logger, "log_metrics"),
      mock.patch.object(logger, "write_metrics_to_tensorboard"),
      mock.patch.object(logger, "write_metrics_locally"),
      mock.patch.object(logger, "write_metrics_for_gcs"),
      mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
  ):
    with self.assertRaises(SystemExit):
      logger.write_metrics(self._metrics(np.inf), step=1, is_training=True)

def test_finite_loss_does_not_exit(self):
  """A finite loss must complete normally even with both abort flags enabled."""
  logger = self._make_logger(True, True)
  with (
      mock.patch.object(logger, "log_metrics"),
      mock.patch.object(logger, "write_metrics_to_tensorboard"),
      mock.patch.object(logger, "write_metrics_locally"),
      mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
      mock.patch("jax.process_index", return_value=1),  # skip gcs branch
  ):
    # Should not raise SystemExit.
    logger.write_metrics(self._metrics(1.23), step=1, is_training=True)

def test_abort_flags_disabled_does_not_exit(self):
  """A NaN loss must not exit when both abort flags are disabled."""
  logger = self._make_logger(False, False)
  with (
      mock.patch.object(logger, "log_metrics"),
      mock.patch.object(logger, "write_metrics_to_tensorboard"),
      mock.patch.object(logger, "write_metrics_locally"),
      mock.patch.object(logger, "write_metrics_to_managed_mldiagnostics"),
      mock.patch("jax.process_index", return_value=1),  # skip gcs branch
  ):
    # Should not raise SystemExit despite the NaN loss.
    logger.write_metrics(self._metrics(np.nan), step=1, is_training=True)


class MetricLoggerMetadataTest(unittest.TestCase):
  """Tests for MetricLogger metadata and setup initialization."""

  def test_metadata_init_without_tensorboard(self):
    """Metadata keys are populated by setup even when TensorBoard is disabled."""
    # Bypass __init__ so no real writers or config objects are constructed.
    logger = MetricLogger.__new__(MetricLogger)
    logger.config = SimpleNamespace(enable_tensorboard=False)
    logger.metadata = {}

    # Patch the utility calls so the method runs without a real model/config.
    params_patch = mock.patch(
        "maxtext.src.maxtext.utils.max_utils.calculate_num_params_from_pytree", return_value=1e9
    )
    tflops_patch = mock.patch(
        "maxtext.src.maxtext.utils.maxtext_utils.calculate_tflops_training_per_device", return_value=(100.0, 0, 0)
    )
    tokens_patch = mock.patch(
        "maxtext.src.maxtext.utils.maxtext_utils.calculate_tokens_training_per_device", return_value=1000.0
    )
    log_patch = mock.patch("maxtext.src.maxtext.utils.max_logging.log")

    with params_patch, tflops_patch, tokens_patch, log_patch:
      logger.write_setup_info_to_tensorboard({})

    self.assertEqual(logger.metadata[MetadataKey.PER_DEVICE_TFLOPS], 100.0)
    self.assertEqual(logger.metadata[MetadataKey.PER_DEVICE_TOKENS], 1000.0)
Loading