diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 0b93bcd..ebe4149 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -20,4 +20,4 @@ jobs: # Check format. To fix, run "ruff format ." ruff format --diff . # Check PEP8 violations, logic/correctness errors, and sort imports. To fix, run "ruff check --fix ." - ruff check --diff . + ruff check . diff --git a/.gitignore b/.gitignore index 1d407a7..e371e6a 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,11 @@ venv/ ENV/ env.bak/ venv.bak/ +.venvs/ +scaffoldvenv*/ + +# Data files +*.npy # Spyder project settings .spyderproject diff --git a/README.md b/README.md index cdfead8..9438b1d 100644 --- a/README.md +++ b/README.md @@ -29,25 +29,24 @@ The model is trained from a random initialization until convergence, which is de 1. Clone the repository: `git clone https://github.com/LBANN/ScaFFold.git && cd ScaFFold` +1. Build the ccl plugin (if not using WCI wheel) + `. scripts/install-rccl.sh` + 1. Create and activate a python venv for running the benchmark: `ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv && source .venvs/scaffoldvenv/bin/activate && pip install --upgrade pip` 1. Necessary LLNL settings: - CUDA (matrix): - 1. `ml cuda/12.6.0 gcc/12.1.1 mvapich2/2.3.7` + 1. `ml cuda/12.9.1 gcc/13.3.1 mvapich2/2.3.7` 1. `export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH` - ROCm (elcap): - 1. `ml load rocm/6.4.2 rccl/fast-env-slows-mpi` - - If using generic wheel: - 1. `export LD_LIBRARY_PATH=/opt/cray/pe/cce/20.0.0/cce/x86_64/lib:$LD_LIBRARY_PATH` - 1. `export LD_LIBRARY_PATH=/collab/usr/global/tools/rccl/toss_4_x86_64_ib_cray/rocm-6.4.1/install/lib/:$LD_LIBRARY_PATH` # Necessary to use libfabric plugin (Only necessary if using generic install, wci already links correctly) + 1. `ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi` - If using WCI wheel: - 1. `export LD_LIBRARY_PATH=/opt/cray/pe/cce/20.0.0/cce-clang/x86_64/lib/:$LD_LIBRARY_PATH` # for libomp.so - 1. `export SPINDLE_FLUXOPT=off` # Avoid spindle error + 1. `export LD_PRELOAD=/opt/rocm-7.1.1/llvm/lib/libomp.so` # for libomp.so 1. Install the benchmark in the python venv: - - CUDA: `pip install --no-binary=mpi4py .[cuda] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/cu126 2>&1 | tee install.log` - - ROCm (generic): `pip install --no-binary=mpi4py .[rocm] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/rocm6.4 2>&1 | tee install.log` + - CUDA: `pip install --no-binary=mpi4py .[cuda] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/cu129 2>&1 | tee install.log` + - ROCm (generic): `pip install --no-binary=mpi4py .[rocm] --prefix=.venvs/scaffoldvenv --extra-index-url https://download.pytorch.org/whl/rocm7.1 2>&1 | tee install.log` - ROCm (LLNL): `pip install .[rocmwci] --prefix=.venvs/scaffoldvenv 2>&1 | tee install.log` @@ -84,7 +83,10 @@ variance_threshold: 0.15 # Variance threshold for valid fractals. Defa n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. val_split: 25 # In percent. epochs: 100 # Number of training epochs. -learning_rate: .0001 # Learning rate for training. +starting_learning_rate: .01 # Initial learning rate for training. +min_learning_rate: .001 # Minimum learning rate for CosineAnnealingWarmRestarts. +T_0: 100 # Epochs in the first cosine restart cycle. +T_mult: 2 # Restart cycle growth factor. disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. more_determinism: 0 # If 1, improve model training determinism. datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. @@ -227,8 +229,8 @@ make && make install git clone https://github.com/LLNL/Caliper.git cd Caliper mkdir pybuild && cd pybuild -ml rocm/6.4.0 -ml cuda/12.6.0 +ml rocm/7.1.1 +ml cuda/12.9.1 cmake -DWITH_PYTHON_BINDINGS=ON \ -DWITH_ROCPROFILER=ON \ -DWITH_CUPTI=ON \ diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 2caf981..469c71e 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -78,6 +78,11 @@ def main(): type=int, help="Determines dataset resolution and number of UNet layers.", ) + generate_fractals_parser.add_argument( + "--n-categories", + type=int, + help="Number of fractal categories present in the dataset.", + ) generate_fractals_parser.add_argument( "--fract-base-dir", type=str, @@ -111,10 +116,14 @@ def main(): benchmark_parser.add_argument( "--base-run-dir", type=str, help="Subfolder of $(pwd) in which to run jobs." ) + benchmark_parser.add_argument( + "--fract-base-dir", + type=str, + help="Base directory for fractal IFS and instances.", + ) benchmark_parser.add_argument( "--n-categories", type=int, - nargs="+", help="Number of fractal categories present in the dataset.", ) benchmark_parser.add_argument( @@ -134,7 +143,17 @@ def main(): ) benchmark_parser.add_argument("--seed", type=int, help="Random seed.") benchmark_parser.add_argument( - "--batch-size", type=int, nargs="+", help="Batch sizes for each volume size." + "--batch-size", type=int, help="Batch sizes for each volume size." + ) + benchmark_parser.add_argument( + "--warmup-batches", + type=int, + help="Number of warmup batches to run per rank before training.", + ) + benchmark_parser.add_argument( + "--dataloader-num-workers", + type=int, + help="Number of DataLoader worker processes per rank.", ) benchmark_parser.add_argument( "--optimizer", @@ -152,6 +171,39 @@ def main(): type=str, help="Resume execution in this specific directory. Overrides --base-run-dir.", ) + benchmark_parser.add_argument( + "--dc-num-shards", + type=int, + nargs=3, + help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum", + ) + benchmark_parser.add_argument( + "--epochs", + type=int, + help="Number of training epochs.", + ) + benchmark_parser.add_argument( + "--starting-learning-rate", + type=float, + help="Initial learning rate for training.", + ) + benchmark_parser.add_argument( + "--min-learning-rate", + type=float, + help="Minimum learning rate for CosineAnnealingWarmRestarts.", + ) + benchmark_parser.add_argument( + "--T-0", + dest="T_0", + type=int, + help="Epochs in the first cosine restart cycle.", + ) + benchmark_parser.add_argument( + "--T-mult", + dest="T_mult", + type=int, + help="Restart cycle growth factor.", + ) comm = MPI.COMM_WORLD rank = comm.Get_rank() @@ -177,12 +229,33 @@ def main(): print(f"Overriding '{key}={combined_config[key]}' with '{key}={value}'") combined_config[key] = value + # Recalculate unet_layers to capture any CLI overrides + combined_config["unet_layers"] = ( + combined_config["problem_scale"] - combined_config["unet_bottleneck_dim"] + ) + + # Resolve paths to absolute, matching Config() behavior + if "base_run_dir" in combined_config and combined_config["base_run_dir"]: + combined_config["base_run_dir"] = str( + Path(combined_config["base_run_dir"]).resolve() + ) + + if "dataset_dir" in combined_config and combined_config["dataset_dir"]: + combined_config["dataset_dir"] = str( + Path(combined_config["dataset_dir"]).resolve() + ) + + if "fract_base_dir" in combined_config and combined_config["fract_base_dir"]: + combined_config["fract_base_dir"] = str( + Path(combined_config["fract_base_dir"]).resolve() + ) + # Calculate these variables after override combined_config["vol_size"] = pow(2, combined_config["problem_scale"]) combined_config["point_num"] = int(combined_config["vol_size"] ** 3 / 256) # Handle Restart / Resume logic - if hasattr(args, "restart") and args.restart == True: + if hasattr(args, "restart") and args.restart: print("Restart flag detected: Forcing train_from_scratch = False") combined_config["train_from_scratch"] = False combined_config["restart"] = True diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 9ba4bc3..1b0310c 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -4,21 +4,26 @@ dataset_dir: "datasets" # Directory in which to store and query for d fract_base_dir: "fractals" # Base directory for fractal IFS and instances. n_categories: 5 # Number of fractal categories present in the dataset. n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. -problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6. +problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. seed: 42 # Random seed. -batch_size: 1 # Batch sizes for each vol size. +batch_size: 1 # Batch sizes for each vol size per rank. +dataloader_num_workers: 1 # Number of DataLoader worker processes per rank. More workers will use more memory optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: 2 # DistConv param: number of shards to divide the tensor into -shard_dim: 2 # DistConv param: dimension on which to shard +dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard +checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. -val_split: 25 # In percent. -epochs: 2000 # Number of training epochs. -learning_rate: .0001 # Learning rate for training. -disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. +val_split: 30 # In percent. +epochs: -1 # Number of training epochs. +starting_learning_rate: 0.001 # Initial learning rate for training. +min_learning_rate: 0.0001 # Minimum learning rate for CosineAnnealingWarmRestarts. +T_0: 100 # Epochs in the first cosine restart cycle. +T_mult: 2 # Restart cycle growth factor. +disable_scheduler: 0 # If 1, disable scheduler during training to use constant LR. more_determinism: 0 # If 1, improve model training determinism. datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off. @@ -28,5 +33,7 @@ framework: "torch" # The DL framework to train with. Only valid checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. loss_freq: 1 # Number of epochs between logging the overall loss. normalize: 1 # Cateogry search normalization parameter -warmup_epochs: 1 # How many warmup epochs before training -dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. \ No newline at end of file +warmup_batches: 64 # How many warmup batches per rank to run before training. +ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights. +dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. +target_dice: 0.95 diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml new file mode 100644 index 0000000..5167de1 --- /dev/null +++ b/ScaFFold/configs/benchmark_testing.yml @@ -0,0 +1,39 @@ +# External/user-facing +base_run_dir: "benchmark_runs" # Subfolder of $(pwd) in which to run jobs. +dataset_dir: "datasets" # Directory in which to store and query for datasets. +fract_base_dir: "fractals" # Base directory for fractal IFS and instances. +n_categories: 5 # Number of fractal categories present in the dataset. +n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. +problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6. +unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. +seed: 42 # Random seed. +batch_size: 1 # Batch sizes for each vol size. +dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. +optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. +num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard +checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely. + +# Internal/dev use only +variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. +n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. +val_split: 25 # In percent. +epochs: 10 # Number of training epochs. +starting_learning_rate: .0001 # Initial learning rate for training. +min_learning_rate: .0001 # Minimum learning rate for CosineAnnealingWarmRestarts. +T_0: 10 # Epochs in the first cosine restart cycle. +T_mult: 1 # Restart cycle growth factor. +disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. +more_determinism: 0 # If 1, improve model training determinism. +datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. +train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off. +dist: 1 # If 1, use torch DDP. +torch_amp: 1 # If 1, use mixed precision in training. +framework: "torch" # The DL framework to train with. Only valid option for now is "torch". +checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. +loss_freq: 1 # Number of epochs between logging the overall loss. +normalize: 1 # Cateogry search normalization parameter +warmup_batches: 5 # How many warmup batches per rank to run before training. +ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights. +dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. +target_dice: 0.95 diff --git a/ScaFFold/datagen/category_search.py b/ScaFFold/datagen/category_search.py index 1730bd1..a7dbc7a 100644 --- a/ScaFFold/datagen/category_search.py +++ b/ScaFFold/datagen/category_search.py @@ -195,9 +195,10 @@ def main(config: Config) -> None: print(f"MPI size = {size}") # Setup directories - repo_src_path = config.library_root - fracts_sub_dir = f"/var{config.variance_threshold}" - fracts_write_dir = f"{repo_src_path}/fractals{fracts_sub_dir}/3DIFS_param" + fracts_sub_dir = f"var{config.variance_threshold}" + fracts_write_dir = os.path.join( + config.fract_base_dir, fracts_sub_dir, "3DIFS_param" + ) if rank == 0: print(f"Writing fractals to {fracts_write_dir}") if os.path.exists(fracts_write_dir) and config.datagen_from_scratch: diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index 65bedd8..c7ffaf9 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -29,14 +29,16 @@ from ScaFFold.datagen import volumegen META_FILENAME = "meta.yaml" +DATASET_FORMAT_VERSION = 2 INCLUDE_KEYS = [ + "dataset_format_version", "n_categories", "n_instances_used_per_fractal", "problem_scale", - "unet_bottleneck_dim", "seed", "variance_threshold", "n_fracts_per_vol", + "val_split", ] @@ -86,12 +88,12 @@ def _git_commit_short() -> str: ) except subprocess.CalledProcessError: print( - f"Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse." + "Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse." ) return "no-commit-id" except Exception: print( - f"Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse." + "Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse." ) return "no-commit-id" @@ -117,8 +119,10 @@ def get_dataset( root.mkdir(exist_ok=True) # Get dict of required keys and compute config_id + config_dict = vars(config).copy() + config_dict["dataset_format_version"] = DATASET_FORMAT_VERSION volume_config = _get_required_keys_dict( - config=vars(config), include_keys=INCLUDE_KEYS + config=config_dict, include_keys=INCLUDE_KEYS ) config_id = _hash_volume_config(volume_config) commit = _git_commit_short() @@ -137,6 +141,8 @@ def get_dataset( meta = yaml.safe_load(meta_path.read_text()) if meta.get("config_id") != config_id: continue + if meta.get("dataset_format_version", 1) != DATASET_FORMAT_VERSION: + continue if require_commit and meta.get("code_commit") != commit: continue # If we pass the above checks, this dataset can be reused @@ -187,6 +193,7 @@ def get_dataset( # Write to tmp, then move, so readers never see half-written dataset meta = { "config_id": config_id, + "dataset_format_version": DATASET_FORMAT_VERSION, "config_subset": volume_config, "include_keys": INCLUDE_KEYS, "code_commit": commit, diff --git a/ScaFFold/datagen/instance.py b/ScaFFold/datagen/instance.py index 91067ee..a14c2fb 100644 --- a/ScaFFold/datagen/instance.py +++ b/ScaFFold/datagen/instance.py @@ -25,7 +25,6 @@ from pathlib import Path import numpy as np -import open3d from mpi4py import MPI from ScaFFold.datagen.generate_fractal_points import generate_fractal_points @@ -81,11 +80,10 @@ def main(config: Config): print(f"MPI size = {size}") # Setup directories - repo_src_path = config.library_root fracts_sub_dir = f"var{config.variance_threshold}" - fracts_read_dir = f"{repo_src_path}/fractals/{fracts_sub_dir}/3DIFS_param" - instance_write_dir = ( - f"{repo_src_path}fractals/{fracts_sub_dir}/instances/np{config.point_num}" + fracts_read_dir = os.path.join(config.fract_base_dir, fracts_sub_dir, "3DIFS_param") + instance_write_dir = os.path.join( + config.fract_base_dir, fracts_sub_dir, "instances", f"np{config.point_num}" ) if rank == 0: print( @@ -113,7 +111,7 @@ def main(config: Config): existing_instances = [ int(path_str.split("_")[-1].split(".")[0]) for path_str in glob.glob( - f"{instance_write_dir}/{category:06d}/[0-9][0-9][0-9][0-9][0-9][0-9]_[0-9][0-9][0-9][0-9].ply" + f"{instance_write_dir}/{category:06d}/[0-9][0-9][0-9][0-9][0-9][0-9]_[0-9][0-9][0-9][0-9].npy" ) ] category_instance_pairs = [ @@ -170,33 +168,18 @@ def main(config: Config): # Generate points points = generate_single_instance(config.point_num, params) - # Force point_data to be contiguous -- prevents possible segfaults in later Vector3dVector call from non-contiguous arrays + # Force point_data to be contiguous points_contiguous = np.ascontiguousarray(points, dtype=DEFAULT_NP_DTYPE) - # Create o3d PointCloud object - pointcloud = open3d.geometry.PointCloud() - - # Populate PointCloud points attribute with point_data - pointcloud.points = open3d.utility.Vector3dVector(points_contiguous) - - # Construct the long path and short filename + # Construct the output path out_dir = Path(instance_write_dir) / f"{category:06d}" - filename = f"{category:06d}_{instance:04d}.ply" + filename = f"{category:06d}_{instance:04d}.npy" # Ensure parent directory exists out_dir.mkdir(parents=True, exist_ok=True) - # Save current working directory so we can restore it later - cwd = os.getcwd() - try: - # Change to the output directory - os.chdir(out_dir) - - # Write with short relative path - open3d.io.write_point_cloud(filename, pointcloud) - finally: - # Always restore the working directory - os.chdir(cwd) + # Save array to out_dir + np.save(out_dir / filename, points_contiguous) end_time = time.time() total_time = end_time - start_time diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index d1120ab..b268aa7 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -23,20 +23,17 @@ from typing import Dict import numpy as np -import open3d as o3d from mpi4py import MPI from ScaFFold.utils.config_utils import Config +from ScaFFold.utils.data_types import DEFAULT_NP_DTYPE, MASK_DTYPE, VOLUME_DTYPE -DEFAULT_NP_DTYPE = np.float64 - -def load_ply_with_open3d(path: str) -> np.ndarray: +def load_np_ptcloud(path: str) -> np.ndarray: """ - Read a .ply via Open3D and return an (N,3) array of dtype float64. + Read a .npy file and return an (N,3) array of dtype float64. """ - pcd = o3d.io.read_point_cloud(path) - pts = np.asarray(pcd.points) + pts = np.load(path) return pts.astype(DEFAULT_NP_DTYPE, copy=False) @@ -168,7 +165,7 @@ def main(config: Dict): fractal_colors = np.random.rand(max(config.n_categories, n_fracts_per_vol), 3) grid_size = math.floor(config.vol_size * config.scale) - library_root = str(config.library_root) + fract_base_dir = str(config.fract_base_dir) # Generation loop start_time = time.time() @@ -179,10 +176,10 @@ def main(config: Dict): volume = np.full( (config.vol_size, config.vol_size, config.vol_size, 3), 0, - dtype=np.float32, + dtype=VOLUME_DTYPE, ) mask = np.full( - (config.vol_size, config.vol_size, config.vol_size), 0, dtype=np.short + (config.vol_size, config.vol_size, config.vol_size), 0, dtype=MASK_DTYPE ) global_vol_idx = curr_vol[0] @@ -200,11 +197,10 @@ def main(config: Dict): ) point_cloud_path = os.path.join( - library_root, - "fractals", + fract_base_dir, instances_dir, f"{curr_category:06d}", - f"{curr_category:06d}_{curr_instance:04d}.ply", + f"{curr_category:06d}_{curr_instance:04d}.npy", ) if not os.path.exists(point_cloud_path): @@ -213,7 +209,7 @@ def main(config: Dict): ) sys.exit(1) - points = load_ply_with_open3d(point_cloud_path) + points = load_np_ptcloud(point_cloud_path) mask3d = points_to_voxelgrid(points, grid_size) assert mask3d.shape == volume.shape[:3], ( @@ -225,14 +221,20 @@ def main(config: Dict): # Determine destination folder subdir = "validation" if global_vol_idx in val_indices else "training" + # Tensors must logically be channels-first, later we will change striding/storage to channels-last on GPU (metadata will always stay channels-first). + volume_channels_first = volume.transpose((3, 0, 1, 2)) + volume_to_save = np.ascontiguousarray( + volume_channels_first, dtype=VOLUME_DTYPE + ) + mask_to_save = np.ascontiguousarray(mask, dtype=MASK_DTYPE) vol_file = os.path.join(vol_path, subdir, f"{global_vol_idx}.npy") with open(vol_file, "wb") as f: - np.save(f, volume) + np.save(f, volume_to_save) mask_file = os.path.join(mask_path, subdir, f"{global_vol_idx}_mask.npy") with open(mask_file, "wb") as f: - np.save(f, mask) + np.save(f, mask_to_save) end_time = time.time() total_time = end_time - start_time diff --git a/ScaFFold/utils/checkpointing.py b/ScaFFold/utils/checkpointing.py index 0bab949..2a06a3c 100644 --- a/ScaFFold/utils/checkpointing.py +++ b/ScaFFold/utils/checkpointing.py @@ -12,11 +12,9 @@ # # SPDX-License-Identifier: (Apache-2.0) -import copy import math import random import shutil -import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Dict, Optional @@ -107,6 +105,46 @@ def wait_for_save(self): self._log(f"Background save failed with error: {e}") self.future = None + def snapshot_training_state(self) -> Dict[str, Any]: + """Capture mutable in-memory training state without writing a checkpoint.""" + model_ref = self.model.module if hasattr(self.model, "module") else self.model + return { + "model_state_dict": self._clone_state_dict(model_ref.state_dict()), + "optimizer_state_dict": self._clone_state_dict(self.optimizer.state_dict()) + if self.optimizer + else None, + "scheduler_state_dict": self._clone_state_dict(self.scheduler.state_dict()) + if self.scheduler + else None, + "grad_scaler_state_dict": self._clone_state_dict( + self.grad_scaler.state_dict() + ) + if self.grad_scaler + else None, + "model_training": model_ref.training, + **self._get_rng_snapshot(), + } + + def restore_training_state(self, snapshot: Dict[str, Any]) -> None: + """Restore an in-memory training snapshot.""" + model_ref = self.model.module if hasattr(self.model, "module") else self.model + model_ref.load_state_dict(snapshot["model_state_dict"]) + + if self.optimizer and snapshot.get("optimizer_state_dict") is not None: + self.optimizer.load_state_dict(snapshot["optimizer_state_dict"]) + + if self.scheduler and snapshot.get("scheduler_state_dict") is not None: + self.scheduler.load_state_dict(snapshot["scheduler_state_dict"]) + + if self.grad_scaler and snapshot.get("grad_scaler_state_dict") is not None: + self.grad_scaler.load_state_dict(snapshot["grad_scaler_state_dict"]) + + self._restore_rng(snapshot) + model_ref.train(snapshot.get("model_training", True)) + + if self.optimizer: + self.optimizer.zero_grad(set_to_none=True) + def load_from_checkpoint(self) -> int: """Load the latest checkpoint. Returns start_epoch (default 1).""" self.wait_for_save() # Safety: don't load while writing @@ -238,7 +276,7 @@ def save_checkpoint( self.best_ckpt_path, is_best, ) - self._log(f"Async checkpoint offloaded to background thread.") + self._log("Async checkpoint offloaded to background thread.") else: # Synchronous Save self._write_to_disk( @@ -257,14 +295,22 @@ def save_checkpoint( def _write_to_disk(state_dict, last_path, best_path, is_best): """Worker function to perform actual disk I/O.""" # Save 'last' - torch.save(state_dict, last_path) + try: + torch.save(state_dict, last_path) + except Exception as e: + print("Saving checkpoint failed. Continuing...") + print(e) # Save 'best' (copy logic) if is_best: # Copy is often faster than re-serializing if last_path.exists(): shutil.copyfile(last_path, best_path) else: - torch.save(state_dict, best_path) + try: + torch.save(state_dict, best_path) + except Exception as e: + print("Saving checkpoint failed. Continuing...") + print(e) def _transfer_dict_to_cpu(self, obj): """Recursively move tensors to CPU.""" @@ -279,6 +325,19 @@ def _transfer_dict_to_cpu(self, obj): else: return obj + def _clone_state_dict(self, obj): + """Recursively clone tensors so in-memory snapshots are isolated.""" + if torch.is_tensor(obj): + return obj.detach().clone() + elif isinstance(obj, dict): + return {k: self._clone_state_dict(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._clone_state_dict(v) for v in obj] + elif isinstance(obj, tuple): + return tuple(self._clone_state_dict(v) for v in obj) + else: + return obj + def _barrier(self): if self.dist_enabled: dist.barrier() @@ -299,14 +358,14 @@ def _log(self, msg): def _get_rng_snapshot(self) -> Dict[str, Any]: snap = {"rng_state_pytorch": torch.get_rng_state()} if torch.cuda.is_available(): - snap["rng_state_pytorch_cuda"] = torch.cuda.get_rng_state() + snap["rng_state_pytorch_cuda"] = torch.cuda.get_rng_state_all() try: snap["rng_state_numpy"] = np.random.get_state() except ImportError: pass try: snap["rng_state_python"] = random.getstate() - except: + except Exception: pass return snap @@ -315,7 +374,11 @@ def _restore_rng(self, snap: Dict[str, Any]): if "rng_state_pytorch" in snap: torch.set_rng_state(snap["rng_state_pytorch"]) if "rng_state_pytorch_cuda" in snap and torch.cuda.is_available(): - torch.cuda.set_rng_state(snap["rng_state_pytorch_cuda"]) + cuda_state = snap["rng_state_pytorch_cuda"] + if isinstance(cuda_state, list): + torch.cuda.set_rng_state_all(cuda_state) + else: + torch.cuda.set_rng_state(cuda_state) if "rng_state_numpy" in snap: np.random.set_state(snap["rng_state_numpy"]) if "rng_state_python" in snap: diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 06ee76d..36f1603 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -32,6 +32,9 @@ def __init__(self, config_dict): self.dataset_dir = str( Path(config_dict.get("dataset_dir", "datasets/")).resolve() ) + self.fract_base_dir = str( + Path(config_dict.get("fract_base_dir", "fractals/")).resolve() + ) self.job_name = config_dict.get("job_name", "benchmark") self.n_categories = config_dict["n_categories"] self.problem_scale = config_dict["problem_scale"] @@ -45,11 +48,12 @@ def __init__(self, config_dict): ) self.problem_scale = math.floor(self.problem_scale) self.unet_bottleneck_dim = config_dict["unet_bottleneck_dim"] - self.unet_layers = self.problem_scale - self.unet_bottleneck_dim + 1 + self.unet_layers = self.problem_scale - self.unet_bottleneck_dim self.n_fracts_per_vol = config_dict["n_fracts_per_vol"] self.n_instances_used_per_fractal = config_dict["n_instances_used_per_fractal"] self.scale = 1 self.batch_size = config_dict["batch_size"] + self.dataloader_num_workers = config_dict["dataloader_num_workers"] self.epochs = config_dict["epochs"] self.optimizer = config_dict["optimizer"] self.disable_scheduler = bool(config_dict["disable_scheduler"]) @@ -60,18 +64,33 @@ def __init__(self, config_dict): self.seed = config_dict["seed"] self.dist = bool(config_dict["dist"]) self.framework = config_dict["framework"] - self.learning_rate = config_dict["learning_rate"] + self.starting_learning_rate = config_dict["starting_learning_rate"] + self.min_learning_rate = config_dict["min_learning_rate"] + self.T_0 = config_dict["T_0"] + self.T_mult = config_dict["T_mult"] self.variance_threshold = config_dict["variance_threshold"] self.torch_amp = bool(config_dict["torch_amp"]) self.loss_freq = config_dict["loss_freq"] self.checkpoint_dir = config_dict["checkpoint_dir"] self.normalize = config_dict["normalize"] - self.warmup_epochs = config_dict["warmup_epochs"] - self.num_shards = config_dict["num_shards"] - self.shard_dim = config_dict["shard_dim"] + self.warmup_batches = config_dict.get("warmup_batches") + self.ce_weight_sample_fraction = config_dict.get( + "ce_weight_sample_fraction", 0.1 + ) self.dataset_reuse_enforce_commit_id = config_dict[ "dataset_reuse_enforce_commit_id" ] + self.target_dice = config_dict["target_dice"] + self.checkpoint_interval = config_dict["checkpoint_interval"] + self.dc_num_shards = config_dict["dc_num_shards"] + self.dc_shard_dims = config_dict["dc_shard_dims"] + self.dc_total_shards = math.prod(self.dc_num_shards) + # Safety Check: Length mismatch + if len(self.dc_num_shards) != len(self.dc_shard_dims): + raise ValueError( + f"Configuration Mismatch: num_shards {self.dc_num_shards} " + f"must have same length as shard_dim {self.dc_shard_dims}" + ) class RunConfig(Config): diff --git a/ScaFFold/utils/create_restart_script.py b/ScaFFold/utils/create_restart_script.py index 4994205..206eae7 100644 --- a/ScaFFold/utils/create_restart_script.py +++ b/ScaFFold/utils/create_restart_script.py @@ -20,7 +20,7 @@ import stat import sys from pathlib import Path -from typing import List, Literal, Optional, Union +from typing import List, Literal, Union import torch @@ -39,7 +39,7 @@ def _rewrite_config_and_add_restart(cli_args: List[str]) -> List[str]: new_args = [] skip_next = False - # Args to strip because they trigger new directory creation + # Args to strip because they trigger new directory creation or shouldn't change args_to_remove = {"--base-run-dir", "--job-name"} for i, tok in enumerate(cli_args): @@ -90,61 +90,29 @@ def _bash_array(var_name: str, argv: List[str], var_subs: dict[str, str]) -> str def _get_env_setup() -> str: - """Return the bash block that sets up the environment (modules, LD_PRELOAD, etc).""" - # Dynamically determine the current virtualenv path + """Return the bash block that sets up the environment based on your stable configuration.""" + # Dynamically determine the current virtualenv path to reuse the active one venv_path = sys.prefix return f""" # --- Begin Environment Setup --- # Load Modules if command -v module &> /dev/null; then - module load rocm/6.4.2 rccl/fast-env-slows-mpi libfabric + ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi fi # Activate Virtual Environment +# (Using the one active when this script was generated) if [ -f "{venv_path}/bin/activate" ]; then source "{venv_path}/bin/activate" else echo "WARNING: Could not find venv activate script at {venv_path}/bin/activate" fi -# 1. Define the path to the ROCm LLVM OpenMP library -ROCM_OMP_LIB="/opt/rocm-6.4.2/llvm/lib/libomp.so" +# Environment variables +export SPINDLE_FLUXOPT=off +export LD_PRELOAD=/opt/rocm-7.1.1/llvm/lib/libomp.so -# 2. Check if it exists before proceeding -if [ ! -f "$ROCM_OMP_LIB" ]; then - echo "ERROR: Could not find OpenMP at $ROCM_OMP_LIB" - # Fallback search if the standard path is wrong - ROCM_OMP_LIB=$(find /opt/rocm-6.4.2 -name libomp.so | head -n 1) - echo "Found alternative at: $ROCM_OMP_LIB" -fi -if [ -z "$ROCM_OMP_LIB" ]; then - echo "CRITICAL: Unable to find libomp.so in /opt/rocm-6.4.2. Aborting." - exit 1 -fi - -# 3. Force the dynamic linker to load this specific library first -echo "Forcing Preload of: $ROCM_OMP_LIB" -export LD_PRELOAD=$ROCM_OMP_LIB - -# Setup Torch Library Path -SITE_PACKAGES=$(python3 -c "import sysconfig; print(sysconfig.get_path('purelib'))") -TORCH_LIB_PATH="$SITE_PACKAGES/torch/lib" -export LD_LIBRARY_PATH=$TORCH_LIB_PATH:$LD_LIBRARY_PATH - -# Setup System Libfabric -SYSTEM_LIBFABRIC=$(ls /opt/cray/libfabric/2.1/lib64/libfabric.so.1 | head -n 1) - -if [ -z "$SYSTEM_LIBFABRIC" ]; then - echo "Error: Could not find system libfabric!" - exit 1 -fi - -echo "Forcing preload of system Libfabric: $SYSTEM_LIBFABRIC" -export LD_PRELOAD=$SYSTEM_LIBFABRIC:$LD_PRELOAD - -export NCCL_NET=Socket -export NCCL_SOCKET_IFNAME=hsi0 export PROFILE_TORCH=ON # --- End Environment Setup --- """ @@ -180,20 +148,25 @@ def _render_torchrun_hpc_restart( # Additional torchrun-hpc arguments (e.g. --launcher-args for specific scheduler flags) LAUNCHER_ADDITIONAL_ARGS='' -LAUNCHER_ARGS="-N $NODES -n $TASKS_PER_NODE --gpus-per-proc $GPUS_PER_PROC $LAUNCHER_ADDITIONAL_ARGS" - -IFS=' ' read -r -a LAUNCHER_ARR <<< "$LAUNCHER_ARGS" +# Use a proper Bash array for arguments to handle paths with spaces safely +LAUNCHER_ARGS=( + -l "$RUN_DIR" + -N "$NODES" + -n "$TASKS_PER_NODE" + --gpus-per-proc "$GPUS_PER_PROC" + $LAUNCHER_ADDITIONAL_ARGS +) # Exact Python command to rerun the CLI {py_array_decl} echo "Restarting in $RUN_DIR via torchrun-hpc:" -echo " torchrun-hpc $LAUNCHER_ARGS ..." +echo " torchrun-hpc ${{LAUNCHER_ARGS[*]}} ..." printf ' python cmd: '; printf '%q ' "${{PY[@]}}"; echo cd "$RUN_DIR" # Invoking torchrun-hpc to handle scheduler interaction (Flux/Slurm) -exec torchrun-hpc "${{LAUNCHER_ARR[@]}}" "${{PY[@]}}" +exec torchrun-hpc "${{LAUNCHER_ARGS[@]}}" "${{PY[@]}}" """ diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index 725854c..688f329 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -13,24 +13,110 @@ # SPDX-License-Identifier: (Apache-2.0) import pickle +from dataclasses import dataclass from os import listdir from os.path import isfile, join, splitext from pathlib import Path +from typing import Dict, Optional, Tuple import numpy as np import torch +import yaml from torch.utils.data import Dataset +from ScaFFold.utils.data_types import MASK_DTYPE, VOLUME_DTYPE from ScaFFold.utils.utils import customlog +DATASET_FORMAT_VERSION = 2 +LEGACY_DATASET_FORMAT_VERSION = 1 +META_FILENAME = "meta.yaml" + + +@dataclass(frozen=True) +class SpatialShardSpec: + """Describe the local spatial shard owned by the current rank.""" + + shard_dims: Tuple[int, ...] + num_shards: Tuple[int, ...] + shard_indices: Tuple[int, ...] + + def __post_init__(self): + if not ( + len(self.shard_dims) == len(self.num_shards) == len(self.shard_indices) + ): + raise ValueError( + "shard_dims, num_shards, and shard_indices must have matching lengths" + ) + if len(set(self.shard_dims)) != len(self.shard_dims): + raise ValueError(f"Shard dimensions must be unique: {self.shard_dims}") + for shard_dim, num_shards, shard_index in zip( + self.shard_dims, self.num_shards, self.shard_indices + ): + if shard_dim < 2: + raise ValueError( + f"Invalid shard_dim {shard_dim}: only spatial dimensions are supported" + ) + if num_shards < 1: + raise ValueError( + f"Invalid num_shards {num_shards} for shard_dim {shard_dim}" + ) + if shard_index < 0 or shard_index >= num_shards: + raise ValueError( + f"Invalid shard_index {shard_index} for shard_dim {shard_dim} with {num_shards} shards" + ) + + @staticmethod + def _chunk_slice(size: int, num_shards: int, shard_index: int) -> slice: + """Match torch.chunk-style uneven shard boundaries.""" + + chunk_size = (size + num_shards - 1) // num_shards + start = shard_index * chunk_size + if start >= size: + raise ValueError( + f"Empty local shard: dim size {size}, num_shards {num_shards}, shard_index {shard_index}" + ) + stop = min(size, start + chunk_size) + return slice(start, stop) + + def slice_array( + self, array: np.ndarray, axis_map: Dict[int, int], array_label: str + ) -> np.ndarray: + if not self.shard_dims: + return array + + slices = [slice(None)] * array.ndim + for shard_dim, num_shards, shard_index in zip( + self.shard_dims, self.num_shards, self.shard_indices + ): + if shard_dim not in axis_map: + raise ValueError( + f"No axis mapping defined for {array_label} shard_dim {shard_dim}" + ) + axis = axis_map[shard_dim] + if axis >= array.ndim: + raise ValueError( + f"Axis {axis} out of range for {array_label} with shape {array.shape}" + ) + slices[axis] = self._chunk_slice(array.shape[axis], num_shards, shard_index) + + return array[tuple(slices)] + class BasicDataset(Dataset): def __init__( - self, images_dir: str, mask_dir: str, mask_suffix: str = "", data_dir: str = "" + self, + images_dir: str, + mask_dir: str, + mask_suffix: str = "", + data_dir: str = "", + spatial_shard_spec: Optional[SpatialShardSpec] = None, ): self.images_dir = Path(images_dir) self.mask_dir = Path(mask_dir) self.mask_suffix = mask_suffix + self.spatial_shard_spec = spatial_shard_spec + self.dataset_root = self.images_dir.parents[1] + self.dataset_format_version = self._load_dataset_format_version() self.ids = [ splitext(file)[0] @@ -49,26 +135,73 @@ def __init__( data = pickle.load(data_file) self.mask_values = data["mask_values"] customlog(f"Unique mask values: {self.mask_values}") + customlog(f"Dataset format version: {self.dataset_format_version}") def __len__(self): return len(self.ids) @staticmethod - def preprocess(mask_values, img, is_mask): - if is_mask: - mask = np.zeros((img.shape[0], img.shape[1], img.shape[2]), dtype=np.short) - for i, v in enumerate(mask_values): - if img.ndim == 3: - mask[img == v] = i - else: - mask[(img == v).all(-1)] = i + def _load_numpy_array(path, mmap_mode=None): + return np.load(path, allow_pickle=False, mmap_mode=mmap_mode) + + def _load_dataset_format_version(self): + meta_path = self.dataset_root / META_FILENAME + if not meta_path.exists(): + return LEGACY_DATASET_FORMAT_VERSION + + try: + with open(meta_path, "r") as meta_file: + meta = yaml.safe_load(meta_file) or {} + except Exception as exc: + customlog( + f"Failed to read dataset metadata from {meta_path}: {exc}. Falling back to legacy loader." + ) + return LEGACY_DATASET_FORMAT_VERSION - return mask + return int(meta.get("dataset_format_version", LEGACY_DATASET_FORMAT_VERSION)) - else: - img = img.transpose((3, 0, 1, 2)) + @staticmethod + def _prepare_legacy_image(img): + return np.ascontiguousarray(img.transpose((3, 0, 1, 2)), dtype=VOLUME_DTYPE) + + @staticmethod + def _prepare_legacy_mask(mask_values, mask): + remapped = np.zeros( + (mask.shape[0], mask.shape[1], mask.shape[2]), dtype=MASK_DTYPE + ) + for i, value in enumerate(mask_values): + if mask.ndim == 3: + remapped[mask == value] = i + else: + remapped[(mask == value).all(-1)] = i + + return remapped + + @staticmethod + def _prepare_optimized_image(img): + return np.array(img, dtype=VOLUME_DTYPE, copy=True, order="C") + + @staticmethod + def _prepare_optimized_mask(mask): + return np.array(mask, dtype=MASK_DTYPE, copy=True, order="C") + + def _slice_image_array(self, img): + if self.spatial_shard_spec is None: return img + if self.dataset_format_version >= DATASET_FORMAT_VERSION: + axis_map = {2: 1, 3: 2, 4: 3} + else: + axis_map = {2: 0, 3: 1, 4: 2} + return self.spatial_shard_spec.slice_array(img, axis_map, "image") + + def _slice_mask_array(self, mask): + if self.spatial_shard_spec is None: + return mask + + axis_map = {2: 0, 3: 1, 4: 2} + return self.spatial_shard_spec.slice_array(mask, axis_map, "mask") + def __getitem__(self, idx): name = self.ids[idx] mask_file = list(self.mask_dir.glob(name + self.mask_suffix + ".*")) @@ -80,22 +213,39 @@ def __getitem__(self, idx): assert len(mask_file) == 1, ( f"Either no mask or multiple masks found for the ID {name}: {mask_file}" ) - with open(mask_file[0], "rb") as f: - mask = np.load(f) - f.close() - with open(img_file[0], "rb") as f: - img = np.load(f) - f.close() - - img = self.preprocess(self.mask_values, img, is_mask=False) - mask = self.preprocess(self.mask_values, mask, is_mask=True) + mmap_mode = "r" if self.spatial_shard_spec is not None else None + # Memmap lets each rank slice out just its local shard without eagerly + # reading the full sample into process memory first. + mask = self._load_numpy_array(mask_file[0], mmap_mode=mmap_mode) + img = self._load_numpy_array(img_file[0], mmap_mode=mmap_mode) + mask = self._slice_mask_array(mask) + img = self._slice_image_array(img) + + if self.dataset_format_version >= DATASET_FORMAT_VERSION: + img = self._prepare_optimized_image(img) + mask = self._prepare_optimized_mask(mask) + else: + img = self._prepare_legacy_image(img) + mask = self._prepare_legacy_mask(self.mask_values, mask) return { - "image": torch.as_tensor(img.copy()).float().contiguous(), - "mask": torch.as_tensor(mask.copy()).long().contiguous(), + "image": torch.from_numpy(img).contiguous().float(), + "mask": torch.from_numpy(mask).contiguous().long(), } class FractalDataset(BasicDataset): - def __init__(self, images_dir, mask_dir, data_dir): - super().__init__(images_dir, mask_dir, mask_suffix="_mask", data_dir=data_dir) + def __init__( + self, + images_dir, + mask_dir, + data_dir, + spatial_shard_spec: Optional[SpatialShardSpec] = None, + ): + super().__init__( + images_dir, + mask_dir, + mask_suffix="_mask", + data_dir=data_dir, + spatial_shard_spec=spatial_shard_spec, + ) diff --git a/ScaFFold/utils/data_types.py b/ScaFFold/utils/data_types.py new file mode 100644 index 0000000..ef1515d --- /dev/null +++ b/ScaFFold/utils/data_types.py @@ -0,0 +1,28 @@ +# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LBANN/ScaFFold. +# +# SPDX-License-Identifier: (Apache-2.0) + +import numpy as np +import torch + +DEFAULT_NP_DTYPE = np.float64 +# Masks are values 0 <= x <= n_categories +MASK_DTYPE = np.uint16 +# Volumes/img are 0 <= x <= 1 +VOLUME_DTYPE_NAME = "float32" +VOLUME_NP_DTYPE = getattr(np, VOLUME_DTYPE_NAME) +VOLUME_TORCH_DTYPE = getattr(torch, VOLUME_DTYPE_NAME) +VOLUME_DTYPE = VOLUME_NP_DTYPE + +# Shared AMP dtype selection for torch.autocast. +AMP_DTYPE = torch.bfloat16 diff --git a/ScaFFold/utils/dice_score.py b/ScaFFold/utils/dice_score.py index 6536345..ed60fd4 100644 --- a/ScaFFold/utils/dice_score.py +++ b/ScaFFold/utils/dice_score.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: (Apache-2.0) import torch +import torch.distributed as dist from torch import Tensor from ScaFFold.utils.perf_measure import annotate @@ -59,3 +60,56 @@ def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): # Dice loss (objective to minimize) between 0 and 1 fn = multiclass_dice_coeff if multiclass else dice_coeff return 1 - fn(input, target, reduce_batch_first=True) + + +class SpatialAllReduce(torch.autograd.Function): + @staticmethod + def forward(ctx, input, spatial_mesh): + output = input.clone() + for mesh_dim in range(spatial_mesh.ndim): + pg = spatial_mesh.get_group(mesh_dim) + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=pg) + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +@annotate() +def compute_sharded_dice( + preds: torch.Tensor, + targets: torch.Tensor, + spatial_mesh, + epsilon: float = 1e-6, +): + """ + Computes the globally sharded Dice score. + Returns the raw score tensor of shape [Batch, Channels]. + """ + assert preds.size() == targets.size(), ( + f"Shape mismatch: {preds.size()} vs {targets.size()}" + ) + assert preds.dim() == 5, f"Expected 5D tensor, got {preds.dim()}D" + + sum_dim = (-1, -2, -3) # D, H, W + + local_inter = 2.0 * (preds * targets).sum(dim=sum_dim) + local_sets_sum_raw = preds.sum(dim=sum_dim) + targets.sum(dim=sum_dim) + + packed = torch.stack([local_inter, local_sets_sum_raw]) + + # Global reduce across spatial mesh + packed_global = SpatialAllReduce.apply(packed, spatial_mesh) + + global_inter = packed_global[0] + global_sets_sum_raw = packed_global[1] + + global_sets_sum = torch.where( + global_sets_sum_raw == 0, global_inter, global_sets_sum_raw + ) + + # Calculate score + dice_score = (global_inter + epsilon) / (global_sets_sum + epsilon) + + return dice_score diff --git a/ScaFFold/utils/distributed.py b/ScaFFold/utils/distributed.py index 2ca1c71..5e1f92e 100644 --- a/ScaFFold/utils/distributed.py +++ b/ScaFFold/utils/distributed.py @@ -98,7 +98,7 @@ def force_cuda_visible_devices(force: bool = False) -> None: other GPUs. """ - print(f"force_cuda_visible_devices is deprecated. Skipping...") + print("force_cuda_visible_devices is deprecated. Skipping...") def get_device() -> torch.device: diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index c2d0672..6b23b15 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -15,127 +15,139 @@ import torch import torch.nn.functional as F from distconv import DCTensor -from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor from tqdm import tqdm -from ScaFFold.utils.dice_score import dice_coeff, dice_loss, multiclass_dice_coeff +from ScaFFold.utils.data_types import AMP_DTYPE +from ScaFFold.utils.dice_score import compute_sharded_dice +from ScaFFold.utils.losses import compute_sharded_cross_entropy_loss from ScaFFold.utils.perf_measure import annotate @annotate() @torch.inference_mode() def evaluate( - net, dataloader, device, amp, primary, criterion, n_categories, parallel_strategy + net, + dataloader, + device, + amp, + primary, + criterion, + n_categories, + parallel_strategy, + max_batches=None, ): + def foreground_dice_stats(dice_scores): + if dice_scores.size(1) > 1: + per_sample_scores = dice_scores[:, 1:].mean(dim=1) + else: + per_sample_scores = dice_scores.mean(dim=1) + return per_sample_scores.sum().item(), per_sample_scores.numel() + net.eval() + autocast_device_type = device.type if device.type != "mps" else "cpu" + autocast_kwargs = {"device_type": autocast_device_type, "enabled": amp} + if amp: + autocast_kwargs["dtype"] = AMP_DTYPE num_val_batches = len(dataloader) - dice_score = 0.0 + if max_batches is not None: + num_val_batches = min(num_val_batches, max_batches) + total_dice_score = 0.0 processed_batches = 0 + processed_samples = 0 + + spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names] - # For reference, dc sharding happens on this spatial dim: 2=D, 3=H, 4=W if primary: print( f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}" ) - with torch.autocast(device.type if device.type != "mps" else "cpu", enabled=amp): + with torch.autocast(**autocast_kwargs): val_loss_epoch = 0.0 - for batch in tqdm( - dataloader, - total=num_val_batches, - desc="Validation round", - unit="batch", - leave=False, - disable=not primary, + class_weights = getattr(criterion, "weight", None) + for batch_idx, batch in enumerate( + tqdm( + dataloader, + total=num_val_batches, + desc="Validation round", + unit="batch", + leave=False, + disable=not primary, + ) ): + if batch_idx >= num_val_batches: + break image, mask_true = batch["image"], batch["mask"] - # move images and labels to correct device and type image = image.to( device=device, dtype=torch.float32, - memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) + memory_format=torch.channels_last_3d, ) - mask_true = mask_true.to( - device=device, dtype=torch.long - ).contiguous() # masks no channels NDHW, but ensure cotinuity. - - # Shard batch across ddp mesh, replicate across dc mesh - image_dp = distribute_tensor( - image, parallel_strategy.device_mesh, placements=[Shard(0), Replicate()] - ).to_local() - mask_true_dp = distribute_tensor( - mask_true, - parallel_strategy.device_mesh, - placements=[Shard(0), Replicate()], - ).to_local() - - # Spatially shard images along the dc mesh and run the model - dcx = DCTensor.distribute(image_dp, parallel_strategy) - dcy = net(dcx) + mask_true = mask_true.to(device=device, dtype=torch.long).contiguous() - # Replicate predictions across dc to get full spatial result on each dc rank - mask_pred = dcy.to_replicate() + # Dummy channel dimension [B, 1, D, H, W] + mask_true = mask_true.unsqueeze(1) - # Use labels that are replicated across dc and sharded across ddp, like predictions - mask_true_ddp = mask_true_dp + # Inputs are already loaded as local shards by the dataset. + dcx = DCTensor.from_shard(image, parallel_strategy) + mask_true_dc = DCTensor.from_shard(mask_true, parallel_strategy) - # Skip if this ddp rank has an empty local batch - if mask_pred.size(0) == 0 or mask_true_ddp.size(0) == 0: - continue - - # Loss - CE_loss = criterion(mask_pred, mask_true_ddp) + # Forward pass on sharded data + dcy = net(dcx) - # Dice loss - mask_pred_softmax = F.softmax(mask_pred, dim=1).float() - mask_true_onehot = ( - F.one_hot(mask_true_ddp, n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() - ) - dice_loss_curr = dice_loss( - mask_pred_softmax, - mask_true_onehot, - multiclass=True, - ) + # Extract underlying local tensors (STAY SHARDED) + local_preds = dcy + local_labels_5d = mask_true_dc + local_labels = local_labels_5d.squeeze(1) - # Combined validation loss - loss = CE_loss + dice_loss_curr - val_loss_epoch += loss.item() - processed_batches += 1 + # Skip empty batches + if local_preds.size(0) == 0 or local_labels.size(0) == 0: + continue - # Dice score - if net.module.n_classes == 1: - assert mask_true_ddp.min() >= 0 and mask_true_ddp.max() <= 1, ( - "True mask indices should be in [0, 1]" - ) - mask_pred_bin = (F.sigmoid(mask_pred) > 0.5).float() - dice_score += dice_coeff( - mask_pred_bin, mask_true_ddp, reduce_batch_first=False + # Calculate CE and Dice loss in single precision for numerical stability. + with torch.autocast(device_type=autocast_device_type, enabled=False): + CE_loss = compute_sharded_cross_entropy_loss( + local_preds, + local_labels, + spatial_mesh, + parallel_strategy.num_shards, + autocast_device_type, + class_weights, ) - else: - assert ( - mask_true_ddp.min() >= 0 - and mask_true_ddp.max() < net.module.n_classes - ), "True mask indices should be in [0, n_classes]" - mask_pred_processed = F.softmax(mask_pred, dim=1).float() - mask_true_onehot_mc = ( - F.one_hot(mask_true_ddp, net.module.n_classes) + + mask_pred_probs = F.softmax(local_preds.float(), dim=1) + mask_true_onehot = ( + F.one_hot(local_labels, n_categories + 1) .permute(0, 4, 1, 2, 3) .float() ) - dice_score += multiclass_dice_coeff( - mask_pred_processed[:, 1:], - mask_true_onehot_mc[:, 1:], - reduce_batch_first=True, + dice_score_probs = compute_sharded_dice( + mask_pred_probs, mask_true_onehot, spatial_mesh + ) + batch_dice_sum, batch_sample_count = foreground_dice_stats( + dice_score_probs ) + batch_dice_score = batch_dice_sum / max(batch_sample_count, 1) + + # Sum global CE Loss and Dice loss + loss = CE_loss + (1.0 - batch_dice_score) + val_loss_epoch += loss.item() + total_dice_score += batch_dice_sum + processed_batches += 1 + processed_samples += batch_sample_count net.train() val_loss_avg = val_loss_epoch / max(processed_batches, 1) if primary: print( - f"evaluate.py: dice_score={dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}" + f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}, num_val_samples={processed_samples}" ) - return dice_score, val_loss_epoch, val_loss_avg, processed_batches + return ( + total_dice_score, + val_loss_epoch, + val_loss_avg, + processed_batches, + processed_samples, + ) diff --git a/ScaFFold/utils/losses.py b/ScaFFold/utils/losses.py new file mode 100644 index 0000000..3869b8b --- /dev/null +++ b/ScaFFold/utils/losses.py @@ -0,0 +1,166 @@ +# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LBANN/ScaFFold. +# +# SPDX-License-Identifier: (Apache-2.0) + +import math + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from ScaFFold.utils.dice_score import SpatialAllReduce + + +def _sample_ce_weight_indices(n_train, sample_fraction): + """Pick a small, deterministic subset of masks to estimate CE weights.""" + if n_train <= 0: + return [] + + if sample_fraction is None: + sample_fraction = 0.1 + + sample_count = min( + max(math.ceil(n_train * float(sample_fraction)), 1), + n_train, + ) + if sample_count == n_train: + return list(range(n_train)) + + return torch.linspace(0, n_train - 1, steps=sample_count).long().tolist() + + +def _compute_ce_class_weights( + train_set, + n_train, + n_categories, + device, + sample_fraction=0.1, + dist_enabled=False, + world_rank=0, + log=None, +): + """ + Estimate background vs foreground CE weights from a few training masks. + + Background keeps its own inverse-frequency weight, and every non-zero + fractal class shares the foreground weight derived from the aggregate + non-empty voxel count. + """ + + num_classes = n_categories + 1 + class_weights = torch.ones(num_classes, device=device, dtype=torch.float32) + + if n_train == 0: + if log is not None: + log.warning( + "Training set is empty while computing CE class weights. Falling back to uniform weights." + ) + return class_weights + + sample_indices = _sample_ce_weight_indices(n_train, sample_fraction) + sampled_class_counts = torch.zeros(num_classes, dtype=torch.long) + + for sample_idx in sample_indices: + mask = train_set[sample_idx]["mask"] + sampled_class_counts += torch.bincount(mask.reshape(-1), minlength=num_classes) + + # The dataset may already return only this rank's local spatial shard, + # so combine per-rank counts before deriving the global CE weights. + sampled_class_counts = sampled_class_counts.to(device=device) + if dist_enabled: + dist.all_reduce(sampled_class_counts, op=dist.ReduceOp.SUM) + + background_voxels = int(sampled_class_counts[0].item()) + foreground_voxels = int(sampled_class_counts[1:].sum().item()) + + if background_voxels > 0 and foreground_voxels > 0: + total_voxels = background_voxels + foreground_voxels + class_weights[0] = total_voxels / background_voxels + class_weights[1:] = total_voxels / foreground_voxels + elif log is not None: + log.warning( + "Sampled masks did not contain both background and foreground voxels. Falling back to uniform CE weights." + ) + + if log is not None and (not dist_enabled or world_rank == 0): + log.info( + f"CE weights estimated from {len(sample_indices)} training masks " + f"(sample_fraction={sample_fraction}, indices={sample_indices}): " + f"background_voxels={background_voxels} " + f"foreground_voxels={foreground_voxels} " + f"weights={class_weights.detach().cpu().tolist()}" + ) + + return class_weights + + +def compute_sharded_cross_entropy_loss( + local_preds, + local_labels, + spatial_mesh, + _num_shards, + device_type, + class_weights=None, +): + """ + Compute the CE loss for a spatially sharded volume. + + Each rank only sees a local spatial shard, so we cannot use the local + `reduction="mean"` result directly. Instead we: + 1. compute the local CE numerator with `reduction="sum"`, + 2. build the correct global denominator, + 3. all-reduce across the spatial mesh, and + 4. divide to recover the same value we would get from a non-sharded tensor. + + When `class_weights` is provided, PyTorch's CE "mean" divides by the sum of + the target weights, not the raw voxel count, so we reproduce that behavior + explicitly here. + """ + + autocast_device = device_type if device_type != "mps" else "cpu" + with torch.autocast(autocast_device, enabled=False): + # Accumulate CE in full precision. Using reduction="sum" gives us the + # numerator of the final global mean; if class weights are present, + # PyTorch applies the target-class weight to each voxel here. + local_ce_sum = F.cross_entropy( + local_preds.float(), + local_labels, + weight=class_weights, + reduction="sum", + ) + + if class_weights is None: + # Sum the actual local voxel counts across spatial shards. We use + # an all-reduced count instead of numel()*num_shards because shard + # sizes can differ at chunk boundaries. + local_voxel_count = local_ce_sum.new_tensor(float(local_labels.numel())) + global_normalizer = SpatialAllReduce.apply(local_voxel_count, spatial_mesh) + else: + # Weighted CE divides by sum(weight[target_i]) over all voxels. + # Build that denominator from the local label histogram, then + # all-reduce it across the spatial mesh. + local_class_counts = torch.bincount( + local_labels.reshape(-1), minlength=class_weights.numel() + ).to(dtype=local_ce_sum.dtype) + local_weight_sum = torch.dot( + local_class_counts, class_weights.to(dtype=local_ce_sum.dtype) + ) + global_normalizer = SpatialAllReduce.apply(local_weight_sum, spatial_mesh) + + # Sum the local CE numerators from each spatial shard to get the global CE + # numerator, then divide by the matching global denominator. + global_ce_sum = SpatialAllReduce.apply(local_ce_sum, spatial_mesh) + # Clamp to avoid a divide-by-zero in degenerate cases. + return global_ce_sum / global_normalizer.clamp_min( + torch.finfo(global_ce_sum.dtype).eps + ) diff --git a/ScaFFold/utils/perf_measure.py b/ScaFFold/utils/perf_measure.py index e0ac1a2..5af8d5b 100644 --- a/ScaFFold/utils/perf_measure.py +++ b/ScaFFold/utils/perf_measure.py @@ -27,8 +27,9 @@ from pycaliper.instrumentation import begin_region, end_region _CALI_PERF_ENABLED = True - except Exception: + except Exception as e: print("User requested Caliper annotations, but could not import Caliper") + print(f"Exception: {e}") elif ( TORCH_PERF_ENV_VAR in os.environ and os.environ.get(TORCH_PERF_ENV_VAR).lower() != "off" diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index c76646a..a1f77f5 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -13,34 +13,34 @@ # SPDX-License-Identifier: (Apache-2.0) # Standard library -import json import math import os -import random import shutil import time from pathlib import Path # Third party -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from distconv import DCTensor from torch import optim -from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor from torch.utils.data import DataLoader from tqdm import tqdm from ScaFFold.utils.checkpointing import CheckpointManager -from ScaFFold.utils.data_loading import FractalDataset -from ScaFFold.utils.dice_score import dice_loss +from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec +from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_TORCH_DTYPE +from ScaFFold.utils.dice_score import compute_sharded_dice from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size # Local from ScaFFold.utils.evaluate import evaluate -from ScaFFold.utils.perf_measure import begin_code_region, end_code_region +from ScaFFold.utils.losses import ( + _compute_ce_class_weights, + compute_sharded_cross_entropy_loss, +) +from ScaFFold.utils.perf_measure import adiak_value, begin_code_region, end_code_region from ScaFFold.utils.utils import gather_and_print_mem @@ -54,6 +54,9 @@ def __init__(self, model, config, device, log): self.config = config self.device = device self.log = log + self.amp_device_type = self.device.type if self.device.type != "mps" else "cpu" + self.amp_dtype = AMP_DTYPE + self.use_grad_scaler = False self.world_size = get_world_size(required=self.config.dist) self.world_rank = get_world_rank(required=self.config.dist) self.local_rank = get_local_rank(required=self.config.dist) @@ -71,8 +74,17 @@ def __init__(self, model, config, device, log): self.scheduler = None self.grad_scaler = None self.criterion = None + self.ce_class_weights = None self.global_step = 0 self.start_epoch = -1 + self.ps = getattr(self.config, "_parallel_strategy", None) + self.spatial_mesh = None # Spatial mesh for use w/ DistConv + self.data_num_replicas = self.world_size + self.data_replica_rank = self.world_rank + if self.ps is not None: + self.spatial_mesh = self.ps.device_mesh[self.ps.distconv_dim_names] + self.data_num_replicas = self.ps.ddp_ranks + self.data_replica_rank = self.ps.ddp_ind self.checkpoint_path_absolute = str( self.config.run_dir + "/" + self.config.checkpoint_dir @@ -99,16 +111,29 @@ def create_dataset(self): val_mask_dir = dataset_dir / "masks/validation" train_unique_masks_path = dataset_dir / "train_unique_mask_vals" val_unique_masks_path = dataset_dir / "val_unique_mask_vals" + spatial_shard_spec = None + if self.ps is not None: + spatial_shard_spec = SpatialShardSpec( + shard_dims=tuple(self.ps.shard_dim), + num_shards=tuple(self.ps.num_shards), + shard_indices=tuple(self.ps.shard_ind), + ) self.train_set = FractalDataset( - train_vol_dir, train_mask_dir, data_dir=train_unique_masks_path + train_vol_dir, + train_mask_dir, + data_dir=train_unique_masks_path, + spatial_shard_spec=spatial_shard_spec, ) self.val_set = FractalDataset( - val_vol_dir, val_mask_dir, data_dir=val_unique_masks_path + val_vol_dir, + val_mask_dir, + data_dir=val_unique_masks_path, + spatial_shard_spec=spatial_shard_spec, ) self.n_train = len(self.train_set) self.n_val = len(self.val_set) - self.log.debug( + self.log.info( f"Datasets created with n_train={self.n_train}, n_val={self.n_val}" ) @@ -116,10 +141,15 @@ def create_sampler(self): """Create DistributedSamplers for train and validation datasets.""" if self.config.dist: self.train_sampler = torch.utils.data.distributed.DistributedSampler( - self.train_set + self.train_set, + num_replicas=self.data_num_replicas, + rank=self.data_replica_rank, ) self.val_sampler = torch.utils.data.distributed.DistributedSampler( - self.val_set, shuffle=False + self.val_set, + num_replicas=self.data_num_replicas, + rank=self.data_replica_rank, + shuffle=False, ) else: self.train_sampler = torch.utils.data.RandomSampler(self.train_set) @@ -130,18 +160,31 @@ def create_dataloaders(self): self.create_dataset() self.create_sampler() + num_workers = self.config.dataloader_num_workers loader_args = dict( - batch_size=self.config.batch_size, num_workers=1, pin_memory=True + batch_size=self.config.batch_size, + num_workers=num_workers, + pin_memory=True, ) + if num_workers > 0: + loader_args["persistent_workers"] = True + loader_args["prefetch_factor"] = 2 self.log.debug( - f"dataloader num_workers={loader_args['num_workers']}, os.cpu_count()={os.cpu_count()}, self.world_size={self.world_size} " + f"dataloader num_workers={loader_args['num_workers']}, prefetch_factor={loader_args.get('prefetch_factor')}, persistent_workers={loader_args.get('persistent_workers', False)}, os.cpu_count()={os.cpu_count()}, self.world_size={self.world_size} " ) self.train_loader = DataLoader( self.train_set, sampler=self.train_sampler, **loader_args ) self.val_loader = DataLoader( - self.val_set, sampler=self.val_sampler, drop_last=True, **loader_args + self.val_set, sampler=self.val_sampler, drop_last=False, **loader_args ) + if len(self.val_loader) == 0: + raise ValueError( + "Validation DataLoader has zero batches. " + f"n_val={self.n_val}, batch_size={self.config.batch_size}, " + f"data_num_replicas={self.data_num_replicas}. " + "Reduce batch_size or adjust validation sharding." + ) def setup_training_components(self): """Set up the optimizer, scheduler, gradient scaler, and loss function.""" @@ -149,38 +192,81 @@ def setup_training_components(self): if self.config.optimizer == "ADAM": self.log.info("Using ADAM optimizer.") self.optimizer = optim.Adam( - self.model.parameters(), lr=self.config.learning_rate + self.model.parameters(), lr=self.config.starting_learning_rate ) elif self.config.optimizer == "SGD": self.log.info("Using SGD optimizer.") self.optimizer = optim.SGD( - self.model.parameters(), lr=self.config.learning_rate + self.model.parameters(), lr=self.config.starting_learning_rate ) else: self.log.info("Using RMSprop optimizer.") self.optimizer = optim.RMSprop( - self.model.parameters(), lr=self.config.learning_rate, foreach=True + self.model.parameters(), + lr=self.config.starting_learning_rate, + foreach=True, ) # Set up learning rate scheduler - self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer, "max", patience=25 + self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( + self.optimizer, + T_0=self.config.T_0, + T_mult=self.config.T_mult, + eta_min=self.config.min_learning_rate, ) # Set up gradient scaler for AMP (Automatic Mixed Precision) - self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.config.torch_amp) + # bfloat does not need grad scaler + self.use_grad_scaler = ( + self.config.torch_amp and self.amp_dtype != torch.bfloat16 + ) + self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler) # Set up loss function - self.criterion = ( - nn.CrossEntropyLoss() - if self.config.n_categories + 1 > 1 - else nn.BCEWithLogitsLoss() - ) + if self.config.n_categories + 1 > 1: + ce_class_weights = _compute_ce_class_weights( + train_set=self.train_set, + n_train=self.n_train, + n_categories=self.config.n_categories, + device=self.device, + sample_fraction=self.config.ce_weight_sample_fraction, + dist_enabled=self.config.dist, + world_rank=self.world_rank, + log=self.log, + ) + self.criterion = nn.CrossEntropyLoss(weight=ce_class_weights).to( + self.device + ) + else: + self.criterion = nn.BCEWithLogitsLoss().to(self.device) + if isinstance(self.criterion, nn.CrossEntropyLoss): + self.ce_class_weights = self.criterion.weight self.log.info( - f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, Gradient Scaler Enabled: {self.config.torch_amp}" + f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}" ) + def _autocast_kwargs(self, enabled=None): + if enabled is None: + enabled = self.config.torch_amp + + kwargs = {"device_type": self.amp_device_type, "enabled": enabled} + if enabled: + kwargs["dtype"] = self.amp_dtype + return kwargs + + @staticmethod + def _foreground_dice_mean(dice_scores): + """Match optimization to the reported validation metric by excluding background.""" + if dice_scores.size(1) > 1: + return dice_scores[:, 1:].mean() + return dice_scores.mean() + + def _current_learning_rate(self): + if self.optimizer is None or not self.optimizer.param_groups: + return self.config.starting_learning_rate + return self.optimizer.param_groups[0]["lr"] + class PyTorchTrainer(BaseTrainer): """ @@ -239,6 +325,15 @@ def cleanup_or_resume(self): "train_mask_values" ] + # If we loaded a checkpoint (start_epoch > 1), we must ensure the CSV + # matches the state of that checkpoint. + if ( + self.world_rank == 0 + and self.start_epoch > 1 + and os.path.exists(self.outfile_path) + ): + self._truncate_stats_file(self.start_epoch) + # Set up the output file headers headers = [ "epoch", @@ -254,102 +349,279 @@ def cleanup_or_resume(self): with open(self.outfile_path, "a", newline="") as outfile: outfile.write(",".join(headers) + "\n") - def train(self): + def _truncate_stats_file(self, start_epoch, path=None): """ - Execute model training + Scans the stats file and truncates it at the first occurrence of + an epoch >= start_epoch. This is O(1) memory and safe for large logs. """ + if path is None: + path = self.outfile_path + self.log.info(f"Truncating {path} to remove epochs >= {start_epoch}") + + try: + # Open in read+update mode ('r+') to allow seeking and truncating + with open(path, "r+") as f: + header = f.readline() + if not header: + return + + # Identify the index of the 'epoch' column + headers = header.strip().split(",") + try: + epoch_idx = headers.index("epoch") + except ValueError: + epoch_idx = 0 + + while True: + # Save the current file position (start of the line) + current_pos = f.tell() + line = f.readline() + + # End of file reached + if not line: + break + + parts = line.strip().split(",") + try: + row_epoch = int(float(parts[epoch_idx])) + + # If we find a row that is "from the future" (or the current restarting epoch) + if row_epoch >= start_epoch: + # Move pointer back to the start of this line + f.seek(current_pos) + # Cut the file off right here + f.truncate() + self.log.info( + f"Truncated stats file at byte {current_pos} (found epoch {row_epoch})" + ) + break + except (ValueError, IndexError): + # Skip malformed lines, or decide to stop. + # Usually safe to continue scanning. + pass - self.cleanup_or_resume() + except Exception as e: + self.log.warning(f"Failed to truncate stats file {path}: {e}") + + def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): + """Log size of tensor in memory""" + + if verbosity < 2: + return + tensor_memory_bytes = tensor[0].element_size() * tensor[0].nelement() + tensor_memory_gb = tensor_memory_bytes / (1024**3) + self.log.info(f"{tensor_label} size on GPU: {tensor_memory_gb:.2f} GB") + + def _run_training_batch( + self, + batch, + *, + log_prefix="", + gather_mem_stats=False, + log_peak_mem=False, + ): + """Run one training batch and return batch size, detached loss, and dice.""" + images, true_masks = batch["image"], batch["mask"] + + begin_code_region("image_to_device") + images = images.to( + device=self.device, + dtype=VOLUME_TORCH_DTYPE, + memory_format=torch.channels_last_3d, + non_blocking=True, + ) + true_masks = true_masks.to( + device=self.device, dtype=torch.long, non_blocking=True + ).contiguous() + end_code_region("image_to_device") + if gather_mem_stats: + gather_and_print_mem(self.log, "after_batch_to_device") + + # Add a dummy channel dimension to get 5D [B, 1, D, H, W] + true_masks = true_masks.unsqueeze(1) + + # Inputs are already loaded as local shards by the dataset. + images_dc = DCTensor.from_shard(images, self.ps) + true_masks_dc = DCTensor.from_shard(true_masks, self.ps) + del images, true_masks + self._get_memsize(images_dc, "Sharded image", self.config.verbose) + + with torch.autocast(**self._autocast_kwargs()): + if gather_mem_stats: + torch.cuda.reset_peak_memory_stats() + gather_and_print_mem(self.log, "pre_forward") + begin_code_region("predict") + self.log.debug(f" {log_prefix}running forward pass") + masks_pred_dc = self.model(images_dc) + end_code_region("predict") + if gather_mem_stats: + gather_and_print_mem(self.log, "post_forward") + self.log.debug(f" {log_prefix}forward pass complete") + + # Extract the underlying PyTorch local tensors + local_preds = masks_pred_dc + local_labels_5d = true_masks_dc + + # Remove the dummy channel dimension so CE Loss is happy [B, D, H, W] + local_labels = local_labels_5d.squeeze(1) + if self.world_rank == 0: + self.log.debug(f" {log_prefix}Local Preds Shape: {local_preds.shape}") + self.log.debug( + f" {log_prefix}Local Labels Shape: {local_labels.shape}" + ) - warmup_epochs = self.config.warmup_epochs - if warmup_epochs > 0: - begin_code_region("warmup") - # Keep BN/Dropout from changing behavior/statistics - self.model.eval() - start_warmup = time.time() - self.log.info(f"Running {warmup_epochs} warmup epoch(s)") + begin_code_region("calculate_loss") + current_mem = torch.cuda.memory_allocated() / (1024**3) + self.log.debug( + f" {log_prefix}Calculating sharded loss. Mem: {current_mem:.2f} GB." + ) - ps = getattr(self.config, "_parallel_strategy", None) + # Calculate CE and Dice loss in single precision for numerical stability. + with torch.autocast(**self._autocast_kwargs(enabled=False)): + loss_ce = compute_sharded_cross_entropy_loss( + local_preds, + local_labels, + self.spatial_mesh, + self.config.dc_num_shards, + self.amp_device_type, + self.ce_class_weights, + ) - for _ in range(warmup_epochs): - for batch in self.train_loader: - images, true_masks = batch["image"], batch["mask"] + local_preds_softmax = F.softmax(local_preds.float(), dim=1) + local_labels_one_hot = ( + F.one_hot(local_labels, num_classes=self.config.n_categories + 1) + .permute(0, 4, 1, 2, 3) + .float() + ) + dice_scores = compute_sharded_dice( + local_preds_softmax, + local_labels_one_hot, + self.spatial_mesh, + ) + batch_dice_score = self._foreground_dice_mean(dice_scores) - images = images.to( - device=self.device, - dtype=torch.float32, - memory_format=torch.channels_last_3d, - non_blocking=False, - ) - images_dc = DCTensor.distribute(images, ps) + # Sum global CE Loss and Dice loss + loss = loss_ce + (1.0 - batch_dice_score) + end_code_region("calculate_loss") - true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True - ) + self.log.debug( + f" {log_prefix}loss calculation complete. Proceeding to backward pass" + ) + if gather_mem_stats: + gather_and_print_mem(self.log, "pre_backward") + begin_code_region("backward") + self.grad_scaler.scale(loss).backward() + end_code_region("backward") + if gather_mem_stats: + gather_and_print_mem(self.log, "post_backward") + + begin_code_region("step_and_update") + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + self.log.debug(f" {log_prefix}backward pass complete. Stepping optimizer") + self.grad_scaler.step(self.optimizer) + if gather_mem_stats: + gather_and_print_mem(self.log, "after_optim_step") + self.grad_scaler.update() + self.optimizer.zero_grad(set_to_none=False) + end_code_region("step_and_update") + + batch_size = images_dc.shape[0] + detached_loss = loss.detach() + + # Free memory aggressively + del images_dc, true_masks_dc, masks_pred_dc + del local_preds, local_labels, local_preds_softmax, local_labels_one_hot + del loss_ce, loss + + if log_peak_mem and self.world_rank == 0: + peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) + peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) + self.log.debug( + f"[MEM-PEAK] Peak alloc: {peak_alloc:.2f} GiB | Peak reserved: {peak_reserved:.2f} GiB", + ) - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): - # Forward on DCTensor - masks_pred_dc = self.model(images_dc) - - # Convert predictions for loss - if images.size(0) < ps.num_shards: - # For small batches (e.g., N=1 with dc_num_shards=2), replicate outputs - masks_pred = masks_pred_dc.to_replicate() - labels_for_loss = true_masks - else: - # Otherwise, shard labels across batch dim to match to_ddp layout - masks_pred = masks_pred_dc.to_ddp() - dt_labels = distribute_tensor( - true_masks, - device_mesh=ps.device_mesh["dc"], - placements=[Shard(0)], - ) - labels_for_loss = dt_labels.to_local() + return batch_size, detached_loss, batch_dice_score - CE_loss = self.criterion(masks_pred, labels_for_loss) + def warmup(self): + """Run warmup iterations before the main training loop.""" + warmup_batches = self.config.warmup_batches + if warmup_batches <= 0: + return - # Calculate the train dice loss - masks_pred_softmax = F.softmax(masks_pred, dim=1).float() - true_masks_onehot = ( - F.one_hot(labels_for_loss, self.config.n_categories + 1) - .permute(0, 4, 1, 2, 3) - .float() - ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, - ) - loss = CE_loss + train_dice_curr + if self.config.dist: + self.train_loader.sampler.set_epoch(0) - # Fine as long as we don't step/update - self.grad_scaler.scale(loss).backward() + start_warmup = time.time() + max_batches = min(warmup_batches, len(self.train_loader)) + max_val_batches = min(warmup_batches, len(self.val_loader)) + self.log.info( + f"Running {max_batches} training warmup batch(es) and {max_val_batches} validation warmup batch(es) per rank" + ) + snapshot = self.checkpoint_manager.snapshot_training_state() + + # Match the main training path as closely as possible, but roll back all + # mutable state so warmup does not affect convergence. + self.model.train() + self.optimizer.zero_grad(set_to_none=False) + + try: + for batch_idx, batch in enumerate(self.train_loader): + if batch_idx >= max_batches: + break + + self._run_training_batch( + batch, + log_prefix="warmup: ", + log_peak_mem=True, + ) + batch_t_end = time.time() + self.log.debug( + f" warmup: batch {batch_idx} completed in {batch_t_end - start_warmup} seconds" + ) - # Nuke any accumulated grads so the first real step starts clean - for p in self.model.parameters(): - p.grad = None + if self.config.dist: + self.val_loader.sampler.set_epoch(0) + + if max_val_batches > 0: + self.log.debug(" warmup: running validation warmup pass") + evaluate( + self.model, + self.val_loader, + self.device, + self.config.torch_amp, + False, + self.criterion, + self.config.n_categories, + self.config._parallel_strategy, + max_batches=max_val_batches, + ) + finally: + self.checkpoint_manager.restore_training_state(snapshot) + + if self.config.dist: torch.distributed.barrier() - end_code_region("warmup") - self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") + self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") + + def train(self): + """ + Execute model training + """ + epoch = self.start_epoch + dice_score_train = 0 with open(self.outfile_path, "a", newline="") as outfile: start = time.time() - for epoch in range(self.start_epoch, self.config.epochs + 1): - # DistConv ParallelStrategy - ps = getattr(self.config, "_parallel_strategy", None) - if ps is None: - raise RuntimeError( - "ParallelStrategy not found in config. Set config._parallel_strategy when wrapping model with DistConvDDP." + while dice_score_train < self.config.target_dice: + if self.config.epochs != -1 and epoch > self.config.epochs: + print( + f"Maxmimum epochs reached '{self.config.epochs}'. Concluding training early (may have not converged)." ) + break # Timer and tracking variables epoch_start_time = time.time() - train_dice_curr = 0 train_dice_total = 0 - CE_loss = 0 epoch_loss = 0 # Accumulator for per-batch losses # Set necessary modes/states @@ -357,142 +629,46 @@ def train(self): self.train_loader.sampler.set_epoch(epoch) self.val_loader.sampler.set_epoch(epoch) self.model.train() + self.optimizer.zero_grad(set_to_none=False) + estr = ( + f"{epoch}" + if self.config.epochs == -1 + else f"{epoch}/{self.config.epochs}" + ) with tqdm( - total=self.n_train // self.world_size, + total=len(self.train_sampler), desc=f"({os.path.basename(self.config.run_dir)}) \ - Epoch {epoch}/{self.config.epochs}", + Epoch {estr}", unit="img", disable=True if self.world_rank != 0 else False, ) as pbar: - batch_step = 0 - begin_code_region("batch_loop") - for batch in self.train_loader: - images, true_masks = batch["image"], batch["mask"] - - begin_code_region("image_to_device") - images = images.to( - device=self.device, - dtype=torch.float32, - memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) - non_blocking=True, - ) - true_masks = true_masks.to( - device=self.device, dtype=torch.long, non_blocking=True - ).contiguous() # masks no channels NDHW, but ensure continuity. - end_code_region("image_to_device") - gather_and_print_mem(self.log, "after_batch_to_device") - - # Replicate batch across dc mesh, shard batch across ddp mesh. - # This ensures all dc ranks in the same ddp group see the same samples, - # and ddp ranks see disjoint samples. - images_dp = distribute_tensor( - images, ps.device_mesh, placements=[Shard(0), Replicate()] - ).to_local() - true_masks_dp = distribute_tensor( - true_masks, - ps.device_mesh, - placements=[Shard(0), Replicate()], - ).to_local() - - with torch.autocast( - self.device.type if self.device.type != "mps" else "cpu", - enabled=self.config.torch_amp, - ): - # Predict on this batch - torch.cuda.reset_peak_memory_stats() - gather_and_print_mem(self.log, "pre_forward") - begin_code_region("predict") - - # Spatially shard the chosen dimension across dc mesh - dcx = DCTensor.distribute(images_dp, ps) - dcy = self.model(dcx) - # Convert back to batch-sharded layout across the dc mesh - masks_pred = dcy.to_ddp() - - end_code_region("predict") - gather_and_print_mem(self.log, "post_forward") - - # Reshard labels across dc mesh to match masks_pred's batch partition - # Start from dc-replicated labels, then shard batch across dc - true_masks_ddp = ( - DTensor.from_local( - true_masks_dp, - device_mesh=ps.device_mesh["dc"], - placements=[Replicate()], - ) - .redistribute( - device_mesh=ps.device_mesh["dc"], - placements=[Shard(0)], - ) - .to_local() + for batch_idx, batch in enumerate(self.train_loader): + time_minibatch = batch_idx == 0 and self.world_rank == 0 + if time_minibatch: + minibatch_start_time = time.perf_counter() + batch_size, batch_loss, batch_dice_score = ( + self._run_training_batch( + batch, + gather_mem_stats=True, ) - - begin_code_region("calculate_loss") - # Calculate the loss - if self.config.n_categories + 1 == 1: - loss = self.criterion( - masks_pred.squeeze(1), true_masks_ddp.float() - ) - loss += dice_loss( - F.sigmoid(masks_pred.squeeze(1)), - true_masks_ddp.float(), - multiclass=False, - ) - else: - # Calculate the CrossEntropy loss - CE_loss = self.criterion(masks_pred, true_masks_ddp) - - # Calculate the train dice loss - masks_pred_softmax = F.softmax( - masks_pred, dim=1 - ).float() - true_masks_onehot = ( - F.one_hot( - true_masks_ddp, self.config.n_categories + 1 - ) - .permute(0, 4, 1, 2, 3) - .float() - ) - train_dice_curr = dice_loss( - masks_pred_softmax, - true_masks_onehot, - multiclass=True, - ) - - # Our loss function is CE loss + dice loss - loss = CE_loss + train_dice_curr - - # Track the train dice loss separately for debugging - train_dice_total += train_dice_curr - end_code_region("calculate_loss") - - gather_and_print_mem(self.log, "pre_backward") - begin_code_region("backward") - self.grad_scaler.scale(loss).backward() - end_code_region("backward") - gather_and_print_mem(self.log, "post_backward") - - begin_code_region("step_and_update") - if batch_step + 1 == len(self.train_loader): - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), max_norm=1.0 - ) - self.grad_scaler.step(self.optimizer) - gather_and_print_mem(self.log, "after_optim_step") - - self.grad_scaler.update() - self.optimizer.zero_grad(set_to_none=False) - end_code_region("step_and_update") + ) + train_dice_total += batch_dice_score # Update the loss begin_code_region("update_loss") - pbar.update(images_dp.shape[0]) + pbar.update(batch_size) self.global_step += 1 - batch_step += 1 # Stay on GPU - epoch_loss += loss.detach() + epoch_loss += batch_loss + if time_minibatch: + # This sync has some potential performance impact + # TODO: Would be better to measure this with Caliper, which uses CUDA events. + torch.cuda.synchronize(self.device) + minibatch_time_s = ( + time.perf_counter() - minibatch_start_time + ) end_code_region("update_loss") end_code_region("batch_loop") @@ -502,17 +678,25 @@ def train(self): # # Evaluate model on validation set, update LR if necessary # - dice_sum, val_loss_epoch, val_loss_avg, numbatch = evaluate( + ( + dice_sum, + val_loss_epoch, + val_loss_avg, + numbatch, + numsamples, + ) = evaluate( self.model, self.val_loader, self.device, self.config.torch_amp, - True if self.world_rank == 0 else False, + self.world_rank == 0, self.criterion, self.config.n_categories, self.config._parallel_strategy, ) - dice_info = torch.tensor([dice_sum, numbatch]) + dice_info = torch.tensor( + [dice_sum, numsamples], dtype=VOLUME_TORCH_DTYPE + ) if self.config.dist: dice_info = dice_info.to(device=self.device) torch.distributed.all_reduce( @@ -520,16 +704,7 @@ def train(self): ) val_score = dice_info[0].item() / max(dice_info[1].item(), 1) if not self.config.disable_scheduler: - # The following is true when trying to overfit, - # in which case we only care about train loss - if self.n_train == 1 or "overfit" in self.outfile_path: - self.log.debug( - "WARNING: scheduler step by overall_loss, \ - not val_score (n_train==1 or overfit in outfile_path)" - ) - self.scheduler.step(overall_loss) - else: # Otherwise, we're really trying to optimize for validation dice score - self.scheduler.step(val_score) + self.scheduler.step() else: self.log.debug("scheduler disabled, no LR update this step") @@ -538,11 +713,9 @@ def train(self): # # Write out data for this epoch to train stats csv # - train_dice = float(train_dice_total / len(self.train_loader)) + train_dice = float(train_dice_total.item() / len(self.train_loader)) self.log.info( - f" epoch {epoch} \ - | train_dice_loss {train_dice:.6f} (type {type(train_dice)}) \ - | val_dice_score {val_score:.6f}" + f" epoch {epoch} | train_dice_score {train_dice:.6f} | val_dice_score {val_score:.6f} | lr {self._current_learning_rate():.8f}" ) self.log.debug(f" writing to csv at {self.outfile_path}") if self.world_rank == 0: @@ -563,7 +736,7 @@ def train(self): ) outfile.flush() print( - f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}" + f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}. Rank 0 first batch minibatch_time_s={minibatch_time_s:.6f}." ) # @@ -571,13 +744,23 @@ def train(self): # begin_code_region("checkpoint") - extras = {"train_mask_values": self.train_set.mask_values} - self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras) + # A checkpoint interval of -1 disables checkpointing entirely. + if ( + self.config.checkpoint_interval > 0 + and epoch % self.config.checkpoint_interval == 0 + ): + extras = {"train_mask_values": self.train_set.mask_values} + self.checkpoint_manager.save_checkpoint(epoch, val_loss_avg, extras) end_code_region("checkpoint") - if val_score >= 0.95: - self.log.info( - f"val_score of {val_score} is > threshold of 0.95. Benchmark run complete. Wrapping up..." + dice_score_train = val_score + epoch += 1 + + # This check must exist otherwise the condition dice_score_train < self.config.target_dice will evaluate to False and incorrectly exit the training + if math.isnan(dice_score_train): + raise ValueError( + "Invalid value (NaN) encountered in dice score computation" ) - return 0 + + adiak_value("final_epochs", epoch) diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 33f8949..fde4582 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: (Apache-2.0) +import math import os import socket import sys @@ -23,14 +24,13 @@ import torch import torch.distributed as dist import yaml -from distconv import DCTensor, DistConvDDP, ParallelStrategy -from torch.nn.parallel import DistributedDataParallel as DDP +from distconv import DistConvDDP, ParallelStrategy +from torch.distributed.tensor import Replicate, Shard from ScaFFold.datagen.get_dataset import get_dataset from ScaFFold.unet import UNet from ScaFFold.utils.distributed import ( get_device, - get_job_id, get_local_rank, get_local_size, get_world_rank, @@ -38,6 +38,7 @@ initialize_dist, ) from ScaFFold.utils.perf_measure import ( + adiak_value, annotate, begin_code_region, end_code_region, @@ -161,10 +162,10 @@ def main(kwargs_dict: dict = {}): # Initialize model begin_code_region("init_model") - config.dc_num_shards = getattr(config, "dc_num_shards", config.num_shards) - config.dc_shard_dim = getattr(config, "dc_shard_dim", config.shard_dim) + config.dc_num_shards = getattr(config, "dc_num_shards", config.dc_num_shards) + config.dc_shard_dims = getattr(config, "dc_shard_dims", config.dc_shard_dims) log.info( - f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dim}" + f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dims}" ) device = get_device() log.info(f"Using device: {device}") @@ -176,18 +177,18 @@ def main(kwargs_dict: dict = {}): ) if config.dist: # DDP + DistConv setup - # Ensure world_size is divisible by dc_num_shards - assert dist.get_world_size() % config.dc_num_shards == 0, ( - f"world_size={dist.get_world_size()} must be divisible by dc_num_shards={config.dc_num_shards}" + # Ensure world_size is divisible by total distconv shards + assert dist.get_world_size() % math.prod(config.dc_num_shards) == 0, ( + f"world_size={dist.get_world_size()} must be divisible by total number of distconv shards = {math.prod(config.dc_num_shards)}" ) - # Select which full-tensor dim to shard: 2 + dc_shard_dim - shard_dim = 2 + int(config.dc_shard_dim) + ps = ParallelStrategy( - num_shards=int(config.dc_num_shards), - shard_dim=shard_dim, + num_shards=config.dc_num_shards, + shard_dim=config.dc_shard_dims, device_type=device.type, ) - model = model.to(device).to(memory_format=torch.contiguous_format) + + model = model.to(device, memory_format=torch.channels_last_3d) # Wrap with DistConvDDP that corrects gradient scaling for dc submesh model = DistConvDDP( model, @@ -213,6 +214,41 @@ def main(kwargs_dict: dict = {}): torch.backends.cudnn.benchmark = False torch.use_deterministic_algorithms(True, warn_only=True) trainer = PyTorchTrainer(model, config, device, log) + trainer.ps = ps + trainer.spatial_mesh = ps.device_mesh[ps.distconv_dim_names] + num_spatial_dims = len(ps.shard_dim) + trainer.ddp_placements = [Shard(0)] + [Replicate()] * num_spatial_dims + total_shards = math.prod(config.dc_num_shards) + global_batch_size = config.batch_size * (world_size // total_shards) + ddp_ranks = world_size // total_shards + adiak_value("global_batch_size", global_batch_size) + adiak_value("ddp_ranks", ddp_ranks) + adiak_value("total_shards", total_shards) + adiak_value("num_spatial_dims", num_spatial_dims) + if rank == 0: + log.info( + f"Effective global batch size = {global_batch_size} " + f"(batch_size={config.batch_size} * " + f"(world_size={world_size} / prod(dc_num_shards)={total_shards}))" + ) + log.info( + f"DDP ranks = {ddp_ranks} " + f"world_size={world_size} // prod(dc_num_shards)={total_shards}" + ) + too_small_splits = [] + if global_batch_size > trainer.n_train: + too_small_splits.append(f"training n_train={trainer.n_train}") + if global_batch_size > trainer.n_val: + too_small_splits.append(f"validation n_val={trainer.n_val}") + if too_small_splits: + raise ValueError( + "Effective global batch size exceeds available samples: " + f"global_batch_size={global_batch_size}, " + f"{', '.join(too_small_splits)}, " + f"batch_size={config.batch_size}, world_size={world_size}, " + f"dc_num_shards={config.dc_num_shards}" + ) + else: raise RuntimeError( "Invalid framework specified. Currently [torch] is the supported framework." @@ -224,6 +260,12 @@ def main(kwargs_dict: dict = {}): ranks_per_node = get_local_size() prof_ctx, TORCH_PERF_LOCAL = get_torch_context(ranks_per_node, rank) with prof_ctx as prof: + begin_code_region("cleanup_or_resume") + trainer.cleanup_or_resume() + end_code_region("cleanup_or_resume") + begin_code_region("warmup") + trainer.warmup() + end_code_region("warmup") begin_code_region("train") trainer.train() end_code_region("train") @@ -239,17 +281,22 @@ def main(kwargs_dict: dict = {}): outfile_path = trainer.outfile_path train_data = np.genfromtxt(outfile_path, dtype=float, delimiter=",", names=True) total_train_time = train_data["epoch_duration"].sum() - total_epochs = train_data["epoch"][-1] - log.info( - f"Benchmark run at scale {config.problem_scale} complete. \n\ - Trained to >= 0.95 validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." - ) + epochs = np.atleast_1d(train_data["epoch"]) + total_epochs = int(epochs[-1]) + if config.epochs == -1: + extra_msg = f"Trained to >= {config.target_dice} validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." + else: + extra_msg = ( + f"Completed in {total_train_time:.2f} seconds, {total_epochs} epochs." + ) + + log.info(f"Benchmark run at scale {config.problem_scale} complete. \n{extra_msg}") # # Generate plots # if rank == 0: - log.info(f"Generating figures on rank 0...") + log.info("Generating figures on rank 0...") begin_code_region("generate_figures") standard_viz.main(config) end_code_region("generate_figures") diff --git a/pyproject.toml b/pyproject.toml index 6943402..70eb0d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,14 +43,14 @@ authors = [ ] license = { file = "LICENSE" } dependencies = [ - "hpc-launcher>=1.0.3", + "hpc-launcher>=1.0.4", "matplotlib>=3.9.4", "numpy>=1.26.4", "numba>=0.60.0", "tqdm>=4.67.1", "wandb>=0.19.6", - "open3d>=0.18.0", "PyYAML>=6.0.2", + "distconv @ git+https://github.com/LBANN/DistConv.git@232cba6", ] requires-python = ">=3.9" @@ -59,16 +59,16 @@ profiling = [ "pybind11>=3.0.0" ] cuda = [ - "torch==2.8.0+cu126", - "mpi4py==4.0.2", + "torch==2.10.0+cu129", + "mpi4py==4.1.1", ] rocm = [ - "torch==2.8.0+rocm6.4", - "mpi4py==4.0.2", + "torch==2.12.0+rocm7.2", + "mpi4py==4.1.1", ] rocmwci = [ - "torch==2.8.0+rocm642", - "mpi4py==4.1.1.dev0+mpich.8.1.32", + "torch==2.10.0+rocm710", + "mpi4py==4.1.1+mpich.9.1.0", ] [project.entry-points.console_scripts] diff --git a/requirements.txt b/requirements.txt index 0f8fb78..e0165fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,10 @@ --index-url https://pypi.org/simple -hpc-launcher>=1.0.3 +hpc-launcher>=1.0.4 matplotlib>=3.9.4 numpy>=1.26.4 numba>=0.60.0 tqdm>=4.67.1 wandb>=0.19.6 -open3d>=0.18.0 PyYAML>=6.0.2 -mpi4py==4.0.2 --no-binary mpi4py - -# cuda -# torch==2.7.1+cu126 -# torchvision==0.22.1+cu126 -# torchaudio==2.7.1+cu126 -# --extra-index-url https://download.pytorch.org/whl/cu126 - -# rocm -# torch==2.8.0+rocm6. -# torchvision==0.23.0+rocm6.4 -# torchaudio==2.8.0+rocm6.4 -# --extra-index-url https://download.pytorch.org/whl/test/rocm6.4 +mpi4py==4.1.1 --no-binary mpi4py +distconv @ git+https://github.com/LBANN/DistConv.git@232cba6 diff --git a/scripts/install-matrix.sh b/scripts/install-matrix.sh index f72d2cb..15c4e6d 100644 --- a/scripts/install-matrix.sh +++ b/scripts/install-matrix.sh @@ -1,5 +1,4 @@ ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-matrix && source .venvs/scaffoldvenv-matrix/bin/activate && pip install --upgrade pip -ml cuda/12.6.0 gcc/12.1.1 mvapich2/2.3.7 +ml cuda/12.9.1 gcc/13.3.1 mvapich2/2.3.7 export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH -git clone git@github.com:LBANN/DistConv.git -pip install --no-binary=mpi4py -e .[cuda] DistConv/ --prefix=.venvs/scaffoldvenv-matrix --extra-index-url https://download.pytorch.org/whl/cu126 2>&1 | tee install.log +pip install --no-binary=mpi4py -e .[cuda] --prefix=.venvs/scaffoldvenv-matrix --extra-index-url https://download.pytorch.org/whl/cu129 2>&1 | tee install.log diff --git a/scripts/install-tuolumne-torchpypi.sh b/scripts/install-tuolumne-torchpypi.sh new file mode 100644 index 0000000..91876c4 --- /dev/null +++ b/scripts/install-tuolumne-torchpypi.sh @@ -0,0 +1,5 @@ +ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo-pypi && source .venvs/scaffoldvenv-tuo-pypi/bin/activate && pip install --upgrade pip +ml cce/21.0.1 cray-mpich/9.1.0 rocm/7.2.1 rccl/fast-env-slows-mpi +pip install -e .[rocm] --find-links https://download.pytorch.org/whl/torch/ --find-links https://download.pytorch.org/whl/triton-rocm/ 2>&1 | tee install.log +# libmpi.so.12 does not exist => ls /opt/cray/pe/lib64/ | grep libmpi +patchelf --replace-needed libmpi.so.12 libmpi_gnu.so.12 .venvs/scaffoldvenv-tuo-pypi/lib/python3.11/site-packages/mpi4py/MPI.mpich.cpython-311-x86_64-linux-gnu.so diff --git a/scripts/install-tuolumne.sh b/scripts/install-tuolumne.sh index 8e74416..4b026f8 100644 --- a/scripts/install-tuolumne.sh +++ b/scripts/install-tuolumne.sh @@ -1,4 +1,26 @@ ml load python/3.11.5 && python3 -m venv .venvs/scaffoldvenv-tuo && source .venvs/scaffoldvenv-tuo/bin/activate && pip install --upgrade pip -ml load rocm/6.4.2 rccl/fast-env-slows-mpi libfabric -git clone git@github.com:LBANN/DistConv.git -pip install -e .[rocmwci] DistConv/ --prefix=.venvs/scaffoldvenv-tuo 2>&1 | tee install.log +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi +pip install -e .[rocmwci] 2>&1 | tee install.log +# Needed until new wheel exists for torch using mpich 9.1.0 +TORCH_LIB_DIR=".venvs/scaffoldvenv-tuo/lib/python3.11/site-packages/torch/lib" +OLD="libmpi_gnu_112.so.12" +NEW="libmpi_gnu.so.12" +cd "$TORCH_LIB_DIR" || exit 1 +# Patch every file that has OLD in its DT_NEEDED +for f in *.so*; do + [ -f "$f" ] || continue + + if patchelf --print-needed "$f" 2>/dev/null | grep -Fxq "$OLD"; then + echo "Patching $f" + patchelf --replace-needed "$OLD" "$NEW" "$f" + fi +done +echo +echo "Verification (should show no $OLD):" +for f in *.so*; do + [ -f "$f" ] || continue + if patchelf --print-needed "$f" 2>/dev/null | grep -Fxq "$OLD"; then + echo "STILL NEEDS $OLD -> $f" + fi +done +cd - diff --git a/scripts/scaffold-matrix.job b/scripts/scaffold-matrix.job index 65f2e85..d194aa8 100644 --- a/scripts/scaffold-matrix.job +++ b/scripts/scaffold-matrix.job @@ -8,8 +8,6 @@ #SBATCH -A fractale #SBATCH -perl -ml cuda/12.6.0 gcc/12.1.1 mvapich2/2.3.7 - . .venvs/scaffoldvenv-matrix/bin/activate export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH diff --git a/scripts/scaffold-tuolumne-torchpypi.job b/scripts/scaffold-tuolumne-torchpypi.job new file mode 100644 index 0000000..0387e5e --- /dev/null +++ b/scripts/scaffold-tuolumne-torchpypi.job @@ -0,0 +1,31 @@ +#!/bin/bash + +# flux: --exclusive +# flux: -N 1 +# flux: -g=1 +# flux: -t 60m +# flux: -qpdebug +# flux: -B flask + +ml cce/21.0.1 cray-mpich/9.1.0 rocm/7.2.1 rccl/fast-env-slows-mpi + +. .venvs/scaffoldvenv-tuo-pypi/bin/activate + +export NCCL_NET_PLUGIN=/collab/usr/global/tools/rccl/toss_4_x86_64_ib_cray/rocm-7.2.0/install/lib/librccl-net.so + +# Disable direct convolution benchmarking (should speedup warmup by a significant amount, does the below three options together) +# export MIOPEN_DEBUG_CONV_DIRECT=0 +# Disable direct naive convolution benchmarking (naive_conv_ab_nonpacked_fwd_ndhwc_half_double_half.kd) +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD=0 +# Disable naive_conv_ab_nonpacked_bwd_ndhwc_half_double_half.kd +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 +# Disable naive_conv_ab_nonpacked_wrw_ndhwc_half_double_half.kd +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0 + +torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml + +# Uncomment if you want torch profiling +#export PROFILE_TORCH=ON + +torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml +# torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml diff --git a/scripts/scaffold-tuolumne.job b/scripts/scaffold-tuolumne.job index 4eb3715..a22d8c6 100644 --- a/scripts/scaffold-tuolumne.job +++ b/scripts/scaffold-tuolumne.job @@ -5,17 +5,25 @@ # flux: -g=1 # flux: -t 60m # flux: -qpdebug -# flux: -B fractale +# flux: -B flask -ml rocm/6.4.2 rccl/fast-env-slows-mpi +ml cce/21.0.0 cray-mpich/9.1.0 rocm/7.1.1 rccl/fast-env-slows-mpi . .venvs/scaffoldvenv-tuo/bin/activate -# Avoid spindle error -export SPINDLE_FLUXOPT=off +# (1) Avoid libmagma error +# (2) Removing libmpi may cause segfault on mpi4py import +# (3-5) undefined symbol: cblas_gemm_f16f16f32 +export LD_PRELOAD="/opt/rocm-7.1.1/llvm/lib/libomp.so /opt/cray/pe/mpich/9.1.0/ofi/gnu/11.2/lib/libmpi_gnu.so.12 /opt/intel/oneapi/mkl/2024.2/lib/libmkl_core.so.2 /opt/intel/oneapi/mkl/2024.2/lib/libmkl_gnu_thread.so.2 /opt/intel/oneapi/mkl/2024.2/lib/libmkl_intel_lp64.so.2" -# Avoid libmagma error -export LD_PRELOAD=/opt/rocm-6.4.2/llvm/lib/libomp.so +# Disable direct convolution benchmarking (should speedup warmup by a significant amount, does the below three options together) +# export MIOPEN_DEBUG_CONV_DIRECT=0 +# Disable direct naive convolution benchmarking (naive_conv_ab_nonpacked_fwd_ndhwc_half_double_half.kd) +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD=0 +# Disable naive_conv_ab_nonpacked_bwd_ndhwc_half_double_half.kd +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 +# Disable naive_conv_ab_nonpacked_wrw_ndhwc_half_double_half.kd +export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0 torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml @@ -23,4 +31,4 @@ torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/co #export PROFILE_TORCH=ON torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml -#torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml +# torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml