43 commits
47b139c
Changes required to make distconv a dependency (#1)
michaelmckinsey1 Jan 22, 2026
f6d120c
Ensure `hpc-launcher@1.0.4` is used (#5)
michaelmckinsey1 Jan 24, 2026
42b7e7e
Fix model size mismatch on restart (#9)
PatrickRMiles Jan 29, 2026
55f1c36
Fix Minor Bugs Discovered During Testing (#7)
michaelmckinsey1 Jan 29, 2026
20d7575
Restore `channels_last_3d`
michaelmckinsey1 Jan 29, 2026
f8b657d
Leverage a Checkpoint Interval to Speed Up Training (#12)
michaelmckinsey1 Feb 6, 2026
94ecfa9
Fix nested directories being created when restarting (#10)
PatrickRMiles Feb 17, 2026
fe31fab
Set Default Behavior to Stop Training Upon Convergence (#16)
michaelmckinsey1 Feb 19, 2026
6a06474
Update to torch=2.10 and rocm=7.1 and Pin Versions (#17)
michaelmckinsey1 Mar 5, 2026
7b46b1b
Remove redundant variable already set by hpc-launcher (#21)
michaelmckinsey1 Mar 5, 2026
8992d5a
Add `num-shards` and `epochs` to cli (#22)
michaelmckinsey1 Mar 10, 2026
6319962
remove unet bottleneck dim from dataset params used to generate uniqu…
PatrickRMiles Mar 10, 2026
af04705
remove open3d dependency (#25)
PatrickRMiles Mar 10, 2026
61c2c62
fix .ply -> .npy (#30)
PatrickRMiles Mar 19, 2026
20b2285
7.1.1 replacing 7.1.0 (#34)
michaelmckinsey1 Mar 26, 2026
7226c8a
Accurately report finish criteria
michaelmckinsey1 Mar 26, 2026
1331765
Implement multi-dimensional DistConv sharding (#27)
PatrickRMiles Mar 26, 2026
87bd3d7
fix unet bottleneck dim off by 1 error (#29)
PatrickRMiles Apr 2, 2026
39c0b93
Speedup Warmup on ROCm (#24)
michaelmckinsey1 Apr 2, 2026
e6856f1
Apply optimizer every batch, not every epoch; unscale gradients befor…
PatrickRMiles Apr 2, 2026
5f1e2a1
Warmup changes: only warm a few batches; extract to separate method i…
PatrickRMiles Apr 2, 2026
f8fca7b
Data loading optimizations (#46)
PatrickRMiles Apr 3, 2026
adbb812
Fix `ruff` check (#47)
michaelmckinsey1 Apr 10, 2026
8db5f39
Fix `inf` val loss in early epochs (#50)
PatrickRMiles Apr 16, 2026
6bd9e14
Warmup evaluation step (#49)
michaelmckinsey1 Apr 16, 2026
04c1b4a
Add configuration option to disable checkpointing (#48)
michaelmckinsey1 Apr 16, 2026
ac98bbc
Enforce samples are not repeated (#55)
michaelmckinsey1 Apr 22, 2026
a36fb3a
Exit if dice score is NaN (#54)
michaelmckinsey1 Apr 22, 2026
d6b7b64
Move to sharded data loading (#52)
PatrickRMiles Apr 22, 2026
f66fff1
Missing Import (#57)
michaelmckinsey1 Apr 23, 2026
3194f6c
Consistently ignore background class in loss/dice (#59)
michaelmckinsey1 Apr 23, 2026
699cf5f
Improve AMP stability (#60)
michaelmckinsey1 Apr 30, 2026
7ba9de8
Implement CosineAnnealingWarmRestarts LR Scheduler (#61)
michaelmckinsey1 May 6, 2026
9cb7b71
Fix usage of `fract_base_dir` and `val_split` (#63)
michaelmckinsey1 May 6, 2026
23dc95f
Fix class imbalance in CE loss (#56)
PatrickRMiles May 6, 2026
1db5916
fix dtypes for torch (#65)
michaelmckinsey1 May 11, 2026
d3f386a
Enable rocm/7.2.1 (#67)
michaelmckinsey1 May 27, 2026
2e8a160
Add metadata (#70)
michaelmckinsey1 May 28, 2026
b362f33
Enable timing minibatch (#66)
michaelmckinsey1 May 28, 2026
362282f
fix restart epoch bug in trainer (#72)
PatrickRMiles May 28, 2026
5231a8f
Update scaffold-tuolumne.job (#74)
michaelmckinsey1 Jun 4, 2026
55292fd
Use snapshot to prevent warmup from affecting training and refactor w…
michaelmckinsey1 Jun 4, 2026
bf7a135
shrink gitobjects and update gitignore
Jun 10, 2026
2 changes: 1 addition & 1 deletion .github/workflows/style.yml
@@ -20,4 +20,4 @@ jobs:
# Check format. To fix, run "ruff format ."
ruff format --diff .
# Check PEP8 violations, logic/correctness errors, and sort imports. To fix, run "ruff check --fix ."
ruff check --diff .
ruff check .
5 changes: 5 additions & 0 deletions .gitignore
@@ -129,6 +129,11 @@ venv/
ENV/
env.bak/
venv.bak/
.venvs/
scaffoldvenv*/

# Data files
*.npy

# Spyder project settings
.spyderproject
26 changes: 14 additions & 12 deletions README.md
@@ -29,25 +29,24 @@ The model is trained from a random initialization until convergence, which is de
1. Clone the repository:
`git clone https://github.com/LBANN/ScaFFold.git && cd ScaFFold`

1. Build the ccl plugin (if not using WCI wheel)
`. scripts/install-rccl.sh`

1. Create and activate a python venv for running the benchmark:
`ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv && source .venvs/scaffoldvenv/bin/activate && pip install --upgrade pip`

1. Necessary LLNL settings:
- CUDA (matrix):
1. `ml cuda/12.6.0 gcc/12.1.1 mvapich2/2.3.7`
1. `ml cuda/12.9.1 gcc/13.3.1 mvapich2/2.3.7`
1. `export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH`
- ROCm (elcap):
1. `ml load rocm/6.4.2 rccl/fast-env-slows-mpi`
- If using generic wheel:
1. `export LD_LIBRARY_PATH=/opt/cray/pe/cce/20.0.0/cce/x86_64/lib:$LD_LIBRARY_PATH`
1. `export LD_LIBRARY_PATH=/collab/usr/global/tools/rccl/toss_4_x86_64_ib_cray/rocm-6.4.1/install/lib/:$LD_LIBRARY_PATH` # Necessary to use libfabric plugin (Only necessary if using generic install, wci already links correctly)
1. `ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi`
- If using WCI wheel:
1. `export LD_LIBRARY_PATH=/opt/cray/pe/cce/20.0.0/cce-clang/x86_64/lib/:$LD_LIBRARY_PATH` # for libomp.so
1. `export SPINDLE_FLUXOPT=off` # Avoid spindle error
1. `export LD_PRELOAD=/opt/rocm-7.1.1/llvm/lib/libomp.so` # for libomp.so

1. Install the benchmark in the python venv:
- CUDA: `pip install --no-binary=mpi4py .[cuda] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/cu126 2>&1 | tee install.log`
- ROCm (generic): `pip install --no-binary=mpi4py .[rocm] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/rocm6.4 2>&1 | tee install.log`
- CUDA: `pip install --no-binary=mpi4py .[cuda] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/cu129 2>&1 | tee install.log`
- ROCm (generic): `pip install --no-binary=mpi4py .[rocm] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/rocm7.1 2>&1 | tee install.log`
- ROCm (LLNL): `pip install .[rocmwci] --prefix=.venvs/scaffoldvenv 2>&1 | tee install.log`


@@ -84,7 +83,10 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: 100 # Number of training epochs.
learning_rate: .0001 # Learning rate for training.
starting_learning_rate: .01 # Initial learning rate for training.
min_learning_rate: .001 # Minimum learning rate for CosineAnnealingWarmRestarts.
T_0: 100 # Epochs in the first cosine restart cycle.
T_mult: 2 # Restart cycle growth factor.
disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
@@ -227,8 +229,8 @@ make && make install
git clone https://github.com/LLNL/Caliper.git
cd Caliper
mkdir pybuild && cd pybuild
ml rocm/6.4.0
ml cuda/12.6.0
ml rocm/7.1.1
ml cuda/12.9.1
cmake -DWITH_PYTHON_BINDINGS=ON \
-DWITH_ROCPROFILER=ON \
-DWITH_CUPTI=ON \
79 changes: 76 additions & 3 deletions ScaFFold/cli.py
@@ -78,6 +78,11 @@ def main():
type=int,
help="Determines dataset resolution and number of UNet layers.",
)
generate_fractals_parser.add_argument(
"--n-categories",
type=int,
help="Number of fractal categories present in the dataset.",
)
generate_fractals_parser.add_argument(
"--fract-base-dir",
type=str,
@@ -111,10 +116,14 @@ def main():
benchmark_parser.add_argument(
"--base-run-dir", type=str, help="Subfolder of $(pwd) in which to run jobs."
)
benchmark_parser.add_argument(
"--fract-base-dir",
type=str,
help="Base directory for fractal IFS and instances.",
)
benchmark_parser.add_argument(
"--n-categories",
type=int,
nargs="+",
help="Number of fractal categories present in the dataset.",
)
benchmark_parser.add_argument(
@@ -134,7 +143,17 @@ def main():
)
benchmark_parser.add_argument("--seed", type=int, help="Random seed.")
benchmark_parser.add_argument(
"--batch-size", type=int, nargs="+", help="Batch sizes for each volume size."
"--batch-size", type=int, help="Batch sizes for each volume size."
)
benchmark_parser.add_argument(
"--warmup-batches",
type=int,
help="Number of warmup batches to run per rank before training.",
)
benchmark_parser.add_argument(
"--dataloader-num-workers",
type=int,
help="Number of DataLoader worker processes per rank.",
)
benchmark_parser.add_argument(
"--optimizer",
@@ -152,6 +171,39 @@ def main():
type=str,
help="Resume execution in this specific directory. Overrides --base-run-dir.",
)
benchmark_parser.add_argument(
"--dc-num-shards",
type=int,
nargs=3,
help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum",
)
benchmark_parser.add_argument(
"--epochs",
type=int,
help="Number of training epochs.",
)
benchmark_parser.add_argument(
"--starting-learning-rate",
type=float,
help="Initial learning rate for training.",
)
benchmark_parser.add_argument(
"--min-learning-rate",
type=float,
help="Minimum learning rate for CosineAnnealingWarmRestarts.",
)
benchmark_parser.add_argument(
"--T-0",
dest="T_0",
type=int,
help="Epochs in the first cosine restart cycle.",
)
benchmark_parser.add_argument(
"--T-mult",
dest="T_mult",
type=int,
help="Restart cycle growth factor.",
)

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
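The new benchmark options can be tried in isolation. The following is a minimal sketch, assuming only the argument definitions shown in this hunk. The standalone parser, the sample `defaults` dict, and the `apply_overrides` helper are hypothetical stand-ins for the real `main()` and its config merge, which this diff shows only in part.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical standalone parser mirroring a few options from this diff."""
    parser = argparse.ArgumentParser(prog="benchmark-sketch")
    parser.add_argument("--batch-size", type=int)
    # nargs=3 requires exactly three integers, one per sharded dimension.
    parser.add_argument("--dc-num-shards", type=int, nargs=3)
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--starting-learning-rate", type=float)
    # dest keeps the config key spelled "T_0" / "T_mult".
    parser.add_argument("--T-0", dest="T_0", type=int)
    parser.add_argument("--T-mult", dest="T_mult", type=int)
    return parser


def apply_overrides(config: dict, args: argparse.Namespace) -> dict:
    """Hypothetical merge: only options given on the command line override the config."""
    merged = dict(config)
    for key, value in vars(args).items():
        if value is not None:
            merged[key] = value
    return merged


# Sample defaults, loosely based on benchmark_default.yml.
defaults = {
    "batch_size": 1,
    "dc_num_shards": [1, 1, 1],
    "epochs": -1,
    "T_0": 100,
    "T_mult": 2,
}
args = build_parser().parse_args(["--dc-num-shards", "1", "2", "2", "--T-0", "50"])
config = apply_overrides(defaults, args)
```

Options that are not passed parse to `None` and leave the YAML value untouched, so only `dc_num_shards` and `T_0` change in this example.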
@@ -177,12 +229,33 @@ def main():
print(f"Overriding '{key}={combined_config[key]}' with '{key}={value}'")
combined_config[key] = value

# Recalculate unet_layers to capture any CLI overrides
combined_config["unet_layers"] = (
combined_config["problem_scale"] - combined_config["unet_bottleneck_dim"]
)

# Resolve paths to absolute, matching Config() behavior
if "base_run_dir" in combined_config and combined_config["base_run_dir"]:
combined_config["base_run_dir"] = str(
Path(combined_config["base_run_dir"]).resolve()
)

if "dataset_dir" in combined_config and combined_config["dataset_dir"]:
combined_config["dataset_dir"] = str(
Path(combined_config["dataset_dir"]).resolve()
)

if "fract_base_dir" in combined_config and combined_config["fract_base_dir"]:
combined_config["fract_base_dir"] = str(
Path(combined_config["fract_base_dir"]).resolve()
)

# Calculate these variables after override
combined_config["vol_size"] = pow(2, combined_config["problem_scale"])
combined_config["point_num"] = int(combined_config["vol_size"] ** 3 / 256)

# Handle Restart / Resume logic
if hasattr(args, "restart") and args.restart == True:
if hasattr(args, "restart") and args.restart:
print("Restart flag detected: Forcing train_from_scratch = False")
combined_config["train_from_scratch"] = False
combined_config["restart"] = True
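The derived values computed after the CLI overrides can be checked by hand. The sketch below restates the arithmetic from this hunk in a self-contained function; `derive_config` and the sample input are hypothetical names used only for illustration.

```python
from pathlib import Path


def derive_config(config: dict) -> dict:
    """Recompute values that depend on settings the CLI may have overridden."""
    derived = dict(config)
    # Number of UNet levels between the input resolution and the bottleneck.
    derived["unet_layers"] = derived["problem_scale"] - derived["unet_bottleneck_dim"]
    # Resolve user-supplied directories to absolute paths.
    for key in ("base_run_dir", "dataset_dir", "fract_base_dir"):
        if derived.get(key):
            derived[key] = str(Path(derived[key]).resolve())
    # Volume edge length is 2**problem_scale; point count scales with the volume.
    derived["vol_size"] = 2 ** derived["problem_scale"]
    derived["point_num"] = int(derived["vol_size"] ** 3 / 256)
    return derived


example = derive_config(
    {"problem_scale": 7, "unet_bottleneck_dim": 3, "fract_base_dir": "fractals"}
)
```

With `problem_scale: 7` and `unet_bottleneck_dim: 3`, this gives 4 UNet layers, a volume size of 128, and 128³ / 256 = 8192 points.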
27 changes: 17 additions & 10 deletions ScaFFold/configs/benchmark_default.yml
@@ -4,21 +4,26 @@ dataset_dir: "datasets" # Directory in which to store and query for d
fract_base_dir: "fractals" # Base directory for fractal IFS and instances.
n_categories: 5 # Number of fractal categories present in the dataset.
n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances.
problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6.
problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6.
unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8.
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
batch_size: 1 # Batch sizes for each vol size per rank.
dataloader_num_workers: 1 # Number of DataLoader worker processes per rank. More workers will use more memory
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defaults to RMSProp.
num_shards: 2 # DistConv param: number of shards to divide the tensor into
shard_dim: 2 # DistConv param: dimension on which to shard
dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard
checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely.

# Internal/dev use only
variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: 2000 # Number of training epochs.
learning_rate: .0001 # Learning rate for training.
disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
val_split: 30 # In percent.
epochs: -1 # Number of training epochs.
starting_learning_rate: 0.001 # Initial learning rate for training.
min_learning_rate: 0.0001 # Minimum learning rate for CosineAnnealingWarmRestarts.
T_0: 100 # Epochs in the first cosine restart cycle.
T_mult: 2 # Restart cycle growth factor.
disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off.
@@ -28,5 +33,7 @@ framework: "torch" # The DL framework to train with. Only valid
checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Category search normalization parameter
warmup_epochs: 1 # How many warmup epochs before training
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
warmup_batches: 64 # How many warmup batches per rank to run before training.
ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
target_dice: 0.95
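To see what the new scheduler defaults imply, the learning rate can be computed in closed form. This is a sketch under two assumptions the diff does not confirm: that `T_0`, `T_mult`, and `min_learning_rate` are passed to PyTorch's `CosineAnnealingWarmRestarts` as `T_0`, `T_mult`, and `eta_min`, and that `starting_learning_rate` is the optimizer's initial rate. The function name is hypothetical, and it re-implements the standard formula in plain Python rather than calling PyTorch.

```python
import math


def cosine_warm_restart_lr(
    epoch: int, base_lr: float, min_lr: float, t_0: int, t_mult: int
) -> float:
    """Learning rate at an integer epoch for cosine annealing with warm restarts."""
    t_i = t_0  # length of the current cycle
    t_cur = epoch  # epochs elapsed inside the current cycle
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult  # each new cycle is t_mult times longer
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t_cur / t_i))


# Defaults from benchmark_default.yml in this diff.
schedule = {
    epoch: cosine_warm_restart_lr(epoch, base_lr=0.001, min_lr=0.0001, t_0=100, t_mult=2)
    for epoch in (0, 50, 100, 200, 300)
}
```

Under these assumptions the rate starts at 0.001, reaches the midpoint 0.00055 at epoch 50, and restarts at 0.001 at epochs 100 and 300, since the second cycle lasts 200 epochs.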
39 changes: 39 additions & 0 deletions ScaFFold/configs/benchmark_testing.yml
@@ -0,0 +1,39 @@
# External/user-facing
base_run_dir: "benchmark_runs" # Subfolder of $(pwd) in which to run jobs.
dataset_dir: "datasets" # Directory in which to store and query for datasets.
fract_base_dir: "fractals" # Base directory for fractal IFS and instances.
n_categories: 5 # Number of fractal categories present in the dataset.
n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances.
problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6.
unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8.
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
dataloader_num_workers: 4 # Number of DataLoader worker processes per rank.
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defaults to RMSProp.
num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard
checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely.

# Internal/dev use only
variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3.
val_split: 25 # In percent.
epochs: 10 # Number of training epochs.
starting_learning_rate: .0001 # Initial learning rate for training.
min_learning_rate: .0001 # Minimum learning rate for CosineAnnealingWarmRestarts.
T_0: 10 # Epochs in the first cosine restart cycle.
T_mult: 1 # Restart cycle growth factor.
disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR.
more_determinism: 0 # If 1, improve model training determinism.
datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch.
train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off.
dist: 1 # If 1, use torch DDP.
torch_amp: 1 # If 1, use mixed precision in training.
framework: "torch" # The DL framework to train with. Only valid option for now is "torch".
checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Category search normalization parameter
warmup_batches: 5 # How many warmup batches per rank to run before training.
ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
target_dice: 0.95
7 changes: 4 additions & 3 deletions ScaFFold/datagen/category_search.py
@@ -195,9 +195,10 @@ def main(config: Config) -> None:
print(f"MPI size = {size}")

# Setup directories
repo_src_path = config.library_root
fracts_sub_dir = f"/var{config.variance_threshold}"
fracts_write_dir = f"{repo_src_path}/fractals{fracts_sub_dir}/3DIFS_param"
fracts_sub_dir = f"var{config.variance_threshold}"
fracts_write_dir = os.path.join(
config.fract_base_dir, fracts_sub_dir, "3DIFS_param"
)
if rank == 0:
print(f"Writing fractals to {fracts_write_dir}")
if os.path.exists(fracts_write_dir) and config.datagen_from_scratch:
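The switch to `os.path.join` also explains why the leading slash was dropped from `fracts_sub_dir`. The sketch below is illustrative only: `fractal_param_dir` is a hypothetical helper, and the second call shows what would happen if the old `"/var..."` component were passed to `os.path.join` unchanged.

```python
import os


def fractal_param_dir(fract_base_dir: str, variance_threshold: float) -> str:
    """Build the IFS parameter directory the way the new code does."""
    sub_dir = f"var{variance_threshold}"
    return os.path.join(fract_base_dir, sub_dir, "3DIFS_param")


new_path = fractal_param_dir("fractals", 0.15)

# A component that starts with a separator is treated as a new root,
# so os.path.join discards everything that came before it.
old_style = os.path.join("fractals", "/var0.15", "3DIFS_param")
```

Here `new_path` stays under `fractals`, while `old_style` no longer contains the base directory at all.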
15 changes: 11 additions & 4 deletions ScaFFold/datagen/get_dataset.py
@@ -29,14 +29,16 @@
from ScaFFold.datagen import volumegen

META_FILENAME = "meta.yaml"
DATASET_FORMAT_VERSION = 2
INCLUDE_KEYS = [
"dataset_format_version",
"n_categories",
"n_instances_used_per_fractal",
"problem_scale",
"unet_bottleneck_dim",
"seed",
"variance_threshold",
"n_fracts_per_vol",
"val_split",
]


@@ -86,12 +88,12 @@ def _git_commit_short() -> str:
)
except subprocess.CalledProcessError:
print(
f"Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse."
"Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse."
)
return "no-commit-id"
except Exception:
print(
f"Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse."
"Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse."
)
return "no-commit-id"

@@ -117,8 +119,10 @@ def get_dataset(
root.mkdir(exist_ok=True)

# Get dict of required keys and compute config_id
config_dict = vars(config).copy()
config_dict["dataset_format_version"] = DATASET_FORMAT_VERSION
volume_config = _get_required_keys_dict(
config=vars(config), include_keys=INCLUDE_KEYS
config=config_dict, include_keys=INCLUDE_KEYS
)
config_id = _hash_volume_config(volume_config)
commit = _git_commit_short()
@@ -137,6 +141,8 @@ def get_dataset(
meta = yaml.safe_load(meta_path.read_text())
if meta.get("config_id") != config_id:
continue
if meta.get("dataset_format_version", 1) != DATASET_FORMAT_VERSION:
continue
if require_commit and meta.get("code_commit") != commit:
continue
# If we pass the above checks, this dataset can be reused
@@ -187,6 +193,7 @@ def get_dataset(
# Write to tmp, then move, so readers never see half-written dataset
meta = {
"config_id": config_id,
"dataset_format_version": DATASET_FORMAT_VERSION,
"config_subset": volume_config,
"include_keys": INCLUDE_KEYS,
"code_commit": commit,
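The effect of adding `dataset_format_version` and `val_split` to the reuse check can be sketched as follows. The body of `_hash_volume_config` is not shown in this diff, so `hash_volume_config` is a hypothetical stand-in that hashes sorted JSON. `config_id_for`, `can_reuse`, and the shortened `INCLUDE_KEYS` list are likewise assumptions made for illustration.

```python
import hashlib
import json

DATASET_FORMAT_VERSION = 2
# Shortened key list for the sketch; the real INCLUDE_KEYS has more entries.
INCLUDE_KEYS = ["dataset_format_version", "n_categories", "problem_scale", "seed", "val_split"]


def hash_volume_config(volume_config: dict) -> str:
    """Hypothetical stand-in for _hash_volume_config: stable hash of sorted JSON."""
    payload = json.dumps(volume_config, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()[:12]


def config_id_for(config: dict) -> str:
    """Inject the format version, keep only the included keys, then hash."""
    config_dict = dict(config)
    config_dict["dataset_format_version"] = DATASET_FORMAT_VERSION
    subset = {key: config_dict[key] for key in INCLUDE_KEYS}
    return hash_volume_config(subset)


def can_reuse(meta: dict, config_id: str) -> bool:
    """Mirror the reuse checks from the diff, minus the optional commit check."""
    if meta.get("config_id") != config_id:
        return False
    # Metadata written before versioning has no key, so it counts as version 1.
    return meta.get("dataset_format_version", 1) == DATASET_FORMAT_VERSION


base = {"n_categories": 5, "problem_scale": 7, "seed": 42, "val_split": 30}
cid = config_id_for(base)
current_meta = {"config_id": cid, "dataset_format_version": 2}
legacy_meta = {"config_id": cid}  # no version recorded
```

In this sketch a dataset whose metadata lacks the version key is skipped even when its `config_id` matches, and changing `val_split` yields a different `config_id`.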