src/maxtext/trainers/pre_train/train.py (19 changes: 10 additions & 9 deletions)

@@ -58,7 +58,8 @@
 )
 from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled
 from maxtext.common.gcloud_stub import vertex_tensorboard_modules
-from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
+from maxtext.common import metric_logger
+from maxtext.common.metric_logger import record_activation_metrics
 from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn
 from maxtext.utils import exceptions
 from maxtext.utils import gcs_utils
@@ -570,10 +571,10 @@ def train_loop(config, recorder, state=None):
 
   start_step = get_first_step(model, state)  # this is the start_step for training
   prof = profiler.Profiler(config, offset_step=start_step)
-  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
+  metric_logger_instance = metric_logger.MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
 
   # Write train config params, num model params, and XLA flags to tensorboard
-  metric_logger.write_setup_info_to_tensorboard(state.params)
+  metric_logger_instance.write_setup_info_to_tensorboard(state.params)
 
   _job_completed_gracefully = False
   try:
@@ -611,7 +612,7 @@ def train_loop(config, recorder, state=None):
         assert eval_data_iterator
         # Explicitly reset the eval iterator and counters before starting the eval loop
         eval_data_iterator.reset()
-        metric_logger.reset_eval_metrics()
+        metric_logger_instance.reset_eval_metrics()
 
         eval_step_count = 0
         # pylint: disable=not-callable
@@ -622,11 +623,11 @@
             break
           with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
             eval_metrics = p_eval_step(state, eval_batch, nextrng)
-          metric_logger.record_eval_metrics(step, metrics=eval_metrics)
+          metric_logger_instance.record_eval_metrics(step, metrics=eval_metrics)
           max_logging.log(f"Completed eval step {eval_step_count}")
           eval_step_count += 1
-        metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
-        if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
+        metric_logger_instance.record_eval_metrics(step, eval_step_count=eval_step_count)
+        if metric_logger_instance.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss:
           prof.deactivate()
           raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.")
 
@@ -635,7 +636,7 @@
       if step == start_step:
         max_utils.print_mem_stats("After params initialized")
 
-      metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta)
+      metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta)
 
     if config.save_checkpoint_on_completion:
       state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0]
@@ -650,7 +651,7 @@
   finally:
     if _job_completed_gracefully:
       record_goodput(recorder, RECORD_JOB_END_TIME)
-    metric_logger_instance.flush_metrics_and_cleanup()
+    metric_logger_instance.flush_metrics_and_cleanup()
   return state
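Why the rename matters: once the module itself is imported as metric_logger, a local variable with the same name cannot coexist with it, because assigning to a name anywhere in a function body makes that name local for the entire function scope. A minimal sketch of the pitfall, assuming only the import and constructor signature shown in the diff (the two-argument train_loop here is simplified for illustration):

from maxtext.common import metric_logger

def train_loop(config, learning_rate_schedule):
  # Reusing the module name would make `metric_logger` local to this whole
  # function, so `metric_logger.MetricLogger(...)` on the right-hand side
  # would raise UnboundLocalError before the module could be resolved:
  #   metric_logger = metric_logger.MetricLogger(...)
  # The renamed variable leaves the module reference intact:
  metric_logger_instance = metric_logger.MetricLogger(
      config=config, learning_rate_schedule=learning_rate_schedule
  )
  return metric_logger_instance

Keeping a single module-style import and renaming the instance avoids the shadowing without having to import MetricLogger directly.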
