From 51765cf1c8528e08356eeaa78057a3fd37a2d99a Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 11 Jun 2026 08:52:52 -0700 Subject: [PATCH 1/9] remove unneeded extra config --- ScaFFold/configs/benchmark_testing.yml | 39 -------------------------- 1 file changed, 39 deletions(-) delete mode 100644 ScaFFold/configs/benchmark_testing.yml diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml deleted file mode 100644 index 5167de1..0000000 --- a/ScaFFold/configs/benchmark_testing.yml +++ /dev/null @@ -1,39 +0,0 @@ -# External/user-facing -base_run_dir: "benchmark_runs" # Subfolder of $(pwd) in which to run jobs. -dataset_dir: "datasets" # Directory in which to store and query for datasets. -fract_base_dir: "fractals" # Base directory for fractal IFS and instances. -n_categories: 5 # Number of fractal categories present in the dataset. -n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. -problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6. -unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. -seed: 42 # Random seed. -batch_size: 1 # Batch sizes for each vol size. -dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. -optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard -checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely. - -# Internal/dev use only -variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. -n_fracts_per_vol: 3 # Number of fractals overlaid in each volume. Default is 3. -val_split: 25 # In percent. -epochs: 10 # Number of training epochs. -starting_learning_rate: .0001 # Initial learning rate for training. -min_learning_rate: .0001 # Minimum learning rate for CosineAnnealingWarmRestarts. -T_0: 10 # Epochs in the first cosine restart cycle. -T_mult: 1 # Restart cycle growth factor. -disable_scheduler: 1 # If 1, disable scheduler during training to use constant LR. -more_determinism: 0 # If 1, improve model training determinism. -datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. -train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off. -dist: 1 # If 1, use torch DDP. -torch_amp: 1 # If 1, use mixed precision in training. -framework: "torch" # The DL framework to train with. Only valid option for now is "torch". -checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. -loss_freq: 1 # Number of epochs between logging the overall loss. -normalize: 1 # Cateogry search normalization parameter -warmup_batches: 5 # How many warmup batches per rank to run before training. -ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights. -dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse. -target_dice: 0.95 From 2834d5178631487a683247f223e23118394e1e00 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 11 Jun 2026 08:53:54 -0700 Subject: [PATCH 2/9] get datagen batch size from config rather than hard coding --- ScaFFold/datagen/category_search.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ScaFFold/datagen/category_search.py b/ScaFFold/datagen/category_search.py index a7dbc7a..fefe993 100644 --- a/ScaFFold/datagen/category_search.py +++ b/ScaFFold/datagen/category_search.py @@ -186,7 +186,9 @@ def main(config: Config) -> None: rank = comm.Get_rank() size = comm.Get_size() - datagen_batch_size = 10000 + datagen_batch_size = int(getattr(config, "datagen_batch_size", 10000)) + if datagen_batch_size <= 0: + raise ValueError("datagen_batch_size must be positive") # FIXME anything else to ensure determinism? np.random.seed(config.seed + rank) From 184a0bb682e2b86be71e6862f867bc102eb77ba0 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 11 Jun 2026 08:54:50 -0700 Subject: [PATCH 3/9] CLI updates: add dc-shard-dims override, lazy imports where possible --- ScaFFold/cli.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 469c71e..6ff6215 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -21,11 +21,9 @@ import yaml from mpi4py import MPI -from ScaFFold import benchmark, generate_fractals from ScaFFold.utils import config_utils from ScaFFold.utils.collect_scheduler_info import collect_scheduler_metadata from ScaFFold.utils.create_restart_script import create_restart_script -from ScaFFold.utils.utils import customlog def main(): @@ -55,7 +53,7 @@ def main(): generate_fractals_parser = subparsers.add_parser( "generate_fractals", help="Generate fractal classes and instances.", - description="Must be ran before 'benchmark'", + description="Must be run before 'benchmark'", ) generate_fractals_parser.add_argument( "-c", @@ -143,7 +141,7 @@ def main(): ) benchmark_parser.add_argument("--seed", type=int, help="Random seed.") benchmark_parser.add_argument( - "--batch-size", type=int, help="Batch sizes for each volume size." + "--batch-size", type=int, help="Batch size per data-parallel rank." ) benchmark_parser.add_argument( "--warmup-batches", @@ -177,6 +175,12 @@ def main(): nargs=3, help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum", ) + benchmark_parser.add_argument( + "--dc-shard-dims", + type=int, + nargs=3, + help="DistConv param: tensor dimensions to shard.", + ) benchmark_parser.add_argument( "--epochs", type=int, @@ -272,7 +276,7 @@ def main(): f"{combined_config.get('job_name')}_%Y%m%d-%H%M%S" ) benchmark_run_dir = base_run_dir / timestamp - customlog( + print( f"benchmark_run_dir created at path {Path.resolve(benchmark_run_dir)}" ) @@ -301,8 +305,12 @@ def main(): print(f"combined_config = {combined_config}") if args.command == "benchmark": + from ScaFFold import benchmark + benchmark.main(kwargs_dict=combined_config) elif args.command == "generate_fractals": + from ScaFFold import generate_fractals + generate_fractals.main(kwargs_dict=combined_config) else: raise ValueError( From b7ddca5185052085f146893c96588f22ffef6168 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 11 Jun 2026 09:14:10 -0700 Subject: [PATCH 4/9] fix typo and cleanup deprecated function --- ScaFFold/utils/distributed.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/ScaFFold/utils/distributed.py b/ScaFFold/utils/distributed.py index 5e1f92e..524fafe 100644 --- a/ScaFFold/utils/distributed.py +++ b/ScaFFold/utils/distributed.py @@ -62,7 +62,7 @@ def get_local_size(required: bool = False) -> int: def get_world_rank(required: bool = False) -> int: - """Return the global MPI rank..""" + """Return the global MPI rank.""" if "MV2_COMM_WORLD_RANK" in os.environ: return int(os.environ["MV2_COMM_WORLD_RANK"]) if "OMPI_COMM_WORLD_RANK" in os.environ: @@ -91,16 +91,6 @@ def get_world_size(required: bool = False) -> int: return 1 -def force_cuda_visible_devices(force: bool = False) -> None: - """Set CUDA_VISIBLE_DEVICES. - - This seems to help avoid PyTorch or something else from touching - other GPUs. - - """ - print("force_cuda_visible_devices is deprecated. Skipping...") - - def get_device() -> torch.device: if torch.cuda.is_available(): torch.cuda.init() From 970491a31ba3b759dd10000a4d86b4c12c275876 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 11 Jun 2026 09:16:56 -0700 Subject: [PATCH 5/9] add psutil as explicit proj dep --- pyproject.toml | 1 + requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 70eb0d5..8e85358 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "tqdm>=4.67.1", "wandb>=0.19.6", "PyYAML>=6.0.2", + "psutil>=5.9.0", "distconv @ git+https://github.com/LBANN/DistConv.git@232cba6", ] requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index e0165fb..d76015d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,6 @@ numba>=0.60.0 tqdm>=4.67.1 wandb>=0.19.6 PyYAML>=6.0.2 +psutil>=5.9.0 mpi4py==4.1.1 --no-binary mpi4py distconv @ git+https://github.com/LBANN/DistConv.git@232cba6 From 71ec053b4ad3ac99d5033aef0a9e61cbbba0d7b1 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 11 Jun 2026 09:17:18 -0700 Subject: [PATCH 6/9] remove option for running in non-distributed mode --- README.md | 17 +++-- ScaFFold/benchmark.py | 2 +- ScaFFold/configs/benchmark_default.yml | 3 +- ScaFFold/utils/config_utils.py | 9 ++- ScaFFold/utils/trainer.py | 59 +++++++--------- ScaFFold/worker.py | 98 ++++++++++---------------- 6 files changed, 83 insertions(+), 105 deletions(-) diff --git a/README.md b/README.md index 9438b1d..f9bd4e2 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,8 @@ The model is trained from a random initialization until convergence, which is de 1. Once fractal generation completes, run the benchmark: `torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c ScaFFold/configs/benchmark_default.yml` +ScaFFold benchmark training always uses PyTorch distributed execution with DistConv spatial parallelism. For a singleton run, launch one distributed rank rather than disabling distributed execution. + `benchmark` creates a folder for the benchmark run(s) at `base_run_dir` set in the config file. For reproducibility, we store a copy of the benchmark run config yml. Within each run subfolder, `benchmark` creates a yml config for that specific run. After each run completes, statistics from the run are stored in `train_stats.csv`. Additionally, users can inspect plots of the training and validation losses over time in ` bottleneck layer of size 8. seed: 42 # Random seed. -batch_size: 1 # Batch sizes for each vol size. -optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. +batch_size: 1 # Batch size per data-parallel rank. +dataloader_num_workers: 1 # Number of DataLoader worker processes per rank. +optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defaults to RMSProp. +dc_num_shards: [1, 1, 1] # DistConv spatial shard counts. +dc_shard_dims: [2, 3, 4] # Tensor dimensions sharded by DistConv. +checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. @@ -91,12 +98,12 @@ disable_scheduler: 1 # If 1, disable scheduler during training to more_determinism: 0 # If 1, improve model training determinism. datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off. -dist: 1 # If 1, use torch DDP. torch_amp: 1 # If 1, use mixed precision in training. framework: "torch" # The DL framework to train with. Only valid option for now is "torch". checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. -checkpoint_interval: 1 # Number of epochs between saving training checkpoints. loss_freq: 1 # Number of epochs between logging the overall loss. +warmup_batches: 64 # How many warmup batches per rank to run before training. +target_dice: 0.95 # Validation Dice score threshold for convergence when epochs is -1. ``` ## How the benchmark works diff --git a/ScaFFold/benchmark.py b/ScaFFold/benchmark.py index d0f6ad8..fb0f7e2 100644 --- a/ScaFFold/benchmark.py +++ b/ScaFFold/benchmark.py @@ -60,7 +60,7 @@ def main(kwargs_dict: dict = {}): # Get MPI information comm = MPI.COMM_WORLD - rank = get_world_rank(required=args.dist) + rank = get_world_rank(required=True) if rank == 0: print(f"args found: {args}") diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 1b0310c..0eb613d 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -9,7 +9,7 @@ unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dim seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size per rank. dataloader_num_workers: 1 # Number of DataLoader worker processes per rank. More workers will use more memory -optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. +optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defaults to RMSProp. dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely. @@ -27,7 +27,6 @@ disable_scheduler: 0 # If 1, disable scheduler during training to more_determinism: 0 # If 1, improve model training determinism. datagen_from_scratch: 0 # If 1, delete existing fractals and instances, then regenerate from scratch. train_from_scratch: 1 # If 1, delete existing train stats and checkpoint files. Keep 0 if want to restart runs where we left off. -dist: 1 # If 1, use torch DDP. torch_amp: 1 # If 1, use mixed precision in training. framework: "torch" # The DL framework to train with. Only valid option for now is "torch". checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 36f1603..8f20859 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -62,7 +62,12 @@ def __init__(self, config_dict): self.train_from_scratch = bool(config_dict["train_from_scratch"]) self.val_split = config_dict["val_split"] self.seed = config_dict["seed"] - self.dist = bool(config_dict["dist"]) + if "dist" in config_dict and not bool(config_dict["dist"]): + raise ValueError( + "The 'dist: 0' mode is no longer supported. ScaFFold benchmark " + "training always runs with distributed execution; use a one-rank " + "torchrun-hpc job for singleton runs." + ) self.framework = config_dict["framework"] self.starting_learning_rate = config_dict["starting_learning_rate"] self.min_learning_rate = config_dict["min_learning_rate"] @@ -120,5 +125,5 @@ def load_config(file_path: str, config_type: str): return RunConfig(config_dict) else: raise ValueError( - f"Invalid config type specified: {type}. Must be either 'sweep' or 'run'" + f"Invalid config type specified: {config_type}. Must be either 'sweep' or 'run'" ) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index a1f77f5..0e50044 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -57,9 +57,9 @@ def __init__(self, model, config, device, log): self.amp_device_type = self.device.type if self.device.type != "mps" else "cpu" self.amp_dtype = AMP_DTYPE self.use_grad_scaler = False - self.world_size = get_world_size(required=self.config.dist) - self.world_rank = get_world_rank(required=self.config.dist) - self.local_rank = get_local_rank(required=self.config.dist) + self.world_size = get_world_size(required=True) + self.world_rank = get_world_rank(required=True) + self.local_rank = get_local_rank(required=True) # Initialize placeholders for attributes that will be set up later self.train_set = None @@ -139,21 +139,17 @@ def create_dataset(self): def create_sampler(self): """Create DistributedSamplers for train and validation datasets.""" - if self.config.dist: - self.train_sampler = torch.utils.data.distributed.DistributedSampler( - self.train_set, - num_replicas=self.data_num_replicas, - rank=self.data_replica_rank, - ) - self.val_sampler = torch.utils.data.distributed.DistributedSampler( - self.val_set, - num_replicas=self.data_num_replicas, - rank=self.data_replica_rank, - shuffle=False, - ) - else: - self.train_sampler = torch.utils.data.RandomSampler(self.train_set) - self.val_sampler = torch.utils.data.SequentialSampler(self.val_set) + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.train_set, + num_replicas=self.data_num_replicas, + rank=self.data_replica_rank, + ) + self.val_sampler = torch.utils.data.distributed.DistributedSampler( + self.val_set, + num_replicas=self.data_num_replicas, + rank=self.data_replica_rank, + shuffle=False, + ) def create_dataloaders(self): """Create dataloaders for training and validation.""" @@ -230,7 +226,7 @@ def setup_training_components(self): n_categories=self.config.n_categories, device=self.device, sample_fraction=self.config.ce_weight_sample_fraction, - dist_enabled=self.config.dist, + dist_enabled=True, world_rank=self.world_rank, log=self.log, ) @@ -286,7 +282,7 @@ def __init__(self, model, config, device, log): base_dir=self.checkpoint_path_absolute, log=self.log, world_rank=self.world_rank, - dist_enabled=self.config.dist, + dist_enabled=True, # Check config for async setting, default to False async_save=getattr(self.config, "async_save", False), ) @@ -549,8 +545,7 @@ def warmup(self): if warmup_batches <= 0: return - if self.config.dist: - self.train_loader.sampler.set_epoch(0) + self.train_loader.sampler.set_epoch(0) start_warmup = time.time() max_batches = min(warmup_batches, len(self.train_loader)) @@ -580,8 +575,7 @@ def warmup(self): f" warmup: batch {batch_idx} completed in {batch_t_end - start_warmup} seconds" ) - if self.config.dist: - self.val_loader.sampler.set_epoch(0) + self.val_loader.sampler.set_epoch(0) if max_val_batches > 0: self.log.debug(" warmup: running validation warmup pass") @@ -599,8 +593,7 @@ def warmup(self): finally: self.checkpoint_manager.restore_training_state(snapshot) - if self.config.dist: - torch.distributed.barrier() + torch.distributed.barrier() self.log.info(f"Done warmup. Took {int(time.time() - start_warmup)}s") def train(self): @@ -625,9 +618,8 @@ def train(self): epoch_loss = 0 # Accumulator for per-batch losses # Set necessary modes/states - if self.config.dist: - self.train_loader.sampler.set_epoch(epoch) - self.val_loader.sampler.set_epoch(epoch) + self.train_loader.sampler.set_epoch(epoch) + self.val_loader.sampler.set_epoch(epoch) self.model.train() self.optimizer.zero_grad(set_to_none=False) @@ -697,11 +689,10 @@ def train(self): dice_info = torch.tensor( [dice_sum, numsamples], dtype=VOLUME_TORCH_DTYPE ) - if self.config.dist: - dice_info = dice_info.to(device=self.device) - torch.distributed.all_reduce( - dice_info, op=torch.distributed.ReduceOp.SUM - ) + dice_info = dice_info.to(device=self.device) + torch.distributed.all_reduce( + dice_info, op=torch.distributed.ReduceOp.SUM + ) val_score = dice_info[0].item() / max(dice_info[1].item(), 1) if not self.config.disable_scheduler: self.scheduler.step() diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index fde4582..375e5c0 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -15,7 +15,6 @@ import math import os import socket -import sys import time from argparse import Namespace @@ -23,7 +22,6 @@ import psutil import torch import torch.distributed as dist -import yaml from distconv import DistConvDDP, ParallelStrategy from torch.distributed.tensor import Replicate, Shard @@ -48,11 +46,6 @@ from ScaFFold.utils.utils import set_seeds, setup_mpi_logger from ScaFFold.viz import standard_viz -if hasattr(os, "sched_getaffinity"): - _orig_affinity = os.sched_getaffinity(0) -else: - _orig_affinity = None - def check_resource_utilization(log, rank, world_size): """Check that we are properly utilizing resources""" @@ -84,26 +77,6 @@ def check_resource_utilization(log, rank, world_size): log.debug(f" Device Properties: {torch.cuda.get_device_properties(this_gpu)}") -def override_config(config) -> None: - """Override base run config if additional configs are provided.""" - if "--config" in sys.argv: - config_idx = 1 # Start at 1 to skip the base run config - while True: - try: - config_idx = sys.argv.index("--config", config_idx) + 1 - except ValueError: - break - config_file = sys.argv[config_idx] - if not os.path.isfile(config_file): - raise ValueError(f"Additional config file {config_file} does not exist") - with open(config_file) as f: - override_config = yaml.full_load(f) - for k, v in override_config.items(): - if not hasattr(config, k): - raise ValueError(f"Unknown configuration option {k}={v}") - setattr(config, k, v) - - @annotate() def main(kwargs_dict: dict = {}): # @@ -120,8 +93,8 @@ def main(kwargs_dict: dict = {}): log.debug(f"random seeds set to {config.seed}") # Get MPI information - rank = get_world_rank(required=config.dist) - world_size = get_world_size(required=config.dist) + rank = get_world_rank(required=True) + world_size = get_world_size(required=True) # Optionally enable additional determinism settings if config.more_determinism: @@ -134,14 +107,14 @@ def main(kwargs_dict: dict = {}): # Default torch.backends.cudnn.benchmark = True - # Initialize DDP + # Initialize DDP. ScaFFold always runs the benchmark as a distributed job; + # a one-rank launch is the supported singleton case. begin_code_region("init_ddp") - if config.dist: - if not dist.is_initialized(): - log.info("Initializing distributed process group...") - initialize_dist(rendezvous="env") - else: - log.info("Distributed process group already initialized by launcher.") + if not dist.is_initialized(): + log.info("Initializing distributed process group...") + initialize_dist(rendezvous="env") + else: + log.info("Distributed process group already initialized by launcher.") end_code_region("init_ddp") # More useful info @@ -150,7 +123,7 @@ def main(kwargs_dict: dict = {}): log.debug( f"Backend={dist.get_backend()}, world_size={world_size}, rank={rank}, local_rank={get_local_rank()}" ) - log.info(f"rank={rank}, world_size={world_size} test") + log.info(f"rank={rank}, world_size={world_size}") # Generate or retrieve dataset begin_code_region("get_dataset") @@ -162,8 +135,6 @@ def main(kwargs_dict: dict = {}): # Initialize model begin_code_region("init_model") - config.dc_num_shards = getattr(config, "dc_num_shards", config.dc_num_shards) - config.dc_shard_dims = getattr(config, "dc_shard_dims", config.dc_shard_dims) log.info( f"DistConv num_shards={config.dc_num_shards}, shard_dim={config.dc_shard_dims}" ) @@ -175,29 +146,33 @@ def main(kwargs_dict: dict = {}): trilinear=False, layers=config.unet_layers, ) - if config.dist: - # DDP + DistConv setup - # Ensure world_size is divisible by total distconv shards - assert dist.get_world_size() % math.prod(config.dc_num_shards) == 0, ( - f"world_size={dist.get_world_size()} must be divisible by total number of distconv shards = {math.prod(config.dc_num_shards)}" + # DDP + DistConv setup + # Ensure world_size is divisible by total distconv shards + total_distconv_shards = math.prod(config.dc_num_shards) + if world_size % total_distconv_shards != 0: + raise ValueError( + f"world_size={world_size} must be divisible by total number of " + f"distconv shards={total_distconv_shards}" ) - ps = ParallelStrategy( - num_shards=config.dc_num_shards, - shard_dim=config.dc_shard_dims, - device_type=device.type, - ) + ps = ParallelStrategy( + num_shards=config.dc_num_shards, + shard_dim=config.dc_shard_dims, + device_type=device.type, + ) - model = model.to(device, memory_format=torch.channels_last_3d) - # Wrap with DistConvDDP that corrects gradient scaling for dc submesh - model = DistConvDDP( - model, - parallel_strategy=ps, - device_ids=[get_local_rank()], - output_device=get_local_rank(), - ) - # Store ps for use in the training loop - config._parallel_strategy = ps + model = model.to(device, memory_format=torch.channels_last_3d) + ddp_device_ids = [device.index] if device.type == "cuda" else None + ddp_output_device = device.index if device.type == "cuda" else None + # Wrap with DistConvDDP that corrects gradient scaling for dc submesh + model = DistConvDDP( + model, + parallel_strategy=ps, + device_ids=ddp_device_ids, + output_device=ddp_output_device, + ) + # Store ps for use in the training loop + config._parallel_strategy = ps end_code_region("init_model") check_resource_utilization(log, rank, world_size) @@ -301,7 +276,8 @@ def main(kwargs_dict: dict = {}): standard_viz.main(config) end_code_region("generate_figures") - dist.barrier() - dist.destroy_process_group() + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() return 0 From 5d78c0e2383c85d0c4035c7f670f3ca5a7757cd2 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 11 Jun 2026 09:37:38 -0700 Subject: [PATCH 7/9] Remove deprecated 0-fractal-class use case --- README.md | 2 +- ScaFFold/cli.py | 3 +++ ScaFFold/configs/benchmark_default.yml | 2 +- ScaFFold/datagen/volumegen.py | 2 +- ScaFFold/utils/config_utils.py | 10 +++++++- ScaFFold/utils/evaluate.py | 5 +--- ScaFFold/utils/trainer.py | 34 ++++++++++---------------- 7 files changed, 29 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index f9bd4e2..0652b93 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ Parameters are set in a `.yml` config file and can be modified by the user: base_run_dir: "benchmark_runs" # Subfolder of $(pwd) in which to run jobs. dataset_dir: "datasets" # Directory in which to store and query generated datasets. fract_base_dir: "fractals" # Base directory for fractal IFS and instances. -n_categories: 5 # Number of fractal categories present in the dataset. +n_categories: 5 # Positive number of fractal categories present in the dataset. n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. problem_scale: 7 # Determines dataset resolution and number of unet layers. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 6ff6215..1984ba2 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -237,6 +237,9 @@ def main(): combined_config["unet_layers"] = ( combined_config["problem_scale"] - combined_config["unet_bottleneck_dim"] ) + config_utils.require_positive_int( + "n_categories", combined_config["n_categories"] + ) # Resolve paths to absolute, matching Config() behavior if "base_run_dir" in combined_config and combined_config["base_run_dir"]: diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml index 0eb613d..ec48cf8 100644 --- a/ScaFFold/configs/benchmark_default.yml +++ b/ScaFFold/configs/benchmark_default.yml @@ -2,7 +2,7 @@ base_run_dir: "benchmark_runs" # Subfolder of $(pwd) in which to run jobs. dataset_dir: "datasets" # Directory in which to store and query for datasets. fract_base_dir: "fractals" # Base directory for fractal IFS and instances. -n_categories: 5 # Number of fractal categories present in the dataset. +n_categories: 5 # Positive number of fractal categories present in the dataset. n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances. problem_scale: 7 # Determines dataset resolution and number of unet layers. Default is 6. unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8. diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index b268aa7..866e0ce 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -162,7 +162,7 @@ def main(config: Dict): ) np.random.seed(config.seed) - fractal_colors = np.random.rand(max(config.n_categories, n_fracts_per_vol), 3) + fractal_colors = np.random.rand(config.n_categories, 3) grid_size = math.floor(config.vol_size * config.scale) fract_base_dir = str(config.fract_base_dir) diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 8f20859..e32af3e 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -21,6 +21,12 @@ import ScaFFold.paths +def require_positive_int(name: str, value: int) -> int: + if not isinstance(value, int) or isinstance(value, bool) or value < 1: + raise ValueError(f"{name} must be a positive integer") + return value + + class Config: """ A class for storing configuration settings for a specific run. @@ -36,7 +42,9 @@ def __init__(self, config_dict): Path(config_dict.get("fract_base_dir", "fractals/")).resolve() ) self.job_name = config_dict.get("job_name", "benchmark") - self.n_categories = config_dict["n_categories"] + self.n_categories = require_positive_int( + "n_categories", config_dict["n_categories"] + ) self.problem_scale = config_dict["problem_scale"] try: assert isinstance(self.problem_scale, int), ( diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index 6b23b15..bf4d6ab 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -37,10 +37,7 @@ def evaluate( max_batches=None, ): def foreground_dice_stats(dice_scores): - if dice_scores.size(1) > 1: - per_sample_scores = dice_scores[:, 1:].mean(dim=1) - else: - per_sample_scores = dice_scores.mean(dim=1) + per_sample_scores = dice_scores[:, 1:].mean(dim=1) return per_sample_scores.sum().item(), per_sample_scores.numel() net.eval() diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 0e50044..c9c03f9 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -219,24 +219,18 @@ def setup_training_components(self): self.grad_scaler = torch.amp.GradScaler("cuda", enabled=self.use_grad_scaler) # Set up loss function - if self.config.n_categories + 1 > 1: - ce_class_weights = _compute_ce_class_weights( - train_set=self.train_set, - n_train=self.n_train, - n_categories=self.config.n_categories, - device=self.device, - sample_fraction=self.config.ce_weight_sample_fraction, - dist_enabled=True, - world_rank=self.world_rank, - log=self.log, - ) - self.criterion = nn.CrossEntropyLoss(weight=ce_class_weights).to( - self.device - ) - else: - self.criterion = nn.BCEWithLogitsLoss().to(self.device) - if isinstance(self.criterion, nn.CrossEntropyLoss): - self.ce_class_weights = self.criterion.weight + ce_class_weights = _compute_ce_class_weights( + train_set=self.train_set, + n_train=self.n_train, + n_categories=self.config.n_categories, + device=self.device, + sample_fraction=self.config.ce_weight_sample_fraction, + dist_enabled=True, + world_rank=self.world_rank, + log=self.log, + ) + self.criterion = nn.CrossEntropyLoss(weight=ce_class_weights).to(self.device) + self.ce_class_weights = self.criterion.weight self.log.info( f"Optimizer: {self.optimizer}, Scheduler: {self.scheduler}, AMP dtype: {self.amp_dtype}, Gradient Scaler Enabled: {self.use_grad_scaler}" @@ -254,9 +248,7 @@ def _autocast_kwargs(self, enabled=None): @staticmethod def _foreground_dice_mean(dice_scores): """Match optimization to the reported validation metric by excluding background.""" - if dice_scores.size(1) > 1: - return dice_scores[:, 1:].mean() - return dice_scores.mean() + return dice_scores[:, 1:].mean() def _current_learning_rate(self): if self.optimizer is None or not self.optimizer.param_groups: From a22b9bef91bd120255ef13c298533f3afccb35e6 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 11 Jun 2026 09:57:52 -0700 Subject: [PATCH 8/9] standardize on MPI logger, clean up benchmark.py, better volgen error and edge case handling; lazier imports --- ScaFFold/benchmark.py | 39 ++--------------- ScaFFold/cli.py | 29 ++++++++----- ScaFFold/datagen/category_search.py | 67 +++++++++++++++++++++-------- ScaFFold/datagen/get_dataset.py | 39 +++++++++-------- ScaFFold/datagen/instance.py | 19 +++++--- ScaFFold/datagen/volumegen.py | 56 +++++++++++++----------- ScaFFold/generate_fractals.py | 10 ++--- ScaFFold/utils/checkpointing.py | 23 +++++++--- ScaFFold/utils/distributed.py | 23 +++++++--- ScaFFold/utils/evaluate.py | 21 ++++++--- ScaFFold/utils/trainer.py | 17 ++++++-- ScaFFold/utils/utils.py | 52 ++++++++++++++-------- ScaFFold/worker.py | 6 +-- 13 files changed, 242 insertions(+), 159 deletions(-) diff --git a/ScaFFold/benchmark.py b/ScaFFold/benchmark.py index fb0f7e2..307a5f8 100644 --- a/ScaFFold/benchmark.py +++ b/ScaFFold/benchmark.py @@ -12,57 +12,26 @@ # # SPDX-License-Identifier: (Apache-2.0) -import logging import shutil from argparse import Namespace from pathlib import Path, PosixPath -import yaml from mpi4py import MPI from ScaFFold import worker from ScaFFold.utils.distributed import get_world_rank from ScaFFold.utils.perf_measure import adiak_init, adiak_value - - -def create_run_directory(base_dir, combination_index, num_runs): - """ - Create new directory for current run, named using unique combination_index - """ - run_dir = base_dir / f"param_set_{combination_index}" - for i in range(num_runs): - run_dir_with_iter = Path(f"{run_dir}/run{i}") - run_dir_with_iter.mkdir(parents=True, exist_ok=True) - return run_dir - - -def write_run_config(run_dir, iter, keys, combination): - """ - Write run config to a yaml file, and create optional override yaml - """ - run_config = {key: value for key, value in zip(keys, combination)} - run_config["run_dir"] = str( - run_dir.resolve() - ) # Add abs path to run dir as entry in dict - run_config["run_iter"] = iter # Add run_iter identifier as entry in dict - run_config_path = run_dir / "run_config.yaml" - with open(run_config_path, "w") as file: - yaml.dump(run_config, file) - return run_config_path +from ScaFFold.utils.utils import setup_mpi_logger def main(kwargs_dict: dict = {}): args = Namespace(**kwargs_dict) - - logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" - ) + log = setup_mpi_logger(__file__, args.verbose) # Get MPI information comm = MPI.COMM_WORLD rank = get_world_rank(required=True) - if rank == 0: - print(f"args found: {args}") + log.debug("args found: %s", args) kdict = None # Now set up and start benchmark run(s) @@ -88,7 +57,7 @@ def main(kwargs_dict: dict = {}): adiak_init(comm) for key, value in kdict.items(): if isinstance(value, dict): - print(f"Adiak: skipping key with dict value '{key}'") + log.debug("Adiak: skipping key with dict value '%s'", key) continue if isinstance(value, PosixPath): value = str(value) diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 1984ba2..06c7323 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -24,6 +24,7 @@ from ScaFFold.utils import config_utils from ScaFFold.utils.collect_scheduler_info import collect_scheduler_metadata from ScaFFold.utils.create_restart_script import create_restart_script +from ScaFFold.utils.utils import setup_mpi_logger def main(): @@ -94,9 +95,9 @@ def main(): "benchmark", help="Run the benchmark.", description=( - "The default run method for ScaFFold." - "Users may specify lists of run parameters in the config file." - "This subcommand runs one instance of the benchmark for each parameter combination." + "The default run method for ScaFFold. " + "Users may specify lists of run parameters in the config file. " + "This subcommand runs one instance of the benchmark for each parameter combination. " "Requires path to config file." ), ) @@ -213,10 +214,11 @@ def main(): rank = comm.Get_rank() # Parse the command-line arguments. args = parser.parse_args() + log = setup_mpi_logger(__file__, args.verbose) combined_config = None if rank == 0: - print(f"args = {args}") + log.debug("args = %s", args) bench_config = config_utils.load_config(Path(args.config), "sweep") bench_config_dict = ( @@ -230,7 +232,13 @@ def main(): if key not in combined_config: combined_config[key] = value elif value is not None and key != "command": - print(f"Overriding '{key}={combined_config[key]}' with '{key}={value}'") + log.info( + "Overriding '%s=%s' with '%s=%s'", + key, + combined_config[key], + key, + value, + ) combined_config[key] = value # Recalculate unet_layers to capture any CLI overrides @@ -263,13 +271,13 @@ def main(): # Handle Restart / Resume logic if hasattr(args, "restart") and args.restart: - print("Restart flag detected: Forcing train_from_scratch = False") + log.info("Restart flag detected: forcing train_from_scratch = False") combined_config["train_from_scratch"] = False combined_config["restart"] = True # If user manually supplied --run-dir (via restart script), use it. if hasattr(args, "run_dir") and args.run_dir is not None: - print(f"Resuming in existing directory: {args.run_dir}") + log.info("Resuming in existing directory: %s", args.run_dir) benchmark_run_dir = Path(args.run_dir) # Ensure we don't accidentally wipe checkpoints even if --restart wasn't explicitly passed combined_config["train_from_scratch"] = False @@ -279,8 +287,9 @@ def main(): f"{combined_config.get('job_name')}_%Y%m%d-%H%M%S" ) benchmark_run_dir = base_run_dir / timestamp - print( - f"benchmark_run_dir created at path {Path.resolve(benchmark_run_dir)}" + log.info( + "benchmark_run_dir created at path %s", + Path.resolve(benchmark_run_dir), ) combined_config["benchmark_run_dir"] = str(benchmark_run_dir) @@ -305,7 +314,7 @@ def main(): comm.Barrier() combined_config = comm.bcast(combined_config, root=0) if rank == 0: - print(f"combined_config = {combined_config}") + log.debug("combined_config = %s", combined_config) if args.command == "benchmark": from ScaFFold import benchmark diff --git a/ScaFFold/datagen/category_search.py b/ScaFFold/datagen/category_search.py index fefe993..27c042e 100644 --- a/ScaFFold/datagen/category_search.py +++ b/ScaFFold/datagen/category_search.py @@ -27,6 +27,7 @@ from ScaFFold.datagen.generate_fractal_points import generate_fractal_points from ScaFFold.utils.config_utils import Config +from ScaFFold.utils.utils import setup_mpi_logger DEFAULT_NP_DTYPE = np.float64 @@ -185,6 +186,7 @@ def main(config: Config) -> None: comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() + log = setup_mpi_logger(__file__, getattr(config, "verbose", 0)) datagen_batch_size = int(getattr(config, "datagen_batch_size", 10000)) if datagen_batch_size <= 0: @@ -193,8 +195,7 @@ def main(config: Config) -> None: # FIXME anything else to ensure determinism? np.random.seed(config.seed + rank) - if rank == 0: - print(f"MPI size = {size}") + log.info("MPI size = %s", size) # Setup directories fracts_sub_dir = f"var{config.variance_threshold}" @@ -202,9 +203,9 @@ def main(config: Config) -> None: config.fract_base_dir, fracts_sub_dir, "3DIFS_param" ) if rank == 0: - print(f"Writing fractals to {fracts_write_dir}") + log.info("Writing fractals to %s", fracts_write_dir) if os.path.exists(fracts_write_dir) and config.datagen_from_scratch: - print("Removing existing fractals dir") + log.info("Removing existing fractals directory") shutil.rmtree(fracts_write_dir) os.makedirs(fracts_write_dir, exist_ok=True) @@ -215,8 +216,12 @@ def main(config: Config) -> None: existing_categories = len(glob.glob(f"{fracts_write_dir}/*.csv")) categories_remaining = config.n_categories - existing_categories if rank == 0: - print( - f"category_search found {existing_categories} existing fractal categories | {config.n_categories} needed | {max(0, categories_remaining)} remaining" + log.info( + "category_search found %s existing fractal categories | %s needed | " + "%s remaining", + existing_categories, + config.n_categories, + max(0, categories_remaining), ) rank_start_time = time.time() @@ -248,11 +253,17 @@ def main(config: Config) -> None: if rank == 0: params_valid = [item for sublist in gathered_params for item in sublist] if attempts % 10000 * size / datagen_batch_size == 0: - print( - f"cat_remaining = {categories_remaining} | total attempts = {attempts} | stats for rank 0: nan_fail_count = {nan_fail_count}, var_fail_count = {var_fail_count}, runaway_fail_count = {runaway_fail_count}" + log.info( + "cat_remaining = %s | total attempts = %s | stats for rank 0: " + "nan_fail_count = %s, var_fail_count = %s, runaway_fail_count = %s", + categories_remaining, + attempts, + nan_fail_count, + var_fail_count, + runaway_fail_count, ) if len(params_valid) > 0: - print(f"Processing {len(params_valid)} param sets from this attempt") + log.info("Processing %s param sets from this attempt", len(params_valid)) for p in params_valid: # Ensure we don't save more categories than needed if categories_remaining > 0: @@ -267,8 +278,9 @@ def main(config: Config) -> None: # Update categories_remaining categories_remaining -= 1 else: - print( - "Generated all fractal categories needed. Ignoring additional found valid categories..." + log.info( + "Generated all fractal categories needed. Ignoring additional " + "valid categories." ) break @@ -286,14 +298,35 @@ def main(config: Config) -> None: global_runaway_fail_count = comm.reduce(runaway_fail_count, op=MPI.SUM, root=0) if rank == 0 and attempts > 0: - print( - f"Generated {config.n_categories - existing_categories} new categories in {attempts * datagen_batch_size} total attempts | {attempts * datagen_batch_size / (config.n_categories - existing_categories)} Attempts per category | Total categories is now {config.n_categories}" + categories_generated = config.n_categories - existing_categories + total_attempts = attempts * datagen_batch_size + log.info( + "Generated %s new categories in %s total attempts | %.2f attempts per " + "category | total categories is now %s", + categories_generated, + total_attempts, + total_attempts / categories_generated, + config.n_categories, ) - print( - f"Failures experienced: {global_nan_fail_count} nan attempts, {100 * global_nan_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts, {global_var_fail_count} var fail attempts, {100 * global_var_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts, {global_runaway_fail_count} runaway attempts, {100 * global_runaway_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts" + log.info( + "Failures experienced: %s nan attempts (%.4f%%), %s variance-fail " + "attempts (%.4f%%), %s runaway attempts (%.4f%%)", + global_nan_fail_count, + 100 * global_nan_fail_count / total_attempts, + global_var_fail_count, + 100 * global_var_fail_count / total_attempts, + global_runaway_fail_count, + 100 * global_runaway_fail_count / total_attempts, ) - print( - f"Rank 0 wall time = {rank_total_time:.2f} | Total CPU time = {global_sum_time:.2f} | Avg wall time per rank {global_sum_time / size:.2f} | {attempts * datagen_batch_size / rank_total_time:.2f} total attempts per wall second | {attempts * datagen_batch_size / rank_total_time / size:.2f} attempts per wall second per rank" + log.info( + "Rank 0 wall time = %.2f | total CPU time = %.2f | avg wall time per " + "rank = %.2f | %.2f total attempts per wall second | %.2f attempts " + "per wall second per rank", + rank_total_time, + global_sum_time, + global_sum_time / size, + total_attempts / rank_total_time, + total_attempts / rank_total_time / size, ) return 0 diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index c7ffaf9..bcea9c2 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -27,6 +27,7 @@ from mpi4py import MPI from ScaFFold.datagen import volumegen +from ScaFFold.utils.utils import setup_mpi_logger META_FILENAME = "meta.yaml" DATASET_FORMAT_VERSION = 2 @@ -76,7 +77,7 @@ def _hash_volume_config(volume_config: Dict[str, Any]) -> str: return hashlib.sha256(s).hexdigest()[:12] -def _git_commit_short() -> str: +def _git_commit_short(log=None) -> str: try: return ( subprocess.check_output( @@ -87,14 +88,18 @@ def _git_commit_short() -> str: .strip() ) except subprocess.CalledProcessError: - print( - "Tried to get git commit id in non-git repo. No commit id will be enforced for dataset reuse." - ) + if log is not None: + log.warning( + "Tried to get git commit id in non-git repo. " + "No commit id will be enforced for dataset reuse." + ) return "no-commit-id" except Exception: - print( - "Exception when trying to get git commit for dataset. No commit id will be enforced for dataset reuse." - ) + if log is not None: + log.warning( + "Exception when trying to get git commit for dataset. " + "No commit id will be enforced for dataset reuse." + ) return "no-commit-id" @@ -114,6 +119,7 @@ def get_dataset( comm = MPI.COMM_WORLD rank = comm.Get_rank() + log = setup_mpi_logger(__file__, getattr(config, "verbose", 0)) root = Path(config.dataset_dir) root.mkdir(exist_ok=True) @@ -125,7 +131,7 @@ def get_dataset( config=config_dict, include_keys=INCLUDE_KEYS ) config_id = _hash_volume_config(volume_config) - commit = _git_commit_short() + commit = _git_commit_short(log) base = root / config_id base.mkdir(parents=True, exist_ok=True) @@ -146,13 +152,11 @@ def get_dataset( if require_commit and meta.get("code_commit") != commit: continue # If we pass the above checks, this dataset can be reused - print( - "Valid existing dataset found. Reusing this dataset..." - ) # FIXME replace with updated logging + log.info("Reusing existing dataset at %s", dataset_path) return dataset_path # Otherwise, generate a new dataset - print(f"No valid existing dataset found at {base}. Generating new dataset...") + log.info("No valid existing dataset found at %s. Generating new dataset.", base) if rank == 0: ts = time.strftime("%Y%m%d-%H%M%S") dest = base / f"{ts}__{commit}" @@ -178,18 +182,19 @@ def get_dataset( all_ok = comm.allreduce(1 if ok else 0, op=MPI.MIN) == 1 comm.Barrier() - # rank 0 has file write + move - if rank == 0: - if not all_ok: + errs = comm.gather(err, root=0) if not all_ok else None + if not all_ok: + if rank == 0: try: shutil.rmtree(tmp, ignore_errors=True) except Exception: pass - # collect & raise a representative error - errs = comm.gather(err, root=0) msgs = "; ".join(e for e in errs if e) raise RuntimeError(f"dataset generation failed: {msgs or 'unknown error'}") + raise RuntimeError("dataset generation failed on another rank") + # rank 0 has file write + move + if rank == 0: # Write to tmp, then move, so readers never see half-written dataset meta = { "config_id": config_id, diff --git a/ScaFFold/datagen/instance.py b/ScaFFold/datagen/instance.py index a14c2fb..fb6ebe3 100644 --- a/ScaFFold/datagen/instance.py +++ b/ScaFFold/datagen/instance.py @@ -29,6 +29,7 @@ from ScaFFold.datagen.generate_fractal_points import generate_fractal_points from ScaFFold.utils.config_utils import Config +from ScaFFold.utils.utils import setup_mpi_logger DEFAULT_NP_DTYPE = np.float64 @@ -72,12 +73,12 @@ def main(config: Config): comm = MPI.COMM_WORLD size = comm.Get_size() rank = comm.Get_rank() + log = setup_mpi_logger(__file__, getattr(config, "verbose", 0)) # FIXME anything else to ensure determinism? np.random.seed(config.seed + rank) - if rank == 0: - print(f"MPI size = {size}") + log.info("MPI size = %s", size) # Setup directories fracts_sub_dir = f"var{config.variance_threshold}" @@ -86,11 +87,13 @@ def main(config: Config): config.fract_base_dir, fracts_sub_dir, "instances", f"np{config.point_num}" ) if rank == 0: - print( - f"Generating instances for num_points={config.point_num}, writing to {instance_write_dir}" + log.info( + "Generating instances for num_points=%s, writing to %s", + config.point_num, + instance_write_dir, ) if os.path.exists(instance_write_dir) and config.datagen_from_scratch: - print("Removing existing instances dir") + log.info("Removing existing instances directory") shutil.rmtree(instance_write_dir) os.makedirs(instance_write_dir, exist_ok=True) @@ -184,8 +187,10 @@ def main(config: Config): end_time = time.time() total_time = end_time - start_time if rank == 0: - print( - f"Generated {len(instances_to_generate)} instances in {total_time:.2f} seconds" + log.info( + "Generated %s instances in %.2f seconds", + len(instances_to_generate), + total_time, ) return 0 diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index 866e0ce..1184e28 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -12,12 +12,10 @@ # # SPDX-License-Identifier: (Apache-2.0) -import logging import math import os import pickle import random -import sys import time from math import ceil from typing import Dict @@ -27,6 +25,7 @@ from ScaFFold.utils.config_utils import Config from ScaFFold.utils.data_types import DEFAULT_NP_DTYPE, MASK_DTYPE, VOLUME_DTYPE +from ScaFFold.utils.utils import setup_mpi_logger def load_np_ptcloud(path: str) -> np.ndarray: @@ -66,12 +65,11 @@ def points_to_voxelgrid( def main(config: Dict): - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - # Initialize MPI comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() + log = setup_mpi_logger(__file__, getattr(config, "verbose", 0)) dataset_dir = str(config.dataset_dir) @@ -96,9 +94,11 @@ def main(config: Dict): # Force n_instances_used_per_fractal to be multiple of n_fracts_per_vol if config.n_instances_used_per_fractal % n_fracts_per_vol != 0: - print( - f"volumegen.py: WARNING: n_instances_used_per_fractal ({config.n_instances_used_per_fractal}) \n" - f"NOT multiple of n_fracts_per_vol={n_fracts_per_vol}. Rounding down." + log.warning( + "n_instances_used_per_fractal (%s) is not a multiple of " + "n_fracts_per_vol=%s. Rounding down.", + config.n_instances_used_per_fractal, + n_fracts_per_vol, ) config.n_instances_used_per_fractal = ( config.n_instances_used_per_fractal @@ -132,8 +132,8 @@ def main(config: Dict): with open(volumes_contents_path, "wb") as f: np.savetxt(f, volumes_contents.astype(int), fmt="%i", delimiter=",") - print( - f"volumegen.py({rank}): finished writing volumes_contents (shape = {volumes_contents.shape})" + log.info( + "Finished writing volumes_contents with shape %s", volumes_contents.shape ) # Broadcast to all ranks @@ -153,12 +153,15 @@ def main(config: Dict): end_idx = min(((rank + 1) * stride), num_volumes) if start_idx >= end_idx: - logging.info(f"Rank {rank} given no volumes to generate") + log.debug("Rank %s given no volumes to generate", rank) else: volumes_contents_subset = volumes_contents[start_idx:end_idx] - print( - f"rank {rank} responsible for volumes {volumes_contents_subset[0][0]} through {volumes_contents_subset[-1][0]}" + log.debug( + "Rank %s responsible for volumes %s through %s", + rank, + volumes_contents_subset[0][0], + volumes_contents_subset[-1][0], ) np.random.seed(config.seed) @@ -171,7 +174,7 @@ def main(config: Dict): start_time = time.time() for i, curr_vol in enumerate(volumes_contents_subset): if i % 10 == 0: - logging.info(f"Rank {rank} processing local volume {i}...") + log.debug("Rank %s processing local volume %s", rank, i) volume = np.full( (config.vol_size, config.vol_size, config.vol_size, 3), @@ -204,10 +207,10 @@ def main(config: Dict): ) if not os.path.exists(point_cloud_path): - print( - f"File {point_cloud_path} does not exist. Ensure you have run 'scaffold generate_fractals ...'" + raise FileNotFoundError( + f"File {point_cloud_path} does not exist. " + "Ensure you have run 'scaffold generate_fractals ...'" ) - sys.exit(1) points = load_np_ptcloud(point_cloud_path) mask3d = points_to_voxelgrid(points, grid_size) @@ -239,19 +242,22 @@ def main(config: Dict): end_time = time.time() total_time = end_time - start_time if rank == 0: - print( - f"Rank 0 generated {len(volumes_contents_subset)} volumes in {total_time:.2f} seconds | {len(volumes_contents_subset) / total_time:.2f} volumes per second" + log.info( + "Rank 0 generated %s volumes in %.2f seconds | %.2f volumes per second", + len(volumes_contents_subset), + total_time, + len(volumes_contents_subset) / total_time, ) # Barrier to ensure all ranks are finished writing comm.Barrier() if rank == 0: - print(f"volumegen.py({rank}): All ranks done. Proceeding to split.") + log.info("All ranks done. Proceeding to split.") # Do the train/val split and generate lists of unique train/val masks if rank == 0: - print("volumegen.py: volume gen COMPLETE. Now generating unique mask lists") + log.info("Volume generation complete. Generating unique mask lists.") # Directories are already created at start of script @@ -259,13 +265,15 @@ def main(config: Dict): val_files = sorted(list(val_indices)) train_files = sorted(list(set(range(num_volumes)) - val_indices)) - print( - f"volumegen.py({rank}): len(val_files)={len(val_files)}, len(train_files)={len(train_files)}." + log.info( + "len(val_files)=%s, len(train_files)=%s", + len(val_files), + len(train_files), ) # Save lists of unique train and val mask values - print( - f"volumegen.py({rank}): calculating unique mask values from configuration (no file read)" + log.info( + "Calculating unique mask values from configuration without reading mask files." ) # volumes_contents layout is [vol_idx, cat1, inst1, cat2, inst2, ...] diff --git a/ScaFFold/generate_fractals.py b/ScaFFold/generate_fractals.py index 99448dc..88f8531 100644 --- a/ScaFFold/generate_fractals.py +++ b/ScaFFold/generate_fractals.py @@ -17,6 +17,7 @@ from mpi4py import MPI from ScaFFold.datagen import category_search, instance +from ScaFFold.utils.utils import setup_mpi_logger def main(kwargs_dict: dict = {}): @@ -24,9 +25,9 @@ def main(kwargs_dict: dict = {}): comm = MPI.COMM_WORLD size = comm.Get_size() rank = comm.Get_rank() + log = setup_mpi_logger(__file__, getattr(args, "verbose", 0)) - if rank == 0: - print(f"generate_fractals.py: world size = {size}") + log.info("Fractal generation world size = %s", size) comm.Barrier() @@ -38,10 +39,7 @@ def main(kwargs_dict: dict = {}): comm.Barrier() - if rank == 0: - print( - f"generate_fractals.py({rank}): Fractal and instance generation has finished. Exiting..." - ) + log.info("Fractal and instance generation has finished.") MPI.Finalize() diff --git a/ScaFFold/utils/checkpointing.py b/ScaFFold/utils/checkpointing.py index 2a06a3c..520b66b 100644 --- a/ScaFFold/utils/checkpointing.py +++ b/ScaFFold/utils/checkpointing.py @@ -275,12 +275,17 @@ def save_checkpoint( self.last_ckpt_path, self.best_ckpt_path, is_best, + self.log, ) self._log("Async checkpoint offloaded to background thread.") else: # Synchronous Save self._write_to_disk( - state_dict, self.last_ckpt_path, self.best_ckpt_path, is_best + state_dict, + self.last_ckpt_path, + self.best_ckpt_path, + is_best, + self.log, ) # Broadcast result (for logging elsewhere) @@ -292,14 +297,13 @@ def save_checkpoint( return is_best @staticmethod - def _write_to_disk(state_dict, last_path, best_path, is_best): + def _write_to_disk(state_dict, last_path, best_path, is_best, log=None): """Worker function to perform actual disk I/O.""" # Save 'last' try: torch.save(state_dict, last_path) except Exception as e: - print("Saving checkpoint failed. Continuing...") - print(e) + CheckpointManager._log_save_failure(log, e) # Save 'best' (copy logic) if is_best: # Copy is often faster than re-serializing @@ -309,8 +313,15 @@ def _write_to_disk(state_dict, last_path, best_path, is_best): try: torch.save(state_dict, best_path) except Exception as e: - print("Saving checkpoint failed. Continuing...") - print(e) + CheckpointManager._log_save_failure(log, e) + + @staticmethod + def _log_save_failure(log, exc): + if log is not None: + log.warning("Saving checkpoint failed. Continuing: %s", exc) + else: + print("Saving checkpoint failed. Continuing...") + print(exc) def _transfer_dict_to_cpu(self, obj): """Recursively move tensors to CPU.""" diff --git a/ScaFFold/utils/distributed.py b/ScaFFold/utils/distributed.py index 524fafe..b986b77 100644 --- a/ScaFFold/utils/distributed.py +++ b/ScaFFold/utils/distributed.py @@ -12,18 +12,19 @@ # # SPDX-License-Identifier: (Apache-2.0) +from __future__ import annotations + import os import os.path import socket import time from typing import Literal, Optional -import torch -import torch.distributed - def get_num_gpus() -> int: """Return the number of GPUs on this node.""" + import torch + return torch.cuda.device_count() @@ -92,6 +93,8 @@ def get_world_size(required: bool = False) -> int: def get_device() -> torch.device: + import torch + if torch.cuda.is_available(): torch.cuda.init() @@ -133,9 +136,12 @@ def get_job_id() -> Optional[str]: def initialize_dist( - init_file: Optional[str] = None, rendezvous: Literal["env", "tcp", "file"] = "env" + init_file: Optional[str] = None, + rendezvous: Literal["env", "tcp", "file"] = "env", + log=None, ) -> None: """Initialize the PyTorch distributed backend and set up NCCL.""" + import torch if rendezvous == "env": init_method = "env://" @@ -173,9 +179,12 @@ def initialize_dist( else: raise ValueError(f'Unrecognized scheme "{rendezvous}"') - print( - f"distributed.py: rank {get_world_rank()} / {get_world_size()} calling init_process_group()" - ) + if log is not None: + log.debug( + "rank %s / %s calling init_process_group()", + get_world_rank(), + get_world_size(), + ) # Initialize torch.distributed.init_process_group( diff --git a/ScaFFold/utils/evaluate.py b/ScaFFold/utils/evaluate.py index bf4d6ab..dfeb476 100644 --- a/ScaFFold/utils/evaluate.py +++ b/ScaFFold/utils/evaluate.py @@ -35,6 +35,7 @@ def evaluate( n_categories, parallel_strategy, max_batches=None, + log=None, ): def foreground_dice_stats(dice_scores): per_sample_scores = dice_scores[:, 1:].mean(dim=1) @@ -54,9 +55,11 @@ def foreground_dice_stats(dice_scores): spatial_mesh = parallel_strategy.device_mesh[parallel_strategy.distconv_dim_names] - if primary: - print( - f"[eval] ps.shard_dim={parallel_strategy.shard_dim} num_shards={parallel_strategy.num_shards}" + if primary and log is not None: + log.debug( + "[eval] ps.shard_dim=%s num_shards=%s", + parallel_strategy.shard_dim, + parallel_strategy.num_shards, ) with torch.autocast(**autocast_kwargs): @@ -137,9 +140,15 @@ def foreground_dice_stats(dice_scores): net.train() val_loss_avg = val_loss_epoch / max(processed_batches, 1) - if primary: - print( - f"evaluate.py: dice_score={total_dice_score}, val_loss_epoch={val_loss_epoch}, val_loss_avg={val_loss_avg}, num_val_batches={processed_batches}, num_val_samples={processed_samples}" + if primary and log is not None: + log.debug( + "evaluate.py: dice_score=%s, val_loss_epoch=%s, val_loss_avg=%s, " + "num_val_batches=%s, num_val_samples=%s", + total_dice_score, + val_loss_epoch, + val_loss_avg, + processed_batches, + processed_samples, ) return ( total_dice_score, diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index c9c03f9..fcbbabd 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -581,6 +581,7 @@ def warmup(self): self.config.n_categories, self.config._parallel_strategy, max_batches=max_val_batches, + log=self.log, ) finally: self.checkpoint_manager.restore_training_state(snapshot) @@ -599,8 +600,10 @@ def train(self): start = time.time() while dice_score_train < self.config.target_dice: if self.config.epochs != -1 and epoch > self.config.epochs: - print( - f"Maxmimum epochs reached '{self.config.epochs}'. Concluding training early (may have not converged)." + self.log.warning( + "Maximum epochs reached '%s'. Concluding training early " + "(may have not converged).", + self.config.epochs, ) break @@ -677,6 +680,7 @@ def train(self): self.criterion, self.config.n_categories, self.config._parallel_strategy, + log=self.log, ) dice_info = torch.tensor( [dice_sum, numsamples], dtype=VOLUME_TORCH_DTYPE @@ -718,8 +722,13 @@ def train(self): + "\n" ) outfile.flush() - print( - f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}. Rank 0 first batch minibatch_time_s={minibatch_time_s:.6f}." + self.log.info( + "Epoch %s completed in %.2f seconds. Total train time so " + "far: %.2f seconds. Rank 0 first batch minibatch_time_s=%.6f.", + epoch, + epoch_duration, + time.time() - start, + minibatch_time_s, ) # diff --git a/ScaFFold/utils/utils.py b/ScaFFold/utils/utils.py index c79b3de..7ba1c15 100644 --- a/ScaFFold/utils/utils.py +++ b/ScaFFold/utils/utils.py @@ -16,11 +16,6 @@ import random import sys -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.distributed as dist - from ScaFFold.utils.distributed import get_world_rank logging.basicConfig( @@ -29,6 +24,8 @@ def plot_img_and_mask(img, mask): + import matplotlib.pyplot as plt + classes = mask.max() + 1 fig, ax = plt.subplots(1, classes + 1) ax[0].set_title("Input image") @@ -42,13 +39,17 @@ def plot_img_and_mask(img, mask): def set_seeds(seed_value=42): """Set seeds for reproducibility.""" + + import numpy as np + import torch + random.seed(seed_value) # Python np.random.seed(seed_value) # NumPy torch.manual_seed(seed_value) # PyTorch torch.cuda.manual_seed_all(seed_value) # PyTorch for GPUs -def customlog(msg: str, ranks=[0], allranks=False, level=0, verbose=0): +def customlog(msg: str, ranks=(0,), allranks=False, level=0, verbose=0): rank = get_world_rank() if (rank in ranks or allranks) and level <= verbose: logging.info(f"(rank {rank}): {msg}") @@ -60,12 +61,13 @@ def __init__(self, ranks: set[int] or None = None): self.allowed_ranks = ranks def filter(self, record: logging.LogRecord) -> bool: - record.mpi_rank = get_world_rank() + rank = get_world_rank() + record.mpi_rank = rank # If no allowed ranks specified, only rank 0 logs if self.allowed_ranks is None: - return get_world_rank() == 0 + return rank == 0 # Otherwise only allow logs from the specified ranks - return record.mpi_rank in self.allowed_ranks + return rank in self.allowed_ranks def setup_mpi_logger( @@ -81,23 +83,37 @@ def setup_mpi_logger( logger = logging.getLogger(name) logger.setLevel(log_level) - # Create a StreamHandler (to stderr) and attach MPI filter - handler = logging.StreamHandler(stream=sys.stderr) - handler.setLevel(log_level) - handler.addFilter(MPIRankFilter(ranks)) - # Set formatting, including MPI rank fmt = "[%(asctime)s][%(filename)s:%(lineno)d][rank=%(mpi_rank)d][%(levelname)s] %(message)s" - handler.setFormatter(logging.Formatter(fmt, datefmt="%H:%M:%S")) + formatter = logging.Formatter(fmt, datefmt="%H:%M:%S") + + for handler in logger.handlers: + if getattr(handler, "_scaffold_mpi_handler", False): + handler.setLevel(log_level) + handler.setFormatter(formatter) + handler.filters = [ + f for f in handler.filters if not isinstance(f, MPIRankFilter) + ] + handler.addFilter(MPIRankFilter(ranks)) + break + else: + # Create a StreamHandler (to stderr) and attach MPI filter + handler = logging.StreamHandler(stream=sys.stderr) + handler._scaffold_mpi_handler = True + handler.setLevel(log_level) + handler.addFilter(MPIRankFilter(ranks)) + handler.setFormatter(formatter) + logger.addHandler(handler) - # Attach handler - logger.addHandler(handler) # Prevent duplicate logging from Basic Logger logger.propagate = False return logger def mem_stats(): + import torch + import torch.distributed as dist + dev = torch.cuda.current_device() free, total = torch.cuda.mem_get_info() # device-level (driver) view stats = torch.cuda.memory_stats(dev) # allocator internals @@ -118,6 +134,8 @@ def mem_stats(): def gather_and_print_mem(log, tag=""): + import torch.distributed as dist + if log.getEffectiveLevel() > 10: # 10 -> DEBUG return stats = mem_stats() diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 375e5c0..d445048 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -112,7 +112,7 @@ def main(kwargs_dict: dict = {}): begin_code_region("init_ddp") if not dist.is_initialized(): log.info("Initializing distributed process group...") - initialize_dist(rendezvous="env") + initialize_dist(rendezvous="env", log=log) else: log.info("Distributed process group already initialized by launcher.") end_code_region("init_ddp") @@ -183,7 +183,7 @@ def main(kwargs_dict: dict = {}): if config.framework == "torch": # Optionally enable additional determinism settings if config.more_determinism: - print( + log.info( "Enabling additional determinism settings to improve training reproducibility" ) torch.backends.cudnn.benchmark = False @@ -248,7 +248,7 @@ def main(kwargs_dict: dict = {}): hostname = socket.gethostname() tracename = f"torch-{hostname}-r{rank}-N{world_size // ranks_per_node}-n{world_size}-ps{config.problem_scale}-e{config.epochs}-nipf{config.n_instances_used_per_fractal}-{int(time.time())}.json" prof.export_chrome_trace(tracename) - print(f"Wrote PyTorch trace '{tracename}'") + log.info("Wrote PyTorch trace '%s'", tracename) # # Calculate benchmark score From 093a01abc7afc6df1d9317741ca5e72df1df81af Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 18 Jun 2026 10:24:04 -0700 Subject: [PATCH 9/9] ruff --- ScaFFold/datagen/category_search.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ScaFFold/datagen/category_search.py b/ScaFFold/datagen/category_search.py index 27c042e..43254ab 100644 --- a/ScaFFold/datagen/category_search.py +++ b/ScaFFold/datagen/category_search.py @@ -263,7 +263,9 @@ def main(config: Config) -> None: runaway_fail_count, ) if len(params_valid) > 0: - log.info("Processing %s param sets from this attempt", len(params_valid)) + log.info( + "Processing %s param sets from this attempt", len(params_valid) + ) for p in params_valid: # Ensure we don't save more categories than needed if categories_remaining > 0: