diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index a1f77f5..65e55e7 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -16,6 +16,7 @@ import math import os import shutil +import statistics import time from pathlib import Path @@ -543,6 +544,29 @@ def _run_training_batch( return batch_size, detached_loss, batch_dice_score + def _sync_gather_minibatch_timer(self, minibatch_events): + minibatch_events[-1][1].synchronize() + local_minibatch_times = torch.tensor( + [ + start_event.elapsed_time(end_event) / 1000.0 + for start_event, end_event in minibatch_events + ], + device=self.device, + ) + if self.config.dist: + gathered_minibatch_times = [ + torch.empty_like(local_minibatch_times) for _ in range(self.world_size) + ] + torch.distributed.all_gather( + gathered_minibatch_times, local_minibatch_times + ) + minibatch_times = torch.stack(gathered_minibatch_times) + minibatch_times = torch.max(minibatch_times, dim=0).values + else: + minibatch_times = local_minibatch_times + minibatch_time_s = statistics.median(minibatch_times.cpu().tolist()) + return minibatch_time_s + def warmup(self): """Run warmup iterations before the main training loop.""" warmup_batches = self.config.warmup_batches @@ -610,6 +634,7 @@ def train(self): epoch = self.start_epoch dice_score_train = 0 + epoch_minibatch_times_s = [] with open(self.outfile_path, "a", newline="") as outfile: start = time.time() while dice_score_train < self.config.target_dice: @@ -623,6 +648,8 @@ def train(self): epoch_start_time = time.time() train_dice_total = 0 epoch_loss = 0 # Accumulator for per-batch losses + minibatch_time_s = None + minibatch_events = [] # Set necessary modes/states if self.config.dist: @@ -645,9 +672,20 @@ def train(self): ) as pbar: begin_code_region("batch_loop") for batch_idx, batch in enumerate(self.train_loader): - time_minibatch = batch_idx == 0 and self.world_rank == 0 + # We don't want to time partial batches, i.e. last batch (time will be lower than expected). + time_minibatch = ( + batch_idx + < len(self.train_sampler) // self.config.batch_size + ) if time_minibatch: - minibatch_start_time = time.perf_counter() + minibatch_start_event = torch.cuda.Event(enable_timing=True) + minibatch_end_event = torch.cuda.Event(enable_timing=True) + minibatch_start_event.record() + minibatch_events.append( + (minibatch_start_event, minibatch_end_event) + ) + begin_code_region("minibatch_time") + begin_code_region("run_training_batch") batch_size, batch_loss, batch_dice_score = ( self._run_training_batch( batch, @@ -655,6 +693,7 @@ def train(self): ) ) train_dice_total += batch_dice_score + end_code_region("run_training_batch") # Update the loss begin_code_region("update_loss") @@ -662,14 +701,11 @@ def train(self): self.global_step += 1 # Stay on GPU epoch_loss += batch_loss - if time_minibatch: - # This sync has some potential performance impact - # TODO: Would be better to measure this with Caliper, which uses CUDA events. - torch.cuda.synchronize(self.device) - minibatch_time_s = ( - time.perf_counter() - minibatch_start_time - ) end_code_region("update_loss") + end_code_region("minibatch_time") + + if time_minibatch: + minibatch_end_event.record() end_code_region("batch_loop") # Calculate overall loss as average of per-batch loss @@ -710,6 +746,14 @@ def train(self): epoch_end_time = time.time() epoch_duration = epoch_end_time - epoch_start_time + + # Sync for batch time happens once after epoch is already done (low overhead) + if len(minibatch_events) > 0: + minibatch_time_s = self._sync_gather_minibatch_timer( + minibatch_events + ) + epoch_minibatch_times_s.append(minibatch_time_s) + # # Write out data for this epoch to train stats csv # @@ -736,7 +780,7 @@ def train(self): ) outfile.flush() print( - f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}. Rank 0 first batch minibatch_time_s={minibatch_time_s:.6f}." + f"Epoch {epoch} completed in {epoch_duration:.6f} seconds. Total train time so far: {time.time() - start:.6f} seconds. Median of minibatch times: {minibatch_time_s:.6f} seconds." ) # @@ -763,4 +807,12 @@ def train(self): "Invalid value (NaN) encountered in dice score computation" ) - adiak_value("final_epochs", epoch) + completed_epochs = epoch - 1 + if epoch_minibatch_times_s: + minibatch_time_s = statistics.median(epoch_minibatch_times_s) + adiak_value("minibatch_time_s", minibatch_time_s) + if self.world_rank == 0: + self.log.info( + f"Median of epoch minibatch time medians: {minibatch_time_s:.6f} seconds." + ) + adiak_value("final_epochs", completed_epochs) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index fde4582..02c9118 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -220,6 +220,7 @@ def main(kwargs_dict: dict = {}): trainer.ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims total_shards = math.prod(config.dc_num_shards) global_batch_size = config.batch_size * (world_size // total_shards) + config.global_batch_size = global_batch_size ddp_ranks = world_size // total_shards adiak_value("global_batch_size", global_batch_size) adiak_value("ddp_ranks", ddp_ranks) @@ -278,24 +279,31 @@ def main(kwargs_dict: dict = {}): # # Calculate benchmark score # - outfile_path = trainer.outfile_path - train_data = np.genfromtxt(outfile_path, dtype=float, delimiter=",", names=True) - total_train_time = train_data["epoch_duration"].sum() - epochs = np.atleast_1d(train_data["epoch"]) - total_epochs = int(epochs[-1]) - if config.epochs == -1: - extra_msg = f"Trained to >= {config.target_dice} validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." - else: - extra_msg = ( - f"Completed in {total_train_time:.2f} seconds, {total_epochs} epochs." + if rank == 0: + outfile_path = trainer.outfile_path + train_data = np.genfromtxt(outfile_path, dtype=float, delimiter=",", names=True) + total_train_time = train_data["epoch_duration"].sum() + fom = 1.0 / total_train_time + adiak_value("FOM", fom) + log.info( + f"FOM = {fom} (1 / total_train_time={total_train_time:.6f} seconds). " + f"This FOM is specific to problem_scale={config.problem_scale}, " + f"target_dice={config.target_dice}, seed={config.seed}." ) + epochs = np.atleast_1d(train_data["epoch"]) + total_epochs = int(epochs[-1]) + if config.epochs == -1: + extra_msg = f"Trained to >= {config.target_dice} validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." + else: + extra_msg = ( + f"Completed in {total_train_time:.2f} seconds, {total_epochs} epochs." + ) - log.info(f"Benchmark run at scale {config.problem_scale} complete. \n{extra_msg}") + log.info( + f"Benchmark run at scale {config.problem_scale} complete. \n{extra_msg}" + ) - # - # Generate plots - # - if rank == 0: + # Generate plots log.info("Generating figures on rank 0...") begin_code_region("generate_figures") standard_viz.main(config)