From 47b139ca740938f00fdf4cd10ec5c9e171db9a7e Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 22 Jan 2026 09:04:03 -0800 Subject: [PATCH 01/43] Changes required to make distconv a dependency (#1) --- pyproject.toml | 1 + requirements.txt | 1 + scripts/install-matrix.sh | 3 +-- scripts/install-tuolumne.sh | 3 +-- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6943402..44ebe27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "wandb>=0.19.6", "open3d>=0.18.0", "PyYAML>=6.0.2", + "distconv @ git+https://github.com/LBANN/DistConv.git", ] requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 0f8fb78..1c6766e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ wandb>=0.19.6 open3d>=0.18.0 PyYAML>=6.0.2 mpi4py==4.0.2 --no-binary mpi4py +distconv @ git+https://github.com/LBANN/DistConv.git # cuda # torch==2.7.1+cu126 diff --git a/scripts/install-matrix.sh b/scripts/install-matrix.sh index f72d2cb..f8c7e61 100644 --- a/scripts/install-matrix.sh +++ b/scripts/install-matrix.sh @@ -1,5 +1,4 @@ ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-matrix && source .venvs/scaffoldvenv-matrix/bin/activate && pip install --upgrade pip ml cuda/12.6.0 gcc/12.1.1 mvapich2/2.3.7 export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH -git clone git@github.com:LBANN/DistConv.git -pip install --no-binary=mpi4py -e .[cuda] DistConv/ --prefix=.venvs/scaffoldvenv-matrix --extra-index-url https://download.pytorch.org/whl/cu126 2>&1 | tee install.log +pip install --no-binary=mpi4py -e .[cuda] --prefix=.venvs/scaffoldvenv-matrix --extra-index-url https://download.pytorch.org/whl/cu126 2>&1 | tee install.log diff --git a/scripts/install-tuolumne.sh b/scripts/install-tuolumne.sh index 8e74416..3e03de3 100644 --- a/scripts/install-tuolumne.sh +++ b/scripts/install-tuolumne.sh @@ -1,4 +1,3 @@ ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo && source .venvs/scaffoldvenv-tuo/bin/activate && pip install --upgrade pip ml load rocm/6.4.2 rccl/fast-env-slows-mpi libfabric -git clone git@github.com:LBANN/DistConv.git -pip install -e .[rocmwci] DistConv/ --prefix=.venvs/scaffoldvenv-tuo 2>&1 | tee install.log +pip install -e .[rocmwci] --prefix=.venvs/scaffoldvenv-tuo 2>&1 | tee install.log From f6d120c4673a32c9689a315562f21eac5b498863 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Sat, 24 Jan 2026 08:55:33 -0800 Subject: [PATCH 02/43] Ensure `hpc-launcher@1.0.4` is used (#5) * Update pyproject.toml * Update requirements.txt --- pyproject.toml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 44ebe27..b715d47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ authors = [ ] license = { file = "LICENSE" } dependencies = [ - "hpc-launcher>=1.0.3", + "hpc-launcher>=1.0.4", "matplotlib>=3.9.4", "numpy>=1.26.4", "numba>=0.60.0", diff --git a/requirements.txt b/requirements.txt index 1c6766e..8af677c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ --index-url https://pypi.org/simple -hpc-launcher>=1.0.3 +hpc-launcher>=1.0.4 matplotlib>=3.9.4 numpy>=1.26.4 numba>=0.60.0 From 42b7e7e1df29ebb398d4fd786d160fff1120f521 Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Thu, 29 Jan 2026 13:03:32 -0800 Subject: [PATCH 03/43] Fix model size mismatch on restart (#9) * fix cli bug: must recalculate unet_layers in CLI since problem_scale can be overwritten iwhtout changing config object unet_layers * whitespace * ruff --------- Co-authored-by: Patrick Miles --- ScaFFold/cli.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 2caf981..82532fe 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -177,6 +177,24 @@ def main(): print(f"Overriding '{key}={combined_config[key]}' with '{key}={value}'") combined_config[key] = value + # Recalculate unet_layers to capture any CLI overrides + combined_config["unet_layers"] = ( + combined_config["problem_scale"] + - combined_config["unet_bottleneck_dim"] + + 1 + ) + + # Resolve paths to absolute, matching Config() behavior + if "base_run_dir" in combined_config and combined_config["base_run_dir"]: + combined_config["base_run_dir"] = str( + Path(combined_config["base_run_dir"]).resolve() + ) + + if "dataset_dir" in combined_config and combined_config["dataset_dir"]: + combined_config["dataset_dir"] = str( + Path(combined_config["dataset_dir"]).resolve() + ) + # Calculate these variables after override combined_config["vol_size"] = pow(2, combined_config["problem_scale"]) combined_config["point_num"] = int(combined_config["vol_size"] ** 3 / 256) From 55f1c362b4f0d89f00677876b84506a49846021e Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 29 Jan 2026 13:41:40 -0800 Subject: [PATCH 04/43] Fix Minor Bugs Discovered During Testing (#7) * Continue if checkpointing fails * Fixes for new distconv * enable running 1 epoch * Update trainer.py * Update trainer.py --- ScaFFold/utils/checkpointing.py | 12 ++++++++++-- ScaFFold/utils/trainer.py | 18 ++++++++++++++---- ScaFFold/worker.py | 3 ++- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/ScaFFold/utils/checkpointing.py b/ScaFFold/utils/checkpointing.py index 0bab949..92c5203 100644 --- a/ScaFFold/utils/checkpointing.py +++ b/ScaFFold/utils/checkpointing.py @@ -257,14 +257,22 @@ def save_checkpoint( def _write_to_disk(state_dict, last_path, best_path, is_best): """Worker function to perform actual disk I/O.""" # Save 'last' - torch.save(state_dict, last_path) + try: + torch.save(state_dict, last_path) + except Exception as e: + print("Saving checkpoint failed. Continuing...") + print(e) # Save 'best' (copy logic) if is_best: # Copy is often faster than re-serializing if last_path.exists(): shutil.copyfile(last_path, best_path) else: - torch.save(state_dict, best_path) + try: + torch.save(state_dict, best_path) + except Exception as e: + print("Saving checkpoint failed. Continuing...") + print(e) def _transfer_dict_to_cpu(self, obj): """Recursively move tensors to CPU.""" diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index c76646a..4196c3a 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -295,7 +295,11 @@ def train(self): masks_pred_dc = self.model(images_dc) # Convert predictions for loss - if images.size(0) < ps.num_shards: + if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1: + n_shards = ps.num_shards[0] + else: + n_shards = ps.num_shards + if images.size(0) < n_shards: # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs masks_pred = masks_pred_dc.to_replicate() labels_for_loss = true_masks @@ -304,7 +308,9 @@ def train(self): masks_pred = masks_pred_dc.to_ddp() dt_labels = distribute_tensor( true_masks, - device_mesh=ps.device_mesh["dc"], + device_mesh=ps.device_mesh[ + f"dc{self.config.shard_dim + 2}" + ], placements=[Shard(0)], ) labels_for_loss = dt_labels.to_local() @@ -419,11 +425,15 @@ def train(self): true_masks_ddp = ( DTensor.from_local( true_masks_dp, - device_mesh=ps.device_mesh["dc"], + device_mesh=ps.device_mesh[ + f"dc{self.config.shard_dim + 2}" + ], placements=[Replicate()], ) .redistribute( - device_mesh=ps.device_mesh["dc"], + device_mesh=ps.device_mesh[ + f"dc{self.config.shard_dim + 2}" + ], placements=[Shard(0)], ) .to_local() diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 33f8949..07bedf0 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -239,7 +239,8 @@ def main(kwargs_dict: dict = {}): outfile_path = trainer.outfile_path train_data = np.genfromtxt(outfile_path, dtype=float, delimiter=",", names=True) total_train_time = train_data["epoch_duration"].sum() - total_epochs = train_data["epoch"][-1] + epochs = np.atleast_1d(train_data["epoch"]) + total_epochs = int(epochs[-1]) log.info( f"Benchmark run at scale {config.problem_scale} complete. \n\ Trained to >= 0.95 validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." From 20d7575f7c881d28c122c467872f599c1630638c Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 29 Jan 2026 14:51:35 -0800 Subject: [PATCH 05/43] Restore `channels_last_3d` --- ScaFFold/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 07bedf0..431c2eb 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -187,7 +187,7 @@ def main(kwargs_dict: dict = {}): shard_dim=shard_dim, device_type=device.type, ) - model = model.to(device).to(memory_format=torch.contiguous_format) + model = model.to(device, memory_format=torch.channels_last_3d) # Wrap with DistConvDDP that corrects gradient scaling for dc submesh model = DistConvDDP( model, From f8b657dada9dfd82fc56573218da851471c1b599 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 5 Feb 2026 18:41:07 -0800 Subject: [PATCH 06/43] Leverage a Checkpoint Interval to Speed Up Training (#12) * set checkpoint interval * truncate stats csv when loading from checkpoint if checkpoint is behind latest CSV entries * lint --------- Co-authored-by: Patrick Miles --- ScaFFold/configs/benchmark_default.yml | 1 + ScaFFold/utils/config_utils.py | 1 + ScaFFold/utils/trainer.py | 68 +++++++++++++++++++++++++- 3 files changed, 68 insertions(+), 2 deletions(-) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 9ba4bc3..f6d6a17 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -11,6 +11,7 @@ batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. num_shards: 2 # DistConv param: number of shards to divide the tensor into shard_dim: 2 # DistConv param: dimension on which to shard +checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 06ee76d..1f3e3a6 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -72,6 +72,7 @@ def __init__(self, config_dict): self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] + self.checkpoint_interval = config_dict["checkpoint_interval"] class RunConfig(Config): diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 4196c3a..ac92908 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -239,6 +239,15 @@ def cleanup_or_resume(self): "train_mask_values" ] + # If we loaded a checkpoint (start_epoch > 1), we must ensure the CSV + # matches the state of that checkpoint. + if ( + self.world_rank == 0 + and self.start_epoch > 1 + and os.path.exists(self.outfile_path) + ): + self._truncate_stats_file(self.start_epoch) + # Set up the output file headers headers = [ "epoch", @@ -254,6 +263,60 @@ def cleanup_or_resume(self): with open(self.outfile_path, "a", newline="") as outfile: outfile.write(",".join(headers) + "\n") + def _truncate_stats_file(self, start_epoch): + """ + Scans the stats file and truncates it at the first occurrence of + an epoch >= start_epoch. This is O(1) memory and safe for large logs. + """ + self.log.info( + f"Truncating {self.outfile_path} to remove epochs >= {start_epoch}" + ) + + try: + # Open in read+update mode ('r+') to allow seeking and truncating + with open(self.outfile_path, "r+") as f: + header = f.readline() + if not header: + return + + # Identify the index of the 'epoch' column + headers = header.strip().split(",") + try: + epoch_idx = headers.index("epoch") + except ValueError: + epoch_idx = 0 + + while True: + # Save the current file position (start of the line) + current_pos = f.tell() + line = f.readline() + + # End of file reached + if not line: + break + + parts = line.strip().split(",") + try: + row_epoch = int(float(parts[epoch_idx])) + + # If we find a row that is "from the future" (or the current restarting epoch) + if row_epoch >= start_epoch: + # Move pointer back to the start of this line + f.seek(current_pos) + # Cut the file off right here + f.truncate() + self.log.info( + f"Truncated stats file at byte {current_pos} (found epoch {row_epoch})" + ) + break + except (ValueError, IndexError): + # Skip malformed lines, or decide to stop. + # Usually safe to continue scanning. + pass + + except Exception as e: + self.log.warning(f"Failed to truncate stats file: {e}") + def train(self): """ Execute model training @@ -581,8 +644,9 @@ def train(self): # begin_code_region("checkpoint") - extras = {"train_mask_values": self.train_set.mask_values} - self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras) + if epoch % self.config.checkpoint_interval == 0: + extras = {"train_mask_values": self.train_set.mask_values} + self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras) end_code_region("checkpoint") From 94ecfa9a1f902f2c9c6beb41490d1c11d89fc2eb Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Tue, 17 Feb 2026 11:26:57 -0800 Subject: [PATCH 07/43] Fix nested directories being created when restarting (#10) * add -l to torchrun-hpc restart command, preventing nested dir creation; also simplify restart script to more closely match default run method like ScaFFold/scripts/scaffold-tuolumne.job * give restarting torchrun the full path to the existing dir * ruff --------- Co-authored-by: Patrick Miles --- ScaFFold/utils/create_restart_script.py | 63 +++++++------------------ 1 file changed, 18 insertions(+), 45 deletions(-) diff --git a/ScaFFold/utils/create_restart_script.py b/ScaFFold/utils/create_restart_script.py index 4994205..a4bd618 100644 --- a/ScaFFold/utils/create_restart_script.py +++ b/ScaFFold/utils/create_restart_script.py @@ -39,7 +39,7 @@ def _rewrite_config_and_add_restart(cli_args: List[str]) -> List[str]: new_args = [] skip_next = False - # Args to strip because they trigger new directory creation + # Args to strip because they trigger new directory creation or shouldn't change args_to_remove = {"--base-run-dir", "--job-name"} for i, tok in enumerate(cli_args): @@ -90,61 +90,29 @@ def _bash_array(var_name: str, argv: List[str], var_subs: dict[str, str]) -> str def _get_env_setup() -> str: - """Return the bash block that sets up the environment (modules, LD_PRELOAD, etc).""" - # Dynamically determine the current virtualenv path + """Return the bash block that sets up the environment based on your stable configuration.""" + # Dynamically determine the current virtualenv path to reuse the active one venv_path = sys.prefix return f""" # --- Begin Environment Setup --- # Load Modules if command -v module &> /dev/null; then - module load rocm/6.4.2 rccl/fast-env-slows-mpi libfabric + module load rocm/6.4.2 rccl/fast-env-slows-mpi fi # Activate Virtual Environment +# (Using the one active when this script was generated) if [ -f "{venv_path}/bin/activate" ]; then source "{venv_path}/bin/activate" else echo "WARNING: Could not find venv activate script at {venv_path}/bin/activate" fi -# 1. Define the path to the ROCm LLVM OpenMP library -ROCM_OMP_LIB="/opt/rocm-6.4.2/llvm/lib/libomp.so" +# Environment variables +export SPINDLE_FLUXOPT=off +export LD_PRELOAD=/opt/rocm-6.4.2/llvm/lib/libomp.so -# 2. Check if it exists before proceeding -if [ ! -f "$ROCM_OMP_LIB" ]; then - echo "ERROR: Could not find OpenMP at $ROCM_OMP_LIB" - # Fallback search if the standard path is wrong - ROCM_OMP_LIB=$(find /opt/rocm-6.4.2 -name libomp.so | head -n 1) - echo "Found alternative at: $ROCM_OMP_LIB" -fi -if [ -z "$ROCM_OMP_LIB" ]; then - echo "CRITICAL: Unable to find libomp.so in /opt/rocm-6.4.2. Aborting." - exit 1 -fi - -# 3. Force the dynamic linker to load this specific library first -echo "Forcing Preload of: $ROCM_OMP_LIB" -export LD_PRELOAD=$ROCM_OMP_LIB - -# Setup Torch Library Path -SITE_PACKAGES=$(python3 -c "import sysconfig; print(sysconfig.get_path('purelib'))") -TORCH_LIB_PATH="$SITE_PACKAGES/torch/lib" -export LD_LIBRARY_PATH=$TORCH_LIB_PATH:$LD_LIBRARY_PATH - -# Setup System Libfabric -SYSTEM_LIBFABRIC=$(ls /opt/cray/libfabric/2.1/lib64/libfabric.so.1 | head -n 1) - -if [ -z "$SYSTEM_LIBFABRIC" ]; then - echo "Error: Could not find system libfabric!" - exit 1 -fi - -echo "Forcing preload of system Libfabric: $SYSTEM_LIBFABRIC" -export LD_PRELOAD=$SYSTEM_LIBFABRIC:$LD_PRELOAD - -export NCCL_NET=Socket -export NCCL_SOCKET_IFNAME=hsi0 export PROFILE_TORCH=ON # --- End Environment Setup --- """ @@ -180,20 +148,25 @@ def _render_torchrun_hpc_restart( # Additional torchrun-hpc arguments (e.g. --launcher-args for specific scheduler flags) LAUNCHER_ADDITIONAL_ARGS='' -LAUNCHER_ARGS="-N $NODES -n $TASKS_PER_NODE --gpus-per-proc $GPUS_PER_PROC $LAUNCHER_ADDITIONAL_ARGS" - -IFS=' ' read -r -a LAUNCHER_ARR <<< "$LAUNCHER_ARGS" +# Use a proper Bash array for arguments to handle paths with spaces safely +LAUNCHER_ARGS=( + -l "$RUN_DIR" + -N "$NODES" + -n "$TASKS_PER_NODE" + --gpus-per-proc "$GPUS_PER_PROC" + $LAUNCHER_ADDITIONAL_ARGS +) # Exact Python command to rerun the CLI {py_array_decl} echo "Restarting in $RUN_DIR via torchrun-hpc:" -echo " torchrun-hpc $LAUNCHER_ARGS ..." +echo " torchrun-hpc ${{LAUNCHER_ARGS[*]}} ..." printf ' python cmd: '; printf '%q ' "${{PY[@]}}"; echo cd "$RUN_DIR" # Invoking torchrun-hpc to handle scheduler interaction (Flux/Slurm) -exec torchrun-hpc "${{LAUNCHER_ARR[@]}}" "${{PY[@]}}" +exec torchrun-hpc "${{LAUNCHER_ARGS[@]}}" "${{PY[@]}}" """ From fe31fabf6285458eb439da751c82e99e8e815915 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 19 Feb 2026 11:23:58 -0800 Subject: [PATCH 08/43] Set Default Behavior to Stop Training Upon Convergence (#16) * init * debug * testing * Enable configuring n_categories * set checkpoint interval * cleanup * lint * Update trainer.py * Create benchmark_testing.yml --- ScaFFold/cli.py | 6 ++++- ScaFFold/configs/benchmark_default.yml | 7 +++--- ScaFFold/configs/benchmark_testing.yml | 34 ++++++++++++++++++++++++++ ScaFFold/utils/config_utils.py | 1 + ScaFFold/utils/perf_measure.py | 3 ++- ScaFFold/utils/trainer.py | 29 ++++++++++++++++------ 6 files changed, 67 insertions(+), 13 deletions(-) create mode 100644 ScaFFold/configs/benchmark_testing.yml diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 82532fe..a360c38 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -78,6 +78,11 @@ def main(): type=int, help="Determines dataset resolution and number of UNet layers.", ) + generate_fractals_parser.add_argument( + "--n-categories", + type=int, + help="Number of fractal categories present in the dataset.", + ) generate_fractals_parser.add_argument( "--fract-base-dir", type=str, @@ -114,7 +119,6 @@ def main(): benchmark_parser.add_argument( "--n-categories", type=int, - nargs="+", help="Number of fractal categories present in the dataset.", ) benchmark_parser.add_argument( diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index f6d6a17..fce1042 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -9,7 +9,7 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: 2 # DistConv param: number of shards to divide the tensor into +num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum shard_dim: 2 # DistConv param: dimension on which to shard checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. @@ -17,7 +17,7 @@ checkpoint_interval: 10 # Checkpoint every C epochs. More frequent ch variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. val_split: 25 # In percent. -epochs: 2000 # Number of training epochs. +epochs: -1 # Number of training epochs. learning_rate: .0001 # Learning rate for training. disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. more_determinism: 0 # If 1, improve model training determinism. @@ -30,4 +30,5 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter warmup_epochs: 1 # How many warmup epochs before training -dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. \ No newline at end of file +dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. +target_dice: 0.95 \ No newline at end of file diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml new file mode 100644 index 0000000..aa97106 --- /dev/null +++ b/ScaFFold/configs/benchmark_testing.yml @@ -0,0 +1,34 @@ +# External/user-facing +base_run_dir: "benchmark_runs" # Subfolder of $(pwd) in which to run jobs. +dataset_dir: "datasets" # Directory in which to store and query for datasets. +fract_base_dir: "fractals" # Base directory for fractal IFS and instances. +n_categories: 5 # Number of fractal categories present in the dataset. +n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. +problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6. +unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. +seed: 42 # Random seed. +batch_size: 1 # Batch sizes for each vol size. +optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. +num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +shard_dim: 2 # DistConv param: dimension on which to shard +checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. + +# Internal/dev use only +variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. +n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. +val_split: 25 # In percent. +epochs: 10 # Number of training epochs. +learning_rate: .0001 # Learning rate for training. +disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. +more_determinism: 0 # If 1, improve model training determinism. +datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. +train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off. +dist: 1 # If 1, use torch DDP. +torch_amp: 1 # If 1, use mixed precision in training. +framework: "torch" # The DL framework to train with. Only valid option for now is "torch". +checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. +loss_freq: 1 # Number of epochs between logging the overall loss. +normalize: 1 # Cateogry search normalization parameter +warmup_epochs: 1 # How many warmup epochs before training +dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. +target_dice: 0.95 diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 1f3e3a6..08cb481 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -72,6 +72,7 @@ def __init__(self, config_dict): self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] + self.target_dice = config_dict["target_dice"] self.checkpoint_interval = config_dict["checkpoint_interval"] diff --git a/ScaFFold/utils/perf_measure.py b/ScaFFold/utils/perf_measure.py index e0ac1a2..5af8d5b 100644 --- a/ScaFFold/utils/perf_measure.py +++ b/ScaFFold/utils/perf_measure.py @@ -27,8 +27,9 @@ from pycaliper.instrumentation import begin_region, end_region _CALI_PERF_ENABLED = True - except Exception: + except Exception as e: print("User requested Caliper annotations, but could not import Caliper") + print(f"Exception: {e}") elif ( TORCH_PERF_ENV_VAR in os.environ and os.environ.get(TORCH_PERF_ENV_VAR).lower() != "off" diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index ac92908..29d6807 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -40,7 +40,7 @@ # Local from ScaFFold.utils.evaluate import evaluate -from ScaFFold.utils.perf_measure import begin_code_region, end_code_region +from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region from ScaFFold.utils.utils import gather_and_print_mem @@ -404,9 +404,17 @@ def train(self): end_code_region("warmup") self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") + epoch = 1 + dice_score_train = 0 with open(self.outfile_path, "a", newline="") as outfile: start = time.time() - for epoch in range(self.start_epoch, self.config.epochs + 1): + while dice_score_train < self.config.target_dice: + if self.config.epochs != -1 and epoch > self.config.epochs: + print( + f"Maxmimum epochs reached '{self.config.epochs}'. Concluding training early (may have not converged)." + ) + break + # DistConv ParallelStrategy ps = getattr(self.config, "_parallel_strategy", None) if ps is None: @@ -427,10 +435,15 @@ def train(self): self.val_loader.sampler.set_epoch(epoch) self.model.train() + estr = ( + f"{epoch}" + if self.config.epochs == -1 + else f"{epoch}/{self.config.epochs}" + ) with tqdm( total=self.n_train // self.world_size, desc=f"({os.path.basename(self.config.run_dir)}) \ - Epoch {epoch}/{self.config.epochs}", + Epoch {estr}", unit="img", disable=True if self.world_rank != 0 else False, ) as pbar: @@ -644,14 +657,14 @@ def train(self): # begin_code_region("checkpoint") + # Checkpoint only if at a checkpoint_interval epoch if epoch % self.config.checkpoint_interval == 0: extras = {"train_mask_values": self.train_set.mask_values} self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras) end_code_region("checkpoint") - if val_score >= 0.95: - self.log.info( - f"val_score of {val_score} is > threshold of 0.95. Benchmark run complete. Wrapping up..." - ) - return 0 + dice_score_train = val_score + epoch += 1 + + adiak_value("final_epochs", epoch) From 6a0647469e01cb152b918c85bf1cb4004166f523 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 5 Mar 2026 10:15:48 -0800 Subject: [PATCH 09/43] Update to torch=2.10 and rocm=7.1 and Pin Versions (#17) * Update versions and pin distconv & ccl. Add separate install for pypi * Ensure libfabric * Don't need spindle off anymore * Enforce cray-mpich 9.1.0 * patch all so files --- README.md | 21 ++++++++--------- ScaFFold/utils/create_restart_script.py | 4 ++-- pyproject.toml | 14 +++++------ requirements.txt | 16 ++----------- scripts/install-matrix.sh | 4 ++-- scripts/install-rccl.sh | 31 +++++++++++++++++++++++++ scripts/install-tuolumne-torchpypi.sh | 4 ++++ scripts/install-tuolumne.sh | 24 ++++++++++++++++++- scripts/scaffold-matrix.job | 2 -- scripts/scaffold-tuolumne-torchpypi.job | 24 +++++++++++++++++++ scripts/scaffold-tuolumne.job | 11 +++++---- 11 files changed, 111 insertions(+), 44 deletions(-) create mode 100644 scripts/install-rccl.sh create mode 100644 scripts/install-tuolumne-torchpypi.sh create mode 100644 scripts/scaffold-tuolumne-torchpypi.job diff --git a/README.md b/README.md index cdfead8..63a4649 100644 --- a/README.md +++ b/README.md @@ -29,25 +29,24 @@ The model is trained from a random initialization until convergence, which is de 1. Clone the repository: `git clone https://github.com/LBANN/ScaFFold.git && cd ScaFFold` +1. Build the ccl plugin (if not using WCI wheel) + `. scripts/install-rccl.sh` + 1. Create and activate a python venv for running the benchmark: `ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv && source .venvs/scaffoldvenv/bin/activate && pip install --upgrade pip` 1. Necessary LLNL settings: - CUDA (matrix): - 1. `ml cuda/12.6.0 gcc/12.1.1 mvapich2/2.3.7` + 1. `ml cuda/12.9.1 gcc/13.3.1 mvapich2/2.3.7` 1. `export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH` - ROCm (elcap): - 1. `ml load rocm/6.4.2 rccl/fast-env-slows-mpi` - - If using generic wheel: - 1. `export LD_LIBRARY_PATH=/opt/cray/pe/cce/20.0.0/cce/x86_64/lib:$LD_LIBRARY_PATH` - 1. `export LD_LIBRARY_PATH=/collab/usr/global/tools/rccl/toss_4_x86_64_ib_cray/rocm-6.4.1/install/lib/:$LD_LIBRARY_PATH` # Necessary to use libfabric plugin (Only necessary if using generic install, wci already links correctly) + 1. `ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi` - If using WCI wheel: - 1. `export LD_LIBRARY_PATH=/opt/cray/pe/cce/20.0.0/cce-clang/x86_64/lib/:$LD_LIBRARY_PATH` # for libomp.so - 1. `export SPINDLE_FLUXOPT=off` # Avoid spindle error + 1. `export LD_PRELOAD=/opt/rocm-7.1.0/llvm/lib/libomp.so` # for libomp.so 1. Install the benchmark in the python venv: - - CUDA: `pip install --no-binary=mpi4py .[cuda] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/cu126 2>&1 | tee install.log` - - ROCm (generic): `pip install --no-binary=mpi4py .[rocm] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/rocm6.4 2>&1 | tee install.log` + - CUDA: `pip install --no-binary=mpi4py .[cuda] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/cu129 2>&1 | tee install.log` + - ROCm (generic): `pip install --no-binary=mpi4py .[rocm] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/rocm7.1 2>&1 | tee install.log` - ROCm (LLNL): `pip install .[rocmwci] --prefix=.venvs/scaffoldvenv 2>&1 | tee install.log` @@ -227,8 +226,8 @@ make && make install git clone https://github.com/LLNL/Caliper.git cd Caliper mkdir pybuild && cd pybuild -ml rocm/6.4.0 -ml cuda/12.6.0 +ml rocm/7.1.0 +ml cuda/12.9.1 cmake -DWITH_PYTHON_BINDINGS=ON \ -DWITH_ROCPROFILER=ON \ -DWITH_CUPTI=ON \ diff --git a/ScaFFold/utils/create_restart_script.py b/ScaFFold/utils/create_restart_script.py index a4bd618..27a892a 100644 --- a/ScaFFold/utils/create_restart_script.py +++ b/ScaFFold/utils/create_restart_script.py @@ -98,7 +98,7 @@ def _get_env_setup() -> str: # --- Begin Environment Setup --- # Load Modules if command -v module &> /dev/null; then - module load rocm/6.4.2 rccl/fast-env-slows-mpi + ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi fi # Activate Virtual Environment @@ -111,7 +111,7 @@ def _get_env_setup() -> str: # Environment variables export SPINDLE_FLUXOPT=off -export LD_PRELOAD=/opt/rocm-6.4.2/llvm/lib/libomp.so +export LD_PRELOAD=/opt/rocm-7.1.0/llvm/lib/libomp.so export PROFILE_TORCH=ON # --- End Environment Setup --- diff --git a/pyproject.toml b/pyproject.toml index b715d47..1caa3a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "wandb>=0.19.6", "open3d>=0.18.0", "PyYAML>=6.0.2", - "distconv @ git+https://github.com/LBANN/DistConv.git", + "distconv @ git+https://github.com/LBANN/DistConv.git@232cba6", ] requires-python = ">=3.9" @@ -60,16 +60,16 @@ profiling = [ "pybind11>=3.0.0" ] cuda = [ - "torch==2.8.0+cu126", - "mpi4py==4.0.2", + "torch==2.10.0+cu129", + "mpi4py==4.1.1", ] rocm = [ - "torch==2.8.0+rocm6.4", - "mpi4py==4.0.2", + "torch==2.10.0+rocm7.1", + "mpi4py==4.1.1", ] rocmwci = [ - "torch==2.8.0+rocm642", - "mpi4py==4.1.1.dev0+mpich.8.1.32", + "torch==2.10.0+rocm710", + "mpi4py==4.1.1+mpich.9.1.0", ] [project.entry-points.console_scripts] diff --git a/requirements.txt b/requirements.txt index 8af677c..8361868 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,17 +7,5 @@ tqdm>=4.67.1 wandb>=0.19.6 open3d>=0.18.0 PyYAML>=6.0.2 -mpi4py==4.0.2 --no-binary mpi4py -distconv @ git+https://github.com/LBANN/DistConv.git - -# cuda -# torch==2.7.1+cu126 -# torchvision==0.22.1+cu126 -# torchaudio==2.7.1+cu126 -# --extra-index-url https://download.pytorch.org/whl/cu126 - -# rocm -# torch==2.8.0+rocm6. -# torchvision==0.23.0+rocm6.4 -# torchaudio==2.8.0+rocm6.4 -# --extra-index-url https://download.pytorch.org/whl/test/rocm6.4 +mpi4py==4.1.1 --no-binary mpi4py +distconv @ git+https://github.com/LBANN/DistConv.git@232cba6 diff --git a/scripts/install-matrix.sh b/scripts/install-matrix.sh index f8c7e61..15c4e6d 100644 --- a/scripts/install-matrix.sh +++ b/scripts/install-matrix.sh @@ -1,4 +1,4 @@ ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-matrix && source .venvs/scaffoldvenv-matrix/bin/activate && pip install --upgrade pip -ml cuda/12.6.0 gcc/12.1.1 mvapich2/2.3.7 +ml cuda/12.9.1 gcc/13.3.1 mvapich2/2.3.7 export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH -pip install --no-binary=mpi4py -e .[cuda] --prefix=.venvs/scaffoldvenv-matrix --extra-index-url https://download.pytorch.org/whl/cu126 2>&1 | tee install.log +pip install --no-binary=mpi4py -e .[cuda] --prefix=.venvs/scaffoldvenv-matrix --extra-index-url https://download.pytorch.org/whl/cu129 2>&1 | tee install.log diff --git a/scripts/install-rccl.sh b/scripts/install-rccl.sh new file mode 100644 index 0000000..306486a --- /dev/null +++ b/scripts/install-rccl.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Exit if target directory already exists +if [ -d "aws-ofi-nccl.git" ]; then + echo "Directory 'aws-ofi-nccl.git' already exists. Exiting to avoid overwrite." + return 1 2>/dev/null || exit 1 +fi + +rocm_version=7.1.0 + +module swap PrgEnv-cray PrgEnv-gnu +module load rocm/$rocm_version + +git clone --recursive --branch v1.18.0 https://github.com/aws/aws-ofi-nccl.git aws-ofi-nccl.git + +cd aws-ofi-nccl.git + +installdir=$(pwd)/install + +./autogen.sh + +export LD_LIBRARY_PATH=$PWD/../rccl/install/lib:/opt/rocm-$rocm_version/lib:$LD_LIBRARY_PATH + +#CC=hipcc CXX=hipcc CFLAGS=-I$PWD/../rccl/install/include/rccl ./configure \ +./configure \ + --with-libfabric=/opt/cray/libfabric/2.1 \ + --with-rocm=$ROCM_PATH \ + --prefix=$installdir + +make +make install \ No newline at end of file diff --git a/scripts/install-tuolumne-torchpypi.sh b/scripts/install-tuolumne-torchpypi.sh new file mode 100644 index 0000000..87c8473 --- /dev/null +++ b/scripts/install-tuolumne-torchpypi.sh @@ -0,0 +1,4 @@ +. install-rccl.sh +ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo-pypi && source .venvs/scaffoldvenv-tuo-pypi/bin/activate && pip install --upgrade pip +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi +pip install -e .[rocm] --prefix=.venvs/scaffoldvenv-tuo-pypi --extra-index-url https://download.pytorch.org/whl/rocm7.1 2>&1 | tee install.log diff --git a/scripts/install-tuolumne.sh b/scripts/install-tuolumne.sh index 3e03de3..d8f5da1 100644 --- a/scripts/install-tuolumne.sh +++ b/scripts/install-tuolumne.sh @@ -1,3 +1,25 @@ ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo && source .venvs/scaffoldvenv-tuo/bin/activate && pip install --upgrade pip -ml load rocm/6.4.2 rccl/fast-env-slows-mpi libfabric +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi pip install -e .[rocmwci] --prefix=.venvs/scaffoldvenv-tuo 2>&1 | tee install.log +# Needed until new wheel exists for torch using mpich 9.1.0 +TORCH_LIB_DIR=".venvs/scaffoldvenv-tuo/lib/python3.11/site-packages/torch/lib" +OLD="libmpi_gnu_112.so.12" +NEW="libmpi_gnu.so.12" +cd "$TORCH_LIB_DIR" || exit 1 +# Patch every file that has OLD in its DT_NEEDED +for f in *.so*; do + [ -f "$f" ] || continue + + if patchelf --print-needed "$f" 2>/dev/null | grep -Fxq "$OLD"; then + echo "Patching $f" + patchelf --replace-needed "$OLD" "$NEW" "$f" + fi +done +echo +echo "Verification (should show no $OLD):" +for f in *.so*; do + [ -f "$f" ] || continue + if patchelf --print-needed "$f" 2>/dev/null | grep -Fxq "$OLD"; then + echo "STILL NEEDS $OLD -> $f" + fi +done \ No newline at end of file diff --git a/scripts/scaffold-matrix.job b/scripts/scaffold-matrix.job index 65f2e85..d194aa8 100644 --- a/scripts/scaffold-matrix.job +++ b/scripts/scaffold-matrix.job @@ -8,8 +8,6 @@ #SBATCH -A fractale #SBATCH -perl -ml cuda/12.6.0 gcc/12.1.1 mvapich2/2.3.7 - . .venvs/scaffoldvenv-matrix/bin/activate export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH diff --git a/scripts/scaffold-tuolumne-torchpypi.job b/scripts/scaffold-tuolumne-torchpypi.job new file mode 100644 index 0000000..3b25274 --- /dev/null +++ b/scripts/scaffold-tuolumne-torchpypi.job @@ -0,0 +1,24 @@ +#!/bin/bash + +# flux: --exclusive +# flux: -N 1 +# flux: -g=1 +# flux: -t 60m +# flux: -qpdebug +# flux: -B fractale + +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi + +. .venvs/scaffoldvenv-tuo-pypi/bin/activate + +# Use ccl plugin that we manually built with install-rccl.sh +export NCCL_NET_PLUGIN=../aws-ofi-nccl.git/install/lib/librccl-net.so +export NCCL_NET="AWS Libfabric" + +torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml + +# Uncomment if you want torch profiling +#export PROFILE_TORCH=ON + +torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml +#torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml diff --git a/scripts/scaffold-tuolumne.job b/scripts/scaffold-tuolumne.job index 4eb3715..79604f7 100644 --- a/scripts/scaffold-tuolumne.job +++ b/scripts/scaffold-tuolumne.job @@ -7,15 +7,16 @@ # flux: -qpdebug # flux: -B fractale -ml rocm/6.4.2 rccl/fast-env-slows-mpi +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi . .venvs/scaffoldvenv-tuo/bin/activate -# Avoid spindle error -export SPINDLE_FLUXOPT=off +# (1) Avoid libmagma error +# (2) Removing libmpi may cause segfault on mpi4py import +export LD_PRELOAD="/opt/rocm-7.1.0/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/ofi/gnu/11.2/lib/libmpi_gnu.so.12" -# Avoid libmagma error -export LD_PRELOAD=/opt/rocm-6.4.2/llvm/lib/libomp.so +# Ensure using libfabric. NCCL_NET_PLUGIN should be unecessary to set for WCI wheel. +export NCCL_NET="AWS Libfabric" torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml From 7b46b1b55b2f41e84b18e8d544d1a0611b3f6ab6 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 5 Mar 2026 11:34:03 -0800 Subject: [PATCH 10/43] Remove redundant variable already set by hpc-launcher (#21) * cd back after done * Do not set var. This will be set by hpclauncher --- scripts/install-tuolumne.sh | 3 ++- scripts/scaffold-tuolumne-torchpypi.job | 1 - scripts/scaffold-tuolumne.job | 3 --- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/scripts/install-tuolumne.sh b/scripts/install-tuolumne.sh index d8f5da1..62760ca 100644 --- a/scripts/install-tuolumne.sh +++ b/scripts/install-tuolumne.sh @@ -22,4 +22,5 @@ for f in *.so*; do if patchelf --print-needed "$f" 2>/dev/null | grep -Fxq "$OLD"; then echo "STILL NEEDS $OLD -> $f" fi -done \ No newline at end of file +done +cd - diff --git a/scripts/scaffold-tuolumne-torchpypi.job b/scripts/scaffold-tuolumne-torchpypi.job index 3b25274..cc9b10e 100644 --- a/scripts/scaffold-tuolumne-torchpypi.job +++ b/scripts/scaffold-tuolumne-torchpypi.job @@ -13,7 +13,6 @@ ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi # Use ccl plugin that we manually built with install-rccl.sh export NCCL_NET_PLUGIN=../aws-ofi-nccl.git/install/lib/librccl-net.so -export NCCL_NET="AWS Libfabric" torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml diff --git a/scripts/scaffold-tuolumne.job b/scripts/scaffold-tuolumne.job index 79604f7..ce50b46 100644 --- a/scripts/scaffold-tuolumne.job +++ b/scripts/scaffold-tuolumne.job @@ -15,9 +15,6 @@ ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi # (2) Removing libmpi may cause segfault on mpi4py import export LD_PRELOAD="/opt/rocm-7.1.0/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/ofi/gnu/11.2/lib/libmpi_gnu.so.12" -# Ensure using libfabric. NCCL_NET_PLUGIN should be unecessary to set for WCI wheel. -export NCCL_NET="AWS Libfabric" - torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml # Uncomment if you want torch profiling From 8992d5ad31f6f20d8649a1edfcc298a4aca10005 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Tue, 10 Mar 2026 09:45:15 -0700 Subject: [PATCH 11/43] Add `num-shards` and `epochs` to cli (#22) * Add num-shards to cli * lint * Update cli.py --- ScaFFold/cli.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index a360c38..840a7e6 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -156,6 +156,16 @@ def main(): type=str, help="Resume execution in this specific directory. Overrides --base-run-dir.", ) + benchmark_parser.add_argument( + "--num-shards", + type=int, + help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum", + ) + benchmark_parser.add_argument( + "--epochs", + type=int, + help="Number of training epochs.", + ) comm = MPI.COMM_WORLD rank = comm.Get_rank() From 6319962657fed58823d73ae42ffff0354bdd9075 Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Tue, 10 Mar 2026 13:01:27 -0700 Subject: [PATCH 12/43] remove unet bottleneck dim from dataset params used to generate unique hash ID (#26) --- ScaFFold/datagen/get_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index 65bedd8..fc19f8c 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -33,7 +33,6 @@ "n_categories", "n_instances_used_per_fractal", "problem_scale", - "unet_bottleneck_dim", "seed", "variance_threshold", "n_fracts_per_vol", From af04705ebd7c0baf13f6409e14ff34444f109022 Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Tue, 10 Mar 2026 13:01:55 -0700 Subject: [PATCH 13/43] remove open3d dependency (#25) * remove open3d dependency * ruff * comments --------- Co-authored-by: Patrick Miles --- ScaFFold/datagen/instance.py | 26 +++++--------------------- ScaFFold/datagen/volumegen.py | 12 +++++------- pyproject.toml | 1 - requirements.txt | 1 - 4 files changed, 10 insertions(+), 30 deletions(-) diff --git a/ScaFFold/datagen/instance.py b/ScaFFold/datagen/instance.py index 91067ee..801cb0b 100644 --- a/ScaFFold/datagen/instance.py +++ b/ScaFFold/datagen/instance.py @@ -25,7 +25,6 @@ from pathlib import Path import numpy as np -import open3d from mpi4py import MPI from ScaFFold.datagen.generate_fractal_points import generate_fractal_points @@ -170,33 +169,18 @@ def main(config: Config): # Generate points points = generate_single_instance(config.point_num, params) - # Force point_data to be contiguous -- prevents possible segfaults in later Vector3dVector call from non-contiguous arrays + # Force point_data to be contiguous points_contiguous = np.ascontiguousarray(points, dtype=DEFAULT_NP_DTYPE) - # Create o3d PointCloud object - pointcloud = open3d.geometry.PointCloud() - - # Populate PointCloud points attribute with point_data - pointcloud.points = open3d.utility.Vector3dVector(points_contiguous) - - # Construct the long path and short filename + # Construct the output path out_dir = Path(instance_write_dir) / f"{category:06d}" - filename = f"{category:06d}_{instance:04d}.ply" + filename = f"{category:06d}_{instance:04d}.npy" # Ensure parent directory exists out_dir.mkdir(parents=True, exist_ok=True) - # Save current working directory so we can restore it later - cwd = os.getcwd() - try: - # Change to the output directory - os.chdir(out_dir) - - # Write with short relative path - open3d.io.write_point_cloud(filename, pointcloud) - finally: - # Always restore the working directory - os.chdir(cwd) + # Save array to out_dir + np.save(out_dir / filename, points_contiguous) end_time = time.time() total_time = end_time - start_time diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index d1120ab..479e67e 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -23,7 +23,6 @@ from typing import Dict import numpy as np -import open3d as o3d from mpi4py import MPI from ScaFFold.utils.config_utils import Config @@ -31,12 +30,11 @@ DEFAULT_NP_DTYPE = np.float64 -def load_ply_with_open3d(path: str) -> np.ndarray: +def load_np_ptcloud(path: str) -> np.ndarray: """ - Read a .ply via Open3D and return an (N,3) array of dtype float64. + Read a .npy file and return an (N,3) array of dtype float64. """ - pcd = o3d.io.read_point_cloud(path) - pts = np.asarray(pcd.points) + pts = np.load(path) return pts.astype(DEFAULT_NP_DTYPE, copy=False) @@ -204,7 +202,7 @@ def main(config: Dict): "fractals", instances_dir, f"{curr_category:06d}", - f"{curr_category:06d}_{curr_instance:04d}.ply", + f"{curr_category:06d}_{curr_instance:04d}.npy", ) if not os.path.exists(point_cloud_path): @@ -213,7 +211,7 @@ def main(config: Dict): ) sys.exit(1) - points = load_ply_with_open3d(point_cloud_path) + points = load_np_ptcloud(point_cloud_path) mask3d = points_to_voxelgrid(points, grid_size) assert mask3d.shape == volume.shape[:3], ( diff --git a/pyproject.toml b/pyproject.toml index 1caa3a3..4f8ed54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ dependencies = [ "numba>=0.60.0", "tqdm>=4.67.1", "wandb>=0.19.6", - "open3d>=0.18.0", "PyYAML>=6.0.2", "distconv @ git+https://github.com/LBANN/DistConv.git@232cba6", ] diff --git a/requirements.txt b/requirements.txt index 8361868..e0165fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ numpy>=1.26.4 numba>=0.60.0 tqdm>=4.67.1 wandb>=0.19.6 -open3d>=0.18.0 PyYAML>=6.0.2 mpi4py==4.1.1 --no-binary mpi4py distconv @ git+https://github.com/LBANN/DistConv.git@232cba6 From 61c2c629f8456e05010254d942e892b7195a48ff Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Thu, 19 Mar 2026 13:42:26 -0700 Subject: [PATCH 14/43] fix .ply -> .npy (#30) Co-authored-by: Patrick Miles --- ScaFFold/datagen/instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/datagen/instance.py b/ScaFFold/datagen/instance.py index 801cb0b..f8cf651 100644 --- a/ScaFFold/datagen/instance.py +++ b/ScaFFold/datagen/instance.py @@ -112,7 +112,7 @@ def main(config: Config): existing_instances = [ int(path_str.split("_")[-1].split(".")[0]) for path_str in glob.glob( - f"{instance_write_dir}/{category:06d}/[0-9][0-9][0-9][0-9][0-9][0-9]_[0-9][0-9][0-9][0-9].ply" + f"{instance_write_dir}/{category:06d}/[0-9][0-9][0-9][0-9][0-9][0-9]_[0-9][0-9][0-9][0-9].npy" ) ] category_instance_pairs = [ From 20b22857a758e8364e467b1d2a2f462c105ad15f Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 26 Mar 2026 09:50:16 -0700 Subject: [PATCH 15/43] 7.1.1 replacing 7.1.0 (#34) --- README.md | 6 +++--- ScaFFold/utils/create_restart_script.py | 4 ++-- scripts/install-rccl.sh | 2 +- scripts/install-tuolumne-torchpypi.sh | 2 +- scripts/install-tuolumne.sh | 2 +- scripts/scaffold-tuolumne-torchpypi.job | 2 +- scripts/scaffold-tuolumne.job | 4 ++-- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 63a4649..04bef50 100644 --- a/README.md +++ b/README.md @@ -40,9 +40,9 @@ The model is trained from a random initialization until convergence, which is de 1. `ml cuda/12.9.1 gcc/13.3.1 mvapich2/2.3.7` 1. `export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH` - ROCm (elcap): - 1. `ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi` + 1. `ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi` - If using WCI wheel: - 1. `export LD_PRELOAD=/opt/rocm-7.1.0/llvm/lib/libomp.so` # for libomp.so + 1. `export LD_PRELOAD=/opt/rocm-7.1.1/llvm/lib/libomp.so` # for libomp.so 1. Install the benchmark in the python venv: - CUDA: `pip install --no-binary=mpi4py .[cuda] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/cu129 2>&1 | tee install.log` @@ -226,7 +226,7 @@ make && make install git clone https://github.com/LLNL/Caliper.git cd Caliper mkdir pybuild && cd pybuild -ml rocm/7.1.0 +ml rocm/7.1.1 ml cuda/12.9.1 cmake -DWITH_PYTHON_BINDINGS=ON \ -DWITH_ROCPROFILER=ON \ diff --git a/ScaFFold/utils/create_restart_script.py b/ScaFFold/utils/create_restart_script.py index 27a892a..cc8bbbc 100644 --- a/ScaFFold/utils/create_restart_script.py +++ b/ScaFFold/utils/create_restart_script.py @@ -98,7 +98,7 @@ def _get_env_setup() -> str: # --- Begin Environment Setup --- # Load Modules if command -v module &> /dev/null; then - ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi + ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi fi # Activate Virtual Environment @@ -111,7 +111,7 @@ def _get_env_setup() -> str: # Environment variables export SPINDLE_FLUXOPT=off -export LD_PRELOAD=/opt/rocm-7.1.0/llvm/lib/libomp.so +export LD_PRELOAD=/opt/rocm-7.1.1/llvm/lib/libomp.so export PROFILE_TORCH=ON # --- End Environment Setup --- diff --git a/scripts/install-rccl.sh b/scripts/install-rccl.sh index 306486a..a84add3 100644 --- a/scripts/install-rccl.sh +++ b/scripts/install-rccl.sh @@ -6,7 +6,7 @@ if [ -d "aws-ofi-nccl.git" ]; then return 1 2>/dev/null || exit 1 fi -rocm_version=7.1.0 +rocm_version=7.1.1 module swap PrgEnv-cray PrgEnv-gnu module load rocm/$rocm_version diff --git a/scripts/install-tuolumne-torchpypi.sh b/scripts/install-tuolumne-torchpypi.sh index 87c8473..26e7a22 100644 --- a/scripts/install-tuolumne-torchpypi.sh +++ b/scripts/install-tuolumne-torchpypi.sh @@ -1,4 +1,4 @@ . install-rccl.sh ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo-pypi && source .venvs/scaffoldvenv-tuo-pypi/bin/activate && pip install --upgrade pip -ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi pip install -e .[rocm] --prefix=.venvs/scaffoldvenv-tuo-pypi --extra-index-url https://download.pytorch.org/whl/rocm7.1 2>&1 | tee install.log diff --git a/scripts/install-tuolumne.sh b/scripts/install-tuolumne.sh index 62760ca..339fd8f 100644 --- a/scripts/install-tuolumne.sh +++ b/scripts/install-tuolumne.sh @@ -1,5 +1,5 @@ ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo && source .venvs/scaffoldvenv-tuo/bin/activate && pip install --upgrade pip -ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi pip install -e .[rocmwci] --prefix=.venvs/scaffoldvenv-tuo 2>&1 | tee install.log # Needed until new wheel exists for torch using mpich 9.1.0 TORCH_LIB_DIR=".venvs/scaffoldvenv-tuo/lib/python3.11/site-packages/torch/lib" diff --git a/scripts/scaffold-tuolumne-torchpypi.job b/scripts/scaffold-tuolumne-torchpypi.job index cc9b10e..2629c60 100644 --- a/scripts/scaffold-tuolumne-torchpypi.job +++ b/scripts/scaffold-tuolumne-torchpypi.job @@ -7,7 +7,7 @@ # flux: -qpdebug # flux: -B fractale -ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi . .venvs/scaffoldvenv-tuo-pypi/bin/activate diff --git a/scripts/scaffold-tuolumne.job b/scripts/scaffold-tuolumne.job index ce50b46..bd0d33a 100644 --- a/scripts/scaffold-tuolumne.job +++ b/scripts/scaffold-tuolumne.job @@ -7,13 +7,13 @@ # flux: -qpdebug # flux: -B fractale -ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.0 rccl/fast-env-slows-mpi +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi . .venvs/scaffoldvenv-tuo/bin/activate # (1) Avoid libmagma error # (2) Removing libmpi may cause segfault on mpi4py import -export LD_PRELOAD="/opt/rocm-7.1.0/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/ofi/gnu/11.2/lib/libmpi_gnu.so.12" +export LD_PRELOAD="/opt/rocm-7.1.1/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/ofi/gnu/11.2/lib/libmpi_gnu.so.12" torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml From 7226c8abe5dea609d286ee2c0139e56065a52f6b Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 26 Mar 2026 14:04:33 -0700 Subject: [PATCH 16/43] Accurately report finish criteria --- ScaFFold/worker.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 431c2eb..fcd1394 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -241,10 +241,14 @@ def main(kwargs_dict: dict = {}): total_train_time = train_data["epoch_duration"].sum() epochs = np.atleast_1d(train_data["epoch"]) total_epochs = int(epochs[-1]) - log.info( - f"Benchmark run at scale {config.problem_scale} complete. \n\ - Trained to >= 0.95 validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." - ) + if config.epochs == -1: + extra_msg = f"Trained to >= {config.target_dice} validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." + else: + extra_msg = ( + f"Completed in {total_train_time:.2f} seconds, {total_epochs} epochs." + ) + + log.info(f"Benchmark run at scale {config.problem_scale} complete. \n{extra_msg}") # # Generate plots From 133176569c497e75e3988c3bc662d9ea2a3d956d Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Thu, 26 Mar 2026 14:45:21 -0700 Subject: [PATCH 17/43] Implement multi-dimensional DistConv sharding (#27) * update config with 3D num_shards and shard_dim * update config util to expect 3D num_shards and shard_dim, add helper to deal with 1D inputs * worker no longer needs to modify distconv params set in config -- just pass them as-is to the ParallelStrategy call * implement multi-dimensional sharding for distconv * update distconv param name scheme * fix loss calc * add sharded dice loss calculation to dice score util * update evaluate to use sharded dice loss calc * update trainer to use new evaluate; other small fixes/tweaks * fix assert * fix naming * fix naming * better default values * missing import * update distconv param names, default vals * use np.prod instead of math.prod * ruff * import math * warmup logging and timing * import math for prod * Add missing import * Remove extra func make cli arg tuple * lint * lint * Update configs --------- Co-authored-by: Patrick Miles Co-authored-by: Michael McKinsey --- ScaFFold/cli.py | 3 +- ScaFFold/configs/benchmark_default.yml | 8 +- ScaFFold/configs/benchmark_testing.yml | 6 +- ScaFFold/utils/config_utils.py | 12 +- ScaFFold/utils/dice_score.py | 54 +++++ ScaFFold/utils/evaluate.py | 124 +++++----- ScaFFold/utils/trainer.py | 321 ++++++++++++++++--------- ScaFFold/worker.py | 21 +- 8 files changed, 354 insertions(+), 195 deletions(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 840a7e6..3c73f40 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -157,8 +157,9 @@ def main(): help="Resume execution in this specific directory. Overrides --base-run-dir.", ) benchmark_parser.add_argument( - "--num-shards", + "--dc-num-shards", type=int, + nargs=3, help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum", ) benchmark_parser.add_argument( diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index fce1042..e96b103 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -4,14 +4,14 @@ dataset_dir: "datasets" # Directory in which to store and query for d fract_base_dir: "fractals" # Base directory for fractal IFS and instances. n_categories: 5 # Number of fractal categories present in the dataset. n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. -problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6. +problem_scale: 8 # Determines dataset resolution and number of unet layers. Default is 6. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: 2 # DistConv param: dimension on which to shard -checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. +dc_num_shards: [1, 1, 2] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard +checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index aa97106..5fea4d0 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -9,9 +9,9 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: 2 # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: 2 # DistConv param: dimension on which to shard -checkpoint_interval: 10 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. +num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard +checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 08cb481..378dc51 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -67,14 +67,22 @@ def __init__(self, config_dict): self.checkpoint_dir = config_dict["checkpoint_dir"] self.normalize = config_dict["normalize"] self.warmup_epochs = config_dict["warmup_epochs"] - self.num_shards = config_dict["num_shards"] - self.shard_dim = config_dict["shard_dim"] self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] self.target_dice = config_dict["target_dice"] self.checkpoint_interval = config_dict["checkpoint_interval"] + self.dc_num_shards = config_dict["dc_num_shards"] + self.dc_shard_dims = config_dict["dc_shard_dims"] + self.dc_total_shards = math.prod(self.dc_num_shards) + # Safety Check: Length mismatch + if len(self.dc_num_shards) != len(self.dc_shard_dims): + raise ValueError( + f"Configuration Mismatch: num_shards {self.dc_num_shards} " + f"must have same length as shard_dim {self.dc_shard_dims}" + ) + class RunConfig(Config): def __init__(self, config_dict): diff --git a/ScaFFold/utils/dice_score.py b/ScaFFold/utils/dice_score.py index 6536345..ed60fd4 100644 --- a/ScaFFold/utils/dice_score.py +++ b/ScaFFold/utils/dice_score.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: (Apache-2.0) import torch +import torch.distributed as dist from torch import Tensor from ScaFFold.utils.perf_measure import annotate @@ -59,3 +60,56 @@ def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): # Dice loss (objective to minimize) between 0 and 1 fn = multiclass_dice_coeff if multiclass else dice_coeff return 1 - fn(input, target, reduce_batch_first=True) + + +class SpatialAllReduce(torch.autograd.Function): + @staticmethod + def forward(ctx, input, spatial_mesh): + output = input.clone() + for mesh_dim in range(spatial_mesh.ndim): + pg = spatial_mesh.get_group(mesh_dim) + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pg) + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +@annotate() +def compute_sharded_dice( + preds: torch.Tensor, + targets: torch.Tensor, + spatial_mesh, + epsilon: float = 1e-6, +): + """ + Computes the globally sharded Dice score. + Returns the raw score tensor of shape [Batch, Channels]. + """ + assert preds.size() == targets.size(), ( + f"Shape mismatch: {preds.size()} vs {targets.size()}" + ) + assert preds.dim() == 5, f"Expected 5D tensor, got {preds.dim()}D" + + sum_dim = (-1, -2, -3) # D, H, W + + local_inter = 2.0 * (preds * targets).sum(dim=sum_dim) + local_sets_sum_raw = preds.sum(dim=sum_dim) + targets.sum(dim=sum_dim) + + packed = torch.stack([local_inter, local_sets_sum_raw]) + + # Global reduce across spatial mesh + packed_global = SpatialAllReduce.apply(packed, spatial_mesh) + + global_inter = packed_global[0] + global_sets_sum_raw = packed_global[1] + + global_sets_sum = torch.where( + global_sets_sum_raw == 0, global_inter, global_sets_sum_raw + ) + + # Calculate score + dice_score = (global_inter + epsilon) / (global_sets_sum + epsilon) + + return dice_score diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index c2d0672..fbecf90 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -12,13 +12,22 @@ # # SPDX-License-Identifier: (Apache-2.0) +import math + +import numpy as np import torch import torch.nn.functional as F from distconv import DCTensor from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor from tqdm import tqdm -from ScaFFold.utils.dice_score import dice_coeff, dice_loss, multiclass_dice_coeff +from ScaFFold.utils.dice_score import ( + SpatialAllReduce, + compute_sharded_dice, + dice_coeff, + dice_loss, + multiclass_dice_coeff, +) from ScaFFold.utils.perf_measure import annotate @@ -29,10 +38,11 @@ def evaluate( ): net.eval() num_val_batches = len(dataloader) - dice_score = 0.0 + total_dice_score = 0.0 processed_batches = 0 - # For reference, dc sharding happens on this spatial dim: 2=D, 3=H, 4=W + spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names] + if primary: print( f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}" @@ -50,92 +60,80 @@ def evaluate( ): image, mask_true = batch["image"], batch["mask"] - # move images and labels to correct device and type image = image.to( device=device, dtype=torch.float32, - memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) + memory_format=torch.channels_last_3d, ) - mask_true = mask_true.to( - device=device, dtype=torch.long - ).contiguous() # masks no channels NDHW, but ensure cotinuity. + mask_true = mask_true.to(device=device, dtype=torch.long).contiguous() + + # Dummy channel dimension [B, 1, D, H, W] + mask_true = mask_true.unsqueeze(1) - # Shard batch across ddp mesh, replicate across dc mesh - image_dp = distribute_tensor( - image, parallel_strategy.device_mesh, placements=[Shard(0), Replicate()] + # DDP Sharding + ddp_placements = [Shard(0)] + [Replicate()] * len( + parallel_strategy.shard_dim + ) + image_dp = DTensor.from_local( + image, parallel_strategy.device_mesh, placements=ddp_placements ).to_local() - mask_true_dp = distribute_tensor( - mask_true, - parallel_strategy.device_mesh, - placements=[Shard(0), Replicate()], + mask_true_dp = DTensor.from_local( + mask_true, parallel_strategy.device_mesh, placements=ddp_placements ).to_local() - # Spatially shard images along the dc mesh and run the model + # DistConv Spatial Sharding dcx = DCTensor.distribute(image_dp, parallel_strategy) - dcy = net(dcx) + mask_true_dc = DCTensor.distribute(mask_true_dp, parallel_strategy) - # Replicate predictions across dc to get full spatial result on each dc rank - mask_pred = dcy.to_replicate() + # Forward pass on sharded data + dcy = net(dcx) - # Use labels that are replicated across dc and sharded across ddp, like predictions - mask_true_ddp = mask_true_dp + # Extract underlying local tensors (STAY SHARDED) + local_preds = dcy + local_labels_5d = mask_true_dc + local_labels = local_labels_5d.squeeze(1) - # Skip if this ddp rank has an empty local batch - if mask_pred.size(0) == 0 or mask_true_ddp.size(0) == 0: + # Skip empty batches + if local_preds.size(0) == 0 or local_labels.size(0) == 0: continue - # Loss - CE_loss = criterion(mask_pred, mask_true_ddp) + # --- 1. Sharded CE Loss --- + local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction="sum") + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) - # Dice loss - mask_pred_softmax = F.softmax(mask_pred, dim=1).float() + # Divide by total global voxels to get the mean CE Loss + global_total_voxels = local_labels.numel() * math.prod( + parallel_strategy.num_shards + ) + CE_loss = global_ce_sum / global_total_voxels + + # --- 2. Format Predictions & Labels (Strictly Multiclass) --- + mask_pred_probs = F.softmax(local_preds, dim=1).float() mask_true_onehot = ( - F.one_hot(mask_true_ddp, n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() + F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float() ) - dice_loss_curr = dice_loss( - mask_pred_softmax, - mask_true_onehot, - multiclass=True, + + # Dice loss uses probabilities + dice_score_probs = compute_sharded_dice( + mask_pred_probs, mask_true_onehot, spatial_mesh ) + dice_loss_curr = 1.0 - dice_score_probs.mean() - # Combined validation loss + # Eval metric (excluding background class 0) + # dice_score_probs shape is [Batch, Channels]. We slice [:, 1:] to drop background + batch_dice_score = dice_score_probs[:, 1:].mean() + + # --- Combine and Accumulate --- loss = CE_loss + dice_loss_curr val_loss_epoch += loss.item() + total_dice_score += batch_dice_score.item() processed_batches += 1 - # Dice score - if net.module.n_classes == 1: - assert mask_true_ddp.min() >= 0 and mask_true_ddp.max() <= 1, ( - "True mask indices should be in [0, 1]" - ) - mask_pred_bin = (F.sigmoid(mask_pred) > 0.5).float() - dice_score += dice_coeff( - mask_pred_bin, mask_true_ddp, reduce_batch_first=False - ) - else: - assert ( - mask_true_ddp.min() >= 0 - and mask_true_ddp.max() < net.module.n_classes - ), "True mask indices should be in [0, n_classes]" - mask_pred_processed = F.softmax(mask_pred, dim=1).float() - mask_true_onehot_mc = ( - F.one_hot(mask_true_ddp, net.module.n_classes) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_score += multiclass_dice_coeff( - mask_pred_processed[:, 1:], - mask_true_onehot_mc[:, 1:], - reduce_batch_first=True, - ) - net.train() val_loss_avg = val_loss_epoch / max(processed_batches, 1) if primary: print( - f"evaluate.py: dice_score={dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" + f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" ) - return dice_score, val_loss_epoch, val_loss_avg, processed_batches + return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 29d6807..20f9251 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -35,7 +35,7 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset -from ScaFFold.utils.dice_score import dice_loss +from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice, dice_loss from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local @@ -317,6 +317,15 @@ def _truncate_stats_file(self, start_epoch): except Exception as e: self.log.warning(f"Failed to truncate stats file: {e}") + def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): + """Log size of tensor in memory""" + + if verbosity < 2: + return + tensor_memory_bytes = tensor[0].element_size() * tensor[0].nelement() + tensor_memory_gb = tensor_memory_bytes / (1024**3) + self.log.info(f"{tensor_label} size on GPU: {tensor_memory_gb:.2f} GB") + def train(self): """ Execute model training @@ -324,82 +333,168 @@ def train(self): self.cleanup_or_resume() + # DistConv ParallelStrategy + ps = getattr(self.config, "_parallel_strategy", None) + if ps is None: + raise RuntimeError( + "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." + ) + # Get the process group for spatial sharding mesh + spatial_mesh = ps.device_mesh[ps.distconv_dim_names] + + # Get placements for DDP sharding + num_spatial_dims = len(ps.shard_dim) + ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims + warmup_epochs = self.config.warmup_epochs if warmup_epochs > 0: begin_code_region("warmup") # Keep BN/Dropout from changing behavior/statistics - self.model.eval() + self.model.train() start_warmup = time.time() self.log.info(f"Running {warmup_epochs} warmup epoch(s)") - ps = getattr(self.config, "_parallel_strategy", None) - for _ in range(warmup_epochs): - for batch in self.train_loader: + for i, batch in enumerate(self.train_loader): + self.log.debug(f" warmup: batch {i} / {len(self.train_loader)}") + batch_t_start = time.time() + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] + # Move samples and labels to GPU images = images.to( device=self.device, dtype=torch.float32, memory_format=torch.channels_last_3d, - non_blocking=False, + non_blocking=True, ) - images_dc = DCTensor.distribute(images, ps) - + self._get_memsize(images, "Original image", self.config.verbose) true_masks = true_masks.to( device=self.device, dtype=torch.long, non_blocking=True ) + self._get_memsize(images, "Original label", self.config.verbose) + + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Data parallel sharding + images_dp = DTensor.from_local( + images, ps.device_mesh, placements=ddp_placements + ).to_local() + + true_masks_dp = DTensor.from_local( + true_masks, ps.device_mesh, placements=ddp_placements + ).to_local() + + # Delete source tensors immediately after use to keep memory down + del images, true_masks + + # Spatial sharding via DistConv + images_dc = DCTensor.distribute(images_dp, ps) + true_masks_dc = DCTensor.distribute(true_masks_dp, ps) + self._get_memsize(images_dc, "Sharded image", self.config.verbose) with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", enabled=self.config.torch_amp, ): # Forward on DCTensor + self.log.debug(f" warmup: running forward pass") masks_pred_dc = self.model(images_dc) + self.log.debug(f" warmup: forward pass complete") - # Convert predictions for loss - if isinstance(ps.num_shards, tuple) and len(ps.num_shards) == 1: - n_shards = ps.num_shards[0] - else: - n_shards = ps.num_shards - if images.size(0) < n_shards: - # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs - masks_pred = masks_pred_dc.to_replicate() - labels_for_loss = true_masks - else: - # Otherwise, shard labels across batch dim to match to_ddp layout - masks_pred = masks_pred_dc.to_ddp() - dt_labels = distribute_tensor( - true_masks, - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Shard(0)], + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc + + # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] + local_labels = local_labels_5d.squeeze(1) + if self.world_rank == 0: + self.log.debug( + f" warmup: Local Preds Shape: {local_preds.shape}" + ) + # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 + self.log.debug( + f" warmup: Local Labels Shape: {local_labels.shape}" ) - labels_for_loss = dt_labels.to_local() + # Should be something like [1, 128, 128, 64] + + # --- SHARDED LOSS CALCULATION --- + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) + + # 1. Sharded Cross Entropy + local_ce_sum = F.cross_entropy( + local_preds, local_labels, reduction="sum" + ) - CE_loss = self.criterion(masks_pred, labels_for_loss) + # Pass the spatial_mesh directly + global_ce_sum = SpatialAllReduce.apply( + local_ce_sum, spatial_mesh + ) + + global_total_voxels = local_labels.numel() * math.prod( + self.config.dc_num_shards + ) + loss_ce = global_ce_sum / global_total_voxels - # Calculate the train dice loss - masks_pred_softmax = F.softmax(masks_pred, dim=1).float() - true_masks_onehot = ( - F.one_hot(labels_for_loss, self.config.n_categories + 1) + # 2. Sharded Dice Loss + local_preds_softmax = F.softmax(local_preds, dim=1).float() + local_labels_one_hot = ( + F.one_hot( + local_labels, num_classes=self.config.n_categories + 1 + ) .permute(0, 4, 1, 2, 3) .float() ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, spatial_mesh ) - loss = CE_loss + train_dice_curr + loss_dice = 1.0 - dice_scores.mean() - # Fine as long as we don't step/update - self.grad_scaler.scale(loss).backward() + # 3. Combine Loss + loss = loss_ce + loss_dice + + self.log.debug( + f" warmup: loss calculation complete. Proceeding to backward pass" + ) + + # Backward pass + self.grad_scaler.scale(loss).backward() + self.log.debug( + f" warmup: backward pass complete. Stepping optimizer" + ) + + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + + # Free memory aggressively + del images_dc, true_masks_dc, masks_pred_dc + del ( + local_preds, + local_labels, + local_preds_softmax, + local_labels_one_hot, + ) + del loss_ce, loss_dice, loss, images_dp, true_masks_dp + + if self.world_rank == 0: + peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) + peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) + self.log.debug( + f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", + ) + batch_t_end = time.time() + self.log.debug( + f" warmup: batch {i} completed in {batch_t_end - batch_t_start} seconds" + ) # Nuke any accumulated grads so the first real step starts clean for p in self.model.parameters(): p.grad = None + self.optimizer.zero_grad(set_to_none=True) torch.distributed.barrier() end_code_region("warmup") self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") @@ -415,13 +510,6 @@ def train(self): ) break - # DistConv ParallelStrategy - ps = getattr(self.config, "_parallel_strategy", None) - if ps is None: - raise RuntimeError( - "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." - ) - # Timer and tracking variables epoch_start_time = time.time() train_dice_curr = 0 @@ -451,6 +539,7 @@ def train(self): begin_code_region("batch_loop") for batch in self.train_loader: + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] begin_code_region("image_to_device") @@ -466,18 +555,28 @@ def train(self): end_code_region("image_to_device") gather_and_print_mem(self.log, "after_batch_to_device") - # Replicate batch across dc mesh, shard batch across ddp mesh. - # This ensures all dc ranks in the same ddp group see the same samples, - # and ddp ranks see disjoint samples. - images_dp = distribute_tensor( - images, ps.device_mesh, placements=[Shard(0), Replicate()] + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Data parallel sharding + images_dp = DTensor.from_local( + images, ps.device_mesh, placements=ddp_placements ).to_local() - true_masks_dp = distribute_tensor( - true_masks, - ps.device_mesh, - placements=[Shard(0), Replicate()], + + true_masks_dp = DTensor.from_local( + true_masks, ps.device_mesh, placements=ddp_placements ).to_local() + # Delete source tensors immediately after use to keep memory down + del images, true_masks + + # Spatial sharding via DistConv + images_dc = DCTensor.distribute(images_dp, ps) + true_masks_dc = DCTensor.distribute(true_masks_dp, ps) + self._get_memsize( + images_dc, "Sharded image", self.config.verbose + ) + with torch.autocast( self.device.type if self.device.type != "mps" else "cpu", enabled=self.config.torch_amp, @@ -486,72 +585,69 @@ def train(self): torch.cuda.reset_peak_memory_stats() gather_and_print_mem(self.log, "pre_forward") begin_code_region("predict") - - # Spatially shard the chosen dimension across dc mesh - dcx = DCTensor.distribute(images_dp, ps) - dcy = self.model(dcx) - # Convert back to batch-sharded layout across the dc mesh - masks_pred = dcy.to_ddp() - + masks_pred_dc = self.model(images_dc) end_code_region("predict") gather_and_print_mem(self.log, "post_forward") - # Reshard labels across dc mesh to match masks_pred's batch partition - # Start from dc-replicated labels, then shard batch across dc - true_masks_ddp = ( - DTensor.from_local( - true_masks_dp, - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Replicate()], + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc + + # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] + local_labels = local_labels_5d.squeeze(1) + if self.world_rank == 0: + self.log.debug( + f"Local Preds Shape: {local_preds.shape}" ) - .redistribute( - device_mesh=ps.device_mesh[ - f"dc{self.config.shard_dim + 2}" - ], - placements=[Shard(0)], + # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 + self.log.debug( + f"Local Labels Shape: {local_labels.shape}" ) - .to_local() - ) + # Should be something like [1, 128, 128, 64] begin_code_region("calculate_loss") - # Calculate the loss - if self.config.n_categories + 1 == 1: - loss = self.criterion( - masks_pred.squeeze(1), true_masks_ddp.float() - ) - loss += dice_loss( - F.sigmoid(masks_pred.squeeze(1)), - true_masks_ddp.float(), - multiclass=False, - ) - else: - # Calculate the CrossEntropy loss - CE_loss = self.criterion(masks_pred, true_masks_ddp) - - # Calculate the train dice loss - masks_pred_softmax = F.softmax( - masks_pred, dim=1 - ).float() - true_masks_onehot = ( - F.one_hot( - true_masks_ddp, self.config.n_categories + 1 - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, + # --- SHARDED LOSS CALCULATION --- + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f"Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) + + # 1. Sharded Cross Entropy + local_ce_sum = F.cross_entropy( + local_preds, local_labels, reduction="sum" + ) + + # Pass the spatial_mesh directly + global_ce_sum = SpatialAllReduce.apply( + local_ce_sum, spatial_mesh + ) + + global_total_voxels = local_labels.numel() * math.prod( + self.config.dc_num_shards + ) + loss_ce = global_ce_sum / global_total_voxels + + # 2. Sharded Dice Loss + local_preds_softmax = F.softmax(local_preds, dim=1).float() + local_labels_one_hot = ( + F.one_hot( + local_labels, + num_classes=self.config.n_categories + 1, ) + .permute(0, 4, 1, 2, 3) + .float() + ) + + # Compute sharded dice using new function + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, spatial_mesh + ) + loss_dice = 1.0 - dice_scores.mean() - # Our loss function is CE loss + dice loss - loss = CE_loss + train_dice_curr + # 3. Combine Loss + loss = loss_ce + loss_dice + train_dice_total += dice_scores[:, 1:].mean().item() - # Track the train dice loss separately for debugging - train_dice_total += train_dice_curr end_code_region("calculate_loss") gather_and_print_mem(self.log, "pre_backward") @@ -562,6 +658,7 @@ def train(self): begin_code_region("step_and_update") if batch_step + 1 == len(self.train_loader): + self.grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_( self.model.parameters(), max_norm=1.0 ) @@ -574,7 +671,7 @@ def train(self): # Update the loss begin_code_region("update_loss") - pbar.update(images_dp.shape[0]) + pbar.update(images_dc.shape[0]) self.global_step += 1 batch_step += 1 # Stay on GPU diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index fcd1394..a11a8c3 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: (Apache-2.0) +import math import os import socket import sys @@ -161,10 +162,10 @@ def main(kwargs_dict: dict = {}): # Initialize model begin_code_region("init_model") - config.dc_num_shards = getattr(config, "dc_num_shards", config.num_shards) - config.dc_shard_dim = getattr(config, "dc_shard_dim", config.shard_dim) + config.dc_num_shards = getattr(config, "dc_num_shards", config.dc_num_shards) + config.dc_shard_dims = getattr(config, "dc_shard_dims", config.dc_shard_dims) log.info( - f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dim}" + f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dims}" ) device = get_device() log.info(f"Using device: {device}") @@ -176,17 +177,17 @@ def main(kwargs_dict: dict = {}): ) if config.dist: # DDP + DistConv setup - # Ensure world_size is divisible by dc_num_shards - assert dist.get_world_size() % config.dc_num_shards == 0, ( - f"world_size={dist.get_world_size()} must be divisible by dc_num_shards={config.dc_num_shards}" + # Ensure world_size is divisible by total distconv shards + assert dist.get_world_size() % math.prod(config.dc_num_shards) == 0, ( + f"world_size={dist.get_world_size()} must be divisible by total number of distconv shards = {math.prod(config.dc_num_shards)}" ) - # Select which full-tensor dim to shard: 2 + dc_shard_dim - shard_dim = 2 + int(config.dc_shard_dim) + ps = ParallelStrategy( - num_shards=int(config.dc_num_shards), - shard_dim=shard_dim, + num_shards=config.dc_num_shards, + shard_dim=config.dc_shard_dims, device_type=device.type, ) + model = model.to(device, memory_format=torch.channels_last_3d) # Wrap with DistConvDDP that corrects gradient scaling for dc submesh model = DistConvDDP( From 87bd3d75920fe19dd37ac3df9aaf23619e2bc7bd Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Thu, 2 Apr 2026 10:09:44 -0700 Subject: [PATCH 18/43] fix unet bottleneck dim off by 1 error (#29) * fix unet bottleneck dim off by 1 error * ruff --------- Co-authored-by: Patrick Miles --- ScaFFold/cli.py | 4 +--- ScaFFold/utils/config_utils.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 3c73f40..9c552ee 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -194,9 +194,7 @@ def main(): # Recalculate unet_layers to capture any CLI overrides combined_config["unet_layers"] = ( - combined_config["problem_scale"] - - combined_config["unet_bottleneck_dim"] - + 1 + combined_config["problem_scale"] - combined_config["unet_bottleneck_dim"] ) # Resolve paths to absolute, matching Config() behavior diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 378dc51..50a198d 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -45,7 +45,7 @@ def __init__(self, config_dict): ) self.problem_scale = math.floor(self.problem_scale) self.unet_bottleneck_dim = config_dict["unet_bottleneck_dim"] - self.unet_layers = self.problem_scale - self.unet_bottleneck_dim + 1 + self.unet_layers = self.problem_scale - self.unet_bottleneck_dim self.n_fracts_per_vol = config_dict["n_fracts_per_vol"] self.n_instances_used_per_fractal = config_dict["n_instances_used_per_fractal"] self.scale = 1 From 39c0b939d4d7968f94098f757d120f01c19e4815 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 2 Apr 2026 10:55:53 -0700 Subject: [PATCH 19/43] Speedup Warmup on ROCm (#24) * Update scaffold-tuolumne.job * Update scaffold-tuolumne-torchpypi.job * Update scaffold-tuolumne.job * Update scaffold-tuolumne-torchpypi.job * Update scaffold-tuolumne-torchpypi.job * Update scaffold-tuolumne.job * Update scaffold-tuolumne-torchpypi.job * Update scaffold-tuolumne.job --- scripts/scaffold-tuolumne-torchpypi.job | 9 +++++++++ scripts/scaffold-tuolumne.job | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/scripts/scaffold-tuolumne-torchpypi.job b/scripts/scaffold-tuolumne-torchpypi.job index 2629c60..c0b0780 100644 --- a/scripts/scaffold-tuolumne-torchpypi.job +++ b/scripts/scaffold-tuolumne-torchpypi.job @@ -14,6 +14,15 @@ ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi # Use ccl plugin that we manually built with install-rccl.sh export NCCL_NET_PLUGIN=../aws-ofi-nccl.git/install/lib/librccl-net.so +# Disable direct convolution benchmarking (should speedup warmup by a significant amount, does the below three options together) +# export MIOPEN_DEBUG_CONV_DIRECT=0 +# Disable direct naive convolution benchmarking (naive_conv_ab_nonpacked_fwd_ndhwc_half_double_half.kd) +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD=0 +# Disable naive_conv_ab_nonpacked_bwd_ndhwc_half_double_half.kd +# export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 +# Disable naive_conv_ab_nonpacked_wrw_ndhwc_half_double_half.kd +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0 + torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml # Uncomment if you want torch profiling diff --git a/scripts/scaffold-tuolumne.job b/scripts/scaffold-tuolumne.job index bd0d33a..d3b9e05 100644 --- a/scripts/scaffold-tuolumne.job +++ b/scripts/scaffold-tuolumne.job @@ -15,6 +15,15 @@ ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi # (2) Removing libmpi may cause segfault on mpi4py import export LD_PRELOAD="/opt/rocm-7.1.1/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/ofi/gnu/11.2/lib/libmpi_gnu.so.12" +# Disable direct convolution benchmarking (should speedup warmup by a significant amount, does the below three options together) +# export MIOPEN_DEBUG_CONV_DIRECT=0 +# Disable direct naive convolution benchmarking (naive_conv_ab_nonpacked_fwd_ndhwc_half_double_half.kd) +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD=0 +# Disable naive_conv_ab_nonpacked_bwd_ndhwc_half_double_half.kd +# export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 +# Disable naive_conv_ab_nonpacked_wrw_ndhwc_half_double_half.kd +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0 + torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml # Uncomment if you want torch profiling From e6856f1283c7125bdfc382692cf8acdc07d9b8d2 Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Thu, 2 Apr 2026 13:37:01 -0700 Subject: [PATCH 20/43] Apply optimizer every batch, not every epoch; unscale gradients before clipping (#40) * apply optimizer every batch, not every epoch; unscale gradients before clipping * trainer tweaks --------- Co-authored-by: Patrick Miles --- ScaFFold/utils/trainer.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 20f9251..1c9ba9a 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -522,6 +522,7 @@ def train(self): self.train_loader.sampler.set_epoch(epoch) self.val_loader.sampler.set_epoch(epoch) self.model.train() + self.optimizer.zero_grad(set_to_none=False) estr = ( f"{epoch}" @@ -535,8 +536,6 @@ def train(self): unit="img", disable=True if self.world_rank != 0 else False, ) as pbar: - batch_step = 0 - begin_code_region("batch_loop") for batch in self.train_loader: # Load initial samples and labels @@ -657,23 +656,20 @@ def train(self): gather_and_print_mem(self.log, "post_backward") begin_code_region("step_and_update") - if batch_step + 1 == len(self.train_loader): - self.grad_scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), max_norm=1.0 - ) - self.grad_scaler.step(self.optimizer) - gather_and_print_mem(self.log, "after_optim_step") - - self.grad_scaler.update() - self.optimizer.zero_grad(set_to_none=False) + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=1.0 + ) + self.grad_scaler.step(self.optimizer) + gather_and_print_mem(self.log, "after_optim_step") + self.grad_scaler.update() + self.optimizer.zero_grad(set_to_none=False) end_code_region("step_and_update") # Update the loss begin_code_region("update_loss") pbar.update(images_dc.shape[0]) self.global_step += 1 - batch_step += 1 # Stay on GPU epoch_loss += loss.detach() end_code_region("update_loss") From 5f1e2a14ea88269debecd9e757c17eec3c134283 Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:47:14 -0700 Subject: [PATCH 21/43] Warmup changes: only warm a few batches; extract to separate method in trainer class (#43) * apply optimizer every batch, not every epoch; unscale gradients before clipping * trainer tweaks * apply optimizer every batch, not every epoch; unscale gradients before clipping * extract warmup to separate method; switch to warming up set number of batches (user configurable) * whitespace; num_workers revert * ruff * make parallelstrategy, spatial_mesh, ddp_placements attrs of trainer; other small tweaks * remove deprecated config attrs * ruff * get device mesh from ps class attr * ruff * missing self. on some ps accesses * Fix imports and missing self.ps * rm legacy warmup_epochs * Move attributes to base class for clarity * remove warmup_epochs -- not useful to keep support for this * call cleanup_or_resume trainer method directly * rm unused vars --------- Co-authored-by: Patrick Miles Co-authored-by: Michael McKinsey --- ScaFFold/cli.py | 5 + ScaFFold/configs/benchmark_default.yml | 4 +- ScaFFold/configs/benchmark_testing.yml | 2 +- ScaFFold/utils/config_utils.py | 2 +- ScaFFold/utils/trainer.py | 318 ++++++++++++------------- ScaFFold/worker.py | 15 +- 6 files changed, 174 insertions(+), 172 deletions(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 9c552ee..5b76c87 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -140,6 +140,11 @@ def main(): benchmark_parser.add_argument( "--batch-size", type=int, nargs="+", help="Batch sizes for each volume size." ) + benchmark_parser.add_argument( + "--warmup-batches", + type=int, + help="Number of warmup batches to run per rank before training.", + ) benchmark_parser.add_argument( "--optimizer", type=str, diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index e96b103..180a0dc 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -29,6 +29,6 @@ framework: "torch" # The DL framework to train with. Only valid checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter -warmup_epochs: 1 # How many warmup epochs before training +warmup_batches: 5 # How many warmup batches per rank to run before training. dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. -target_dice: 0.95 \ No newline at end of file +target_dice: 0.95 diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index 5fea4d0..6b8c30a 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -29,6 +29,6 @@ framework: "torch" # The DL framework to train with. Only valid checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter -warmup_epochs: 1 # How many warmup epochs before training +warmup_batches: 5 # How many warmup batches per rank to run before training. dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. target_dice: 0.95 diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 50a198d..640ad19 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -66,7 +66,7 @@ def __init__(self, config_dict): self.loss_freq = config_dict["loss_freq"] self.checkpoint_dir = config_dict["checkpoint_dir"] self.normalize = config_dict["normalize"] - self.warmup_epochs = config_dict["warmup_epochs"] + self.warmup_batches = config_dict.get("warmup_batches") self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 1c9ba9a..ec5fc7b 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -29,7 +29,7 @@ import torch.nn.functional as F from distconv import DCTensor from torch import optim -from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.distributed.tensor import DTensor from torch.utils.data import DataLoader from tqdm import tqdm @@ -73,6 +73,9 @@ def __init__(self, model, config, device, log): self.criterion = None self.global_step = 0 self.start_epoch = -1 + self.ps = None # DistConv ParallelStrategy + self.spatial_mesh = None # Spatial mesh for use w/ DistConv + self.ddp_placements = None # DDP placements for use w/ DistConv self.checkpoint_path_absolute = str( self.config.run_dir + "/" + self.config.checkpoint_dir @@ -326,178 +329,159 @@ def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): tensor_memory_gb = tensor_memory_bytes / (1024**3) self.log.info(f"{tensor_label} size on GPU: {tensor_memory_gb:.2f} GB") - def train(self): - """ - Execute model training - """ - - self.cleanup_or_resume() + def warmup(self): + """Run warmup iterations before the main training loop.""" + warmup_batches = self.config.warmup_batches + if warmup_batches <= 0: + return - # DistConv ParallelStrategy - ps = getattr(self.config, "_parallel_strategy", None) - if ps is None: - raise RuntimeError( - "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." + if self.config.dist: + self.train_loader.sampler.set_epoch(0) + + # Match the main training path as closely as possible. + self.model.train() + self.optimizer.zero_grad(set_to_none=False) + start_warmup = time.time() + max_batches = min(warmup_batches, len(self.train_loader)) + self.log.info(f"Running {max_batches} warmup batch(es) per rank") + + for batch_idx, batch in enumerate(self.train_loader): + if batch_idx >= max_batches: + break + + images, true_masks = batch["image"], batch["mask"] + + images = images.to( + device=self.device, + dtype=torch.float32, + memory_format=torch.channels_last_3d, + non_blocking=True, ) - # Get the process group for spatial sharding mesh - spatial_mesh = ps.device_mesh[ps.distconv_dim_names] - - # Get placements for DDP sharding - num_spatial_dims = len(ps.shard_dim) - ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims - - warmup_epochs = self.config.warmup_epochs - if warmup_epochs > 0: - begin_code_region("warmup") - # Keep BN/Dropout from changing behavior/statistics - self.model.train() - start_warmup = time.time() - self.log.info(f"Running {warmup_epochs} warmup epoch(s)") - - for _ in range(warmup_epochs): - for i, batch in enumerate(self.train_loader): - self.log.debug(f" warmup: batch {i} / {len(self.train_loader)}") - batch_t_start = time.time() - # Load initial samples and labels - images, true_masks = batch["image"], batch["mask"] - - # Move samples and labels to GPU - images = images.to( - device=self.device, - dtype=torch.float32, - memory_format=torch.channels_last_3d, - non_blocking=True, - ) - self._get_memsize(images, "Original image", self.config.verbose) - true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True - ) - self._get_memsize(images, "Original label", self.config.verbose) - - # Add a dummy channel dimension to get 5D [B, 1, D, H, W] - true_masks = true_masks.unsqueeze(1) - - # Data parallel sharding - images_dp = DTensor.from_local( - images, ps.device_mesh, placements=ddp_placements - ).to_local() - - true_masks_dp = DTensor.from_local( - true_masks, ps.device_mesh, placements=ddp_placements - ).to_local() - - # Delete source tensors immediately after use to keep memory down - del images, true_masks - - # Spatial sharding via DistConv - images_dc = DCTensor.distribute(images_dp, ps) - true_masks_dc = DCTensor.distribute(true_masks_dp, ps) - self._get_memsize(images_dc, "Sharded image", self.config.verbose) - - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): - # Forward on DCTensor - self.log.debug(f" warmup: running forward pass") - masks_pred_dc = self.model(images_dc) - self.log.debug(f" warmup: forward pass complete") - - # Extract the underlying PyTorch local tensors - local_preds = masks_pred_dc - local_labels_5d = true_masks_dc - - # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] - local_labels = local_labels_5d.squeeze(1) - if self.world_rank == 0: - self.log.debug( - f" warmup: Local Preds Shape: {local_preds.shape}" - ) - # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 - self.log.debug( - f" warmup: Local Labels Shape: {local_labels.shape}" - ) - # Should be something like [1, 128, 128, 64] + true_masks = true_masks.to( + device=self.device, dtype=torch.long, non_blocking=True + ).contiguous() + + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Data parallel sharding + images_dp = DTensor.from_local( + images, self.ps.device_mesh, placements=self.ddp_placements + ).to_local() + + true_masks_dp = DTensor.from_local( + true_masks, self.ps.device_mesh, placements=self.ddp_placements + ).to_local() + + # Spatial sharding via DistConv + images_dc = DCTensor.distribute(images_dp, self.ps) + true_masks_dc = DCTensor.distribute(true_masks_dp, self.ps) + self._get_memsize(images_dc, "Sharded image", self.config.verbose) + + with torch.autocast( + self.device.type if self.device.type != "mps" else "cpu", + enabled=self.config.torch_amp, + ): + # Forward on DCTensor + self.log.debug(f" warmup: running forward pass") + masks_pred_dc = self.model(images_dc) + self.log.debug(f" warmup: forward pass complete") - # --- SHARDED LOSS CALCULATION --- - current_mem = torch.cuda.memory_allocated() / (1024**3) - self.log.debug( - f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." - ) + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc - # 1. Sharded Cross Entropy - local_ce_sum = F.cross_entropy( - local_preds, local_labels, reduction="sum" - ) + # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] + local_labels = local_labels_5d.squeeze(1) + if self.world_rank == 0: + self.log.debug(f" warmup: Local Preds Shape: {local_preds.shape}") + # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 + self.log.debug( + f" warmup: Local Labels Shape: {local_labels.shape}" + ) + # Should be something like [1, 128, 128, 64] - # Pass the spatial_mesh directly - global_ce_sum = SpatialAllReduce.apply( - local_ce_sum, spatial_mesh - ) + # --- SHARDED LOSS CALCULATION --- + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) - global_total_voxels = local_labels.numel() * math.prod( - self.config.dc_num_shards - ) - loss_ce = global_ce_sum / global_total_voxels + # 1. Sharded Cross Entropy + local_ce_sum = F.cross_entropy( + local_preds, local_labels, reduction="sum" + ) - # 2. Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1).float() - local_labels_one_hot = ( - F.one_hot( - local_labels, num_classes=self.config.n_categories + 1 - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_scores = compute_sharded_dice( - local_preds_softmax, local_labels_one_hot, spatial_mesh - ) - loss_dice = 1.0 - dice_scores.mean() + # Pass the spatial_mesh directly + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, self.spatial_mesh) - # 3. Combine Loss - loss = loss_ce + loss_dice + global_total_voxels = local_labels.numel() * math.prod( + self.config.dc_num_shards + ) + loss_ce = global_ce_sum / global_total_voxels + + # 2. Sharded Dice Loss + local_preds_softmax = F.softmax(local_preds, dim=1).float() + local_labels_one_hot = ( + F.one_hot(local_labels, num_classes=self.config.n_categories + 1) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, self.spatial_mesh + ) + loss_dice = 1.0 - dice_scores.mean() - self.log.debug( - f" warmup: loss calculation complete. Proceeding to backward pass" - ) + # 3. Combine Loss + loss = loss_ce + loss_dice - # Backward pass - self.grad_scaler.scale(loss).backward() - self.log.debug( - f" warmup: backward pass complete. Stepping optimizer" - ) + self.log.debug( + f" warmup: loss calculation complete. Proceeding to backward pass" + ) - self.grad_scaler.step(self.optimizer) - self.grad_scaler.update() + # Backward pass + self.grad_scaler.scale(loss).backward() + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + self.log.debug(f" warmup: backward pass complete. Stepping optimizer") + + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + + # Free memory aggressively + del images_dc, true_masks_dc, masks_pred_dc + del ( + local_preds, + local_labels, + local_preds_softmax, + local_labels_one_hot, + ) + del loss_ce, loss_dice, loss, images_dp, true_masks_dp - # Free memory aggressively - del images_dc, true_masks_dc, masks_pred_dc - del ( - local_preds, - local_labels, - local_preds_softmax, - local_labels_one_hot, - ) - del loss_ce, loss_dice, loss, images_dp, true_masks_dp + if self.world_rank == 0: + peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) + peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) + self.log.debug( + f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", + ) + batch_t_end = time.time() + self.log.debug( + f" warmup: batch {batch_idx} completed in {batch_t_end - start_warmup} seconds" + ) - if self.world_rank == 0: - peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) - peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) - self.log.debug( - f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", - ) - batch_t_end = time.time() - self.log.debug( - f" warmup: batch {i} completed in {batch_t_end - batch_t_start} seconds" - ) + # Nuke any accumulated grads so the first real step starts clean + for p in self.model.parameters(): + p.grad = None + self.optimizer.zero_grad(set_to_none=True) - # Nuke any accumulated grads so the first real step starts clean - for p in self.model.parameters(): - p.grad = None - self.optimizer.zero_grad(set_to_none=True) + if self.config.dist: torch.distributed.barrier() - end_code_region("warmup") - self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") + self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") + + def train(self): + """ + Execute model training + """ epoch = 1 dice_score_train = 0 @@ -512,9 +496,7 @@ def train(self): # Timer and tracking variables epoch_start_time = time.time() - train_dice_curr = 0 train_dice_total = 0 - CE_loss = 0 epoch_loss = 0 # Accumulator for per-batch losses # Set necessary modes/states @@ -559,19 +541,21 @@ def train(self): # Data parallel sharding images_dp = DTensor.from_local( - images, ps.device_mesh, placements=ddp_placements + images, self.ps.device_mesh, placements=self.ddp_placements ).to_local() true_masks_dp = DTensor.from_local( - true_masks, ps.device_mesh, placements=ddp_placements + true_masks, + self.ps.device_mesh, + placements=self.ddp_placements, ).to_local() # Delete source tensors immediately after use to keep memory down del images, true_masks # Spatial sharding via DistConv - images_dc = DCTensor.distribute(images_dp, ps) - true_masks_dc = DCTensor.distribute(true_masks_dp, ps) + images_dc = DCTensor.distribute(images_dp, self.ps) + true_masks_dc = DCTensor.distribute(true_masks_dp, self.ps) self._get_memsize( images_dc, "Sharded image", self.config.verbose ) @@ -618,7 +602,7 @@ def train(self): # Pass the spatial_mesh directly global_ce_sum = SpatialAllReduce.apply( - local_ce_sum, spatial_mesh + local_ce_sum, self.spatial_mesh ) global_total_voxels = local_labels.numel() * math.prod( @@ -639,7 +623,9 @@ def train(self): # Compute sharded dice using new function dice_scores = compute_sharded_dice( - local_preds_softmax, local_labels_one_hot, spatial_mesh + local_preds_softmax, + local_labels_one_hot, + self.spatial_mesh, ) loss_dice = 1.0 - dice_scores.mean() diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index a11a8c3..ab20c4e 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -24,8 +24,8 @@ import torch import torch.distributed as dist import yaml -from distconv import DCTensor, DistConvDDP, ParallelStrategy -from torch.nn.parallel import DistributedDataParallel as DDP +from distconv import DistConvDDP, ParallelStrategy +from torch.distributed.tensor import Replicate, Shard from ScaFFold.datagen.get_dataset import get_dataset from ScaFFold.unet import UNet @@ -214,6 +214,11 @@ def main(kwargs_dict: dict = {}): torch.backends.cudnn.benchmark = False torch.use_deterministic_algorithms(True, warn_only=True) trainer = PyTorchTrainer(model, config, device, log) + trainer.ps = ps + trainer.spatial_mesh = ps.device_mesh[ps.distconv_dim_names] + num_spatial_dims = len(ps.shard_dim) + trainer.ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims + else: raise RuntimeError( "Invalid framework specified. Currently [torch] is the supported framework." @@ -225,6 +230,12 @@ def main(kwargs_dict: dict = {}): ranks_per_node = get_local_size() prof_ctx, TORCH_PERF_LOCAL = get_torch_context(ranks_per_node, rank) with prof_ctx as prof: + begin_code_region("cleanup_or_resume") + trainer.cleanup_or_resume() + end_code_region("cleanup_or_resume") + begin_code_region("warmup") + trainer.warmup() + end_code_region("warmup") begin_code_region("train") trainer.train() end_code_region("train") From f8fca7b214b825dec24117c2ac766c4f418d53cb Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Fri, 3 Apr 2026 09:05:31 -0700 Subject: [PATCH 22/43] Data loading optimizations (#46) - make dataloader num_workers user-configurable - shift dataloader preprocessing work into dataset generation for speedup, maintaining support for old datasets - data_loading.py: restore .contiguous() and dtype cast calls, but change order to avoid redundant copies --- ScaFFold/cli.py | 5 ++ ScaFFold/configs/benchmark_default.yml | 1 + ScaFFold/configs/benchmark_testing.yml | 1 + ScaFFold/datagen/get_dataset.py | 9 ++- ScaFFold/datagen/volumegen.py | 17 ++++-- ScaFFold/utils/config_utils.py | 1 + ScaFFold/utils/data_loading.py | 83 +++++++++++++++++++------- ScaFFold/utils/data_types.py | 21 +++++++ ScaFFold/utils/trainer.py | 10 +++- 9 files changed, 117 insertions(+), 31 deletions(-) create mode 100644 ScaFFold/utils/data_types.py diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 5b76c87..b387917 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -145,6 +145,11 @@ def main(): type=int, help="Number of warmup batches to run per rank before training.", ) + benchmark_parser.add_argument( + "--dataloader-num-workers", + type=int, + help="Number of DataLoader worker processes per rank.", + ) benchmark_parser.add_argument( "--optimizer", type=str, diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 180a0dc..2d1d414 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -8,6 +8,7 @@ problem_scale: 8 # Determines dataset resolution and number of unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. +dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. dc_num_shards: [1, 1, 2] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index 6b8c30a..f37749c 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -8,6 +8,7 @@ problem_scale: 6 # Determines dataset resolution and number of unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. +dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index fc19f8c..2e74abe 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -29,7 +29,9 @@ from ScaFFold.datagen import volumegen META_FILENAME = "meta.yaml" +DATASET_FORMAT_VERSION = 2 INCLUDE_KEYS = [ + "dataset_format_version", "n_categories", "n_instances_used_per_fractal", "problem_scale", @@ -116,8 +118,10 @@ def get_dataset( root.mkdir(exist_ok=True) # Get dict of required keys and compute config_id + config_dict = vars(config).copy() + config_dict["dataset_format_version"] = DATASET_FORMAT_VERSION volume_config = _get_required_keys_dict( - config=vars(config), include_keys=INCLUDE_KEYS + config=config_dict, include_keys=INCLUDE_KEYS ) config_id = _hash_volume_config(volume_config) commit = _git_commit_short() @@ -136,6 +140,8 @@ def get_dataset( meta = yaml.safe_load(meta_path.read_text()) if meta.get("config_id") != config_id: continue + if meta.get("dataset_format_version", 1) != DATASET_FORMAT_VERSION: + continue if require_commit and meta.get("code_commit") != commit: continue # If we pass the above checks, this dataset can be reused @@ -186,6 +192,7 @@ def get_dataset( # Write to tmp, then move, so readers never see half-written dataset meta = { "config_id": config_id, + "dataset_format_version": DATASET_FORMAT_VERSION, "config_subset": volume_config, "include_keys": INCLUDE_KEYS, "code_commit": commit, diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index 479e67e..efee01e 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -26,8 +26,7 @@ from mpi4py import MPI from ScaFFold.utils.config_utils import Config - -DEFAULT_NP_DTYPE = np.float64 +from ScaFFold.utils.data_types import DEFAULT_NP_DTYPE, MASK_DTYPE, VOLUME_DTYPE def load_np_ptcloud(path: str) -> np.ndarray: @@ -177,10 +176,10 @@ def main(config: Dict): volume = np.full( (config.vol_size, config.vol_size, config.vol_size, 3), 0, - dtype=np.float32, + dtype=VOLUME_DTYPE, ) mask = np.full( - (config.vol_size, config.vol_size, config.vol_size), 0, dtype=np.short + (config.vol_size, config.vol_size, config.vol_size), 0, dtype=MASK_DTYPE ) global_vol_idx = curr_vol[0] @@ -223,14 +222,20 @@ def main(config: Dict): # Determine destination folder subdir = "validation" if global_vol_idx in val_indices else "training" + # Tensors must logically be channels-first, later we will change striding/storage to channels-last on GPU (metadata will always stay channels-first). + volume_channels_first = volume.transpose((3, 0, 1, 2)) + volume_to_save = np.ascontiguousarray( + volume_channels_first, dtype=VOLUME_DTYPE + ) + mask_to_save = np.ascontiguousarray(mask, dtype=MASK_DTYPE) vol_file = os.path.join(vol_path, subdir, f"{global_vol_idx}.npy") with open(vol_file, "wb") as f: - np.save(f, volume) + np.save(f, volume_to_save) mask_file = os.path.join(mask_path, subdir, f"{global_vol_idx}_mask.npy") with open(mask_file, "wb") as f: - np.save(f, mask) + np.save(f, mask_to_save) end_time = time.time() total_time = end_time - start_time diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 640ad19..9a67182 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -50,6 +50,7 @@ def __init__(self, config_dict): self.n_instances_used_per_fractal = config_dict["n_instances_used_per_fractal"] self.scale = 1 self.batch_size = config_dict["batch_size"] + self.dataloader_num_workers = config_dict["dataloader_num_workers"] self.epochs = config_dict["epochs"] self.optimizer = config_dict["optimizer"] self.disable_scheduler = bool(config_dict["disable_scheduler"]) diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index 725854c..74326e4 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -19,10 +19,16 @@ import numpy as np import torch +import yaml from torch.utils.data import Dataset +from ScaFFold.utils.data_types import MASK_DTYPE, VOLUME_DTYPE from ScaFFold.utils.utils import customlog +DATASET_FORMAT_VERSION = 2 +LEGACY_DATASET_FORMAT_VERSION = 1 +META_FILENAME = "meta.yaml" + class BasicDataset(Dataset): def __init__( @@ -31,6 +37,8 @@ def __init__( self.images_dir = Path(images_dir) self.mask_dir = Path(mask_dir) self.mask_suffix = mask_suffix + self.dataset_root = self.images_dir.parents[1] + self.dataset_format_version = self._load_dataset_format_version() self.ids = [ splitext(file)[0] @@ -49,25 +57,56 @@ def __init__( data = pickle.load(data_file) self.mask_values = data["mask_values"] customlog(f"Unique mask values: {self.mask_values}") + customlog(f"Dataset format version: {self.dataset_format_version}") def __len__(self): return len(self.ids) @staticmethod - def preprocess(mask_values, img, is_mask): - if is_mask: - mask = np.zeros((img.shape[0], img.shape[1], img.shape[2]), dtype=np.short) - for i, v in enumerate(mask_values): - if img.ndim == 3: - mask[img == v] = i - else: - mask[(img == v).all(-1)] = i + def _load_numpy_array(path): + with open(path, "rb") as handle: + return np.load(handle) + + def _load_dataset_format_version(self): + meta_path = self.dataset_root / META_FILENAME + if not meta_path.exists(): + return LEGACY_DATASET_FORMAT_VERSION + + try: + with open(meta_path, "r") as meta_file: + meta = yaml.safe_load(meta_file) or {} + except Exception as exc: + customlog( + f"Failed to read dataset metadata from {meta_path}: {exc}. Falling back to legacy loader." + ) + return LEGACY_DATASET_FORMAT_VERSION - return mask + return int(meta.get("dataset_format_version", LEGACY_DATASET_FORMAT_VERSION)) - else: - img = img.transpose((3, 0, 1, 2)) - return img + @staticmethod + def _prepare_legacy_image(img): + return np.ascontiguousarray(img.transpose((3, 0, 1, 2)), dtype=VOLUME_DTYPE) + + @staticmethod + def _prepare_legacy_mask(mask_values, mask): + remapped = np.zeros( + (mask.shape[0], mask.shape[1], mask.shape[2]), dtype=MASK_DTYPE + ) + for i, value in enumerate(mask_values): + if mask.ndim == 3: + remapped[mask == value] = i + else: + remapped[(mask == value).all(-1)] = i + + return remapped + + @staticmethod + def _prepare_optimized_image(img): + return np.ascontiguousarray(img, dtype=VOLUME_DTYPE) + + @staticmethod + def _prepare_optimized_mask(mask): + return np.ascontiguousarray(mask, dtype=MASK_DTYPE) def __getitem__(self, idx): name = self.ids[idx] @@ -80,19 +119,19 @@ def __getitem__(self, idx): assert len(mask_file) == 1, ( f"Either no mask or multiple masks found for the ID {name}: {mask_file}" ) - with open(mask_file[0], "rb") as f: - mask = np.load(f) - f.close() - with open(img_file[0], "rb") as f: - img = np.load(f) - f.close() + mask = self._load_numpy_array(mask_file[0]) + img = self._load_numpy_array(img_file[0]) - img = self.preprocess(self.mask_values, img, is_mask=False) - mask = self.preprocess(self.mask_values, mask, is_mask=True) + if self.dataset_format_version >= DATASET_FORMAT_VERSION: + img = self._prepare_optimized_image(img) + mask = self._prepare_optimized_mask(mask) + else: + img = self._prepare_legacy_image(img) + mask = self._prepare_legacy_mask(self.mask_values, mask) return { - "image": torch.as_tensor(img.copy()).float().contiguous(), - "mask": torch.as_tensor(mask.copy()).long().contiguous(), + "image": torch.from_numpy(img).contiguous().float(), + "mask": torch.from_numpy(mask).contiguous().long(), } diff --git a/ScaFFold/utils/data_types.py b/ScaFFold/utils/data_types.py new file mode 100644 index 0000000..90186db --- /dev/null +++ b/ScaFFold/utils/data_types.py @@ -0,0 +1,21 @@ +# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LBANN/ScaFFold. +# +# SPDX-License-Identifier: (Apache-2.0) + +import numpy as np + +DEFAULT_NP_DTYPE = np.float64 +# Masks are values 0 <= x <= n_categories +MASK_DTYPE = np.uint16 +# Volumes/img are 0 <= x <= 1 +VOLUME_DTYPE = np.float32 diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index ec5fc7b..2a69637 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -133,11 +133,17 @@ def create_dataloaders(self): self.create_dataset() self.create_sampler() + num_workers = self.config.dataloader_num_workers loader_args = dict( - batch_size=self.config.batch_size, num_workers=1, pin_memory=True + batch_size=self.config.batch_size, + num_workers=num_workers, + pin_memory=True, ) + if num_workers > 0: + loader_args["persistent_workers"] = True + loader_args["prefetch_factor"] = 2 self.log.debug( - f"dataloader num_workers={loader_args['num_workers']}, os.cpu_count()={os.cpu_count()}, self.world_size={self.world_size} " + f"dataloader num_workers={loader_args['num_workers']}, prefetch_factor={loader_args.get('prefetch_factor')}, persistent_workers={loader_args.get('persistent_workers', False)}, os.cpu_count()={os.cpu_count()}, self.world_size={self.world_size} " ) self.train_loader = DataLoader( self.train_set, sampler=self.train_sampler, **loader_args From adbb81236505517b8d283aaefd6720451f3cd0b6 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Fri, 10 Apr 2026 08:40:09 -0700 Subject: [PATCH 23/43] Fix `ruff` check (#47) * Fix check * Fix flake --- .github/workflows/style.yml | 2 +- ScaFFold/cli.py | 2 +- ScaFFold/datagen/get_dataset.py | 4 ++-- ScaFFold/utils/checkpointing.py | 6 ++---- ScaFFold/utils/create_restart_script.py | 2 +- ScaFFold/utils/distributed.py | 2 +- ScaFFold/utils/evaluate.py | 6 +----- ScaFFold/utils/trainer.py | 14 +++++--------- ScaFFold/worker.py | 3 +-- 9 files changed, 15 insertions(+), 26 deletions(-) diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 0b93bcd..ebe4149 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -20,4 +20,4 @@ jobs: # Check format. To fix, run "ruff format ." ruff format --diff . # Check PEP8 violations, logic/correctness errors, and sort imports. To fix, run "ruff check --fix ." - ruff check --diff . + ruff check . diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index b387917..0aa755b 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -223,7 +223,7 @@ def main(): combined_config["point_num"] = int(combined_config["vol_size"] ** 3 / 256) # Handle Restart / Resume logic - if hasattr(args, "restart") and args.restart == True: + if hasattr(args, "restart") and args.restart: print("Restart flag detected: Forcing train_from_scratch = False") combined_config["train_from_scratch"] = False combined_config["restart"] = True diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index 2e74abe..885377b 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -87,12 +87,12 @@ def _git_commit_short() -> str: ) except subprocess.CalledProcessError: print( - f"Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse." + "Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse." ) return "no-commit-id" except Exception: print( - f"Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse." + "Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse." ) return "no-commit-id" diff --git a/ScaFFold/utils/checkpointing.py b/ScaFFold/utils/checkpointing.py index 92c5203..5a65b09 100644 --- a/ScaFFold/utils/checkpointing.py +++ b/ScaFFold/utils/checkpointing.py @@ -12,11 +12,9 @@ # # SPDX-License-Identifier: (Apache-2.0) -import copy import math import random import shutil -import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Dict, Optional @@ -238,7 +236,7 @@ def save_checkpoint( self.best_ckpt_path, is_best, ) - self._log(f"Async checkpoint offloaded to background thread.") + self._log("Async checkpoint offloaded to background thread.") else: # Synchronous Save self._write_to_disk( @@ -314,7 +312,7 @@ def _get_rng_snapshot(self) -> Dict[str, Any]: pass try: snap["rng_state_python"] = random.getstate() - except: + except Exception: pass return snap diff --git a/ScaFFold/utils/create_restart_script.py b/ScaFFold/utils/create_restart_script.py index cc8bbbc..206eae7 100644 --- a/ScaFFold/utils/create_restart_script.py +++ b/ScaFFold/utils/create_restart_script.py @@ -20,7 +20,7 @@ import stat import sys from pathlib import Path -from typing import List, Literal, Optional, Union +from typing import List, Literal, Union import torch diff --git a/ScaFFold/utils/distributed.py b/ScaFFold/utils/distributed.py index 2ca1c71..5e1f92e 100644 --- a/ScaFFold/utils/distributed.py +++ b/ScaFFold/utils/distributed.py @@ -98,7 +98,7 @@ def force_cuda_visible_devices(force: bool = False) -> None: other GPUs. """ - print(f"force_cuda_visible_devices is deprecated. Skipping...") + print("force_cuda_visible_devices is deprecated. Skipping...") def get_device() -> torch.device: diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index fbecf90..440f18f 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -14,19 +14,15 @@ import math -import numpy as np import torch import torch.nn.functional as F from distconv import DCTensor -from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.distributed.tensor import DTensor, Replicate, Shard from tqdm import tqdm from ScaFFold.utils.dice_score import ( SpatialAllReduce, compute_sharded_dice, - dice_coeff, - dice_loss, - multiclass_dice_coeff, ) from ScaFFold.utils.perf_measure import annotate diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 2a69637..fe923a1 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -13,18 +13,14 @@ # SPDX-License-Identifier: (Apache-2.0) # Standard library -import json import math import os -import random import shutil import time from pathlib import Path # Third party -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from distconv import DCTensor @@ -35,7 +31,7 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset -from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice, dice_loss +from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local @@ -389,9 +385,9 @@ def warmup(self): enabled=self.config.torch_amp, ): # Forward on DCTensor - self.log.debug(f" warmup: running forward pass") + self.log.debug(" warmup: running forward pass") masks_pred_dc = self.model(images_dc) - self.log.debug(f" warmup: forward pass complete") + self.log.debug(" warmup: forward pass complete") # Extract the underlying PyTorch local tensors local_preds = masks_pred_dc @@ -442,14 +438,14 @@ def warmup(self): loss = loss_ce + loss_dice self.log.debug( - f" warmup: loss calculation complete. Proceeding to backward pass" + " warmup: loss calculation complete. Proceeding to backward pass" ) # Backward pass self.grad_scaler.scale(loss).backward() self.grad_scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - self.log.debug(f" warmup: backward pass complete. Stepping optimizer") + self.log.debug(" warmup: backward pass complete. Stepping optimizer") self.grad_scaler.step(self.optimizer) self.grad_scaler.update() diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index ab20c4e..f0279e1 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -31,7 +31,6 @@ from ScaFFold.unet import UNet from ScaFFold.utils.distributed import ( get_device, - get_job_id, get_local_rank, get_local_size, get_world_rank, @@ -266,7 +265,7 @@ def main(kwargs_dict: dict = {}): # Generate plots # if rank == 0: - log.info(f"Generating figures on rank 0...") + log.info("Generating figures on rank 0...") begin_code_region("generate_figures") standard_viz.main(config) end_code_region("generate_figures") From 8db5f39d46340f3a995c8838dd90049cd36f889c Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:07:47 -0700 Subject: [PATCH 24/43] Fix `inf` val loss in early epochs (#50) * Calculate local CE loss w/o AMP to prevent inf from f16 overflow * ruff --------- Co-authored-by: Patrick Miles --- ScaFFold/utils/evaluate.py | 7 ++++++- ScaFFold/utils/trainer.py | 24 ++++++++++++++++++------ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 440f18f..c66c255 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -94,7 +94,12 @@ def evaluate( continue # --- 1. Sharded CE Loss --- - local_ce_sum = F.cross_entropy(local_preds, local_labels, reduction="sum") + with torch.autocast( + device.type if device.type != "mps" else "cpu", enabled=False + ): + local_ce_sum = F.cross_entropy( + local_preds.float(), local_labels, reduction="sum" + ) global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) # Divide by total global voxels to get the mean CE Loss diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index fe923a1..c98fa64 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -410,9 +410,13 @@ def warmup(self): ) # 1. Sharded Cross Entropy - local_ce_sum = F.cross_entropy( - local_preds, local_labels, reduction="sum" - ) + with torch.autocast( + self.device.type if self.device.type != "mps" else "cpu", + enabled=False, + ): + local_ce_sum = F.cross_entropy( + local_preds.float(), local_labels, reduction="sum" + ) # Pass the spatial_mesh directly global_ce_sum = SpatialAllReduce.apply(local_ce_sum, self.spatial_mesh) @@ -598,9 +602,17 @@ def train(self): ) # 1. Sharded Cross Entropy - local_ce_sum = F.cross_entropy( - local_preds, local_labels, reduction="sum" - ) + with torch.autocast( + self.device.type + if self.device.type != "mps" + else "cpu", + enabled=False, + ): + local_ce_sum = F.cross_entropy( + local_preds.float(), + local_labels, + reduction="sum", + ) # Pass the spatial_mesh directly global_ce_sum = SpatialAllReduce.apply( From 6bd9e14413423df9d577d9679a43cfc0e54cb2b0 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 16 Apr 2026 14:01:41 -0700 Subject: [PATCH 25/43] Warmup evaluation step (#49) * Warmup evaluation * cleanup --- ScaFFold/utils/trainer.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index c98fa64..6f3cf80 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -480,6 +480,21 @@ def warmup(self): p.grad = None self.optimizer.zero_grad(set_to_none=True) + if self.config.dist: + self.val_loader.sampler.set_epoch(0) + + evaluate( + self.model, + self.val_loader, + self.device, + self.config.torch_amp, + self.world_rank == 0, + self.criterion, + self.config.n_categories, + self.config._parallel_strategy, + ) + self.model.train() + if self.config.dist: torch.distributed.barrier() self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") @@ -686,7 +701,7 @@ def train(self): self.val_loader, self.device, self.config.torch_amp, - True if self.world_rank == 0 else False, + self.world_rank == 0, self.criterion, self.config.n_categories, self.config._parallel_strategy, From 04c1b4a92a5a182987483c165ba11a7a403c00f0 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 16 Apr 2026 14:10:22 -0700 Subject: [PATCH 26/43] Add configuration option to disable checkpointing (#48) * enable option to never checkpoint and make default * Update config_utils.py --- ScaFFold/configs/benchmark_default.yml | 6 +++--- ScaFFold/configs/benchmark_testing.yml | 2 +- ScaFFold/utils/config_utils.py | 1 - ScaFFold/utils/trainer.py | 7 +++++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 2d1d414..0bc715d 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -4,15 +4,15 @@ dataset_dir: "datasets" # Directory in which to store and query for d fract_base_dir: "fractals" # Base directory for fractal IFS and instances. n_categories: 5 # Number of fractal categories present in the dataset. n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. -problem_scale: 8 # Determines dataset resolution and number of unet layers. Default is 6. +problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -dc_num_shards: [1, 1, 2] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard -checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. +checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index f37749c..8fbd52c 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -12,7 +12,7 @@ dataloader_num_workers: 4 # Number of DataLoader worker processes per r optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard -checkpoint_interval: 100 # Checkpoint every C epochs. More frequent checkpointing can be very expensive on slow filesystems. +checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 9a67182..4684779 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -73,7 +73,6 @@ def __init__(self, config_dict): ] self.target_dice = config_dict["target_dice"] self.checkpoint_interval = config_dict["checkpoint_interval"] - self.dc_num_shards = config_dict["dc_num_shards"] self.dc_shard_dims = config_dict["dc_shard_dims"] self.dc_total_shards = math.prod(self.dc_num_shards) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 6f3cf80..e60d173 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -765,8 +765,11 @@ def train(self): # begin_code_region("checkpoint") - # Checkpoint only if at a checkpoint_interval epoch - if epoch % self.config.checkpoint_interval == 0: + # A checkpoint interval of -1 disables checkpointing entirely. + if ( + self.config.checkpoint_interval > 0 + and epoch % self.config.checkpoint_interval == 0 + ): extras = {"train_mask_values": self.train_set.mask_values} self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras) From ac98bbc94bcee05e36520e8f71bbfc17e4d31c96 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Wed, 22 Apr 2026 15:02:37 -0700 Subject: [PATCH 27/43] Enforce samples are not repeated (#55) --- ScaFFold/cli.py | 2 +- ScaFFold/worker.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 0aa755b..30512a6 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -138,7 +138,7 @@ def main(): ) benchmark_parser.add_argument("--seed", type=int, help="Random seed.") benchmark_parser.add_argument( - "--batch-size", type=int, nargs="+", help="Batch sizes for each volume size." + "--batch-size", type=int, help="Batch sizes for each volume size." ) benchmark_parser.add_argument( "--warmup-batches", diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index f0279e1..d2818c1 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -217,6 +217,22 @@ def main(kwargs_dict: dict = {}): trainer.spatial_mesh = ps.device_mesh[ps.distconv_dim_names] num_spatial_dims = len(ps.shard_dim) trainer.ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims + global_batch_size = config.batch_size * ( + world_size // math.prod(config.dc_num_shards) + ) + if rank == 0: + log.info( + f"Effective global batch size = {global_batch_size} " + f"(batch_size={config.batch_size} * " + f"(world_size={world_size} / prod(dc_num_shards)={math.prod(config.dc_num_shards)}))" + ) + if global_batch_size > trainer.n_train: + raise ValueError( + "Effective global batch size exceeds available training samples: " + f"global_batch_size={global_batch_size}, n_train={trainer.n_train}, " + f"batch_size={config.batch_size}, world_size={world_size}, " + f"dc_num_shards={config.dc_num_shards}" + ) else: raise RuntimeError( From a36fb3adfaace120877770fe729e4c05af9b8b4c Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Wed, 22 Apr 2026 15:54:50 -0700 Subject: [PATCH 28/43] Exit if dice score is NaN (#54) * check for nan dice_score_train * lint --- ScaFFold/utils/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index e60d173..d0f4dc7 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -778,4 +778,10 @@ def train(self): dice_score_train = val_score epoch += 1 + # This check must exist otherwise the condition dice_score_train < self.config.target_dice will evaluate to False and incorrectly exit the training + if math.isnan(dice_score_train): + raise ValueError( + "Invalid value (NaN) encountered in dice score computation" + ) + adiak_value("final_epochs", epoch) From d6b7b641b6a83ac05379cebf2319092bbcbff72c Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Wed, 22 Apr 2026 16:16:52 -0700 Subject: [PATCH 29/43] Move to sharded data loading (#52) * move to sharded data loading * bug fixes * ruff * restore missing import * ruff --------- Co-authored-by: Patrick Miles --- ScaFFold/utils/data_loading.py | 131 ++++++++++++++++++++++++++++++--- ScaFFold/utils/evaluate.py | 31 +++----- ScaFFold/utils/trainer.py | 96 +++++++++++++----------- 3 files changed, 184 insertions(+), 74 deletions(-) diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index 74326e4..688f329 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -13,9 +13,11 @@ # SPDX-License-Identifier: (Apache-2.0) import pickle +from dataclasses import dataclass from os import listdir from os.path import isfile, join, splitext from pathlib import Path +from typing import Dict, Optional, Tuple import numpy as np import torch @@ -30,13 +32,89 @@ META_FILENAME = "meta.yaml" +@dataclass(frozen=True) +class SpatialShardSpec: + """Describe the local spatial shard owned by the current rank.""" + + shard_dims: Tuple[int, ...] + num_shards: Tuple[int, ...] + shard_indices: Tuple[int, ...] + + def __post_init__(self): + if not ( + len(self.shard_dims) == len(self.num_shards) == len(self.shard_indices) + ): + raise ValueError( + "shard_dims, num_shards, and shard_indices must have matching lengths" + ) + if len(set(self.shard_dims)) != len(self.shard_dims): + raise ValueError(f"Shard dimensions must be unique: {self.shard_dims}") + for shard_dim, num_shards, shard_index in zip( + self.shard_dims, self.num_shards, self.shard_indices + ): + if shard_dim < 2: + raise ValueError( + f"Invalid shard_dim {shard_dim}: only spatial dimensions are supported" + ) + if num_shards < 1: + raise ValueError( + f"Invalid num_shards {num_shards} for shard_dim {shard_dim}" + ) + if shard_index < 0 or shard_index >= num_shards: + raise ValueError( + f"Invalid shard_index {shard_index} for shard_dim {shard_dim} with {num_shards} shards" + ) + + @staticmethod + def _chunk_slice(size: int, num_shards: int, shard_index: int) -> slice: + """Match torch.chunk-style uneven shard boundaries.""" + + chunk_size = (size + num_shards - 1) // num_shards + start = shard_index * chunk_size + if start >= size: + raise ValueError( + f"Empty local shard: dim size {size}, num_shards {num_shards}, shard_index {shard_index}" + ) + stop = min(size, start + chunk_size) + return slice(start, stop) + + def slice_array( + self, array: np.ndarray, axis_map: Dict[int, int], array_label: str + ) -> np.ndarray: + if not self.shard_dims: + return array + + slices = [slice(None)] * array.ndim + for shard_dim, num_shards, shard_index in zip( + self.shard_dims, self.num_shards, self.shard_indices + ): + if shard_dim not in axis_map: + raise ValueError( + f"No axis mapping defined for {array_label} shard_dim {shard_dim}" + ) + axis = axis_map[shard_dim] + if axis >= array.ndim: + raise ValueError( + f"Axis {axis} out of range for {array_label} with shape {array.shape}" + ) + slices[axis] = self._chunk_slice(array.shape[axis], num_shards, shard_index) + + return array[tuple(slices)] + + class BasicDataset(Dataset): def __init__( - self, images_dir: str, mask_dir: str, mask_suffix: str = "", data_dir: str = "" + self, + images_dir: str, + mask_dir: str, + mask_suffix: str = "", + data_dir: str = "", + spatial_shard_spec: Optional[SpatialShardSpec] = None, ): self.images_dir = Path(images_dir) self.mask_dir = Path(mask_dir) self.mask_suffix = mask_suffix + self.spatial_shard_spec = spatial_shard_spec self.dataset_root = self.images_dir.parents[1] self.dataset_format_version = self._load_dataset_format_version() @@ -63,9 +141,8 @@ def __len__(self): return len(self.ids) @staticmethod - def _load_numpy_array(path): - with open(path, "rb") as handle: - return np.load(handle) + def _load_numpy_array(path, mmap_mode=None): + return np.load(path, allow_pickle=False, mmap_mode=mmap_mode) def _load_dataset_format_version(self): meta_path = self.dataset_root / META_FILENAME @@ -102,11 +179,28 @@ def _prepare_legacy_mask(mask_values, mask): @staticmethod def _prepare_optimized_image(img): - return np.ascontiguousarray(img, dtype=VOLUME_DTYPE) + return np.array(img, dtype=VOLUME_DTYPE, copy=True, order="C") @staticmethod def _prepare_optimized_mask(mask): - return np.ascontiguousarray(mask, dtype=MASK_DTYPE) + return np.array(mask, dtype=MASK_DTYPE, copy=True, order="C") + + def _slice_image_array(self, img): + if self.spatial_shard_spec is None: + return img + + if self.dataset_format_version >= DATASET_FORMAT_VERSION: + axis_map = {2: 1, 3: 2, 4: 3} + else: + axis_map = {2: 0, 3: 1, 4: 2} + return self.spatial_shard_spec.slice_array(img, axis_map, "image") + + def _slice_mask_array(self, mask): + if self.spatial_shard_spec is None: + return mask + + axis_map = {2: 0, 3: 1, 4: 2} + return self.spatial_shard_spec.slice_array(mask, axis_map, "mask") def __getitem__(self, idx): name = self.ids[idx] @@ -119,8 +213,13 @@ def __getitem__(self, idx): assert len(mask_file) == 1, ( f"Either no mask or multiple masks found for the ID {name}: {mask_file}" ) - mask = self._load_numpy_array(mask_file[0]) - img = self._load_numpy_array(img_file[0]) + mmap_mode = "r" if self.spatial_shard_spec is not None else None + # Memmap lets each rank slice out just its local shard without eagerly + # reading the full sample into process memory first. + mask = self._load_numpy_array(mask_file[0], mmap_mode=mmap_mode) + img = self._load_numpy_array(img_file[0], mmap_mode=mmap_mode) + mask = self._slice_mask_array(mask) + img = self._slice_image_array(img) if self.dataset_format_version >= DATASET_FORMAT_VERSION: img = self._prepare_optimized_image(img) @@ -136,5 +235,17 @@ def __getitem__(self, idx): class FractalDataset(BasicDataset): - def __init__(self, images_dir, mask_dir, data_dir): - super().__init__(images_dir, mask_dir, mask_suffix="_mask", data_dir=data_dir) + def __init__( + self, + images_dir, + mask_dir, + data_dir, + spatial_shard_spec: Optional[SpatialShardSpec] = None, + ): + super().__init__( + images_dir, + mask_dir, + mask_suffix="_mask", + data_dir=data_dir, + spatial_shard_spec=spatial_shard_spec, + ) diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index c66c255..372aa70 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -12,12 +12,9 @@ # # SPDX-License-Identifier: (Apache-2.0) -import math - import torch import torch.nn.functional as F from distconv import DCTensor -from torch.distributed.tensor import DTensor, Replicate, Shard from tqdm import tqdm from ScaFFold.utils.dice_score import ( @@ -66,20 +63,9 @@ def evaluate( # Dummy channel dimension [B, 1, D, H, W] mask_true = mask_true.unsqueeze(1) - # DDP Sharding - ddp_placements = [Shard(0)] + [Replicate()] * len( - parallel_strategy.shard_dim - ) - image_dp = DTensor.from_local( - image, parallel_strategy.device_mesh, placements=ddp_placements - ).to_local() - mask_true_dp = DTensor.from_local( - mask_true, parallel_strategy.device_mesh, placements=ddp_placements - ).to_local() - - # DistConv Spatial Sharding - dcx = DCTensor.distribute(image_dp, parallel_strategy) - mask_true_dc = DCTensor.distribute(mask_true_dp, parallel_strategy) + # Inputs are already loaded as local shards by the dataset. + dcx = DCTensor.from_shard(image, parallel_strategy) + mask_true_dc = DCTensor.from_shard(mask_true, parallel_strategy) # Forward pass on sharded data dcy = net(dcx) @@ -102,9 +88,14 @@ def evaluate( ) global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) - # Divide by total global voxels to get the mean CE Loss - global_total_voxels = local_labels.numel() * math.prod( - parallel_strategy.num_shards + # Divide by the actual global voxel count to handle uneven shards. + local_voxel_count = torch.tensor( + float(local_labels.numel()), + device=local_labels.device, + dtype=torch.float32, + ) + global_total_voxels = SpatialAllReduce.apply( + local_voxel_count, spatial_mesh ) CE_loss = global_ce_sum / global_total_voxels diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index d0f4dc7..e8752fd 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -12,8 +12,6 @@ # # SPDX-License-Identifier: (Apache-2.0) -# Standard library -import math import os import shutil import time @@ -25,12 +23,11 @@ import torch.nn.functional as F from distconv import DCTensor from torch import optim -from torch.distributed.tensor import DTensor from torch.utils.data import DataLoader from tqdm import tqdm from ScaFFold.utils.checkpointing import CheckpointManager -from ScaFFold.utils.data_loading import FractalDataset +from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size @@ -69,9 +66,14 @@ def __init__(self, model, config, device, log): self.criterion = None self.global_step = 0 self.start_epoch = -1 - self.ps = None # DistConv ParallelStrategy + self.ps = getattr(self.config, "_parallel_strategy", None) self.spatial_mesh = None # Spatial mesh for use w/ DistConv - self.ddp_placements = None # DDP placements for use w/ DistConv + self.data_num_replicas = self.world_size + self.data_replica_rank = self.world_rank + if self.ps is not None: + self.spatial_mesh = self.ps.device_mesh[self.ps.distconv_dim_names] + self.data_num_replicas = self.ps.ddp_ranks + self.data_replica_rank = self.ps.ddp_ind self.checkpoint_path_absolute = str( self.config.run_dir + "/" + self.config.checkpoint_dir @@ -98,12 +100,25 @@ def create_dataset(self): val_mask_dir = dataset_dir / "masks/validation" train_unique_masks_path = dataset_dir / "train_unique_mask_vals" val_unique_masks_path = dataset_dir / "val_unique_mask_vals" + spatial_shard_spec = None + if self.ps is not None: + spatial_shard_spec = SpatialShardSpec( + shard_dims=tuple(self.ps.shard_dim), + num_shards=tuple(self.ps.num_shards), + shard_indices=tuple(self.ps.shard_ind), + ) self.train_set = FractalDataset( - train_vol_dir, train_mask_dir, data_dir=train_unique_masks_path + train_vol_dir, + train_mask_dir, + data_dir=train_unique_masks_path, + spatial_shard_spec=spatial_shard_spec, ) self.val_set = FractalDataset( - val_vol_dir, val_mask_dir, data_dir=val_unique_masks_path + val_vol_dir, + val_mask_dir, + data_dir=val_unique_masks_path, + spatial_shard_spec=spatial_shard_spec, ) self.n_train = len(self.train_set) self.n_val = len(self.val_set) @@ -115,10 +130,15 @@ def create_sampler(self): """Create DistributedSamplers for train and validation datasets.""" if self.config.dist: self.train_sampler = torch.utils.data.distributed.DistributedSampler( - self.train_set + self.train_set, + num_replicas=self.data_num_replicas, + rank=self.data_replica_rank, ) self.val_sampler = torch.utils.data.distributed.DistributedSampler( - self.val_set, shuffle=False + self.val_set, + num_replicas=self.data_num_replicas, + rank=self.data_replica_rank, + shuffle=False, ) else: self.train_sampler = torch.utils.data.RandomSampler(self.train_set) @@ -366,18 +386,9 @@ def warmup(self): # Add a dummy channel dimension to get 5D [B, 1, D, H, W] true_masks = true_masks.unsqueeze(1) - # Data parallel sharding - images_dp = DTensor.from_local( - images, self.ps.device_mesh, placements=self.ddp_placements - ).to_local() - - true_masks_dp = DTensor.from_local( - true_masks, self.ps.device_mesh, placements=self.ddp_placements - ).to_local() - - # Spatial sharding via DistConv - images_dc = DCTensor.distribute(images_dp, self.ps) - true_masks_dc = DCTensor.distribute(true_masks_dp, self.ps) + # Inputs are already loaded as local shards by the dataset. + images_dc = DCTensor.from_shard(images, self.ps) + true_masks_dc = DCTensor.from_shard(true_masks, self.ps) self._get_memsize(images_dc, "Sharded image", self.config.verbose) with torch.autocast( @@ -421,8 +432,13 @@ def warmup(self): # Pass the spatial_mesh directly global_ce_sum = SpatialAllReduce.apply(local_ce_sum, self.spatial_mesh) - global_total_voxels = local_labels.numel() * math.prod( - self.config.dc_num_shards + local_voxel_count = torch.tensor( + float(local_labels.numel()), + device=local_labels.device, + dtype=torch.float32, + ) + global_total_voxels = SpatialAllReduce.apply( + local_voxel_count, self.spatial_mesh ) loss_ce = global_ce_sum / global_total_voxels @@ -462,7 +478,7 @@ def warmup(self): local_preds_softmax, local_labels_one_hot, ) - del loss_ce, loss_dice, loss, images_dp, true_masks_dp + del loss_ce, loss_dice, loss if self.world_rank == 0: peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) @@ -533,7 +549,7 @@ def train(self): else f"{epoch}/{self.config.epochs}" ) with tqdm( - total=self.n_train // self.world_size, + total=len(self.train_sampler), desc=f"({os.path.basename(self.config.run_dir)}) \ Epoch {estr}", unit="img", @@ -560,23 +576,10 @@ def train(self): # Add a dummy channel dimension to get 5D [B, 1, D, H, W] true_masks = true_masks.unsqueeze(1) - # Data parallel sharding - images_dp = DTensor.from_local( - images, self.ps.device_mesh, placements=self.ddp_placements - ).to_local() - - true_masks_dp = DTensor.from_local( - true_masks, - self.ps.device_mesh, - placements=self.ddp_placements, - ).to_local() - - # Delete source tensors immediately after use to keep memory down + # Inputs are already loaded as local shards by the dataset. + images_dc = DCTensor.from_shard(images, self.ps) + true_masks_dc = DCTensor.from_shard(true_masks, self.ps) del images, true_masks - - # Spatial sharding via DistConv - images_dc = DCTensor.distribute(images_dp, self.ps) - true_masks_dc = DCTensor.distribute(true_masks_dp, self.ps) self._get_memsize( images_dc, "Sharded image", self.config.verbose ) @@ -634,8 +637,13 @@ def train(self): local_ce_sum, self.spatial_mesh ) - global_total_voxels = local_labels.numel() * math.prod( - self.config.dc_num_shards + local_voxel_count = torch.tensor( + float(local_labels.numel()), + device=local_labels.device, + dtype=torch.float32, + ) + global_total_voxels = SpatialAllReduce.apply( + local_voxel_count, self.spatial_mesh ) loss_ce = global_ce_sum / global_total_voxels From f66fff1299c7df97d2bef23058a82af3dff7854c Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 23 Apr 2026 10:21:05 -0700 Subject: [PATCH 30/43] Missing Import (#57) --- ScaFFold/utils/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index e8752fd..a5d4355 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: (Apache-2.0) +import math import os import shutil import time From 3194f6c7f011336232a41642d99ee8f553e6a778 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 23 Apr 2026 14:59:26 -0700 Subject: [PATCH 31/43] Consistently ignore background class in loss/dice (#59) * ignore background class * .item() --- ScaFFold/utils/evaluate.py | 14 +++++++++----- ScaFFold/utils/trainer.py | 22 +++++++++++++++------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 372aa70..56198cc 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -29,6 +29,12 @@ def evaluate( net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy ): + + def foreground_dice_mean(dice_scores): + if dice_scores.size(1) > 1: + return dice_scores[:, 1:].mean().item() + return dice_scores.mean().item() + net.eval() num_val_batches = len(dataloader) total_dice_score = 0.0 @@ -109,14 +115,12 @@ def evaluate( dice_score_probs = compute_sharded_dice( mask_pred_probs, mask_true_onehot, spatial_mesh ) - dice_loss_curr = 1.0 - dice_score_probs.mean() - # Eval metric (excluding background class 0) - # dice_score_probs shape is [Batch, Channels]. We slice [:, 1:] to drop background - batch_dice_score = dice_score_probs[:, 1:].mean() + # dice_score_probs shape is [Batch, Channels]. + batch_dice_score = foreground_dice_mean(dice_score_probs) # --- Combine and Accumulate --- - loss = CE_loss + dice_loss_curr + loss = CE_loss + (1.0 - batch_dice_score) val_loss_epoch += loss.item() total_dice_score += batch_dice_score.item() processed_batches += 1 diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index a5d4355..27482ab 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -207,6 +207,13 @@ def setup_training_components(self): f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, Gradient Scaler Enabled: {self.config.torch_amp}" ) + @staticmethod + def _foreground_dice_mean(dice_scores): + """Match optimization to the reported validation metric by excluding background.""" + if dice_scores.size(1) > 1: + return dice_scores[:, 1:].mean().item() + return dice_scores.mean().item() + class PyTorchTrainer(BaseTrainer): """ @@ -453,10 +460,10 @@ def warmup(self): dice_scores = compute_sharded_dice( local_preds_softmax, local_labels_one_hot, self.spatial_mesh ) - loss_dice = 1.0 - dice_scores.mean() + batch_dice_score = self._foreground_dice_mean(dice_scores) # 3. Combine Loss - loss = loss_ce + loss_dice + loss = loss_ce + (1.0 - batch_dice_score) self.log.debug( " warmup: loss calculation complete. Proceeding to backward pass" @@ -479,7 +486,7 @@ def warmup(self): local_preds_softmax, local_labels_one_hot, ) - del loss_ce, loss_dice, loss + del loss_ce, loss if self.world_rank == 0: peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) @@ -665,11 +672,11 @@ def train(self): local_labels_one_hot, self.spatial_mesh, ) - loss_dice = 1.0 - dice_scores.mean() + batch_dice_score = self._foreground_dice_mean(dice_scores) # 3. Combine Loss - loss = loss_ce + loss_dice - train_dice_total += dice_scores[:, 1:].mean().item() + loss = loss_ce + (1.0 - batch_dice_score) + train_dice_total += batch_dice_score end_code_region("calculate_loss") @@ -745,7 +752,8 @@ def train(self): self.log.info( f" epoch {epoch} \ | train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \ - | val_dice_score {val_score:.6f}" + | val_dice_score {val_score:.6f} \ + | lr {self.config.learning_rate:.8f}" ) self.log.debug(f" writing to csv at {self.outfile_path}") if self.world_rank == 0: From 699cf5f3c80858ad4eb71ae1f71b146ddf4b722d Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 30 Apr 2026 15:00:19 -0700 Subject: [PATCH 32/43] Improve AMP stability (#60) * bf16 and more fp32 sections for dice * Refactor * ruff * fix merge artifact * Update trainer.py * Update trainer.py * Update trainer.py * Refactor * mv .item() --- ScaFFold/utils/data_types.py | 4 + ScaFFold/utils/evaluate.py | 71 +++++++------- ScaFFold/utils/trainer.py | 182 ++++++++++++++++++----------------- 3 files changed, 134 insertions(+), 123 deletions(-) diff --git a/ScaFFold/utils/data_types.py b/ScaFFold/utils/data_types.py index 90186db..b555811 100644 --- a/ScaFFold/utils/data_types.py +++ b/ScaFFold/utils/data_types.py @@ -13,9 +13,13 @@ # SPDX-License-Identifier: (Apache-2.0) import numpy as np +import torch DEFAULT_NP_DTYPE = np.float64 # Masks are values 0 <= x <= n_categories MASK_DTYPE = np.uint16 # Volumes/img are 0 <= x <= 1 VOLUME_DTYPE = np.float32 + +# Shared AMP dtype selection for torch.autocast. +AMP_DTYPE = torch.bfloat16 diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 56198cc..62d0fdf 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -17,6 +17,7 @@ from distconv import DCTensor from tqdm import tqdm +from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE from ScaFFold.utils.dice_score import ( SpatialAllReduce, compute_sharded_dice, @@ -29,13 +30,16 @@ def evaluate( net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy ): - def foreground_dice_mean(dice_scores): if dice_scores.size(1) > 1: return dice_scores[:, 1:].mean().item() return dice_scores.mean().item() net.eval() + autocast_device_type = device.type if device.type != "mps" else "cpu" + autocast_kwargs = {"device_type": autocast_device_type, "enabled": amp} + if amp: + autocast_kwargs["dtype"] = AMP_DTYPE num_val_batches = len(dataloader) total_dice_score = 0.0 processed_batches = 0 @@ -47,7 +51,7 @@ def foreground_dice_mean(dice_scores): f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}" ) - with torch.autocast(device.type if device.type != "mps" else "cpu", enabled=amp): + with torch.autocast(**autocast_kwargs): val_loss_epoch = 0.0 for batch in tqdm( dataloader, @@ -85,44 +89,39 @@ def foreground_dice_mean(dice_scores): if local_preds.size(0) == 0 or local_labels.size(0) == 0: continue - # --- 1. Sharded CE Loss --- - with torch.autocast( - device.type if device.type != "mps" else "cpu", enabled=False - ): + # Calculate CE and Dice loss in single precision for numerical stability. + with torch.autocast(device_type=autocast_device_type, enabled=False): + # Compute global CE loss from sharded CE loss local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, reduction="sum" ) - global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) - - # Divide by the actual global voxel count to handle uneven shards. - local_voxel_count = torch.tensor( - float(local_labels.numel()), - device=local_labels.device, - dtype=torch.float32, - ) - global_total_voxels = SpatialAllReduce.apply( - local_voxel_count, spatial_mesh - ) - CE_loss = global_ce_sum / global_total_voxels - - # --- 2. Format Predictions & Labels (Strictly Multiclass) --- - mask_pred_probs = F.softmax(local_preds, dim=1).float() - mask_true_onehot = ( - F.one_hot(local_labels, n_categories + 1).permute(0, 4, 1, 2, 3).float() - ) + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) + local_voxel_count = torch.tensor( + float(local_labels.numel()), + device=local_labels.device, + dtype=VOLUME_DTYPE, + ) + global_total_voxels = SpatialAllReduce.apply( + local_voxel_count, spatial_mesh + ) + CE_loss = global_ce_sum / global_total_voxels + + # Compute global dice loss from sharded dice loss + mask_pred_probs = F.softmax(local_preds.float(), dim=1) + mask_true_onehot = ( + F.one_hot(local_labels, n_categories + 1) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_score_probs = compute_sharded_dice( + mask_pred_probs, mask_true_onehot, spatial_mesh + ) + batch_dice_score = foreground_dice_mean(dice_score_probs) - # Dice loss uses probabilities - dice_score_probs = compute_sharded_dice( - mask_pred_probs, mask_true_onehot, spatial_mesh - ) - # Eval metric (excluding background class 0) - # dice_score_probs shape is [Batch, Channels]. - batch_dice_score = foreground_dice_mean(dice_score_probs) - - # --- Combine and Accumulate --- - loss = CE_loss + (1.0 - batch_dice_score) - val_loss_epoch += loss.item() - total_dice_score += batch_dice_score.item() + # Sum global CE Loss and Dice loss + loss = CE_loss + (1.0 - batch_dice_score) + val_loss_epoch += loss.item() + total_dice_score += batch_dice_score processed_batches += 1 net.train() diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 27482ab..b680746 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -29,7 +29,11 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec -from ScaFFold.utils.dice_score import SpatialAllReduce, compute_sharded_dice +from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE +from ScaFFold.utils.dice_score import ( + SpatialAllReduce, + compute_sharded_dice, +) from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local @@ -48,6 +52,9 @@ def __init__(self, model, config, device, log): self.config = config self.device = device self.log = log + self.amp_device_type = self.device.type if self.device.type != "mps" else "cpu" + self.amp_dtype = AMP_DTYPE + self.use_grad_scaler = False self.world_size = get_world_size(required=self.config.dist) self.world_rank = get_world_rank(required=self.config.dist) self.local_rank = get_local_rank(required=self.config.dist) @@ -194,7 +201,11 @@ def setup_training_components(self): ) # Set up gradient scaler for AMP (Automatic Mixed Precision) - self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.config.torch_amp) + # bfloat does not need grad scaler + self.use_grad_scaler = ( + self.config.torch_amp and self.amp_dtype != torch.bfloat16 + ) + self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler) # Set up loss function self.criterion = ( @@ -204,15 +215,24 @@ def setup_training_components(self): ) self.log.info( - f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, Gradient Scaler Enabled: {self.config.torch_amp}" + f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}" ) + def _autocast_kwargs(self, enabled=None): + if enabled is None: + enabled = self.config.torch_amp + + kwargs = {"device_type": self.amp_device_type, "enabled": enabled} + if enabled: + kwargs["dtype"] = self.amp_dtype + return kwargs + @staticmethod def _foreground_dice_mean(dice_scores): """Match optimization to the reported validation metric by excluding background.""" if dice_scores.size(1) > 1: - return dice_scores[:, 1:].mean().item() - return dice_scores.mean().item() + return dice_scores[:, 1:].mean() + return dice_scores.mean() class PyTorchTrainer(BaseTrainer): @@ -399,10 +419,7 @@ def warmup(self): true_masks_dc = DCTensor.from_shard(true_masks, self.ps) self._get_memsize(images_dc, "Sharded image", self.config.verbose) - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): + with torch.autocast(**self._autocast_kwargs()): # Forward on DCTensor self.log.debug(" warmup: running forward pass") masks_pred_dc = self.model(images_dc) @@ -428,42 +445,41 @@ def warmup(self): f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." ) - # 1. Sharded Cross Entropy - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=False, - ): + # Calculate CE and Dice loss in single precision for numerical stability. + with torch.autocast(**self._autocast_kwargs(enabled=False)): + # Compute global CE loss from sharded CE loss local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, reduction="sum" ) + global_ce_sum = SpatialAllReduce.apply( + local_ce_sum, self.spatial_mesh + ) + local_voxel_count = torch.tensor( + float(local_labels.numel()), + device=local_labels.device, + dtype=VOLUME_DTYPE, + ) + global_total_voxels = SpatialAllReduce.apply( + local_voxel_count, self.spatial_mesh + ) + loss_ce = global_ce_sum / global_total_voxels - # Pass the spatial_mesh directly - global_ce_sum = SpatialAllReduce.apply(local_ce_sum, self.spatial_mesh) - - local_voxel_count = torch.tensor( - float(local_labels.numel()), - device=local_labels.device, - dtype=torch.float32, - ) - global_total_voxels = SpatialAllReduce.apply( - local_voxel_count, self.spatial_mesh - ) - loss_ce = global_ce_sum / global_total_voxels - - # 2. Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1).float() - local_labels_one_hot = ( - F.one_hot(local_labels, num_classes=self.config.n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_scores = compute_sharded_dice( - local_preds_softmax, local_labels_one_hot, self.spatial_mesh - ) - batch_dice_score = self._foreground_dice_mean(dice_scores) + # Compute global dice loss from sharded dice loss + local_preds_softmax = F.softmax(local_preds.float(), dim=1) + local_labels_one_hot = ( + F.one_hot( + local_labels, num_classes=self.config.n_categories + 1 + ) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_scores = compute_sharded_dice( + local_preds_softmax, local_labels_one_hot, self.spatial_mesh + ) + batch_dice_score = self._foreground_dice_mean(dice_scores) - # 3. Combine Loss - loss = loss_ce + (1.0 - batch_dice_score) + # Sum global CE Loss and Dice loss + loss = loss_ce + (1.0 - batch_dice_score) self.log.debug( " warmup: loss calculation complete. Proceeding to backward pass" @@ -592,10 +608,7 @@ def train(self): images_dc, "Sharded image", self.config.verbose ) - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): + with torch.autocast(**self._autocast_kwargs()): # Predict on this batch torch.cuda.reset_peak_memory_stats() gather_and_print_mem(self.log, "pre_forward") @@ -627,56 +640,51 @@ def train(self): f"Calculating sharded loss. Mem: {current_mem:.2f} GB." ) - # 1. Sharded Cross Entropy - with torch.autocast( - self.device.type - if self.device.type != "mps" - else "cpu", - enabled=False, - ): + # Calculate CE and Dice loss in single precision for numerical stability. + with torch.autocast(**self._autocast_kwargs(enabled=False)): + # Compute global CE loss from sharded CE loss local_ce_sum = F.cross_entropy( local_preds.float(), local_labels, reduction="sum", ) - - # Pass the spatial_mesh directly - global_ce_sum = SpatialAllReduce.apply( - local_ce_sum, self.spatial_mesh - ) - - local_voxel_count = torch.tensor( - float(local_labels.numel()), - device=local_labels.device, - dtype=torch.float32, - ) - global_total_voxels = SpatialAllReduce.apply( - local_voxel_count, self.spatial_mesh - ) - loss_ce = global_ce_sum / global_total_voxels - - # 2. Sharded Dice Loss - local_preds_softmax = F.softmax(local_preds, dim=1).float() - local_labels_one_hot = ( - F.one_hot( - local_labels, - num_classes=self.config.n_categories + 1, + global_ce_sum = SpatialAllReduce.apply( + local_ce_sum, self.spatial_mesh ) - .permute(0, 4, 1, 2, 3) - .float() - ) + local_voxel_count = torch.tensor( + float(local_labels.numel()), + device=local_labels.device, + dtype=VOLUME_DTYPE, + ) + global_total_voxels = SpatialAllReduce.apply( + local_voxel_count, self.spatial_mesh + ) + loss_ce = global_ce_sum / global_total_voxels - # Compute sharded dice using new function - dice_scores = compute_sharded_dice( - local_preds_softmax, - local_labels_one_hot, - self.spatial_mesh, - ) - batch_dice_score = self._foreground_dice_mean(dice_scores) + # Compute global dice loss from sharded dice loss + local_preds_softmax = F.softmax( + local_preds.float(), dim=1 + ) + local_labels_one_hot = ( + F.one_hot( + local_labels, + num_classes=self.config.n_categories + 1, + ) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_scores = compute_sharded_dice( + local_preds_softmax, + local_labels_one_hot, + self.spatial_mesh, + ) + batch_dice_score = self._foreground_dice_mean( + dice_scores + ) - # 3. Combine Loss - loss = loss_ce + (1.0 - batch_dice_score) - train_dice_total += batch_dice_score + # Sum global CE Loss and Dice loss + loss = loss_ce + (1.0 - batch_dice_score) + train_dice_total += batch_dice_score end_code_region("calculate_loss") @@ -748,7 +756,7 @@ def train(self): # # Write out data for this epoch to train stats csv # - train_dice = float(train_dice_total / len(self.train_loader)) + train_dice = float(train_dice_total.item() / len(self.train_loader)) self.log.info( f" epoch {epoch} \ | train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \ From 7ba9de82ccf2d03b891c1798cbbbaf85c80ef82e Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Wed, 6 May 2026 08:17:22 -0700 Subject: [PATCH 33/43] Implement CosineAnnealingWarmRestarts LR Scheduler (#61) * bf16 and more fp32 sections for dice * Refactor * ruff * fix merge artifact * cosine sched * Validation needs smaller batch size otherwise val_dice can be 0. * config * Update benchmark_default.yml * lint * undo val_batch_size * Fix dtypes * README --- README.md | 5 ++- ScaFFold/cli.py | 22 ++++++++++ ScaFFold/configs/benchmark_default.yml | 11 +++-- ScaFFold/configs/benchmark_testing.yml | 5 ++- ScaFFold/utils/config_utils.py | 5 ++- ScaFFold/utils/evaluate.py | 27 +++++++++--- ScaFFold/utils/trainer.py | 61 +++++++++++++++----------- ScaFFold/worker.py | 12 +++-- 8 files changed, 105 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 04bef50..9438b1d 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,10 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. val_split: 25 # In percent. epochs: 100 # Number of training epochs. -learning_rate: .0001 # Learning rate for training. +starting_learning_rate: .01 # Initial learning rate for training. +min_learning_rate: .001 # Minimum learning rate for CosineAnnealingWarmRestarts. +T_0: 100 # Epochs in the first cosine restart cycle. +T_mult: 2 # Restart cycle growth factor. disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. more_determinism: 0 # If 1, improve model training determinism. datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 30512a6..461c338 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -177,6 +177,28 @@ def main(): type=int, help="Number of training epochs.", ) + benchmark_parser.add_argument( + "--starting-learning-rate", + type=float, + help="Initial learning rate for training.", + ) + benchmark_parser.add_argument( + "--min-learning-rate", + type=float, + help="Minimum learning rate for CosineAnnealingWarmRestarts.", + ) + benchmark_parser.add_argument( + "--T-0", + dest="T_0", + type=int, + help="Epochs in the first cosine restart cycle.", + ) + benchmark_parser.add_argument( + "--T-mult", + dest="T_mult", + type=int, + help="Restart cycle growth factor.", + ) comm = MPI.COMM_WORLD rank = comm.Get_rank() diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 0bc715d..ac54d54 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -7,8 +7,8 @@ n_instances_used_per_fractal: 145 # Number of unique instances to pull from eac problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. -batch_size: 1 # Batch sizes for each vol size. -dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. +batch_size: 1 # Batch sizes for each vol size per rank. +dataloader_num_workers: 1 # Number of DataLoader worker processes per rank. More workers will use more memory optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard @@ -19,8 +19,11 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. val_split: 25 # In percent. epochs: -1 # Number of training epochs. -learning_rate: .0001 # Learning rate for training. -disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. +starting_learning_rate: 0.1 # Initial learning rate for training. +min_learning_rate: 0.001 # Minimum learning rate for CosineAnnealingWarmRestarts. +T_0: 100 # Epochs in the first cosine restart cycle. +T_mult: 2 # Restart cycle growth factor. +disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR. more_determinism: 0 # If 1, improve model training determinism. datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off. diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index 8fbd52c..8b72435 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -19,7 +19,10 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. val_split: 25 # In percent. epochs: 10 # Number of training epochs. -learning_rate: .0001 # Learning rate for training. +starting_learning_rate: .0001 # Initial learning rate for training. +min_learning_rate: .0001 # Minimum learning rate for CosineAnnealingWarmRestarts. +T_0: 10 # Epochs in the first cosine restart cycle. +T_mult: 1 # Restart cycle growth factor. disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. more_determinism: 0 # If 1, improve model training determinism. datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 4684779..2c73a6c 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -61,7 +61,10 @@ def __init__(self, config_dict): self.seed = config_dict["seed"] self.dist = bool(config_dict["dist"]) self.framework = config_dict["framework"] - self.learning_rate = config_dict["learning_rate"] + self.starting_learning_rate = config_dict["starting_learning_rate"] + self.min_learning_rate = config_dict["min_learning_rate"] + self.T_0 = config_dict["T_0"] + self.T_mult = config_dict["T_mult"] self.variance_threshold = config_dict["variance_threshold"] self.torch_amp = bool(config_dict["torch_amp"]) self.loss_freq = config_dict["loss_freq"] diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 62d0fdf..2fd3eb1 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -30,10 +30,12 @@ def evaluate( net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy ): - def foreground_dice_mean(dice_scores): + def foreground_dice_stats(dice_scores): if dice_scores.size(1) > 1: - return dice_scores[:, 1:].mean().item() - return dice_scores.mean().item() + per_sample_scores = dice_scores[:, 1:].mean(dim=1) + else: + per_sample_scores = dice_scores.mean(dim=1) + return per_sample_scores.sum().item(), per_sample_scores.numel() net.eval() autocast_device_type = device.type if device.type != "mps" else "cpu" @@ -43,6 +45,7 @@ def foreground_dice_mean(dice_scores): num_val_batches = len(dataloader) total_dice_score = 0.0 processed_batches = 0 + processed_samples = 0 spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names] @@ -116,19 +119,29 @@ def foreground_dice_mean(dice_scores): dice_score_probs = compute_sharded_dice( mask_pred_probs, mask_true_onehot, spatial_mesh ) - batch_dice_score = foreground_dice_mean(dice_score_probs) + batch_dice_sum, batch_sample_count = foreground_dice_stats( + dice_score_probs + ) + batch_dice_score = batch_dice_sum / max(batch_sample_count, 1) # Sum global CE Loss and Dice loss loss = CE_loss + (1.0 - batch_dice_score) val_loss_epoch += loss.item() - total_dice_score += batch_dice_score + total_dice_score += batch_dice_sum processed_batches += 1 + processed_samples += batch_sample_count net.train() val_loss_avg = val_loss_epoch / max(processed_batches, 1) if primary: print( - f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" + f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}, num_val_samples={processed_samples}" ) - return total_dice_score, val_loss_epoch, val_loss_avg, processed_batches + return ( + total_dice_score, + val_loss_epoch, + val_loss_avg, + processed_batches, + processed_samples, + ) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index b680746..654cc19 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -130,7 +130,7 @@ def create_dataset(self): ) self.n_train = len(self.train_set) self.n_val = len(self.val_set) - self.log.debug( + self.log.info( f"Datasets created with n_train={self.n_train}, n_val={self.n_val}" ) @@ -173,8 +173,15 @@ def create_dataloaders(self): self.train_set, sampler=self.train_sampler, **loader_args ) self.val_loader = DataLoader( - self.val_set, sampler=self.val_sampler, drop_last=True, **loader_args + self.val_set, sampler=self.val_sampler, drop_last=False, **loader_args ) + if len(self.val_loader) == 0: + raise ValueError( + "Validation DataLoader has zero batches. " + f"n_val={self.n_val}, batch_size={self.config.batch_size}, " + f"data_num_replicas={self.data_num_replicas}. " + "Reduce batch_size or adjust validation sharding." + ) def setup_training_components(self): """Set up the optimizer, scheduler, gradient scaler, and loss function.""" @@ -182,22 +189,27 @@ def setup_training_components(self): if self.config.optimizer == "ADAM": self.log.info("Using ADAM optimizer.") self.optimizer = optim.Adam( - self.model.parameters(), lr=self.config.learning_rate + self.model.parameters(), lr=self.config.starting_learning_rate ) elif self.config.optimizer == "SGD": self.log.info("Using SGD optimizer.") self.optimizer = optim.SGD( - self.model.parameters(), lr=self.config.learning_rate + self.model.parameters(), lr=self.config.starting_learning_rate ) else: self.log.info("Using RMSprop optimizer.") self.optimizer = optim.RMSprop( - self.model.parameters(), lr=self.config.learning_rate, foreach=True + self.model.parameters(), + lr=self.config.starting_learning_rate, + foreach=True, ) # Set up learning rate scheduler - self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer, "max", patience=25 + self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( + self.optimizer, + T_0=self.config.T_0, + T_mult=self.config.T_mult, + eta_min=self.config.min_learning_rate, ) # Set up gradient scaler for AMP (Automatic Mixed Precision) @@ -234,6 +246,11 @@ def _foreground_dice_mean(dice_scores): return dice_scores[:, 1:].mean() return dice_scores.mean() + def _current_learning_rate(self): + if self.optimizer is None or not self.optimizer.param_groups: + return self.config.starting_learning_rate + return self.optimizer.param_groups[0]["lr"] + class PyTorchTrainer(BaseTrainer): """ @@ -403,7 +420,7 @@ def warmup(self): images = images.to( device=self.device, - dtype=torch.float32, + dtype=VOLUME_DTYPE, memory_format=torch.channels_last_3d, non_blocking=True, ) @@ -587,7 +604,7 @@ def train(self): begin_code_region("image_to_device") images = images.to( device=self.device, - dtype=torch.float32, + dtype=VOLUME_DTYPE, memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) non_blocking=True, ) @@ -720,7 +737,13 @@ def train(self): # # Evaluate model on validation set, update LR if necessary # - dice_sum, val_loss_epoch, val_loss_avg, numbatch = evaluate( + ( + dice_sum, + val_loss_epoch, + val_loss_avg, + numbatch, + numsamples, + ) = evaluate( self.model, self.val_loader, self.device, @@ -730,7 +753,7 @@ def train(self): self.config.n_categories, self.config._parallel_strategy, ) - dice_info = torch.tensor([dice_sum, numbatch]) + dice_info = torch.tensor([dice_sum, numsamples], dtype=VOLUME_DTYPE) if self.config.dist: dice_info = dice_info.to(device=self.device) torch.distributed.all_reduce( @@ -738,16 +761,7 @@ def train(self): ) val_score = dice_info[0].item() / max(dice_info[1].item(), 1) if not self.config.disable_scheduler: - # The following is true when trying to overfit, - # in which case we only care about train loss - if self.n_train == 1 or "overfit" in self.outfile_path: - self.log.debug( - "WARNING: scheduler step by overall_loss, \ - not val_score (n_train==1 or overfit in outfile_path)" - ) - self.scheduler.step(overall_loss) - else: # Otherwise, we're really trying to optimize for validation dice score - self.scheduler.step(val_score) + self.scheduler.step() else: self.log.debug("scheduler disabled, no LR update this step") @@ -758,10 +772,7 @@ def train(self): # train_dice = float(train_dice_total.item() / len(self.train_loader)) self.log.info( - f" epoch {epoch} \ - | train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \ - | val_dice_score {val_score:.6f} \ - | lr {self.config.learning_rate:.8f}" + f" epoch {epoch} | train_dice_score {train_dice:.6f} | val_dice_score {val_score:.6f} | lr {self._current_learning_rate():.8f}" ) self.log.debug(f" writing to csv at {self.outfile_path}") if self.world_rank == 0: diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index d2818c1..16087c8 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -217,14 +217,18 @@ def main(kwargs_dict: dict = {}): trainer.spatial_mesh = ps.device_mesh[ps.distconv_dim_names] num_spatial_dims = len(ps.shard_dim) trainer.ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims - global_batch_size = config.batch_size * ( - world_size // math.prod(config.dc_num_shards) - ) + total_shards = math.prod(config.dc_num_shards) + global_batch_size = config.batch_size * (world_size // total_shards) + ddp_ranks = world_size // total_shards if rank == 0: log.info( f"Effective global batch size = {global_batch_size} " f"(batch_size={config.batch_size} * " - f"(world_size={world_size} / prod(dc_num_shards)={math.prod(config.dc_num_shards)}))" + f"(world_size={world_size} / prod(dc_num_shards)={total_shards}))" + ) + log.info( + f"DDP ranks = {ddp_ranks} " + f"world_size={world_size} // prod(dc_num_shards)={total_shards}" ) if global_batch_size > trainer.n_train: raise ValueError( From 9cb7b71ad1aeedd4e3edb0cb0f829f32e423a645 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Wed, 6 May 2026 10:25:44 -0700 Subject: [PATCH 34/43] Fix usage of `fract_base_dir` and `val_split` (#63) * bf16 and more fp32 sections for dice * Refactor * ruff * fix merge artifact * cosine sched * Validation needs smaller batch size otherwise val_dice can be 0. * config * Update benchmark_default.yml * fix fractal base dir * fix config * lint * Update get_dataset.py --- ScaFFold/cli.py | 10 ++++++++++ ScaFFold/configs/benchmark_default.yml | 6 +++--- ScaFFold/datagen/category_search.py | 7 ++++--- ScaFFold/datagen/get_dataset.py | 1 + ScaFFold/datagen/instance.py | 7 +++---- ScaFFold/datagen/volumegen.py | 5 ++--- ScaFFold/utils/config_utils.py | 3 +++ ScaFFold/worker.py | 10 ++++++++-- 8 files changed, 34 insertions(+), 15 deletions(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 461c338..469c71e 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -116,6 +116,11 @@ def main(): benchmark_parser.add_argument( "--base-run-dir", type=str, help="Subfolder of $(pwd) in which to run jobs." ) + benchmark_parser.add_argument( + "--fract-base-dir", + type=str, + help="Base directory for fractal IFS and instances.", + ) benchmark_parser.add_argument( "--n-categories", type=int, @@ -240,6 +245,11 @@ def main(): Path(combined_config["dataset_dir"]).resolve() ) + if "fract_base_dir" in combined_config and combined_config["fract_base_dir"]: + combined_config["fract_base_dir"] = str( + Path(combined_config["fract_base_dir"]).resolve() + ) + # Calculate these variables after override combined_config["vol_size"] = pow(2, combined_config["problem_scale"]) combined_config["point_num"] = int(combined_config["vol_size"] ** 3 / 256) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index ac54d54..4e44d2e 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -17,11 +17,11 @@ checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to dis # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. -val_split: 25 # In percent. +val_split: 30 # In percent. epochs: -1 # Number of training epochs. -starting_learning_rate: 0.1 # Initial learning rate for training. +starting_learning_rate: 0.01 # Initial learning rate for training. min_learning_rate: 0.001 # Minimum learning rate for CosineAnnealingWarmRestarts. -T_0: 100 # Epochs in the first cosine restart cycle. +T_0: 100 # Epochs in the first cosine restart cycle. T_mult: 2 # Restart cycle growth factor. disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR. more_determinism: 0 # If 1, improve model training determinism. diff --git a/ScaFFold/datagen/category_search.py b/ScaFFold/datagen/category_search.py index 1730bd1..a7dbc7a 100644 --- a/ScaFFold/datagen/category_search.py +++ b/ScaFFold/datagen/category_search.py @@ -195,9 +195,10 @@ def main(config: Config) -> None: print(f"MPI size = {size}") # Setup directories - repo_src_path = config.library_root - fracts_sub_dir = f"/var{config.variance_threshold}" - fracts_write_dir = f"{repo_src_path}/fractals{fracts_sub_dir}/3DIFS_param" + fracts_sub_dir = f"var{config.variance_threshold}" + fracts_write_dir = os.path.join( + config.fract_base_dir, fracts_sub_dir, "3DIFS_param" + ) if rank == 0: print(f"Writing fractals to {fracts_write_dir}") if os.path.exists(fracts_write_dir) and config.datagen_from_scratch: diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index 885377b..c7ffaf9 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -38,6 +38,7 @@ "seed", "variance_threshold", "n_fracts_per_vol", + "val_split", ] diff --git a/ScaFFold/datagen/instance.py b/ScaFFold/datagen/instance.py index f8cf651..a14c2fb 100644 --- a/ScaFFold/datagen/instance.py +++ b/ScaFFold/datagen/instance.py @@ -80,11 +80,10 @@ def main(config: Config): print(f"MPI size = {size}") # Setup directories - repo_src_path = config.library_root fracts_sub_dir = f"var{config.variance_threshold}" - fracts_read_dir = f"{repo_src_path}/fractals/{fracts_sub_dir}/3DIFS_param" - instance_write_dir = ( - f"{repo_src_path}fractals/{fracts_sub_dir}/instances/np{config.point_num}" + fracts_read_dir = os.path.join(config.fract_base_dir, fracts_sub_dir, "3DIFS_param") + instance_write_dir = os.path.join( + config.fract_base_dir, fracts_sub_dir, "instances", f"np{config.point_num}" ) if rank == 0: print( diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index efee01e..b268aa7 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -165,7 +165,7 @@ def main(config: Dict): fractal_colors = np.random.rand(max(config.n_categories, n_fracts_per_vol), 3) grid_size = math.floor(config.vol_size * config.scale) - library_root = str(config.library_root) + fract_base_dir = str(config.fract_base_dir) # Generation loop start_time = time.time() @@ -197,8 +197,7 @@ def main(config: Dict): ) point_cloud_path = os.path.join( - library_root, - "fractals", + fract_base_dir, instances_dir, f"{curr_category:06d}", f"{curr_category:06d}_{curr_instance:04d}.npy", diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 2c73a6c..744004a 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -32,6 +32,9 @@ def __init__(self, config_dict): self.dataset_dir = str( Path(config_dict.get("dataset_dir", "datasets/")).resolve() ) + self.fract_base_dir = str( + Path(config_dict.get("fract_base_dir", "fractals/")).resolve() + ) self.job_name = config_dict.get("job_name", "benchmark") self.n_categories = config_dict["n_categories"] self.problem_scale = config_dict["problem_scale"] diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 16087c8..168bd85 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -230,10 +230,16 @@ def main(kwargs_dict: dict = {}): f"DDP ranks = {ddp_ranks} " f"world_size={world_size} // prod(dc_num_shards)={total_shards}" ) + too_small_splits = [] if global_batch_size > trainer.n_train: + too_small_splits.append(f"training n_train={trainer.n_train}") + if global_batch_size > trainer.n_val: + too_small_splits.append(f"validation n_val={trainer.n_val}") + if too_small_splits: raise ValueError( - "Effective global batch size exceeds available training samples: " - f"global_batch_size={global_batch_size}, n_train={trainer.n_train}, " + "Effective global batch size exceeds available samples: " + f"global_batch_size={global_batch_size}, " + f"{', '.join(too_small_splits)}, " f"batch_size={config.batch_size}, world_size={world_size}, " f"dc_num_shards={config.dc_num_shards}" ) From 23dc95f24db0e7ede5ec6df042487f8ce73f8574 Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Wed, 6 May 2026 12:49:52 -0700 Subject: [PATCH 35/43] Fix class imbalance in CE loss (#56) * use class weights in CE loss to make background less dominant; calc weights at trainer init * ruff * missing import * ruff * ruff * fix missing volume dtype * remove default ce_weight_num_samples Co-authored-by: Michael McKinsey * remove default ce_weight_num_samples in trainer Co-authored-by: Michael McKinsey * move ce loss helpers to losses.py * sample by fraction of total rather than hard number * ruff --------- Co-authored-by: Patrick Miles Co-authored-by: Michael McKinsey --- ScaFFold/configs/benchmark_default.yml | 1 + ScaFFold/configs/benchmark_testing.yml | 1 + ScaFFold/utils/config_utils.py | 3 + ScaFFold/utils/evaluate.py | 30 ++--- ScaFFold/utils/losses.py | 166 +++++++++++++++++++++++++ ScaFFold/utils/trainer.py | 80 ++++++------ 6 files changed, 220 insertions(+), 61 deletions(-) create mode 100644 ScaFFold/utils/losses.py diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 4e44d2e..541d664 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -34,5 +34,6 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter warmup_batches: 5 # How many warmup batches per rank to run before training. +ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights. dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. target_dice: 0.95 diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index 8b72435..5167de1 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -34,5 +34,6 @@ checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpo loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter warmup_batches: 5 # How many warmup batches per rank to run before training. +ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights. dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. target_dice: 0.95 diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 744004a..36f1603 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -74,6 +74,9 @@ def __init__(self, config_dict): self.checkpoint_dir = config_dict["checkpoint_dir"] self.normalize = config_dict["normalize"] self.warmup_batches = config_dict.get("warmup_batches") + self.ce_weight_sample_fraction = config_dict.get( + "ce_weight_sample_fraction", 0.1 + ) self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 2fd3eb1..67b01da 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -17,11 +17,9 @@ from distconv import DCTensor from tqdm import tqdm -from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE -from ScaFFold.utils.dice_score import ( - SpatialAllReduce, - compute_sharded_dice, -) +from ScaFFold.utils.data_types import AMP_DTYPE +from ScaFFold.utils.dice_score import compute_sharded_dice +from ScaFFold.utils.losses import compute_sharded_cross_entropy_loss from ScaFFold.utils.perf_measure import annotate @@ -56,6 +54,7 @@ def foreground_dice_stats(dice_scores): with torch.autocast(**autocast_kwargs): val_loss_epoch = 0.0 + class_weights = getattr(criterion, "weight", None) for batch in tqdm( dataloader, total=num_val_batches, @@ -94,22 +93,15 @@ def foreground_dice_stats(dice_scores): # Calculate CE and Dice loss in single precision for numerical stability. with torch.autocast(device_type=autocast_device_type, enabled=False): - # Compute global CE loss from sharded CE loss - local_ce_sum = F.cross_entropy( - local_preds.float(), local_labels, reduction="sum" + CE_loss = compute_sharded_cross_entropy_loss( + local_preds, + local_labels, + spatial_mesh, + parallel_strategy.num_shards, + autocast_device_type, + class_weights, ) - global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) - local_voxel_count = torch.tensor( - float(local_labels.numel()), - device=local_labels.device, - dtype=VOLUME_DTYPE, - ) - global_total_voxels = SpatialAllReduce.apply( - local_voxel_count, spatial_mesh - ) - CE_loss = global_ce_sum / global_total_voxels - # Compute global dice loss from sharded dice loss mask_pred_probs = F.softmax(local_preds.float(), dim=1) mask_true_onehot = ( F.one_hot(local_labels, n_categories + 1) diff --git a/ScaFFold/utils/losses.py b/ScaFFold/utils/losses.py new file mode 100644 index 0000000..3869b8b --- /dev/null +++ b/ScaFFold/utils/losses.py @@ -0,0 +1,166 @@ +# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LBANN/ScaFFold. +# +# SPDX-License-Identifier: (Apache-2.0) + +import math + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from ScaFFold.utils.dice_score import SpatialAllReduce + + +def _sample_ce_weight_indices(n_train, sample_fraction): + """Pick a small, deterministic subset of masks to estimate CE weights.""" + if n_train <= 0: + return [] + + if sample_fraction is None: + sample_fraction = 0.1 + + sample_count = min( + max(math.ceil(n_train * float(sample_fraction)), 1), + n_train, + ) + if sample_count == n_train: + return list(range(n_train)) + + return torch.linspace(0, n_train - 1, steps=sample_count).long().tolist() + + +def _compute_ce_class_weights( + train_set, + n_train, + n_categories, + device, + sample_fraction=0.1, + dist_enabled=False, + world_rank=0, + log=None, +): + """ + Estimate background vs foreground CE weights from a few training masks. + + Background keeps its own inverse-frequency weight, and every non-zero + fractal class shares the foreground weight derived from the aggregate + non-empty voxel count. + """ + + num_classes = n_categories + 1 + class_weights = torch.ones(num_classes, device=device, dtype=torch.float32) + + if n_train == 0: + if log is not None: + log.warning( + "Training set is empty while computing CE class weights. Falling back to uniform weights." + ) + return class_weights + + sample_indices = _sample_ce_weight_indices(n_train, sample_fraction) + sampled_class_counts = torch.zeros(num_classes, dtype=torch.long) + + for sample_idx in sample_indices: + mask = train_set[sample_idx]["mask"] + sampled_class_counts += torch.bincount(mask.reshape(-1), minlength=num_classes) + + # The dataset may already return only this rank's local spatial shard, + # so combine per-rank counts before deriving the global CE weights. + sampled_class_counts = sampled_class_counts.to(device=device) + if dist_enabled: + dist.all_reduce(sampled_class_counts, op=dist.ReduceOp.SUM) + + background_voxels = int(sampled_class_counts[0].item()) + foreground_voxels = int(sampled_class_counts[1:].sum().item()) + + if background_voxels > 0 and foreground_voxels > 0: + total_voxels = background_voxels + foreground_voxels + class_weights[0] = total_voxels / background_voxels + class_weights[1:] = total_voxels / foreground_voxels + elif log is not None: + log.warning( + "Sampled masks did not contain both background and foreground voxels. Falling back to uniform CE weights." + ) + + if log is not None and (not dist_enabled or world_rank == 0): + log.info( + f"CE weights estimated from {len(sample_indices)} training masks " + f"(sample_fraction={sample_fraction}, indices={sample_indices}): " + f"background_voxels={background_voxels} " + f"foreground_voxels={foreground_voxels} " + f"weights={class_weights.detach().cpu().tolist()}" + ) + + return class_weights + + +def compute_sharded_cross_entropy_loss( + local_preds, + local_labels, + spatial_mesh, + _num_shards, + device_type, + class_weights=None, +): + """ + Compute the CE loss for a spatially sharded volume. + + Each rank only sees a local spatial shard, so we cannot use the local + `reduction="mean"` result directly. Instead we: + 1. compute the local CE numerator with `reduction="sum"`, + 2. build the correct global denominator, + 3. all-reduce across the spatial mesh, and + 4. divide to recover the same value we would get from a non-sharded tensor. + + When `class_weights` is provided, PyTorch's CE "mean" divides by the sum of + the target weights, not the raw voxel count, so we reproduce that behavior + explicitly here. + """ + + autocast_device = device_type if device_type != "mps" else "cpu" + with torch.autocast(autocast_device, enabled=False): + # Accumulate CE in full precision. Using reduction="sum" gives us the + # numerator of the final global mean; if class weights are present, + # PyTorch applies the target-class weight to each voxel here. + local_ce_sum = F.cross_entropy( + local_preds.float(), + local_labels, + weight=class_weights, + reduction="sum", + ) + + if class_weights is None: + # Sum the actual local voxel counts across spatial shards. We use + # an all-reduced count instead of numel()*num_shards because shard + # sizes can differ at chunk boundaries. + local_voxel_count = local_ce_sum.new_tensor(float(local_labels.numel())) + global_normalizer = SpatialAllReduce.apply(local_voxel_count, spatial_mesh) + else: + # Weighted CE divides by sum(weight[target_i]) over all voxels. + # Build that denominator from the local label histogram, then + # all-reduce it across the spatial mesh. + local_class_counts = torch.bincount( + local_labels.reshape(-1), minlength=class_weights.numel() + ).to(dtype=local_ce_sum.dtype) + local_weight_sum = torch.dot( + local_class_counts, class_weights.to(dtype=local_ce_sum.dtype) + ) + global_normalizer = SpatialAllReduce.apply(local_weight_sum, spatial_mesh) + + # Sum the local CE numerators from each spatial shard to get the global CE + # numerator, then divide by the matching global denominator. + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) + # Clamp to avoid a divide-by-zero in degenerate cases. + return global_ce_sum / global_normalizer.clamp_min( + torch.finfo(global_ce_sum.dtype).eps + ) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 654cc19..1a1d2e0 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: (Apache-2.0) +# Standard library import math import os import shutil @@ -30,14 +31,15 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE -from ScaFFold.utils.dice_score import ( - SpatialAllReduce, - compute_sharded_dice, -) +from ScaFFold.utils.dice_score import compute_sharded_dice from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local from ScaFFold.utils.evaluate import evaluate +from ScaFFold.utils.losses import ( + _compute_ce_class_weights, + compute_sharded_cross_entropy_loss, +) from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region from ScaFFold.utils.utils import gather_and_print_mem @@ -72,6 +74,7 @@ def __init__(self, model, config, device, log): self.scheduler = None self.grad_scaler = None self.criterion = None + self.ce_class_weights = None self.global_step = 0 self.start_epoch = -1 self.ps = getattr(self.config, "_parallel_strategy", None) @@ -220,11 +223,24 @@ def setup_training_components(self): self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler) # Set up loss function - self.criterion = ( - nn.CrossEntropyLoss() - if self.config.n_categories + 1 > 1 - else nn.BCEWithLogitsLoss() - ) + if self.config.n_categories + 1 > 1: + ce_class_weights = _compute_ce_class_weights( + train_set=self.train_set, + n_train=self.n_train, + n_categories=self.config.n_categories, + device=self.device, + sample_fraction=self.config.ce_weight_sample_fraction, + dist_enabled=self.config.dist, + world_rank=self.world_rank, + log=self.log, + ) + self.criterion = nn.CrossEntropyLoss(weight=ce_class_weights).to( + self.device + ) + else: + self.criterion = nn.BCEWithLogitsLoss().to(self.device) + if isinstance(self.criterion, nn.CrossEntropyLoss): + self.ce_class_weights = self.criterion.weight self.log.info( f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}" @@ -464,24 +480,15 @@ def warmup(self): # Calculate CE and Dice loss in single precision for numerical stability. with torch.autocast(**self._autocast_kwargs(enabled=False)): - # Compute global CE loss from sharded CE loss - local_ce_sum = F.cross_entropy( - local_preds.float(), local_labels, reduction="sum" + loss_ce = compute_sharded_cross_entropy_loss( + local_preds, + local_labels, + self.spatial_mesh, + self.config.dc_num_shards, + self.amp_device_type, + self.ce_class_weights, ) - global_ce_sum = SpatialAllReduce.apply( - local_ce_sum, self.spatial_mesh - ) - local_voxel_count = torch.tensor( - float(local_labels.numel()), - device=local_labels.device, - dtype=VOLUME_DTYPE, - ) - global_total_voxels = SpatialAllReduce.apply( - local_voxel_count, self.spatial_mesh - ) - loss_ce = global_ce_sum / global_total_voxels - # Compute global dice loss from sharded dice loss local_preds_softmax = F.softmax(local_preds.float(), dim=1) local_labels_one_hot = ( F.one_hot( @@ -659,26 +666,15 @@ def train(self): # Calculate CE and Dice loss in single precision for numerical stability. with torch.autocast(**self._autocast_kwargs(enabled=False)): - # Compute global CE loss from sharded CE loss - local_ce_sum = F.cross_entropy( - local_preds.float(), + loss_ce = compute_sharded_cross_entropy_loss( + local_preds, local_labels, - reduction="sum", - ) - global_ce_sum = SpatialAllReduce.apply( - local_ce_sum, self.spatial_mesh - ) - local_voxel_count = torch.tensor( - float(local_labels.numel()), - device=local_labels.device, - dtype=VOLUME_DTYPE, - ) - global_total_voxels = SpatialAllReduce.apply( - local_voxel_count, self.spatial_mesh + self.spatial_mesh, + self.config.dc_num_shards, + self.amp_device_type, + self.ce_class_weights, ) - loss_ce = global_ce_sum / global_total_voxels - # Compute global dice loss from sharded dice loss local_preds_softmax = F.softmax( local_preds.float(), dim=1 ) From 1db5916cb84ce1ecb3257e5c393d3d6de771d90b Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Mon, 11 May 2026 16:16:30 -0700 Subject: [PATCH 36/43] fix dtypes for torch (#65) --- ScaFFold/utils/data_types.py | 5 ++++- ScaFFold/utils/trainer.py | 10 ++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/ScaFFold/utils/data_types.py b/ScaFFold/utils/data_types.py index b555811..ef1515d 100644 --- a/ScaFFold/utils/data_types.py +++ b/ScaFFold/utils/data_types.py @@ -19,7 +19,10 @@ # Masks are values 0 <= x <= n_categories MASK_DTYPE = np.uint16 # Volumes/img are 0 <= x <= 1 -VOLUME_DTYPE = np.float32 +VOLUME_DTYPE_NAME = "float32" +VOLUME_NP_DTYPE = getattr(np, VOLUME_DTYPE_NAME) +VOLUME_TORCH_DTYPE = getattr(torch, VOLUME_DTYPE_NAME) +VOLUME_DTYPE = VOLUME_NP_DTYPE # Shared AMP dtype selection for torch.autocast. AMP_DTYPE = torch.bfloat16 diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 1a1d2e0..3add912 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -30,7 +30,7 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec -from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE +from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_TORCH_DTYPE from ScaFFold.utils.dice_score import compute_sharded_dice from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size @@ -436,7 +436,7 @@ def warmup(self): images = images.to( device=self.device, - dtype=VOLUME_DTYPE, + dtype=VOLUME_TORCH_DTYPE, memory_format=torch.channels_last_3d, non_blocking=True, ) @@ -611,7 +611,7 @@ def train(self): begin_code_region("image_to_device") images = images.to( device=self.device, - dtype=VOLUME_DTYPE, + dtype=VOLUME_TORCH_DTYPE, memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) non_blocking=True, ) @@ -749,7 +749,9 @@ def train(self): self.config.n_categories, self.config._parallel_strategy, ) - dice_info = torch.tensor([dice_sum, numsamples], dtype=VOLUME_DTYPE) + dice_info = torch.tensor( + [dice_sum, numsamples], dtype=VOLUME_TORCH_DTYPE + ) if self.config.dist: dice_info = dice_info.to(device=self.device) torch.distributed.all_reduce( From d3f386a81a34607f4249dda24e67d327332439bf Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Wed, 27 May 2026 09:06:13 -0700 Subject: [PATCH 37/43] Enable rocm/7.2.1 (#67) * working rocm/7.2.1 * Change wheel source and use torch2.12 * Update install-tuolumne-torchpypi.sh * Update scaffold-tuolumne-torchpypi.job --- pyproject.toml | 2 +- scripts/install-rccl.sh | 31 ------------------------- scripts/install-tuolumne-torchpypi.sh | 7 +++--- scripts/install-tuolumne.sh | 2 +- scripts/scaffold-tuolumne-torchpypi.job | 11 ++++----- scripts/scaffold-tuolumne.job | 6 ++--- 6 files changed, 14 insertions(+), 45 deletions(-) delete mode 100644 scripts/install-rccl.sh diff --git a/pyproject.toml b/pyproject.toml index 4f8ed54..70eb0d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ cuda = [ "mpi4py==4.1.1", ] rocm = [ - "torch==2.10.0+rocm7.1", + "torch==2.12.0+rocm7.2", "mpi4py==4.1.1", ] rocmwci = [ diff --git a/scripts/install-rccl.sh b/scripts/install-rccl.sh deleted file mode 100644 index a84add3..0000000 --- a/scripts/install-rccl.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -# Exit if target directory already exists -if [ -d "aws-ofi-nccl.git" ]; then - echo "Directory 'aws-ofi-nccl.git' already exists. Exiting to avoid overwrite." - return 1 2>/dev/null || exit 1 -fi - -rocm_version=7.1.1 - -module swap PrgEnv-cray PrgEnv-gnu -module load rocm/$rocm_version - -git clone --recursive --branch v1.18.0 https://github.com/aws/aws-ofi-nccl.git aws-ofi-nccl.git - -cd aws-ofi-nccl.git - -installdir=$(pwd)/install - -./autogen.sh - -export LD_LIBRARY_PATH=$PWD/../rccl/install/lib:/opt/rocm-$rocm_version/lib:$LD_LIBRARY_PATH - -#CC=hipcc CXX=hipcc CFLAGS=-I$PWD/../rccl/install/include/rccl ./configure \ -./configure \ - --with-libfabric=/opt/cray/libfabric/2.1 \ - --with-rocm=$ROCM_PATH \ - --prefix=$installdir - -make -make install \ No newline at end of file diff --git a/scripts/install-tuolumne-torchpypi.sh b/scripts/install-tuolumne-torchpypi.sh index 26e7a22..91876c4 100644 --- a/scripts/install-tuolumne-torchpypi.sh +++ b/scripts/install-tuolumne-torchpypi.sh @@ -1,4 +1,5 @@ -. install-rccl.sh ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo-pypi && source .venvs/scaffoldvenv-tuo-pypi/bin/activate && pip install --upgrade pip -ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi -pip install -e .[rocm] --prefix=.venvs/scaffoldvenv-tuo-pypi --extra-index-url https://download.pytorch.org/whl/rocm7.1 2>&1 | tee install.log +ml cce/21.0.1 cray-mpich/9.1.0 rocm/7.2.1 rccl/fast-env-slows-mpi +pip install -e .[rocm] --find-links https://download.pytorch.org/whl/torch/ --find-links https://download.pytorch.org/whl/triton-rocm/ 2>&1 | tee install.log +# libmpi.so.12 does not exist => ls /opt/cray/pe/lib64/ | grep libmpi +patchelf --replace-needed libmpi.so.12 libmpi_gnu.so.12 .venvs/scaffoldvenv-tuo-pypi/lib/python3.11/site-packages/mpi4py/MPI.mpich.cpython-311-x86_64-linux-gnu.so diff --git a/scripts/install-tuolumne.sh b/scripts/install-tuolumne.sh index 339fd8f..4b026f8 100644 --- a/scripts/install-tuolumne.sh +++ b/scripts/install-tuolumne.sh @@ -1,6 +1,6 @@ ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo && source .venvs/scaffoldvenv-tuo/bin/activate && pip install --upgrade pip ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi -pip install -e .[rocmwci] --prefix=.venvs/scaffoldvenv-tuo 2>&1 | tee install.log +pip install -e .[rocmwci] 2>&1 | tee install.log # Needed until new wheel exists for torch using mpich 9.1.0 TORCH_LIB_DIR=".venvs/scaffoldvenv-tuo/lib/python3.11/site-packages/torch/lib" OLD="libmpi_gnu_112.so.12" diff --git a/scripts/scaffold-tuolumne-torchpypi.job b/scripts/scaffold-tuolumne-torchpypi.job index c0b0780..0387e5e 100644 --- a/scripts/scaffold-tuolumne-torchpypi.job +++ b/scripts/scaffold-tuolumne-torchpypi.job @@ -5,21 +5,20 @@ # flux: -g=1 # flux: -t 60m # flux: -qpdebug -# flux: -B fractale +# flux: -B flask -ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi +ml cce/21.0.1 cray-mpich/9.1.0 rocm/7.2.1 rccl/fast-env-slows-mpi . .venvs/scaffoldvenv-tuo-pypi/bin/activate -# Use ccl plugin that we manually built with install-rccl.sh -export NCCL_NET_PLUGIN=../aws-ofi-nccl.git/install/lib/librccl-net.so +export NCCL_NET_PLUGIN=/collab/usr/global/tools/rccl/toss_4_x86_64_ib_cray/rocm-7.2.0/install/lib/librccl-net.so # Disable direct convolution benchmarking (should speedup warmup by a significant amount, does the below three options together) # export MIOPEN_DEBUG_CONV_DIRECT=0 # Disable direct naive convolution benchmarking (naive_conv_ab_nonpacked_fwd_ndhwc_half_double_half.kd) export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD=0 # Disable naive_conv_ab_nonpacked_bwd_ndhwc_half_double_half.kd -# export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 # Disable naive_conv_ab_nonpacked_wrw_ndhwc_half_double_half.kd export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0 @@ -29,4 +28,4 @@ torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/co #export PROFILE_TORCH=ON torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml -#torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml +# torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml diff --git a/scripts/scaffold-tuolumne.job b/scripts/scaffold-tuolumne.job index d3b9e05..bbbba33 100644 --- a/scripts/scaffold-tuolumne.job +++ b/scripts/scaffold-tuolumne.job @@ -5,7 +5,7 @@ # flux: -g=1 # flux: -t 60m # flux: -qpdebug -# flux: -B fractale +# flux: -B flask ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi @@ -20,7 +20,7 @@ export LD_PRELOAD="/opt/rocm-7.1.1/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/o # Disable direct naive convolution benchmarking (naive_conv_ab_nonpacked_fwd_ndhwc_half_double_half.kd) export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD=0 # Disable naive_conv_ab_nonpacked_bwd_ndhwc_half_double_half.kd -# export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 # Disable naive_conv_ab_nonpacked_wrw_ndhwc_half_double_half.kd export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0 @@ -30,4 +30,4 @@ torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/co #export PROFILE_TORCH=ON torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml -#torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml +# torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml From 2e8a160163364b9191e7c98ebd6b1fad00f31551 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 28 May 2026 11:25:02 -0700 Subject: [PATCH 38/43] Add metadata (#70) * Add adiak metadata * Update worker.py --- ScaFFold/worker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 168bd85..fde4582 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -38,6 +38,7 @@ initialize_dist, ) from ScaFFold.utils.perf_measure import ( + adiak_value, annotate, begin_code_region, end_code_region, @@ -220,6 +221,10 @@ def main(kwargs_dict: dict = {}): total_shards = math.prod(config.dc_num_shards) global_batch_size = config.batch_size * (world_size // total_shards) ddp_ranks = world_size // total_shards + adiak_value("global_batch_size", global_batch_size) + adiak_value("ddp_ranks", ddp_ranks) + adiak_value("total_shards", total_shards) + adiak_value("num_spatial_dims", num_spatial_dims) if rank == 0: log.info( f"Effective global batch size = {global_batch_size} " From b362f33deb6a76127bdf03da2b01671e220aecf7 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 28 May 2026 11:25:49 -0700 Subject: [PATCH 39/43] Enable timing minibatch (#66) * fix dtypes for torch * Add per minibatch timer * cleanup --- ScaFFold/utils/trainer.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 3add912..e981e8a 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -349,18 +349,18 @@ def cleanup_or_resume(self): with open(self.outfile_path, "a", newline="") as outfile: outfile.write(",".join(headers) + "\n") - def _truncate_stats_file(self, start_epoch): + def _truncate_stats_file(self, start_epoch, path=None): """ Scans the stats file and truncates it at the first occurrence of an epoch >= start_epoch. This is O(1) memory and safe for large logs. """ - self.log.info( - f"Truncating {self.outfile_path} to remove epochs >= {start_epoch}" - ) + if path is None: + path = self.outfile_path + self.log.info(f"Truncating {path} to remove epochs >= {start_epoch}") try: # Open in read+update mode ('r+') to allow seeking and truncating - with open(self.outfile_path, "r+") as f: + with open(path, "r+") as f: header = f.readline() if not header: return @@ -401,7 +401,7 @@ def _truncate_stats_file(self, start_epoch): pass except Exception as e: - self.log.warning(f"Failed to truncate stats file: {e}") + self.log.warning(f"Failed to truncate stats file {path}: {e}") def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): """Log size of tensor in memory""" @@ -604,7 +604,11 @@ def train(self): disable=True if self.world_rank != 0 else False, ) as pbar: begin_code_region("batch_loop") - for batch in self.train_loader: + for batch_idx, batch in enumerate(self.train_loader): + time_minibatch = batch_idx == 0 and self.world_rank == 0 + if time_minibatch: + minibatch_start_time = time.perf_counter() + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] @@ -724,6 +728,13 @@ def train(self): self.global_step += 1 # Stay on GPU epoch_loss += loss.detach() + if time_minibatch: + # This sync has some potential performance impact + # TODO: Would be better to measure this with Caliper, which uses CUDA events. + torch.cuda.synchronize(self.device) + minibatch_time_s = ( + time.perf_counter() - minibatch_start_time + ) end_code_region("update_loss") end_code_region("batch_loop") @@ -791,7 +802,7 @@ def train(self): ) outfile.flush() print( - f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}" + f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}. Rank 0 first batch minibatch_time_s={minibatch_time_s:.6f}." ) # From 362282f6cdcfa477a7e7fdb6b5b0164cd2fe9dd7 Mon Sep 17 00:00:00 2001 From: Patrick R Miles <78748866+PatrickRMiles@users.noreply.github.com> Date: Thu, 28 May 2026 15:08:41 -0700 Subject: [PATCH 40/43] fix restart epoch bug in trainer (#72) Co-authored-by: Patrick Miles --- ScaFFold/utils/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index e981e8a..1a710ff 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -568,7 +568,7 @@ def train(self): Execute model training """ - epoch = 1 + epoch = self.start_epoch dice_score_train = 0 with open(self.outfile_path, "a", newline="") as outfile: start = time.time() From 5231a8f430c143a5424eac92998cca578690f525 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 4 Jun 2026 10:13:32 -0700 Subject: [PATCH 41/43] Update scaffold-tuolumne.job (#74) --- scripts/scaffold-tuolumne.job | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/scaffold-tuolumne.job b/scripts/scaffold-tuolumne.job index bbbba33..a22d8c6 100644 --- a/scripts/scaffold-tuolumne.job +++ b/scripts/scaffold-tuolumne.job @@ -13,7 +13,8 @@ ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi # (1) Avoid libmagma error # (2) Removing libmpi may cause segfault on mpi4py import -export LD_PRELOAD="/opt/rocm-7.1.1/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/ofi/gnu/11.2/lib/libmpi_gnu.so.12" +# (3-5) undefined symbol: cblas_gemm_f16f16f32 +export LD_PRELOAD="/opt/rocm-7.1.1/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/ofi/gnu/11.2/lib/libmpi_gnu.so.12 /opt/intel/oneapi/mkl/2024.2/lib/libmkl_core.so.2 /opt/intel/oneapi/mkl/2024.2/lib/libmkl_gnu_thread.so.2 /opt/intel/oneapi/mkl/2024.2/lib/libmkl_intel_lp64.so.2" # Disable direct convolution benchmarking (should speedup warmup by a significant amount, does the below three options together) # export MIOPEN_DEBUG_CONV_DIRECT=0 From 55292fda743e015ce9e09802a945e0ccdac8a4ad Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 4 Jun 2026 14:55:42 -0700 Subject: [PATCH 42/43] Use snapshot to prevent warmup from affecting training and refactor warmup (#68) * Use snapshot to prevent warmup from influencing training * Fix validation warmup and increase default warmup * Refactor shared trainer logic in warmup * better default --- ScaFFold/configs/benchmark_default.yml | 6 +- ScaFFold/utils/checkpointing.py | 61 +++- ScaFFold/utils/evaluate.py | 30 +- ScaFFold/utils/trainer.py | 414 +++++++++++-------------- 4 files changed, 258 insertions(+), 253 deletions(-) diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 541d664..1b0310c 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -19,8 +19,8 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. val_split: 30 # In percent. epochs: -1 # Number of training epochs. -starting_learning_rate: 0.01 # Initial learning rate for training. -min_learning_rate: 0.001 # Minimum learning rate for CosineAnnealingWarmRestarts. +starting_learning_rate: 0.001 # Initial learning rate for training. +min_learning_rate: 0.0001 # Minimum learning rate for CosineAnnealingWarmRestarts. T_0: 100 # Epochs in the first cosine restart cycle. T_mult: 2 # Restart cycle growth factor. disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR. @@ -33,7 +33,7 @@ framework: "torch" # The DL framework to train with. Only valid checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter -warmup_batches: 5 # How many warmup batches per rank to run before training. +warmup_batches: 64 # How many warmup batches per rank to run before training. ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights. dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. target_dice: 0.95 diff --git a/ScaFFold/utils/checkpointing.py b/ScaFFold/utils/checkpointing.py index 5a65b09..2a06a3c 100644 --- a/ScaFFold/utils/checkpointing.py +++ b/ScaFFold/utils/checkpointing.py @@ -105,6 +105,46 @@ def wait_for_save(self): self._log(f"Background save failed with error: {e}") self.future = None + def snapshot_training_state(self) -> Dict[str, Any]: + """Capture mutable in-memory training state without writing a checkpoint.""" + model_ref = self.model.module if hasattr(self.model, "module") else self.model + return { + "model_state_dict": self._clone_state_dict(model_ref.state_dict()), + "optimizer_state_dict": self._clone_state_dict(self.optimizer.state_dict()) + if self.optimizer + else None, + "scheduler_state_dict": self._clone_state_dict(self.scheduler.state_dict()) + if self.scheduler + else None, + "grad_scaler_state_dict": self._clone_state_dict( + self.grad_scaler.state_dict() + ) + if self.grad_scaler + else None, + "model_training": model_ref.training, + **self._get_rng_snapshot(), + } + + def restore_training_state(self, snapshot: Dict[str, Any]) -> None: + """Restore an in-memory training snapshot.""" + model_ref = self.model.module if hasattr(self.model, "module") else self.model + model_ref.load_state_dict(snapshot["model_state_dict"]) + + if self.optimizer and snapshot.get("optimizer_state_dict") is not None: + self.optimizer.load_state_dict(snapshot["optimizer_state_dict"]) + + if self.scheduler and snapshot.get("scheduler_state_dict") is not None: + self.scheduler.load_state_dict(snapshot["scheduler_state_dict"]) + + if self.grad_scaler and snapshot.get("grad_scaler_state_dict") is not None: + self.grad_scaler.load_state_dict(snapshot["grad_scaler_state_dict"]) + + self._restore_rng(snapshot) + model_ref.train(snapshot.get("model_training", True)) + + if self.optimizer: + self.optimizer.zero_grad(set_to_none=True) + def load_from_checkpoint(self) -> int: """Load the latest checkpoint. Returns start_epoch (default 1).""" self.wait_for_save() # Safety: don't load while writing @@ -285,6 +325,19 @@ def _transfer_dict_to_cpu(self, obj): else: return obj + def _clone_state_dict(self, obj): + """Recursively clone tensors so in-memory snapshots are isolated.""" + if torch.is_tensor(obj): + return obj.detach().clone() + elif isinstance(obj, dict): + return {k: self._clone_state_dict(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._clone_state_dict(v) for v in obj] + elif isinstance(obj, tuple): + return tuple(self._clone_state_dict(v) for v in obj) + else: + return obj + def _barrier(self): if self.dist_enabled: dist.barrier() @@ -305,7 +358,7 @@ def _log(self, msg): def _get_rng_snapshot(self) -> Dict[str, Any]: snap = {"rng_state_pytorch": torch.get_rng_state()} if torch.cuda.is_available(): - snap["rng_state_pytorch_cuda"] = torch.cuda.get_rng_state() + snap["rng_state_pytorch_cuda"] = torch.cuda.get_rng_state_all() try: snap["rng_state_numpy"] = np.random.get_state() except ImportError: @@ -321,7 +374,11 @@ def _restore_rng(self, snap: Dict[str, Any]): if "rng_state_pytorch" in snap: torch.set_rng_state(snap["rng_state_pytorch"]) if "rng_state_pytorch_cuda" in snap and torch.cuda.is_available(): - torch.cuda.set_rng_state(snap["rng_state_pytorch_cuda"]) + cuda_state = snap["rng_state_pytorch_cuda"] + if isinstance(cuda_state, list): + torch.cuda.set_rng_state_all(cuda_state) + else: + torch.cuda.set_rng_state(cuda_state) if "rng_state_numpy" in snap: np.random.set_state(snap["rng_state_numpy"]) if "rng_state_python" in snap: diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 67b01da..6b23b15 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -26,7 +26,15 @@ @annotate() @torch.inference_mode() def evaluate( - net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy + net, + dataloader, + device, + amp, + primary, + criterion, + n_categories, + parallel_strategy, + max_batches=None, ): def foreground_dice_stats(dice_scores): if dice_scores.size(1) > 1: @@ -41,6 +49,8 @@ def foreground_dice_stats(dice_scores): if amp: autocast_kwargs["dtype"] = AMP_DTYPE num_val_batches = len(dataloader) + if max_batches is not None: + num_val_batches = min(num_val_batches, max_batches) total_dice_score = 0.0 processed_batches = 0 processed_samples = 0 @@ -55,14 +65,18 @@ def foreground_dice_stats(dice_scores): with torch.autocast(**autocast_kwargs): val_loss_epoch = 0.0 class_weights = getattr(criterion, "weight", None) - for batch in tqdm( - dataloader, - total=num_val_batches, - desc="Validation round", - unit="batch", - leave=False, - disable=not primary, + for batch_idx, batch in enumerate( + tqdm( + dataloader, + total=num_val_batches, + desc="Validation round", + unit="batch", + leave=False, + disable=not primary, + ) ): + if batch_idx >= num_val_batches: + break image, mask_true = batch["image"], batch["mask"] image = image.to( diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 1a710ff..a1f77f5 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -412,6 +412,137 @@ def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): tensor_memory_gb = tensor_memory_bytes / (1024**3) self.log.info(f"{tensor_label} size on GPU: {tensor_memory_gb:.2f} GB") + def _run_training_batch( + self, + batch, + *, + log_prefix="", + gather_mem_stats=False, + log_peak_mem=False, + ): + """Run one training batch and return batch size, detached loss, and dice.""" + images, true_masks = batch["image"], batch["mask"] + + begin_code_region("image_to_device") + images = images.to( + device=self.device, + dtype=VOLUME_TORCH_DTYPE, + memory_format=torch.channels_last_3d, + non_blocking=True, + ) + true_masks = true_masks.to( + device=self.device, dtype=torch.long, non_blocking=True + ).contiguous() + end_code_region("image_to_device") + if gather_mem_stats: + gather_and_print_mem(self.log, "after_batch_to_device") + + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Inputs are already loaded as local shards by the dataset. + images_dc = DCTensor.from_shard(images, self.ps) + true_masks_dc = DCTensor.from_shard(true_masks, self.ps) + del images, true_masks + self._get_memsize(images_dc, "Sharded image", self.config.verbose) + + with torch.autocast(**self._autocast_kwargs()): + if gather_mem_stats: + torch.cuda.reset_peak_memory_stats() + gather_and_print_mem(self.log, "pre_forward") + begin_code_region("predict") + self.log.debug(f" {log_prefix}running forward pass") + masks_pred_dc = self.model(images_dc) + end_code_region("predict") + if gather_mem_stats: + gather_and_print_mem(self.log, "post_forward") + self.log.debug(f" {log_prefix}forward pass complete") + + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc + + # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] + local_labels = local_labels_5d.squeeze(1) + if self.world_rank == 0: + self.log.debug(f" {log_prefix}Local Preds Shape: {local_preds.shape}") + self.log.debug( + f" {log_prefix}Local Labels Shape: {local_labels.shape}" + ) + + begin_code_region("calculate_loss") + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f" {log_prefix}Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) + + # Calculate CE and Dice loss in single precision for numerical stability. + with torch.autocast(**self._autocast_kwargs(enabled=False)): + loss_ce = compute_sharded_cross_entropy_loss( + local_preds, + local_labels, + self.spatial_mesh, + self.config.dc_num_shards, + self.amp_device_type, + self.ce_class_weights, + ) + + local_preds_softmax = F.softmax(local_preds.float(), dim=1) + local_labels_one_hot = ( + F.one_hot(local_labels, num_classes=self.config.n_categories + 1) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_scores = compute_sharded_dice( + local_preds_softmax, + local_labels_one_hot, + self.spatial_mesh, + ) + batch_dice_score = self._foreground_dice_mean(dice_scores) + + # Sum global CE Loss and Dice loss + loss = loss_ce + (1.0 - batch_dice_score) + end_code_region("calculate_loss") + + self.log.debug( + f" {log_prefix}loss calculation complete. Proceeding to backward pass" + ) + if gather_mem_stats: + gather_and_print_mem(self.log, "pre_backward") + begin_code_region("backward") + self.grad_scaler.scale(loss).backward() + end_code_region("backward") + if gather_mem_stats: + gather_and_print_mem(self.log, "post_backward") + + begin_code_region("step_and_update") + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + self.log.debug(f" {log_prefix}backward pass complete. Stepping optimizer") + self.grad_scaler.step(self.optimizer) + if gather_mem_stats: + gather_and_print_mem(self.log, "after_optim_step") + self.grad_scaler.update() + self.optimizer.zero_grad(set_to_none=False) + end_code_region("step_and_update") + + batch_size = images_dc.shape[0] + detached_loss = loss.detach() + + # Free memory aggressively + del images_dc, true_masks_dc, masks_pred_dc + del local_preds, local_labels, local_preds_softmax, local_labels_one_hot + del loss_ce, loss + + if log_peak_mem and self.world_rank == 0: + peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) + peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) + self.log.debug( + f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", + ) + + return batch_size, detached_loss, batch_dice_score + def warmup(self): """Run warmup iterations before the main training loop.""" warmup_batches = self.config.warmup_batches @@ -421,143 +552,52 @@ def warmup(self): if self.config.dist: self.train_loader.sampler.set_epoch(0) - # Match the main training path as closely as possible. - self.model.train() - self.optimizer.zero_grad(set_to_none=False) start_warmup = time.time() max_batches = min(warmup_batches, len(self.train_loader)) - self.log.info(f"Running {max_batches} warmup batch(es) per rank") - - for batch_idx, batch in enumerate(self.train_loader): - if batch_idx >= max_batches: - break + max_val_batches = min(warmup_batches, len(self.val_loader)) + self.log.info( + f"Running {max_batches} training warmup batch(es) and {max_val_batches} validation warmup batch(es) per rank" + ) + snapshot = self.checkpoint_manager.snapshot_training_state() - images, true_masks = batch["image"], batch["mask"] + # Match the main training path as closely as possible, but roll back all + # mutable state so warmup does not affect convergence. + self.model.train() + self.optimizer.zero_grad(set_to_none=False) - images = images.to( - device=self.device, - dtype=VOLUME_TORCH_DTYPE, - memory_format=torch.channels_last_3d, - non_blocking=True, - ) - true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True - ).contiguous() - - # Add a dummy channel dimension to get 5D [B, 1, D, H, W] - true_masks = true_masks.unsqueeze(1) - - # Inputs are already loaded as local shards by the dataset. - images_dc = DCTensor.from_shard(images, self.ps) - true_masks_dc = DCTensor.from_shard(true_masks, self.ps) - self._get_memsize(images_dc, "Sharded image", self.config.verbose) - - with torch.autocast(**self._autocast_kwargs()): - # Forward on DCTensor - self.log.debug(" warmup: running forward pass") - masks_pred_dc = self.model(images_dc) - self.log.debug(" warmup: forward pass complete") - - # Extract the underlying PyTorch local tensors - local_preds = masks_pred_dc - local_labels_5d = true_masks_dc - - # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] - local_labels = local_labels_5d.squeeze(1) - if self.world_rank == 0: - self.log.debug(f" warmup: Local Preds Shape: {local_preds.shape}") - # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 - self.log.debug( - f" warmup: Local Labels Shape: {local_labels.shape}" - ) - # Should be something like [1, 128, 128, 64] + try: + for batch_idx, batch in enumerate(self.train_loader): + if batch_idx >= max_batches: + break - # --- SHARDED LOSS CALCULATION --- - current_mem = torch.cuda.memory_allocated() / (1024**3) - self.log.debug( - f" warmup: Calculating sharded loss. Mem: {current_mem:.2f} GB." + self._run_training_batch( + batch, + log_prefix="warmup: ", + log_peak_mem=True, ) - - # Calculate CE and Dice loss in single precision for numerical stability. - with torch.autocast(**self._autocast_kwargs(enabled=False)): - loss_ce = compute_sharded_cross_entropy_loss( - local_preds, - local_labels, - self.spatial_mesh, - self.config.dc_num_shards, - self.amp_device_type, - self.ce_class_weights, - ) - - local_preds_softmax = F.softmax(local_preds.float(), dim=1) - local_labels_one_hot = ( - F.one_hot( - local_labels, num_classes=self.config.n_categories + 1 - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_scores = compute_sharded_dice( - local_preds_softmax, local_labels_one_hot, self.spatial_mesh - ) - batch_dice_score = self._foreground_dice_mean(dice_scores) - - # Sum global CE Loss and Dice loss - loss = loss_ce + (1.0 - batch_dice_score) - - self.log.debug( - " warmup: loss calculation complete. Proceeding to backward pass" - ) - - # Backward pass - self.grad_scaler.scale(loss).backward() - self.grad_scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - self.log.debug(" warmup: backward pass complete. Stepping optimizer") - - self.grad_scaler.step(self.optimizer) - self.grad_scaler.update() - - # Free memory aggressively - del images_dc, true_masks_dc, masks_pred_dc - del ( - local_preds, - local_labels, - local_preds_softmax, - local_labels_one_hot, - ) - del loss_ce, loss - - if self.world_rank == 0: - peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) - peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) + batch_t_end = time.time() self.log.debug( - f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", + f" warmup: batch {batch_idx} completed in {batch_t_end - start_warmup} seconds" ) - batch_t_end = time.time() - self.log.debug( - f" warmup: batch {batch_idx} completed in {batch_t_end - start_warmup} seconds" - ) - # Nuke any accumulated grads so the first real step starts clean - for p in self.model.parameters(): - p.grad = None - self.optimizer.zero_grad(set_to_none=True) + if self.config.dist: + self.val_loader.sampler.set_epoch(0) - if self.config.dist: - self.val_loader.sampler.set_epoch(0) - - evaluate( - self.model, - self.val_loader, - self.device, - self.config.torch_amp, - self.world_rank == 0, - self.criterion, - self.config.n_categories, - self.config._parallel_strategy, - ) - self.model.train() + if max_val_batches > 0: + self.log.debug(" warmup: running validation warmup pass") + evaluate( + self.model, + self.val_loader, + self.device, + self.config.torch_amp, + False, + self.criterion, + self.config.n_categories, + self.config._parallel_strategy, + max_batches=max_val_batches, + ) + finally: + self.checkpoint_manager.restore_training_state(snapshot) if self.config.dist: torch.distributed.barrier() @@ -608,126 +648,20 @@ def train(self): time_minibatch = batch_idx == 0 and self.world_rank == 0 if time_minibatch: minibatch_start_time = time.perf_counter() - - # Load initial samples and labels - images, true_masks = batch["image"], batch["mask"] - - begin_code_region("image_to_device") - images = images.to( - device=self.device, - dtype=VOLUME_TORCH_DTYPE, - memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) - non_blocking=True, - ) - true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True - ).contiguous() # masks no channels NDHW, but ensure continuity. - end_code_region("image_to_device") - gather_and_print_mem(self.log, "after_batch_to_device") - - # Add a dummy channel dimension to get 5D [B, 1, D, H, W] - true_masks = true_masks.unsqueeze(1) - - # Inputs are already loaded as local shards by the dataset. - images_dc = DCTensor.from_shard(images, self.ps) - true_masks_dc = DCTensor.from_shard(true_masks, self.ps) - del images, true_masks - self._get_memsize( - images_dc, "Sharded image", self.config.verbose - ) - - with torch.autocast(**self._autocast_kwargs()): - # Predict on this batch - torch.cuda.reset_peak_memory_stats() - gather_and_print_mem(self.log, "pre_forward") - begin_code_region("predict") - masks_pred_dc = self.model(images_dc) - end_code_region("predict") - gather_and_print_mem(self.log, "post_forward") - - # Extract the underlying PyTorch local tensors - local_preds = masks_pred_dc - local_labels_5d = true_masks_dc - - # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] - local_labels = local_labels_5d.squeeze(1) - if self.world_rank == 0: - self.log.debug( - f"Local Preds Shape: {local_preds.shape}" - ) - # Should be something like [1, 6, 128, 128, 64] if sharding Width by 2 - self.log.debug( - f"Local Labels Shape: {local_labels.shape}" - ) - # Should be something like [1, 128, 128, 64] - - begin_code_region("calculate_loss") - # --- SHARDED LOSS CALCULATION --- - current_mem = torch.cuda.memory_allocated() / (1024**3) - self.log.debug( - f"Calculating sharded loss. Mem: {current_mem:.2f} GB." + batch_size, batch_loss, batch_dice_score = ( + self._run_training_batch( + batch, + gather_mem_stats=True, ) - - # Calculate CE and Dice loss in single precision for numerical stability. - with torch.autocast(**self._autocast_kwargs(enabled=False)): - loss_ce = compute_sharded_cross_entropy_loss( - local_preds, - local_labels, - self.spatial_mesh, - self.config.dc_num_shards, - self.amp_device_type, - self.ce_class_weights, - ) - - local_preds_softmax = F.softmax( - local_preds.float(), dim=1 - ) - local_labels_one_hot = ( - F.one_hot( - local_labels, - num_classes=self.config.n_categories + 1, - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_scores = compute_sharded_dice( - local_preds_softmax, - local_labels_one_hot, - self.spatial_mesh, - ) - batch_dice_score = self._foreground_dice_mean( - dice_scores - ) - - # Sum global CE Loss and Dice loss - loss = loss_ce + (1.0 - batch_dice_score) - train_dice_total += batch_dice_score - - end_code_region("calculate_loss") - - gather_and_print_mem(self.log, "pre_backward") - begin_code_region("backward") - self.grad_scaler.scale(loss).backward() - end_code_region("backward") - gather_and_print_mem(self.log, "post_backward") - - begin_code_region("step_and_update") - self.grad_scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), max_norm=1.0 ) - self.grad_scaler.step(self.optimizer) - gather_and_print_mem(self.log, "after_optim_step") - self.grad_scaler.update() - self.optimizer.zero_grad(set_to_none=False) - end_code_region("step_and_update") + train_dice_total += batch_dice_score # Update the loss begin_code_region("update_loss") - pbar.update(images_dc.shape[0]) + pbar.update(batch_size) self.global_step += 1 # Stay on GPU - epoch_loss += loss.detach() + epoch_loss += batch_loss if time_minibatch: # This sync has some potential performance impact # TODO: Would be better to measure this with Caliper, which uses CUDA events. From bf7a1359c558150c42fb589429b2a13575c1a289 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 10 Jun 2026 15:19:54 -0700 Subject: [PATCH 43/43] shrink gitobjects and update gitignore --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 1d407a7..e371e6a 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,11 @@ venv/ ENV/ env.bak/ venv.bak/ +.venvs/ +scaffoldvenv*/ + +# Data files +*.npy # Spyder project settings .spyderproject