From 0e9d9f289571e8d2f39f1c722ab23c700da388d1 Mon Sep 17 00:00:00 2001
From: maxtext authors
Date: Tue, 27 Jan 2026 14:26:11 -0800
Subject: [PATCH] Internal.

PiperOrigin-RevId: 861886966
---
 src/maxtext/trainers/pre_train/train.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py
index 0e94e0c8ba..1011563a7b 100644
--- a/src/maxtext/trainers/pre_train/train.py
+++ b/src/maxtext/trainers/pre_train/train.py
@@ -58,7 +58,8 @@
 )
 from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled
 from maxtext.common.gcloud_stub import vertex_tensorboard_modules
-from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
+from maxtext.common import metric_logger
+from maxtext.common.metric_logger import record_activation_metrics
 from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
 from maxtext.utils import exceptions
 from maxtext.utils import gcs_utils
@@ -570,10 +571,10 @@ def train_loop(config, recorder, state=None):
   start_step = get_first_step(model, state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
 
-  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+  metric_logger_instance = metric_logger.MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
 
   # Write train config params, num model params, and XLA flags to tensorboard
-  metric_logger.write_setup_info_to_tensorboard(state.params)
+  metric_logger_instance.write_setup_info_to_tensorboard(state.params)
 
   _job_completed_gracefully = False
   try:
@@ -611,7 +612,7 @@ def train_loop(config, recorder, state=None):
 
         assert eval_data_iterator
         # Explicitly reset the eval iterator and counters before starting the eval loop
         eval_data_iterator.reset()
-        metric_logger.reset_eval_metrics()
+        metric_logger_instance.reset_eval_metrics()
         eval_step_count = 0
         # pylint: disable=not-callable
@@ -622,11 +623,11 @@ def train_loop(config, recorder, state=None):
             break
           with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
             eval_metrics = p_eval_step(state, eval_batch, nextrng)
-          metric_logger.record_eval_metrics(step, metrics=eval_metrics)
+          metric_logger_instance.record_eval_metrics(step, metrics=eval_metrics)
           max_logging.log(f"Completed eval step {eval_step_count}")
           eval_step_count += 1
 
-        metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
-        if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
+        metric_logger_instance.record_eval_metrics(step, eval_step_count=eval_step_count)
+        if metric_logger_instance.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
           prof.deactivate()
           raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")
@@ -635,7 +636,7 @@ def train_loop(config, recorder, state=None):
       if step == start_step:
         max_utils.print_mem_stats("After params initialized")
 
-      metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
+      metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta)
 
     if config.save_checkpoint_on_completion:
       state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
@@ -650,7 +651,7 @@ def train_loop(config, recorder, state=None):
   finally:
     if _job_completed_gracefully:
       record_goodput(recorder, RECORD_JOB_END_TIME)
 
-    metric_logger.flush_metrics_and_cleanup()
+    metric_logger_instance.flush_metrics_and_cleanup()
 
   return state
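
Note on the rename: once `from maxtext.common import metric_logger` binds the module name at file scope, the old local `metric_logger = MetricLogger(...)` inside `train_loop` would shadow it, because assigning to a name anywhere in a Python function makes that name local for the whole function body; the new expression `metric_logger.MetricLogger(...)` would then raise UnboundLocalError. Below is a minimal, self-contained sketch of that failure mode; the `logging` module and `getLogger` are illustrative stand-ins for `maxtext.common.metric_logger` and `MetricLogger`, not the MaxText API.

import logging


def broken_loop():
  # Assigning to the name `logging` anywhere in this function makes it local
  # for the WHOLE function body, so the module import is unreachable here and
  # the right-hand side of the next line raises UnboundLocalError.
  logging = logging.getLogger("train")
  return logging


def fixed_loop():
  # A distinct local name (cf. `metric_logger_instance` in the patch) leaves
  # the module-level name bound to the module, so attribute access works.
  logger_instance = logging.getLogger("train")
  return logger_instance


if __name__ == "__main__":
  print(fixed_loop().name)  # prints: train
  try:
    broken_loop()
  except UnboundLocalError as err:
    print(f"broken_loop failed as expected: {err}")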