Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e583d85
fix dtypes for torch
michaelmckinsey1 May 6, 2026
3dfbd13
Add per minibatch timer
michaelmckinsey1 May 7, 2026
47a4812
Merge remote-tracking branch 'origin/fix-dtypes' into per-minibatch
michaelmckinsey1 May 7, 2026
c9ef075
cleanup
michaelmckinsey1 May 7, 2026
ed246f5
Merge remote-tracking branch 'origin/main' into per-minibatch
michaelmckinsey1 May 14, 2026
c0ba9c9
Add adiak metadata
michaelmckinsey1 May 22, 2026
1f7af32
Update worker.py
michaelmckinsey1 May 22, 2026
1fee7d3
Merge remote-tracking branch 'origin/per-minibatch' into adiak
michaelmckinsey1 May 26, 2026
fa113f1
Define FOM
michaelmckinsey1 May 27, 2026
1e9df19
rm gbs, divide by epochs
michaelmckinsey1 May 27, 2026
ef23c72
Merge remote-tracking branch 'origin/main' into FOM
michaelmckinsey1 May 28, 2026
cc7961c
lint
michaelmckinsey1 May 28, 2026
d0a61f4
Redefine FOM
michaelmckinsey1 May 28, 2026
e1a7def
Cleanup
michaelmckinsey1 May 28, 2026
8b34295
Update worker.py
michaelmckinsey1 Jun 4, 2026
7c54441
Merge remote-tracking branch 'origin/main' into FOM
michaelmckinsey1 Jun 4, 2026
3603678
Merge remote-tracking branch 'origin/main' into FOM
michaelmckinsey1 Jun 4, 2026
b3a4848
Fix merge artifact
michaelmckinsey1 Jun 4, 2026
4aa9238
lint
michaelmckinsey1 Jun 4, 2026
6014356
Cleanup
michaelmckinsey1 Jun 4, 2026
54c40de
Cleanup
michaelmckinsey1 Jun 4, 2026
8819965
Cleanup
michaelmckinsey1 Jun 4, 2026
2de4c55
fix restart epoch bug in trainer (#72)
PatrickRMiles May 28, 2026
fb69a72
Update scaffold-tuolumne.job
michaelmckinsey1 Jun 3, 2026
d4f523d
Merge branch 'FOM' of github.com:michaelmckinsey1/ScaFFold into FOM
michaelmckinsey1 Jun 11, 2026
3528cbc
fix region. Make file reading scope to 1 proc
michaelmckinsey1 Jun 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 63 additions & 11 deletions ScaFFold/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
import os
import shutil
import statistics
import time
from pathlib import Path

Expand Down Expand Up @@ -543,6 +544,29 @@ def _run_training_batch(

return batch_size, detached_loss, batch_dice_score

def _sync_gather_minibatch_timer(self, minibatch_events):
minibatch_events[-1][1].synchronize()
local_minibatch_times = torch.tensor(
[
start_event.elapsed_time(end_event) / 1000.0
for start_event, end_event in minibatch_events
],
device=self.device,
)
if self.config.dist:
gathered_minibatch_times = [
torch.empty_like(local_minibatch_times) for _ in range(self.world_size)
]
torch.distributed.all_gather(
gathered_minibatch_times, local_minibatch_times
)
minibatch_times = torch.stack(gathered_minibatch_times)
minibatch_times = torch.max(minibatch_times, dim=0).values
else:
minibatch_times = local_minibatch_times
minibatch_time_s = statistics.median(minibatch_times.cpu().tolist())
return minibatch_time_s

def warmup(self):
"""Run warmup iterations before the main training loop."""
warmup_batches = self.config.warmup_batches
Expand Down Expand Up @@ -610,6 +634,7 @@ def train(self):

epoch = self.start_epoch
dice_score_train = 0
epoch_minibatch_times_s = []
with open(self.outfile_path, "a", newline="") as outfile:
start = time.time()
while dice_score_train < self.config.target_dice:
Expand All @@ -623,6 +648,8 @@ def train(self):
epoch_start_time = time.time()
train_dice_total = 0
epoch_loss = 0 # Accumulator for per-batch losses
minibatch_time_s = None
minibatch_events = []

# Set necessary modes/states
if self.config.dist:
Expand All @@ -645,31 +672,40 @@ def train(self):
) as pbar:
begin_code_region("batch_loop")
for batch_idx, batch in enumerate(self.train_loader):
time_minibatch = batch_idx == 0 and self.world_rank == 0
# We don't want to time partial batches, i.e. last batch (time will be lower than expected).
time_minibatch = (
batch_idx
< len(self.train_sampler) // self.config.batch_size
)
if time_minibatch:
minibatch_start_time = time.perf_counter()
minibatch_start_event = torch.cuda.Event(enable_timing=True)
minibatch_end_event = torch.cuda.Event(enable_timing=True)
minibatch_start_event.record()
minibatch_events.append(
(minibatch_start_event, minibatch_end_event)
)
begin_code_region("minibatch_time")
begin_code_region("run_training_batch")
batch_size, batch_loss, batch_dice_score = (
self._run_training_batch(
batch,
gather_mem_stats=True,
)
)
train_dice_total += batch_dice_score
end_code_region("run_training_batch")

# Update the loss
begin_code_region("update_loss")
pbar.update(batch_size)
self.global_step += 1
# Stay on GPU
epoch_loss += batch_loss
if time_minibatch:
# This sync has some potential performance impact
# TODO: Would be better to measure this with Caliper, which uses CUDA events.
torch.cuda.synchronize(self.device)
minibatch_time_s = (
time.perf_counter() - minibatch_start_time
)
end_code_region("update_loss")
end_code_region("minibatch_time")

if time_minibatch:
minibatch_end_event.record()
end_code_region("batch_loop")

# Calculate overall loss as average of per-batch loss
Expand Down Expand Up @@ -710,6 +746,14 @@ def train(self):

epoch_end_time = time.time()
epoch_duration = epoch_end_time - epoch_start_time

# Sync for batch time happens once after epoch is already done (low overhead)
if len(minibatch_events) > 0:
minibatch_time_s = self._sync_gather_minibatch_timer(
minibatch_events
)
epoch_minibatch_times_s.append(minibatch_time_s)

#
# Write out data for this epoch to train stats csv
#
Expand All @@ -736,7 +780,7 @@ def train(self):
)
outfile.flush()
print(
f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}. Rank 0 first batch minibatch_time_s={minibatch_time_s:.6f}."
f"Epoch {epoch} completed in {epoch_duration:.6f} seconds. Total train time so far: {time.time() - start:.6f} seconds. Median of minibatch times: {minibatch_time_s:.6f} seconds."
)

#
Expand All @@ -763,4 +807,12 @@ def train(self):
"Invalid value (NaN) encountered in dice score computation"
)

adiak_value("final_epochs", epoch)
completed_epochs = epoch - 1
if epoch_minibatch_times_s:
minibatch_time_s = statistics.median(epoch_minibatch_times_s)
adiak_value("minibatch_time_s", minibatch_time_s)
if self.world_rank == 0:
self.log.info(
f"Median of epoch minibatch time medians: {minibatch_time_s:.6f} seconds."
)
adiak_value("final_epochs", completed_epochs)
38 changes: 23 additions & 15 deletions ScaFFold/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def main(kwargs_dict: dict = {}):
trainer.ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims
total_shards = math.prod(config.dc_num_shards)
global_batch_size = config.batch_size * (world_size // total_shards)
config.global_batch_size = global_batch_size
ddp_ranks = world_size // total_shards
adiak_value("global_batch_size", global_batch_size)
adiak_value("ddp_ranks", ddp_ranks)
Expand Down Expand Up @@ -278,24 +279,31 @@ def main(kwargs_dict: dict = {}):
#
# Calculate benchmark score
#
outfile_path = trainer.outfile_path
train_data = np.genfromtxt(outfile_path, dtype=float, delimiter=",", names=True)
total_train_time = train_data["epoch_duration"].sum()
epochs = np.atleast_1d(train_data["epoch"])
total_epochs = int(epochs[-1])
if config.epochs == -1:
extra_msg = f"Trained to >= {config.target_dice} validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs."
else:
extra_msg = (
f"Completed in {total_train_time:.2f} seconds, {total_epochs} epochs."
if rank == 0:
outfile_path = trainer.outfile_path
train_data = np.genfromtxt(outfile_path, dtype=float, delimiter=",", names=True)
total_train_time = train_data["epoch_duration"].sum()
fom = 1.0 / total_train_time
adiak_value("FOM", fom)
log.info(
f"FOM = {fom} (1 / total_train_time={total_train_time:.6f} seconds). "
f"This FOM is specific to problem_scale={config.problem_scale}, "
f"target_dice={config.target_dice}, seed={config.seed}."
)
epochs = np.atleast_1d(train_data["epoch"])
total_epochs = int(epochs[-1])
if config.epochs == -1:
extra_msg = f"Trained to >= {config.target_dice} validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs."
else:
extra_msg = (
f"Completed in {total_train_time:.2f} seconds, {total_epochs} epochs."
)

log.info(f"Benchmark run at scale {config.problem_scale} complete. \n{extra_msg}")
log.info(
f"Benchmark run at scale {config.problem_scale} complete. \n{extra_msg}"
)

#
# Generate plots
#
if rank == 0:
# Generate plots
log.info("Generating figures on rank 0...")
begin_code_region("generate_figures")
standard_viz.main(config)
Expand Down
Loading