From 4753e1a050677077ca2e80761a12d21bcc1c253d Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Fri, 5 Jun 2026 13:14:57 -0700 Subject: [PATCH 01/10] fix fractalgen bug: computing acceptance criteria variance across wrong axis --- ScaFFold/datagen/category_search.py | 65 ++++++++++++++++++----------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/ScaFFold/datagen/category_search.py b/ScaFFold/datagen/category_search.py index a7dbc7a..8b819e7 100644 --- a/ScaFFold/datagen/category_search.py +++ b/ScaFFold/datagen/category_search.py @@ -50,8 +50,8 @@ def generate_single_category(config: Config) -> tuple[bool, np.array, bool, bool A bool for whether a valid category was found on this attempt. params : np.array A numpy array containing IFS parameters for this category attempt, if attempt was valid. - (not nan_check_pass) : bool - A bool for whether this attempt passed the NaN check. + (not value_check_pass) : bool + A bool for whether this attempt passed the NaN/non-finite check. (not variance_check_pass) : bool A bool for whether this attempt passed the variance check. (not runaway_check_pass) : bool @@ -80,31 +80,40 @@ def generate_single_category(config: Config) -> tuple[bool, np.array, bool, bool ), ) - # Sum number of NaNs + # Sum number of NaNs and reject infinities before normalization. 
nan_count = np.isnan(points).sum() - nan_check_pass = nan_count == 0 + value_check_pass = nan_count == 0 and np.isfinite(points).all() variance_check_pass = False - if nan_check_pass: + if value_check_pass: # Normalize + center mins = points.min(axis=0) maxs = points.max(axis=0) means = points.mean(axis=0) - scales = (2 * config.normalize) / (maxs - mins) - points = (points - means) * scales - - # Calc dimension-wise variance and compare to threshold - points_variance = np.var(points, axis=1) - variance_check_pass = np.all(points_variance > config.variance_threshold) - if variance_check_pass and nan_check_pass and runaway_check_pass: + with np.errstate(over="ignore", invalid="ignore"): + ranges = maxs - mins + value_check_pass = np.all(np.isfinite(ranges)) and np.all(ranges > 0) + if value_check_pass: + scales = (2 * config.normalize) / ranges + with np.errstate(over="ignore", invalid="ignore"): + points = (points - means) * scales + + value_check_pass = np.isfinite(points).all() + if value_check_pass: + # Calc dimension-wise variance and compare to threshold + points_variance = np.var(points, axis=0) + variance_check_pass = np.all( + points_variance > config.variance_threshold + ) + if variance_check_pass and value_check_pass and runaway_check_pass: valid = True # Return result return ( valid, params, - not nan_check_pass, - not variance_check_pass, + bool(not value_check_pass), + bool(value_check_pass and not variance_check_pass), not runaway_check_pass, ) @@ -129,7 +138,7 @@ def generate_categories_batch( params : np.array A numpy array containing IFS parameters for this category attempt, if attempt was valid. failed_nan_check_count : int - The number of attempts in this batch which failed the nan check. + The number of attempts in this batch which failed the NaN/non-finite check. failed_var_check_count : int The number of attempts in this batch which failed the var check. 
runaway_failure_count : int @@ -186,7 +195,11 @@ def main(config: Config) -> None: rank = comm.Get_rank() size = comm.Get_size() - datagen_batch_size = 10000 + datagen_batch_size = int(getattr(config, "datagen_batch_size", 10000)) + if datagen_batch_size < 1: + raise ValueError( + f"datagen_batch_size must be positive, got {datagen_batch_size}" + ) # FIXME anything else to ensure determinism? np.random.seed(config.seed + rank) @@ -224,7 +237,7 @@ def main(config: Config) -> None: var_fail_count = 0 runaway_fail_count = 0 while categories_remaining > 0: - attempts += size + attempts += datagen_batch_size * size # Each rank attempts to generate datagen_batch_size categories ( @@ -245,12 +258,15 @@ def main(config: Config) -> None: # Process IFS params one at a time, writing each to a CSV if rank == 0: params_valid = [item for sublist in gathered_params for item in sublist] - if attempts % 10000 * size / datagen_batch_size == 0: + print( + f"cat_remaining = {categories_remaining} | total attempts = {attempts} | stats for rank 0: invalid_value_fail_count = {nan_fail_count}, var_fail_count = {var_fail_count}, runaway_fail_count = {runaway_fail_count}", + flush=True, + ) + if len(params_valid) > 0: print( - f"cat_remaining = {categories_remaining} | total attempts = {attempts} | stats for rank 0: nan_fail_count = {nan_fail_count}, var_fail_count = {var_fail_count}, runaway_fail_count = {runaway_fail_count}" + f"Processing {len(params_valid)} valid param sets from this batch", + flush=True, ) - if len(params_valid) > 0: - print(f"Processing {len(params_valid)} param sets from this attempt") for p in params_valid: # Ensure we don't save more categories than needed if categories_remaining > 0: @@ -284,14 +300,15 @@ def main(config: Config) -> None: global_runaway_fail_count = comm.reduce(runaway_fail_count, op=MPI.SUM, root=0) if rank == 0 and attempts > 0: + generated_categories = config.n_categories - existing_categories print( - f"Generated {config.n_categories - 
existing_categories} new categories in {attempts * datagen_batch_size} total attempts | {attempts * datagen_batch_size / (config.n_categories - existing_categories)} Attempts per category | Total categories is now {config.n_categories}" + f"Generated {generated_categories} new categories in {attempts} total attempts | {attempts / generated_categories} Attempts per category | Total categories is now {config.n_categories}" ) print( - f"Failures experienced: {global_nan_fail_count} nan attempts, {100 * global_nan_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts, {global_var_fail_count} var fail attempts, {100 * global_var_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts, {global_runaway_fail_count} runaway attempts, {100 * global_runaway_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts" + f"Failures experienced: {global_nan_fail_count} invalid-value attempts, {100 * global_nan_fail_count / attempts:.4f}% of all attempts, {global_var_fail_count} var fail attempts, {100 * global_var_fail_count / attempts:.4f}% of all attempts, {global_runaway_fail_count} runaway attempts, {100 * global_runaway_fail_count / attempts:.4f}% of all attempts" ) print( - f"Rank 0 wall time = {rank_total_time:.2f} | Total CPU time = {global_sum_time:.2f} | Avg wall time per rank {global_sum_time / size:.2f} | {attempts * datagen_batch_size / rank_total_time:.2f} total attempts per wall second | {attempts * datagen_batch_size / rank_total_time / size:.2f} attempts per wall second per rank" + f"Rank 0 wall time = {rank_total_time:.2f} | Total CPU time = {global_sum_time:.2f} | Avg wall time per rank {global_sum_time / size:.2f} | {attempts / rank_total_time:.2f} total attempts per wall second | {attempts / rank_total_time / size:.2f} attempts per wall second per rank" ) return 0 From a407644af786ea530501faa56cde6980456a994d Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Fri, 5 Jun 2026 13:23:03 -0700 Subject: [PATCH 02/10] pass absolute 
fractal dir path to scaffold calls to ensure existing fractals found --- scripts/scaffold-matrix.job | 13 ++++++++++--- scripts/scaffold-tuolumne-torchpypi.job | 13 ++++++++++--- scripts/scaffold-tuolumne.job | 13 ++++++++++--- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/scripts/scaffold-matrix.job b/scripts/scaffold-matrix.job index d194aa8..f143fcd 100644 --- a/scripts/scaffold-matrix.job +++ b/scripts/scaffold-matrix.job @@ -12,10 +12,17 @@ export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH -scaffold generate_fractals -c ScaFFold/configs/benchmark_default.yml +CONFIG_PATH="$(pwd)/ScaFFold/configs/benchmark_default.yml" +FRACT_BASE_DIR="${FRACT_BASE_DIR:-$(pwd)/ScaFFold/fractals}" + +scaffold generate_fractals \ + -c "$CONFIG_PATH" \ + --fract-base-dir "$FRACT_BASE_DIR" # Uncomment if you want torch profiling #export PROFILE_TORCH=ON -torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml -#torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml +torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark \ + -c "$CONFIG_PATH" \ + --fract-base-dir "$FRACT_BASE_DIR" +#torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c "$CONFIG_PATH" --fract-base-dir "$FRACT_BASE_DIR" diff --git a/scripts/scaffold-tuolumne-torchpypi.job b/scripts/scaffold-tuolumne-torchpypi.job index 0387e5e..c78af96 100644 --- a/scripts/scaffold-tuolumne-torchpypi.job +++ b/scripts/scaffold-tuolumne-torchpypi.job @@ -22,10 +22,17 @@ export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 # Disable naive_conv_ab_nonpacked_wrw_ndhwc_half_double_half.kd export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0 -torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml +CONFIG_PATH="$(pwd)/ScaFFold/configs/benchmark_default.yml" +FRACT_BASE_DIR="${FRACT_BASE_DIR:-$(pwd)/ScaFFold/fractals}" + 
+torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals \ + -c "$CONFIG_PATH" \ + --fract-base-dir "$FRACT_BASE_DIR" # Uncomment if you want torch profiling #export PROFILE_TORCH=ON -torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml -# torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml +torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark \ + -c "$CONFIG_PATH" \ + --fract-base-dir "$FRACT_BASE_DIR" +# torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c "$CONFIG_PATH" --fract-base-dir "$FRACT_BASE_DIR" diff --git a/scripts/scaffold-tuolumne.job b/scripts/scaffold-tuolumne.job index a22d8c6..1ae88b0 100644 --- a/scripts/scaffold-tuolumne.job +++ b/scripts/scaffold-tuolumne.job @@ -25,10 +25,17 @@ export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD=0 # Disable naive_conv_ab_nonpacked_wrw_ndhwc_half_double_half.kd export MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW=0 -torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals -c $(pwd)/ScaFFold/configs/benchmark_default.yml +CONFIG_PATH="$(pwd)/ScaFFold/configs/benchmark_default.yml" +FRACT_BASE_DIR="${FRACT_BASE_DIR:-$(pwd)/ScaFFold/fractals}" + +torchrun-hpc -N 1 -n 1 $(which scaffold) generate_fractals \ + -c "$CONFIG_PATH" \ + --fract-base-dir "$FRACT_BASE_DIR" # Uncomment if you want torch profiling #export PROFILE_TORCH=ON -torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml -# torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c $(pwd)/ScaFFold/configs/benchmark_default.yml +torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark \ + -c "$CONFIG_PATH" \ + --fract-base-dir "$FRACT_BASE_DIR" +# torchrun-hpc -N 2 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c "$CONFIG_PATH" --fract-base-dir "$FRACT_BASE_DIR" From 
77626bee39435430b4083636238f7590f7eb4122 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 27 May 2026 08:07:25 -0700 Subject: [PATCH 03/10] revert more complex changes, focus on simpler implementation: dataset sharding matches distconv spec --- ScaFFold/cli.py | 6 + ScaFFold/configs/benchmark_testing.yml | 4 +- ScaFFold/datagen/get_dataset.py | 170 ++++++++--- ScaFFold/datagen/volumegen.py | 400 ++++++++++++++++--------- ScaFFold/utils/config_utils.py | 11 + ScaFFold/utils/data_loading.py | 194 ++++++++++-- ScaFFold/utils/spatial_sharding.py | 127 ++++++++ 7 files changed, 699 insertions(+), 213 deletions(-) create mode 100644 ScaFFold/utils/spatial_sharding.py diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py index 469c71e..d5766bb 100644 --- a/ScaFFold/cli.py +++ b/ScaFFold/cli.py @@ -177,6 +177,12 @@ def main(): nargs=3, help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum", ) + benchmark_parser.add_argument( + "--dc-shard-dims", + type=int, + nargs=3, + help="DistConv param: dimensions on which to shard.", + ) benchmark_parser.add_argument( "--epochs", type=int, diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml index 5167de1..bf741e5 100644 --- a/ScaFFold/configs/benchmark_testing.yml +++ b/ScaFFold/configs/benchmark_testing.yml @@ -10,8 +10,8 @@ seed: 42 # Random seed. batch_size: 1 # Batch sizes for each vol size. dataloader_num_workers: 4 # Number of DataLoader worker processes per rank. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. -num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. 
It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum -shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard +dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum +dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely. # Internal/dev use only diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index c7ffaf9..511cdf0 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -21,7 +21,7 @@ import time from argparse import Namespace from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional import yaml from mpi4py import MPI @@ -29,8 +29,9 @@ from ScaFFold.datagen import volumegen META_FILENAME = "meta.yaml" -DATASET_FORMAT_VERSION = 2 -INCLUDE_KEYS = [ +DATASET_FORMAT_VERSION = 3 +V2_DATASET_FORMAT_VERSION = 2 +V2_INCLUDE_KEYS = [ "dataset_format_version", "n_categories", "n_instances_used_per_fractal", @@ -40,6 +41,10 @@ "n_fracts_per_vol", "val_split", ] +INCLUDE_KEYS = V2_INCLUDE_KEYS + [ + "dc_num_shards", + "dc_shard_dims", +] def canonicalize(input): @@ -76,6 +81,26 @@ def _hash_volume_config(volume_config: Dict[str, Any]) -> str: return hashlib.sha256(s).hexdigest()[:12] +def _volume_config_for_version(config_dict, dataset_format_version): + versioned_config = config_dict.copy() + versioned_config["dataset_format_version"] = dataset_format_version + if dataset_format_version == DATASET_FORMAT_VERSION: + include_keys = INCLUDE_KEYS + else: + include_keys = V2_INCLUDE_KEYS + return _get_required_keys_dict( + config=versioned_config, + include_keys=include_keys, + ) + + +def _requested_unsharded_layout(config_dict: Dict[str, Any]) 
-> bool: + total_shards = 1 + for value in config_dict["dc_num_shards"]: + total_shards *= int(value) + return total_shards == 1 + + def _git_commit_short() -> str: try: return ( @@ -98,6 +123,36 @@ def _git_commit_short() -> str: return "no-commit-id" +def _find_reusable_dataset( + root: Path, + config_id: str, + dataset_format_version: int, + commit: str, + require_commit: bool, +) -> Optional[Path]: + base = root / config_id + if not base.exists(): + return None + + candidates = sorted( + (p for p in base.iterdir() if p.is_dir()), key=lambda p: p.name, reverse=True + ) + for dataset_path in candidates: + meta_path = dataset_path / META_FILENAME + if not meta_path.exists(): + continue + meta = yaml.safe_load(meta_path.read_text()) or {} + if meta.get("config_id") != config_id: + continue + if meta.get("dataset_format_version", 1) != dataset_format_version: + continue + if require_commit and meta.get("code_commit") != commit: + continue + return dataset_path + + return None + + def get_dataset( config: Namespace, require_commit: bool = False, # default: ignore commit mismatches for reuse @@ -118,39 +173,49 @@ def get_dataset( root = Path(config.dataset_dir) root.mkdir(exist_ok=True) - # Get dict of required keys and compute config_id + # V3 is the current physical-shard format. The physical dataset layout is + # defined by dc_num_shards/dc_shard_dims, matching the DistConv layout. 
config_dict = vars(config).copy() - config_dict["dataset_format_version"] = DATASET_FORMAT_VERSION - volume_config = _get_required_keys_dict( - config=config_dict, include_keys=INCLUDE_KEYS - ) + volume_config = _volume_config_for_version(config_dict, DATASET_FORMAT_VERSION) config_id = _hash_volume_config(volume_config) + v2_volume_config = _volume_config_for_version(config_dict, V2_DATASET_FORMAT_VERSION) + v2_config_id = _hash_volume_config(v2_volume_config) commit = _git_commit_short() base = root / config_id base.mkdir(parents=True, exist_ok=True) - # Try to reuse latest candidate dataset - candidates = sorted( - (p for p in base.iterdir() if p.is_dir()), key=lambda p: p.name, reverse=True + # Prefer a matching V3 physical-shard dataset. + dataset_path = _find_reusable_dataset( + root, + config_id, + DATASET_FORMAT_VERSION, + commit, + require_commit, ) - for dataset_path in candidates: - meta_path = dataset_path / META_FILENAME - if not meta_path.exists(): - continue - meta = yaml.safe_load(meta_path.read_text()) - if meta.get("config_id") != config_id: - continue - if meta.get("dataset_format_version", 1) != DATASET_FORMAT_VERSION: - continue - if require_commit and meta.get("code_commit") != commit: - continue - # If we pass the above checks, this dataset can be reused + if dataset_path is not None: print( - "Valid existing dataset found. Reusing this dataset..." + "Valid existing v3 sharded dataset found. Reusing this dataset..." ) # FIXME replace with updated logging return dataset_path + # V2 datasets are full-volume files without shard suffixes. Reuse them only + # for unsharded requests so sharded generation never silently returns a + # cache that lacks the requested shard files. + if _requested_unsharded_layout(config_dict): + dataset_path = _find_reusable_dataset( + root, + v2_config_id, + V2_DATASET_FORMAT_VERSION, + commit, + require_commit, + ) + if dataset_path is not None: + print( + "Valid existing v2 full-volume dataset found. 
Reusing this dataset..." + ) # FIXME replace with updated logging + return dataset_path + # Otherwise, generate a new dataset print(f"No valid existing dataset found at {base}. Generating new dataset...") if rank == 0: @@ -176,33 +241,52 @@ def get_dataset( # Check that all ranks succeeded in volumegen, then sync all_ok = comm.allreduce(1 if ok else 0, op=MPI.MIN) == 1 - comm.Barrier() + errs = comm.gather(err, root=0) - # rank 0 has file write + move + failure_msg = None if rank == 0: if not all_ok: try: shutil.rmtree(tmp, ignore_errors=True) except Exception: pass - # collect & raise a representative error - errs = comm.gather(err, root=0) msgs = "; ".join(e for e in errs if e) - raise RuntimeError(f"dataset generation failed: {msgs or 'unknown error'}") - - # Write to tmp, then move, so readers never see half-written dataset - meta = { - "config_id": config_id, - "dataset_format_version": DATASET_FORMAT_VERSION, - "config_subset": volume_config, - "include_keys": INCLUDE_KEYS, - "code_commit": commit, - "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - } - (tmp / META_FILENAME).write_text( - yaml.safe_dump(meta, sort_keys=True, default_flow_style=False) - ) - tmp.rename(dest) + failure_msg = f"dataset generation failed: {msgs or 'unknown error'}" + + failure_msg = comm.bcast(failure_msg, root=0) + if failure_msg: + raise RuntimeError(failure_msg) + + # rank 0 has file write + move + finalize_err = "" + if rank == 0: + try: + # Write to tmp, then move, so readers never see half-written dataset + meta = { + "config_id": config_id, + "dataset_format_version": DATASET_FORMAT_VERSION, + "config_subset": volume_config, + "include_keys": INCLUDE_KEYS, + "code_commit": commit, + "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + } + (tmp / META_FILENAME).write_text( + yaml.safe_dump(meta, sort_keys=True, default_flow_style=False) + ) + tmp.rename(dest) + except Exception as e: + finalize_err = ( + f"dataset finalization failed: rank 0: 
{type(e).__name__}: {e}" + ) + + finalize_err = comm.bcast(finalize_err, root=0) + if finalize_err: + if rank == 0: + try: + shutil.rmtree(tmp, ignore_errors=True) + except Exception: + pass + raise RuntimeError(finalize_err) # ensure the rename is visible everywhere before returning comm.Barrier() diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index b268aa7..ffeec44 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -17,16 +17,22 @@ import os import pickle import random -import sys import time from math import ceil -from typing import Dict +from typing import Callable, Dict import numpy as np from mpi4py import MPI from ScaFFold.utils.config_utils import Config from ScaFFold.utils.data_types import DEFAULT_NP_DTYPE, MASK_DTYPE, VOLUME_DTYPE +from ScaFFold.utils.spatial_sharding import ( + normalize_sharding, + shard_file_suffix, + shard_id_to_indices, + spatial_slices, + total_shards, +) def load_np_ptcloud(path: str) -> np.ndarray: @@ -44,7 +50,24 @@ def points_to_voxelgrid( Convert an (N,3) float64 point cloud directly into a boolean voxel grid of shape (grid_size, grid_size, grid_size). """ - # 1) Axis‐aligned bounding box in float64 + idx = points_to_voxel_indices(points, grid_size, eps=eps) + + # Scatter into a boolean grid + grid = np.zeros((grid_size, grid_size, grid_size), dtype=bool) + grid[idx[:, 0], idx[:, 1], idx[:, 2]] = True + + return grid + + +def points_to_voxel_indices( + points: np.ndarray, grid_size: int, eps: float = 1e-6 +) -> np.ndarray: + """ + Convert an (N,3) point cloud into global voxel indices using the same + math as points_to_voxelgrid(), without allocating a full boolean grid. 
+ """ + + # 1) Axis-aligned bounding box in float64 mins = points.min(axis=0) maxs = points.max(axis=0) @@ -53,16 +76,159 @@ def points_to_voxelgrid( # 3) Map points into [0,grid_size) indices scaled = (points - mins) / voxel_size - idx = np.floor(scaled).astype(int) + np.floor(scaled, out=scaled) + idx = scaled.astype(int) # 4) Clip to valid range - idx = np.clip(idx, 0, grid_size - 1) + np.clip(idx, 0, grid_size - 1, out=idx) - # 5) Scatter into a boolean grid - grid = np.zeros((grid_size, grid_size, grid_size), dtype=bool) - grid[idx[:, 0], idx[:, 1], idx[:, 2]] = True + return idx - return grid + +def _build_volumes_contents(config, n_fracts_per_vol): + # Force n_instances_used_per_fractal to be multiple of n_fracts_per_vol + if config.n_instances_used_per_fractal % n_fracts_per_vol != 0: + print( + f"volumegen.py: WARNING: n_instances_used_per_fractal ({config.n_instances_used_per_fractal}) \n" + f"NOT multiple of n_fracts_per_vol={n_fracts_per_vol}. Rounding down." + ) + config.n_instances_used_per_fractal = ( + config.n_instances_used_per_fractal + // n_fracts_per_vol + * n_fracts_per_vol + ) + + # Randomly select n_instances_used_per_fractal instances from each fractal class. 
+ instances_list = [] + for category in range(config.n_categories): + instances_remaining = config.n_instances_used_per_fractal + random_instances = [] + while instances_remaining > 0: + random_instances.extend( + random.sample(range(145), min(145, instances_remaining)) + ) + instances_remaining -= min(145, instances_remaining) + + category_instance_pairs = [[category, instance] for instance in random_instances] + instances_list.extend(category_instance_pairs) + + instances_list = np.array(instances_list, dtype=int) + np.random.shuffle(instances_list) + + volumes_contents = instances_list.reshape(-1, 2 * n_fracts_per_vol) + + indices = np.arange(volumes_contents.shape[0]).reshape(-1, 1) + return np.hstack([indices, volumes_contents]) + + +def _validation_indices(num_volumes: int, config) -> set[int]: + random.seed(config.seed) + return set( + random.sample(range(num_volumes), int(num_volumes * config.val_split / 100)) + ) + + +def _fractal_colors(config, n_fracts_per_vol): + np.random.seed(config.seed) + return np.random.rand(max(config.n_categories, n_fracts_per_vol), 3) + + +def _point_cloud_path(config, curr_category: int, curr_instance: int) -> str: + instances_dir = f"var{config.variance_threshold}/instances/np{config.point_num}" + return os.path.join( + str(config.fract_base_dir), + instances_dir, + f"{curr_category:06d}", + f"{curr_category:06d}_{curr_instance:04d}.npy", + ) + + +def _local_shape(slices): + return tuple(s.stop - s.start for s in slices) + + +def _physical_sharding(config): + return normalize_sharding(config.dc_num_shards, config.dc_shard_dims) + + +def _validate_generation_config(config): + num_shards, shard_dims = _physical_sharding(config) + n_total_shards = total_shards(num_shards) + + grid_size = math.floor(config.vol_size * config.scale) + if grid_size != config.vol_size: + raise ValueError( + "Sharded volume generation currently requires config.scale == 1 so shard files tile the full volume" + ) + + return num_shards, shard_dims, 
n_total_shards, grid_size + + +def generate_volume_shard( + config, + curr_vol: np.ndarray, + shard_id: int, + fractal_colors: np.ndarray, + point_cloud_loader: Callable[[str], np.ndarray] = load_np_ptcloud, +): + """ + Generate one physical shard for one logical volume. + + Voxel indices are computed in the full-volume coordinate system first, then + filtered to the shard. This preserves bitwise reconstruction across + different shard layouts. + """ + + n_fracts_per_vol = config.n_fracts_per_vol + num_shards, shard_dims = _physical_sharding(config) + shard_indices = shard_id_to_indices(shard_id, num_shards) + slices = spatial_slices( + (config.vol_size, config.vol_size, config.vol_size), + shard_dims, + num_shards, + shard_indices, + ) + local_shape = _local_shape(slices) + + volume = np.full((3, *local_shape), 0, dtype=VOLUME_DTYPE) + mask = np.full(local_shape, 0, dtype=MASK_DTYPE) + grid_size = math.floor(config.vol_size * config.scale) + + for curr_fract in range(n_fracts_per_vol): + curr_category = int(curr_vol[1 + 2 * curr_fract]) + curr_instance = int(curr_vol[1 + 2 * curr_fract + 1]) + fractal_color = fractal_colors[curr_category] + + point_cloud_path = _point_cloud_path(config, curr_category, curr_instance) + if point_cloud_loader is load_np_ptcloud and not os.path.exists( + point_cloud_path + ): + raise FileNotFoundError( + f"File {point_cloud_path} does not exist. 
Ensure you have run 'scaffold generate_fractals ...'" + ) + + points = point_cloud_loader(point_cloud_path) + idx = points_to_voxel_indices(points, grid_size) + keep = np.ones(idx.shape[0], dtype=bool) + for axis, axis_slice in enumerate(slices): + keep &= idx[:, axis] >= axis_slice.start + keep &= idx[:, axis] < axis_slice.stop + + if not np.any(keep): + continue + + local_idx = idx[keep] + local_idx[:, 0] -= slices[0].start + local_idx[:, 1] -= slices[1].start + local_idx[:, 2] -= slices[2].start + d = local_idx[:, 0] + h = local_idx[:, 1] + w = local_idx[:, 2] + + volume[:, d, h, w] = fractal_color[:, None] + mask[d, h, w] = curr_category + 1 + + return volume, mask def main(config: Dict): @@ -80,168 +246,120 @@ def main(config: Dict): volumes_contents_path = os.path.join(dataset_dir, "volumes_contents.csv") n_fracts_per_vol = config.n_fracts_per_vol + _, _, n_total_shards, _ = _validate_generation_config(config) random.seed(config.seed) # Python np.random.seed(config.seed) # NumPy # Set up directories and select instances from each category volumes_contents = None + setup_err = "" if rank == 0: - if not os.path.exists(dataset_dir): - os.makedirs(dataset_dir) - for subdir in ["training", "validation"]: - os.makedirs(os.path.join(vol_path, subdir), exist_ok=True) - os.makedirs(os.path.join(mask_path, subdir), exist_ok=True) - - # Force n_instances_used_per_fractal to be multiple of n_fracts_per_vol - if config.n_instances_used_per_fractal % n_fracts_per_vol != 0: - print( - f"volumegen.py: WARNING: n_instances_used_per_fractal ({config.n_instances_used_per_fractal}) \n" - f"NOT multiple of n_fracts_per_vol={n_fracts_per_vol}. Rounding down." - ) - config.n_instances_used_per_fractal = ( - config.n_instances_used_per_fractal - // n_fracts_per_vol - * n_fracts_per_vol - ) - - # Randomly select n_instances_used_per_fractal instances from each fractal class. 
- instances_list = [] - for category in range(config.n_categories): - instances_remaining = config.n_instances_used_per_fractal - random_instances = [] - while instances_remaining > 0: - random_instances.extend( - random.sample(range(145), min(145, instances_remaining)) - ) - instances_remaining -= min(145, instances_remaining) - - category_instance_pairs = [ - [category, instance] for instance in random_instances - ] - instances_list.extend(category_instance_pairs) - - instances_list = np.array(instances_list, dtype=int) - np.random.shuffle(instances_list) - - volumes_contents = instances_list.reshape(-1, 2 * n_fracts_per_vol) + try: + if not os.path.exists(dataset_dir): + os.makedirs(dataset_dir) + for subdir in ["training", "validation"]: + os.makedirs(os.path.join(vol_path, subdir), exist_ok=True) + os.makedirs(os.path.join(mask_path, subdir), exist_ok=True) - indices = np.arange(volumes_contents.shape[0]).reshape(-1, 1) - volumes_contents = np.hstack([indices, volumes_contents]) + volumes_contents = _build_volumes_contents(config, n_fracts_per_vol) - with open(volumes_contents_path, "wb") as f: - np.savetxt(f, volumes_contents.astype(int), fmt="%i", delimiter=",") - print( - f"volumegen.py({rank}): finished writing volumes_contents (shape = {volumes_contents.shape})" - ) + with open(volumes_contents_path, "wb") as f: + np.savetxt(f, volumes_contents.astype(int), fmt="%i", delimiter=",") + print( + f"volumegen.py({rank}): finished writing volumes_contents (shape = {volumes_contents.shape})" + ) + except Exception as e: + setup_err = f"setup failed: rank {rank}: {type(e).__name__}: {e}" # Broadcast to all ranks - volumes_contents = comm.bcast(volumes_contents, root=0) + volumes_contents, setup_err = comm.bcast((volumes_contents, setup_err), root=0) + if setup_err: + raise RuntimeError(setup_err) # Determine train/val split globally so all ranks know where to save num_volumes = len(volumes_contents) - random.seed(config.seed) # Reset seed to ensure all ranks get 
same split - val_indices = set( - random.sample(range(num_volumes), int(num_volumes * config.val_split / 100)) - ) + val_indices = _validation_indices(num_volumes, config) # Work distribution - num_volumes = len(volumes_contents) - stride = ceil(num_volumes / size) + total_tasks = num_volumes * n_total_shards + stride = ceil(total_tasks / size) start_idx = rank * stride - end_idx = min(((rank + 1) * stride), num_volumes) - - if start_idx >= end_idx: - logging.info(f"Rank {rank} given no volumes to generate") - - else: - volumes_contents_subset = volumes_contents[start_idx:end_idx] - print( - f"rank {rank} responsible for volumes {volumes_contents_subset[0][0]} through {volumes_contents_subset[-1][0]}" - ) - - np.random.seed(config.seed) - fractal_colors = np.random.rand(max(config.n_categories, n_fracts_per_vol), 3) + end_idx = min(((rank + 1) * stride), total_tasks) - grid_size = math.floor(config.vol_size * config.scale) - fract_base_dir = str(config.fract_base_dir) + generation_err = "" + try: + if start_idx >= end_idx: + logging.info(f"Rank {rank} given no volume shards to generate") - # Generation loop - start_time = time.time() - for i, curr_vol in enumerate(volumes_contents_subset): - if i % 10 == 0: - logging.info(f"Rank {rank} processing local volume {i}...") - - volume = np.full( - (config.vol_size, config.vol_size, config.vol_size, 3), - 0, - dtype=VOLUME_DTYPE, - ) - mask = np.full( - (config.vol_size, config.vol_size, config.vol_size), 0, dtype=MASK_DTYPE + else: + task_ids = range(start_idx, end_idx) + print( + f"rank {rank} responsible for volume-shard tasks {start_idx} through {end_idx - 1}" ) - global_vol_idx = curr_vol[0] - vol_seed = config.seed + int(global_vol_idx) - random.seed(vol_seed) - np.random.seed(vol_seed) - - for curr_fract in range(n_fracts_per_vol): - curr_category = curr_vol[1 + 2 * curr_fract] - curr_instance = curr_vol[1 + 2 * curr_fract + 1] - fractal_color = fractal_colors[curr_category] - - instances_dir = ( - 
f"var{config.variance_threshold}/instances/np{config.point_num}" - ) + fractal_colors = _fractal_colors(config, n_fracts_per_vol) - point_cloud_path = os.path.join( - fract_base_dir, - instances_dir, - f"{curr_category:06d}", - f"{curr_category:06d}_{curr_instance:04d}.npy", - ) - - if not os.path.exists(point_cloud_path): - print( - f"File {point_cloud_path} does not exist. Ensure you have run 'scaffold generate_fractals ...'" + # Generation loop + start_time = time.time() + for i, task_id in enumerate(task_ids): + if i % 10 == 0: + logging.info( + f"Rank {rank} processing local volume-shard task {i}..." ) - sys.exit(1) - - points = load_np_ptcloud(point_cloud_path) - mask3d = points_to_voxelgrid(points, grid_size) - assert mask3d.shape == volume.shape[:3], ( - f"mask3d {mask3d.shape} != volume spatial dims {volume.shape[:3]}" + volume_idx = task_id // n_total_shards + shard_id = task_id % n_total_shards + curr_vol = volumes_contents[volume_idx] + global_vol_idx = curr_vol[0] + vol_seed = config.seed + int(global_vol_idx) + random.seed(vol_seed) + np.random.seed(vol_seed) + + volume_to_save, mask_to_save = generate_volume_shard( + config, + curr_vol, + shard_id, + fractal_colors, ) - volume[mask3d] = fractal_color - mask[mask3d] = curr_category + 1 + # Determine destination folder + subdir = "validation" if global_vol_idx in val_indices else "training" + shard_suffix = shard_file_suffix(shard_id) - # Determine destination folder - subdir = "validation" if global_vol_idx in val_indices else "training" - # Tensors must logically be channels-first, later we will change striding/storage to channels-last on GPU (metadata will always stay channels-first). 
- volume_channels_first = volume.transpose((3, 0, 1, 2)) - volume_to_save = np.ascontiguousarray( - volume_channels_first, dtype=VOLUME_DTYPE - ) - mask_to_save = np.ascontiguousarray(mask, dtype=MASK_DTYPE) - - vol_file = os.path.join(vol_path, subdir, f"{global_vol_idx}.npy") - with open(vol_file, "wb") as f: - np.save(f, volume_to_save) + vol_file = os.path.join( + vol_path, subdir, f"{global_vol_idx}{shard_suffix}.npy" + ) + with open(vol_file, "wb") as f: + np.save(f, volume_to_save) - mask_file = os.path.join(mask_path, subdir, f"{global_vol_idx}_mask.npy") - with open(mask_file, "wb") as f: - np.save(f, mask_to_save) + mask_file = os.path.join( + mask_path, subdir, f"{global_vol_idx}{shard_suffix}_mask.npy" + ) + with open(mask_file, "wb") as f: + np.save(f, mask_to_save) + + end_time = time.time() + total_time = end_time - start_time + if rank == 0: + print( + f"Rank 0 generated {end_idx - start_idx} volume shards in {total_time:.2f} seconds | {(end_idx - start_idx) / total_time:.2f} shards per second" + ) + except Exception as e: + generation_err = ( + f"volume shard generation failed: rank {rank}: {type(e).__name__}: {e}" + ) - end_time = time.time() - total_time = end_time - start_time - if rank == 0: - print( - f"Rank 0 generated {len(volumes_contents_subset)} volumes in {total_time:.2f} seconds | {len(volumes_contents_subset) / total_time:.2f} volumes per second" - ) + all_generated = comm.allreduce(1 if not generation_err else 0, op=MPI.MIN) == 1 + generation_errs = comm.gather(generation_err, root=0) + generation_failure = "" + if rank == 0 and not all_generated: + msgs = "; ".join(e for e in generation_errs if e) + generation_failure = msgs or "unknown volume shard generation error" + generation_failure = comm.bcast(generation_failure, root=0) + if generation_failure: + raise RuntimeError(generation_failure) # Barrier to ensure all ranks are finished writing comm.Barrier() diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py 
index 36f1603..0d03358 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -85,6 +85,17 @@ def __init__(self, config_dict): self.dc_num_shards = config_dict["dc_num_shards"] self.dc_shard_dims = config_dict["dc_shard_dims"] self.dc_total_shards = math.prod(self.dc_num_shards) + unsupported_dataset_keys = [ + key + for key in ("dataset_num_shards", "dataset_shard_dims") + if key in config_dict + ] + if unsupported_dataset_keys: + raise ValueError( + "Configuration Mismatch: dataset_num_shards/dataset_shard_dims " + "are not supported. Use dc_num_shards/dc_shard_dims for the " + "v3 physical dataset layout." + ) # Safety Check: Length mismatch if len(self.dc_num_shards) != len(self.dc_shard_dims): raise ValueError( diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index 688f329..c655c2b 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: (Apache-2.0) import pickle +import re from dataclasses import dataclass from os import listdir from os.path import isfile, join, splitext @@ -25,9 +26,18 @@ from torch.utils.data import Dataset from ScaFFold.utils.data_types import MASK_DTYPE, VOLUME_DTYPE +from ScaFFold.utils.spatial_sharding import ( + chunk_slice, + normalize_sharding, + shard_file_suffix, + shard_indices_to_id, + total_shards, +) from ScaFFold.utils.utils import customlog -DATASET_FORMAT_VERSION = 2 +DATASET_FORMAT_VERSION = 3 +FULL_VOLUME_DATASET_FORMAT_VERSION = 2 +SHARDED_DATASET_FORMAT_VERSION = 3 LEGACY_DATASET_FORMAT_VERSION = 1 META_FILENAME = "meta.yaml" @@ -69,14 +79,11 @@ def __post_init__(self): def _chunk_slice(size: int, num_shards: int, shard_index: int) -> slice: """Match torch.chunk-style uneven shard boundaries.""" - chunk_size = (size + num_shards - 1) // num_shards - start = shard_index * chunk_size - if start >= size: - raise ValueError( - f"Empty local shard: dim size {size}, num_shards {num_shards}, shard_index 
{shard_index}" - ) - stop = min(size, start + chunk_size) - return slice(start, stop) + return chunk_slice(size, num_shards, shard_index) + + @property + def shard_id(self) -> int: + return shard_indices_to_id(self.shard_indices, self.num_shards) def slice_array( self, array: np.ndarray, axis_map: Dict[int, int], array_label: str @@ -116,13 +123,30 @@ def __init__( self.mask_suffix = mask_suffix self.spatial_shard_spec = spatial_shard_spec self.dataset_root = self.images_dir.parents[1] - self.dataset_format_version = self._load_dataset_format_version() + self.dataset_meta = self._load_dataset_metadata() + self.dataset_format_version = int( + self.dataset_meta.get( + "dataset_format_version", LEGACY_DATASET_FORMAT_VERSION + ) + ) + if self.dataset_format_version > DATASET_FORMAT_VERSION: + raise RuntimeError( + f"Unsupported dataset format version {self.dataset_format_version}; " + f"expected <= {DATASET_FORMAT_VERSION}" + ) + self.physical_shards = ( + self.dataset_format_version >= SHARDED_DATASET_FORMAT_VERSION + ) + self.physical_num_shards, self.physical_shard_dims = ( + self._load_physical_sharding() + ) + self.physical_total_shards = ( + total_shards(self.physical_num_shards) if self.physical_shards else 1 + ) + self.shard_id = self._select_physical_shard_id() + self.shard_suffix = shard_file_suffix(self.shard_id) - self.ids = [ - splitext(file)[0] - for file in listdir(images_dir) - if isfile(join(images_dir, file)) and not file.startswith(".") - ] + self.ids = self._list_ids(images_dir) if not self.ids: raise RuntimeError( f"No input file found in {images_dir}, make sure you put your images there" @@ -136,6 +160,12 @@ def __init__( self.mask_values = data["mask_values"] customlog(f"Unique mask values: {self.mask_values}") customlog(f"Dataset format version: {self.dataset_format_version}") + if self.physical_shards: + customlog( + f"Loading physical shard files with suffix {self.shard_suffix}; " + f"dc_num_shards={self.physical_num_shards}, " + 
f"dc_shard_dims={self.physical_shard_dims}" + ) def __len__(self): return len(self.ids) @@ -144,21 +174,104 @@ def __len__(self): def _load_numpy_array(path, mmap_mode=None): return np.load(path, allow_pickle=False, mmap_mode=mmap_mode) - def _load_dataset_format_version(self): + def _list_ids(self, images_dir): + if not self.physical_shards: + return sorted( + [ + splitext(file)[0] + for file in listdir(images_dir) + if isfile(join(images_dir, file)) and not file.startswith(".") + ] + ) + + pattern = re.compile(rf"^(?P<id>.+){re.escape(self.shard_suffix)}\.npy$") + ids = [] + for file in listdir(images_dir): + if file.startswith(".") or not isfile(join(images_dir, file)): + continue + match = pattern.match(file) + if match is not None: + ids.append(match.group("id")) + return sorted(ids) + + def _load_dataset_metadata(self): meta_path = self.dataset_root / META_FILENAME if not meta_path.exists(): - return LEGACY_DATASET_FORMAT_VERSION + return {"dataset_format_version": LEGACY_DATASET_FORMAT_VERSION} try: with open(meta_path, "r") as meta_file: - meta = yaml.safe_load(meta_file) or {} + return yaml.safe_load(meta_file) or {} except Exception as exc: customlog( f"Failed to read dataset metadata from {meta_path}: {exc}. Falling back to legacy loader." ) - return LEGACY_DATASET_FORMAT_VERSION + return {"dataset_format_version": LEGACY_DATASET_FORMAT_VERSION} + + def _load_physical_sharding(self): + if not self.physical_shards: + return (), () + + config_subset = self.dataset_meta.get("config_subset") or {} + num_shards = config_subset.get("dc_num_shards") + shard_dims = config_subset.get("dc_shard_dims") + if num_shards is None or shard_dims is None: + raise RuntimeError( + "Physical dataset is missing shard metadata. Expected " + "config_subset.dc_num_shards/config_subset.dc_shard_dims in meta.yaml." 
+ ) - return int(meta.get("dataset_format_version", LEGACY_DATASET_FORMAT_VERSION)) + return normalize_sharding(num_shards, shard_dims) + + @staticmethod + def _layout_by_dim(num_shards, shard_dims): + return {int(dim): int(num) for num, dim in zip(num_shards, shard_dims)} + + def _physical_layout_matches_spatial_spec(self): + if self.spatial_shard_spec is None: + return False + return self._layout_by_dim( + self.physical_num_shards, self.physical_shard_dims + ) == self._layout_by_dim( + self.spatial_shard_spec.num_shards, + self.spatial_shard_spec.shard_dims, + ) + + def _physical_shard_id_for_spatial_spec(self): + spec_indices_by_dim = { + int(dim): int(index) + for dim, index in zip( + self.spatial_shard_spec.shard_dims, + self.spatial_shard_spec.shard_indices, + ) + } + shard_indices = tuple( + spec_indices_by_dim[int(dim)] for dim in self.physical_shard_dims + ) + return shard_indices_to_id(shard_indices, self.physical_num_shards) + + def _select_physical_shard_id(self): + if not self.physical_shards: + return 0 + if self.spatial_shard_spec is None: + if self.physical_total_shards == 1: + return 0 + raise RuntimeError( + "Physical dataset has multiple shard files, but no SpatialShardSpec " + "was provided. Use a DistConv layout matching the v3 dataset." + ) + if not self._physical_layout_matches_spatial_spec(): + raise RuntimeError( + "V3 physical dataset shard layout does not match the requested " + "DistConv layout. V3 requires physical dataset layout and " + "DistConv layout to match. 
" + f"dataset dc_num_shards={self.physical_num_shards}, " + f"dataset dc_shard_dims={self.physical_shard_dims}, " + f"dc_num_shards={self.spatial_shard_spec.num_shards}, " + f"dc_shard_dims={self.spatial_shard_spec.shard_dims}" + ) + + return self._physical_shard_id_for_spatial_spec() @staticmethod def _prepare_legacy_image(img): @@ -188,8 +301,10 @@ def _prepare_optimized_mask(mask): def _slice_image_array(self, img): if self.spatial_shard_spec is None: return img + if self.physical_shards: + return img - if self.dataset_format_version >= DATASET_FORMAT_VERSION: + if self.dataset_format_version >= FULL_VOLUME_DATASET_FORMAT_VERSION: axis_map = {2: 1, 3: 2, 4: 3} else: axis_map = {2: 0, 3: 1, 4: 2} @@ -198,12 +313,26 @@ def _slice_image_array(self, img): def _slice_mask_array(self, mask): if self.spatial_shard_spec is None: return mask + if self.physical_shards: + return mask axis_map = {2: 0, 3: 1, 4: 2} return self.spatial_shard_spec.slice_array(mask, axis_map, "mask") - def __getitem__(self, idx): - name = self.ids[idx] + def _resolve_sample_files(self, name): + if self.physical_shards: + img_file = self.images_dir / f"{name}{self.shard_suffix}.npy" + mask_file = ( + self.mask_dir / f"{name}{self.shard_suffix}{self.mask_suffix}.npy" + ) + assert img_file.is_file(), ( + f"No image found for ID {name}, shard {self.shard_id}: {img_file}" + ) + assert mask_file.is_file(), ( + f"No mask found for ID {name}, shard {self.shard_id}: {mask_file}" + ) + return img_file, mask_file + mask_file = list(self.mask_dir.glob(name + self.mask_suffix + ".*")) img_file = list(self.images_dir.glob(name + ".*")) @@ -213,15 +342,26 @@ def __getitem__(self, idx): assert len(mask_file) == 1, ( f"Either no mask or multiple masks found for the ID {name}: {mask_file}" ) - mmap_mode = "r" if self.spatial_shard_spec is not None else None + return img_file[0], mask_file[0] + + def __getitem__(self, idx): + name = self.ids[idx] + img_file, mask_file = self._resolve_sample_files(name) + + 
mmap_mode = ( + "r" + if self.spatial_shard_spec is not None + and not self.physical_shards + else None + ) # Memmap lets each rank slice out just its local shard without eagerly # reading the full sample into process memory first. - mask = self._load_numpy_array(mask_file[0], mmap_mode=mmap_mode) - img = self._load_numpy_array(img_file[0], mmap_mode=mmap_mode) + mask = self._load_numpy_array(mask_file, mmap_mode=mmap_mode) + img = self._load_numpy_array(img_file, mmap_mode=mmap_mode) mask = self._slice_mask_array(mask) img = self._slice_image_array(img) - if self.dataset_format_version >= DATASET_FORMAT_VERSION: + if self.dataset_format_version >= FULL_VOLUME_DATASET_FORMAT_VERSION: img = self._prepare_optimized_image(img) mask = self._prepare_optimized_mask(mask) else: diff --git a/ScaFFold/utils/spatial_sharding.py b/ScaFFold/utils/spatial_sharding.py new file mode 100644 index 0000000..a32695b --- /dev/null +++ b/ScaFFold/utils/spatial_sharding.py @@ -0,0 +1,127 @@ +# Copyright (c) 2014-2026, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LBANN/ScaFFold. 
+# +# SPDX-License-Identifier: (Apache-2.0) + +from math import prod +from typing import Iterable, Tuple + + +def normalize_sharding(num_shards: Iterable[int], shard_dims: Iterable[int]): + """Validate and normalize spatial sharding config.""" + + num_shards = tuple(int(x) for x in num_shards) + shard_dims = tuple(int(x) for x in shard_dims) + + if len(num_shards) != len(shard_dims): + raise ValueError( + f"num_shards {num_shards} must have same length as shard_dims {shard_dims}" + ) + if len(set(shard_dims)) != len(shard_dims): + raise ValueError(f"Shard dimensions must be unique: {shard_dims}") + + for num_shards_i, shard_dim_i in zip(num_shards, shard_dims): + if num_shards_i < 1: + raise ValueError(f"Invalid num_shards value {num_shards_i}") + if shard_dim_i not in (2, 3, 4): + raise ValueError( + f"Invalid shard_dim {shard_dim_i}: only 3D spatial dimensions 2, 3, and 4 are supported" + ) + + return num_shards, shard_dims + + +def total_shards(num_shards: Iterable[int]) -> int: + return prod(tuple(int(x) for x in num_shards)) + + +def shard_id_to_indices(shard_id: int, num_shards: Iterable[int]) -> Tuple[int, ...]: + """Convert row-major linear shard id to multi-dimensional shard indices.""" + + num_shards = tuple(int(x) for x in num_shards) + total = total_shards(num_shards) + if shard_id < 0 or shard_id >= total: + raise ValueError(f"shard_id {shard_id} out of range for num_shards={num_shards}") + + indices = [] + linear_idx = int(shard_id) + stride = total + for num_shards_i in num_shards: + stride //= num_shards_i + indices.append(linear_idx // stride) + linear_idx %= stride + return tuple(indices) + + +def shard_indices_to_id( + shard_indices: Iterable[int], num_shards: Iterable[int] +) -> int: + """Convert multi-dimensional shard indices to row-major linear shard id.""" + + shard_indices = tuple(int(x) for x in shard_indices) + num_shards = tuple(int(x) for x in num_shards) + if len(shard_indices) != len(num_shards): + raise ValueError( + f"shard_indices 
{shard_indices} must match num_shards {num_shards}" + ) + + shard_id = 0 + stride = 1 + for shard_index_i, num_shards_i in zip(reversed(shard_indices), reversed(num_shards)): + if shard_index_i < 0 or shard_index_i >= num_shards_i: + raise ValueError( + f"Invalid shard index {shard_index_i} for num_shards={num_shards}" + ) + shard_id += shard_index_i * stride + stride *= num_shards_i + return shard_id + + +def chunk_slice(size: int, num_shards: int, shard_index: int) -> slice: + """Match torch.chunk-style uneven shard boundaries.""" + + chunk_size = (size + num_shards - 1) // num_shards + start = shard_index * chunk_size + if start >= size: + raise ValueError( + f"Empty local shard: dim size {size}, num_shards {num_shards}, shard_index {shard_index}" + ) + stop = min(size, start + chunk_size) + return slice(start, stop) + + +def spatial_slices( + spatial_shape: Iterable[int], + shard_dims: Iterable[int], + num_shards: Iterable[int], + shard_indices: Iterable[int], +) -> Tuple[slice, slice, slice]: + """Return local D/H/W slices for DistConv spatial dims 2/3/4.""" + + spatial_shape = tuple(int(x) for x in spatial_shape) + if len(spatial_shape) != 3: + raise ValueError(f"Expected 3D spatial shape, got {spatial_shape}") + + slices = [slice(0, size) for size in spatial_shape] + for shard_dim, num_shards_i, shard_index_i in zip( + shard_dims, num_shards, shard_indices + ): + spatial_axis = int(shard_dim) - 2 + slices[spatial_axis] = chunk_slice( + spatial_shape[spatial_axis], int(num_shards_i), int(shard_index_i) + ) + + return tuple(slices) + + +def shard_file_suffix(shard_id: int) -> str: + return f"_shard{int(shard_id):06d}" From b7dc67d2d69b80e6f57dfb866edb41c0b6c725a6 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 27 May 2026 14:56:32 -0700 Subject: [PATCH 04/10] sharded volgen WIP --- ScaFFold/datagen/volumegen.py | 108 ++++++++++++++++++++++----------- ScaFFold/utils/config_utils.py | 11 ---- 2 files changed, 71 insertions(+), 48 deletions(-) diff --git 
a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index ffeec44..2b1f24d 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -179,20 +179,24 @@ def generate_volume_shard( different shard layouts. """ - n_fracts_per_vol = config.n_fracts_per_vol - num_shards, shard_dims = _physical_sharding(config) - shard_indices = shard_id_to_indices(shard_id, num_shards) - slices = spatial_slices( - (config.vol_size, config.vol_size, config.vol_size), - shard_dims, - num_shards, - shard_indices, + voxelized_fractals = _voxelized_fractals_for_volume( + config, + curr_vol, + fractal_colors, + point_cloud_loader=point_cloud_loader, ) - local_shape = _local_shape(slices) + return _render_volume_shard(config, voxelized_fractals, shard_id) - volume = np.full((3, *local_shape), 0, dtype=VOLUME_DTYPE) - mask = np.full(local_shape, 0, dtype=MASK_DTYPE) + +def _voxelized_fractals_for_volume( + config, + curr_vol: np.ndarray, + fractal_colors: np.ndarray, + point_cloud_loader: Callable[[str], np.ndarray] = load_np_ptcloud, +): + n_fracts_per_vol = config.n_fracts_per_vol grid_size = math.floor(config.vol_size * config.scale) + voxelized_fractals = [] for curr_fract in range(n_fracts_per_vol): curr_category = int(curr_vol[1 + 2 * curr_fract]) @@ -209,6 +213,26 @@ def generate_volume_shard( points = point_cloud_loader(point_cloud_path) idx = points_to_voxel_indices(points, grid_size) + voxelized_fractals.append((curr_category, fractal_color, idx)) + + return voxelized_fractals + + +def _render_volume_shard(config, voxelized_fractals, shard_id: int): + num_shards, shard_dims = _physical_sharding(config) + shard_indices = shard_id_to_indices(shard_id, num_shards) + slices = spatial_slices( + (config.vol_size, config.vol_size, config.vol_size), + shard_dims, + num_shards, + shard_indices, + ) + local_shape = _local_shape(slices) + + volume = np.full((3, *local_shape), 0, dtype=VOLUME_DTYPE) + mask = np.full(local_shape, 0, dtype=MASK_DTYPE) + + for 
curr_category, fractal_color, idx in voxelized_fractals: keep = np.ones(idx.shape[0], dtype=bool) for axis, axis_slice in enumerate(slices): keep &= idx[:, axis] >= axis_slice.start @@ -283,7 +307,7 @@ def main(config: Dict): val_indices = _validation_indices(num_volumes, config) # Work distribution - total_tasks = num_volumes * n_total_shards + total_tasks = num_volumes stride = ceil(total_tasks / size) start_idx = rank * stride end_idx = min(((rank + 1) * stride), total_tasks) @@ -291,60 +315,70 @@ def main(config: Dict): generation_err = "" try: if start_idx >= end_idx: - logging.info(f"Rank {rank} given no volume shards to generate") + logging.info(f"Rank {rank} given no logical volumes to generate") else: - task_ids = range(start_idx, end_idx) + volume_ids = range(start_idx, end_idx) print( - f"rank {rank} responsible for volume-shard tasks {start_idx} through {end_idx - 1}" + f"rank {rank} responsible for logical volume tasks " + f"{start_idx} through {end_idx - 1}" ) fractal_colors = _fractal_colors(config, n_fracts_per_vol) # Generation loop start_time = time.time() - for i, task_id in enumerate(task_ids): + n_generated_shards = 0 + for i, volume_idx in enumerate(volume_ids): if i % 10 == 0: - logging.info( - f"Rank {rank} processing local volume-shard task {i}..." 
- ) + logging.info(f"Rank {rank} processing local volume task {i}...") - volume_idx = task_id // n_total_shards - shard_id = task_id % n_total_shards curr_vol = volumes_contents[volume_idx] global_vol_idx = curr_vol[0] vol_seed = config.seed + int(global_vol_idx) random.seed(vol_seed) np.random.seed(vol_seed) - volume_to_save, mask_to_save = generate_volume_shard( + voxelized_fractals = _voxelized_fractals_for_volume( config, curr_vol, - shard_id, fractal_colors, ) - # Determine destination folder - subdir = "validation" if global_vol_idx in val_indices else "training" - shard_suffix = shard_file_suffix(shard_id) + for shard_id in range(n_total_shards): + volume_to_save, mask_to_save = _render_volume_shard( + config, + voxelized_fractals, + shard_id, + ) - vol_file = os.path.join( - vol_path, subdir, f"{global_vol_idx}{shard_suffix}.npy" - ) - with open(vol_file, "wb") as f: - np.save(f, volume_to_save) + # Determine destination folder + subdir = ( + "validation" if global_vol_idx in val_indices else "training" + ) + shard_suffix = shard_file_suffix(shard_id) - mask_file = os.path.join( - mask_path, subdir, f"{global_vol_idx}{shard_suffix}_mask.npy" - ) - with open(mask_file, "wb") as f: - np.save(f, mask_to_save) + vol_file = os.path.join( + vol_path, subdir, f"{global_vol_idx}{shard_suffix}.npy" + ) + with open(vol_file, "wb") as f: + np.save(f, volume_to_save) + + mask_file = os.path.join( + mask_path, subdir, f"{global_vol_idx}{shard_suffix}_mask.npy" + ) + with open(mask_file, "wb") as f: + np.save(f, mask_to_save) + n_generated_shards += 1 end_time = time.time() total_time = end_time - start_time if rank == 0: + shard_rate = n_generated_shards / total_time print( - f"Rank 0 generated {end_idx - start_idx} volume shards in {total_time:.2f} seconds | {(end_idx - start_idx) / total_time:.2f} shards per second" + f"Rank 0 generated {n_generated_shards} volume shards " + f"from {end_idx - start_idx} logical volumes in " + f"{total_time:.2f} seconds | 
{shard_rate:.2f} shards per second" ) except Exception as e: generation_err = ( diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py index 0d03358..36f1603 100644 --- a/ScaFFold/utils/config_utils.py +++ b/ScaFFold/utils/config_utils.py @@ -85,17 +85,6 @@ def __init__(self, config_dict): self.dc_num_shards = config_dict["dc_num_shards"] self.dc_shard_dims = config_dict["dc_shard_dims"] self.dc_total_shards = math.prod(self.dc_num_shards) - unsupported_dataset_keys = [ - key - for key in ("dataset_num_shards", "dataset_shard_dims") - if key in config_dict - ] - if unsupported_dataset_keys: - raise ValueError( - "Configuration Mismatch: dataset_num_shards/dataset_shard_dims " - "are not supported. Use dc_num_shards/dc_shard_dims for the " - "v3 physical dataset layout." - ) # Safety Check: Length mismatch if len(self.dc_num_shards) != len(self.dc_shard_dims): raise ValueError( From d0eda7934f35b89bb660fa597423ac1e29ba56e8 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 27 May 2026 21:35:54 -0700 Subject: [PATCH 05/10] normalize shard layout in dataset config hashing so equivalent shard layouts hash the same way --- ScaFFold/datagen/get_dataset.py | 39 ++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index 511cdf0..d51d662 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -76,22 +76,50 @@ def _get_required_keys_dict( return canonicalize(required) +def _canonicalize_v3_shard_layout(volume_config: Dict[str, Any]) -> Dict[str, Any]: + """Normalize shard layout ordering so equivalent v3 layouts share cache IDs.""" + + canonical_config = volume_config.copy() + num_shards = canonical_config["dc_num_shards"] + shard_dims = canonical_config["dc_shard_dims"] + if len(num_shards) != len(shard_dims): + raise ValueError( + f"dc_num_shards {num_shards} must have same length as dc_shard_dims 
{shard_dims}" + ) + + shard_layout = sorted( + (int(shard_dim), int(num_shard)) + for num_shard, shard_dim in zip(num_shards, shard_dims) + ) + canonical_config["dc_shard_dims"] = [shard_dim for shard_dim, _ in shard_layout] + canonical_config["dc_num_shards"] = [num_shard for _, num_shard in shard_layout] + return canonical_config + + def _hash_volume_config(volume_config: Dict[str, Any]) -> str: s = json.dumps(volume_config, separators=(",", ":"), sort_keys=True).encode() return hashlib.sha256(s).hexdigest()[:12] -def _volume_config_for_version(config_dict, dataset_format_version): +def _volume_config_for_version( + config_dict, dataset_format_version, canonicalize_v3_shard_layout=True +): versioned_config = config_dict.copy() versioned_config["dataset_format_version"] = dataset_format_version if dataset_format_version == DATASET_FORMAT_VERSION: include_keys = INCLUDE_KEYS else: include_keys = V2_INCLUDE_KEYS - return _get_required_keys_dict( + volume_config = _get_required_keys_dict( config=versioned_config, include_keys=include_keys, ) + if ( + dataset_format_version == DATASET_FORMAT_VERSION + and canonicalize_v3_shard_layout + ): + volume_config = _canonicalize_v3_shard_layout(volume_config) + return volume_config def _requested_unsharded_layout(config_dict: Dict[str, Any]) -> bool: @@ -177,6 +205,11 @@ def get_dataset( # defined by dc_num_shards/dc_shard_dims, matching the DistConv layout. 
config_dict = vars(config).copy() volume_config = _volume_config_for_version(config_dict, DATASET_FORMAT_VERSION) + metadata_volume_config = _volume_config_for_version( + config_dict, + DATASET_FORMAT_VERSION, + canonicalize_v3_shard_layout=False, + ) config_id = _hash_volume_config(volume_config) v2_volume_config = _volume_config_for_version(config_dict, V2_DATASET_FORMAT_VERSION) v2_config_id = _hash_volume_config(v2_volume_config) @@ -265,7 +298,7 @@ def get_dataset( meta = { "config_id": config_id, "dataset_format_version": DATASET_FORMAT_VERSION, - "config_subset": volume_config, + "config_subset": metadata_volume_config, "include_keys": INCLUDE_KEYS, "code_commit": commit, "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), From 9244c767c94989e69ea43fe793eea80ca884bb0e Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 28 May 2026 08:05:27 -0700 Subject: [PATCH 06/10] small tweaks to shrink diff --- ScaFFold/datagen/get_dataset.py | 5 +- ScaFFold/datagen/volumegen.py | 118 ++++++++++++-------------------- ScaFFold/utils/data_loading.py | 12 +--- 3 files changed, 46 insertions(+), 89 deletions(-) diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index d51d662..5af28e8 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -215,9 +215,6 @@ def get_dataset( v2_config_id = _hash_volume_config(v2_volume_config) commit = _git_commit_short() - base = root / config_id - base.mkdir(parents=True, exist_ok=True) - # Prefer a matching V3 physical-shard dataset. dataset_path = _find_reusable_dataset( root, @@ -250,8 +247,10 @@ def get_dataset( return dataset_path # Otherwise, generate a new dataset + base = root / config_id print(f"No valid existing dataset found at {base}. 
Generating new dataset...") if rank == 0: + base.mkdir(parents=True, exist_ok=True) ts = time.strftime("%Y%m%d-%H%M%S") dest = base / f"{ts}__{commit}" tmp = base / f".tmp_{ts}" diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index 2b1f24d..76fccc4 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -85,54 +85,6 @@ def points_to_voxel_indices( return idx -def _build_volumes_contents(config, n_fracts_per_vol): - # Force n_instances_used_per_fractal to be multiple of n_fracts_per_vol - if config.n_instances_used_per_fractal % n_fracts_per_vol != 0: - print( - f"volumegen.py: WARNING: n_instances_used_per_fractal ({config.n_instances_used_per_fractal}) \n" - f"NOT multiple of n_fracts_per_vol={n_fracts_per_vol}. Rounding down." - ) - config.n_instances_used_per_fractal = ( - config.n_instances_used_per_fractal - // n_fracts_per_vol - * n_fracts_per_vol - ) - - # Randomly select n_instances_used_per_fractal instances from each fractal class. 
- instances_list = [] - for category in range(config.n_categories): - instances_remaining = config.n_instances_used_per_fractal - random_instances = [] - while instances_remaining > 0: - random_instances.extend( - random.sample(range(145), min(145, instances_remaining)) - ) - instances_remaining -= min(145, instances_remaining) - - category_instance_pairs = [[category, instance] for instance in random_instances] - instances_list.extend(category_instance_pairs) - - instances_list = np.array(instances_list, dtype=int) - np.random.shuffle(instances_list) - - volumes_contents = instances_list.reshape(-1, 2 * n_fracts_per_vol) - - indices = np.arange(volumes_contents.shape[0]).reshape(-1, 1) - return np.hstack([indices, volumes_contents]) - - -def _validation_indices(num_volumes: int, config) -> set[int]: - random.seed(config.seed) - return set( - random.sample(range(num_volumes), int(num_volumes * config.val_split / 100)) - ) - - -def _fractal_colors(config, n_fracts_per_vol): - np.random.seed(config.seed) - return np.random.rand(max(config.n_categories, n_fracts_per_vol), 3) - - def _point_cloud_path(config, curr_category: int, curr_instance: int) -> str: instances_dir = f"var{config.variance_threshold}/instances/np{config.point_num}" return os.path.join( @@ -164,30 +116,6 @@ def _validate_generation_config(config): return num_shards, shard_dims, n_total_shards, grid_size -def generate_volume_shard( - config, - curr_vol: np.ndarray, - shard_id: int, - fractal_colors: np.ndarray, - point_cloud_loader: Callable[[str], np.ndarray] = load_np_ptcloud, -): - """ - Generate one physical shard for one logical volume. - - Voxel indices are computed in the full-volume coordinate system first, then - filtered to the shard. This preserves bitwise reconstruction across - different shard layouts. 
- """ - - voxelized_fractals = _voxelized_fractals_for_volume( - config, - curr_vol, - fractal_colors, - point_cloud_loader=point_cloud_loader, - ) - return _render_volume_shard(config, voxelized_fractals, shard_id) - - def _voxelized_fractals_for_volume( config, curr_vol: np.ndarray, @@ -287,7 +215,41 @@ def main(config: Dict): os.makedirs(os.path.join(vol_path, subdir), exist_ok=True) os.makedirs(os.path.join(mask_path, subdir), exist_ok=True) - volumes_contents = _build_volumes_contents(config, n_fracts_per_vol) + # Force n_instances_used_per_fractal to be multiple of n_fracts_per_vol + if config.n_instances_used_per_fractal % n_fracts_per_vol != 0: + print( + f"volumegen.py: WARNING: n_instances_used_per_fractal ({config.n_instances_used_per_fractal}) \n" + f"NOT multiple of n_fracts_per_vol={n_fracts_per_vol}. Rounding down." + ) + config.n_instances_used_per_fractal = ( + config.n_instances_used_per_fractal + // n_fracts_per_vol + * n_fracts_per_vol + ) + + # Randomly select n_instances_used_per_fractal instances from each fractal class. 
+ instances_list = [] + for category in range(config.n_categories): + instances_remaining = config.n_instances_used_per_fractal + random_instances = [] + while instances_remaining > 0: + random_instances.extend( + random.sample(range(145), min(145, instances_remaining)) + ) + instances_remaining -= min(145, instances_remaining) + + category_instance_pairs = [ + [category, instance] for instance in random_instances + ] + instances_list.extend(category_instance_pairs) + + instances_list = np.array(instances_list, dtype=int) + np.random.shuffle(instances_list) + + volumes_contents = instances_list.reshape(-1, 2 * n_fracts_per_vol) + + indices = np.arange(volumes_contents.shape[0]).reshape(-1, 1) + volumes_contents = np.hstack([indices, volumes_contents]) with open(volumes_contents_path, "wb") as f: np.savetxt(f, volumes_contents.astype(int), fmt="%i", delimiter=",") @@ -304,7 +266,10 @@ def main(config: Dict): # Determine train/val split globally so all ranks know where to save num_volumes = len(volumes_contents) - val_indices = _validation_indices(num_volumes, config) + random.seed(config.seed) + val_indices = set( + random.sample(range(num_volumes), int(num_volumes * config.val_split / 100)) + ) # Work distribution total_tasks = num_volumes @@ -324,7 +289,10 @@ def main(config: Dict): f"{start_idx} through {end_idx - 1}" ) - fractal_colors = _fractal_colors(config, n_fracts_per_vol) + np.random.seed(config.seed) + fractal_colors = np.random.rand( + max(config.n_categories, n_fracts_per_vol), 3 + ) # Generation loop start_time = time.time() diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index c655c2b..faa40e7 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -75,16 +75,6 @@ def __post_init__(self): f"Invalid shard_index {shard_index} for shard_dim {shard_dim} with {num_shards} shards" ) - @staticmethod - def _chunk_slice(size: int, num_shards: int, shard_index: int) -> slice: - """Match torch.chunk-style 
uneven shard boundaries.""" - - return chunk_slice(size, num_shards, shard_index) - - @property - def shard_id(self) -> int: - return shard_indices_to_id(self.shard_indices, self.num_shards) - def slice_array( self, array: np.ndarray, axis_map: Dict[int, int], array_label: str ) -> np.ndarray: @@ -104,7 +94,7 @@ def slice_array( raise ValueError( f"Axis {axis} out of range for {array_label} with shape {array.shape}" ) - slices[axis] = self._chunk_slice(array.shape[axis], num_shards, shard_index) + slices[axis] = chunk_slice(array.shape[axis], num_shards, shard_index) return array[tuple(slices)] From f852415d0965ff7c771548c00f10f8825621fc97 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 28 May 2026 08:05:54 -0700 Subject: [PATCH 07/10] ruff --- ScaFFold/datagen/get_dataset.py | 4 +++- ScaFFold/utils/data_loading.py | 3 +-- ScaFFold/utils/spatial_sharding.py | 12 +++++++----- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index 5af28e8..ab757c1 100644 --- a/ScaFFold/datagen/get_dataset.py +++ b/ScaFFold/datagen/get_dataset.py @@ -211,7 +211,9 @@ def get_dataset( canonicalize_v3_shard_layout=False, ) config_id = _hash_volume_config(volume_config) - v2_volume_config = _volume_config_for_version(config_dict, V2_DATASET_FORMAT_VERSION) + v2_volume_config = _volume_config_for_version( + config_dict, V2_DATASET_FORMAT_VERSION + ) v2_config_id = _hash_volume_config(v2_volume_config) commit = _git_commit_short() diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index faa40e7..c4f3f18 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -340,8 +340,7 @@ def __getitem__(self, idx): mmap_mode = ( "r" - if self.spatial_shard_spec is not None - and not self.physical_shards + if self.spatial_shard_spec is not None and not self.physical_shards else None ) # Memmap lets each rank slice out just its local shard without eagerly diff --git 
a/ScaFFold/utils/spatial_sharding.py b/ScaFFold/utils/spatial_sharding.py index a32695b..fac1e06 100644 --- a/ScaFFold/utils/spatial_sharding.py +++ b/ScaFFold/utils/spatial_sharding.py @@ -50,7 +50,9 @@ def shard_id_to_indices(shard_id: int, num_shards: Iterable[int]) -> Tuple[int, num_shards = tuple(int(x) for x in num_shards) total = total_shards(num_shards) if shard_id < 0 or shard_id >= total: - raise ValueError(f"shard_id {shard_id} out of range for num_shards={num_shards}") + raise ValueError( + f"shard_id {shard_id} out of range for num_shards={num_shards}" + ) indices = [] linear_idx = int(shard_id) @@ -62,9 +64,7 @@ def shard_id_to_indices(shard_id: int, num_shards: Iterable[int]) -> Tuple[int, return tuple(indices) -def shard_indices_to_id( - shard_indices: Iterable[int], num_shards: Iterable[int] -) -> int: +def shard_indices_to_id(shard_indices: Iterable[int], num_shards: Iterable[int]) -> int: """Convert multi-dimensional shard indices to row-major linear shard id.""" shard_indices = tuple(int(x) for x in shard_indices) @@ -76,7 +76,9 @@ def shard_indices_to_id( shard_id = 0 stride = 1 - for shard_index_i, num_shards_i in zip(reversed(shard_indices), reversed(num_shards)): + for shard_index_i, num_shards_i in zip( + reversed(shard_indices), reversed(num_shards) + ): if shard_index_i < 0 or shard_index_i >= num_shards_i: raise ValueError( f"Invalid shard index {shard_index_i} for num_shards={num_shards}" From d2099195ae596ef4ed3d19958681dea18ba60dde Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Thu, 28 May 2026 08:10:50 -0700 Subject: [PATCH 08/10] add docstrings --- ScaFFold/datagen/get_dataset.py | 6 ++++++ ScaFFold/datagen/volumegen.py | 12 ++++++++++++ ScaFFold/utils/data_loading.py | 16 ++++++++++++++++ ScaFFold/utils/spatial_sharding.py | 4 ++++ 4 files changed, 38 insertions(+) diff --git a/ScaFFold/datagen/get_dataset.py b/ScaFFold/datagen/get_dataset.py index ab757c1..3adc646 100644 --- a/ScaFFold/datagen/get_dataset.py +++ 
b/ScaFFold/datagen/get_dataset.py @@ -104,6 +104,8 @@ def _hash_volume_config(volume_config: Dict[str, Any]) -> str: def _volume_config_for_version( config_dict, dataset_format_version, canonicalize_v3_shard_layout=True ): + """Build the hashable dataset config subset for a format version.""" + versioned_config = config_dict.copy() versioned_config["dataset_format_version"] = dataset_format_version if dataset_format_version == DATASET_FORMAT_VERSION: @@ -123,6 +125,8 @@ def _volume_config_for_version( def _requested_unsharded_layout(config_dict: Dict[str, Any]) -> bool: + """Return whether the requested layout has exactly one physical shard.""" + total_shards = 1 for value in config_dict["dc_num_shards"]: total_shards *= int(value) @@ -158,6 +162,8 @@ def _find_reusable_dataset( commit: str, require_commit: bool, ) -> Optional[Path]: + """Find the newest reusable dataset matching format, config, and commit.""" + base = root / config_id if not base.exists(): return None diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index 76fccc4..f7d746a 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -86,6 +86,8 @@ def points_to_voxel_indices( def _point_cloud_path(config, curr_category: int, curr_instance: int) -> str: + """Return the input point-cloud path for a fractal instance.""" + instances_dir = f"var{config.variance_threshold}/instances/np{config.point_num}" return os.path.join( str(config.fract_base_dir), @@ -96,14 +98,20 @@ def _point_cloud_path(config, curr_category: int, curr_instance: int) -> str: def _local_shape(slices): + """Return the local spatial shape described by shard slices.""" + return tuple(s.stop - s.start for s in slices) def _physical_sharding(config): + """Return normalized physical sharding from the generation config.""" + return normalize_sharding(config.dc_num_shards, config.dc_shard_dims) def _validate_generation_config(config): + """Validate sharded generation settings and return 
normalized layout data.""" + num_shards, shard_dims = _physical_sharding(config) n_total_shards = total_shards(num_shards) @@ -122,6 +130,8 @@ def _voxelized_fractals_for_volume( fractal_colors: np.ndarray, point_cloud_loader: Callable[[str], np.ndarray] = load_np_ptcloud, ): + """Load and voxelize all fractals needed for one logical volume.""" + n_fracts_per_vol = config.n_fracts_per_vol grid_size = math.floor(config.vol_size * config.scale) voxelized_fractals = [] @@ -147,6 +157,8 @@ def _voxelized_fractals_for_volume( def _render_volume_shard(config, voxelized_fractals, shard_id: int): + """Render one physical shard from precomputed global voxel indices.""" + num_shards, shard_dims = _physical_sharding(config) shard_indices = shard_id_to_indices(shard_id, num_shards) slices = spatial_slices( diff --git a/ScaFFold/utils/data_loading.py b/ScaFFold/utils/data_loading.py index c4f3f18..13d5d41 100644 --- a/ScaFFold/utils/data_loading.py +++ b/ScaFFold/utils/data_loading.py @@ -165,6 +165,8 @@ def _load_numpy_array(path, mmap_mode=None): return np.load(path, allow_pickle=False, mmap_mode=mmap_mode) def _list_ids(self, images_dir): + """List logical sample IDs visible to this dataset instance.""" + if not self.physical_shards: return sorted( [ @@ -185,6 +187,8 @@ def _list_ids(self, images_dir): return sorted(ids) def _load_dataset_metadata(self): + """Load dataset metadata, falling back to legacy defaults.""" + meta_path = self.dataset_root / META_FILENAME if not meta_path.exists(): return {"dataset_format_version": LEGACY_DATASET_FORMAT_VERSION} @@ -199,6 +203,8 @@ def _load_dataset_metadata(self): return {"dataset_format_version": LEGACY_DATASET_FORMAT_VERSION} def _load_physical_sharding(self): + """Load and normalize the physical shard layout from metadata.""" + if not self.physical_shards: return (), () @@ -215,9 +221,13 @@ def _load_physical_sharding(self): @staticmethod def _layout_by_dim(num_shards, shard_dims): + """Map each sharded dimension to its shard 
count.""" + return {int(dim): int(num) for num, dim in zip(num_shards, shard_dims)} def _physical_layout_matches_spatial_spec(self): + """Return whether dataset shards match the requested spatial layout.""" + if self.spatial_shard_spec is None: return False return self._layout_by_dim( @@ -228,6 +238,8 @@ def _physical_layout_matches_spatial_spec(self): ) def _physical_shard_id_for_spatial_spec(self): + """Return the physical shard id selected by the spatial shard spec.""" + spec_indices_by_dim = { int(dim): int(index) for dim, index in zip( @@ -241,6 +253,8 @@ def _physical_shard_id_for_spatial_spec(self): return shard_indices_to_id(shard_indices, self.physical_num_shards) def _select_physical_shard_id(self): + """Select the physical shard file this dataset instance should read.""" + if not self.physical_shards: return 0 if self.spatial_shard_spec is None: @@ -310,6 +324,8 @@ def _slice_mask_array(self, mask): return self.spatial_shard_spec.slice_array(mask, axis_map, "mask") def _resolve_sample_files(self, name): + """Resolve image and mask file paths for a logical sample ID.""" + if self.physical_shards: img_file = self.images_dir / f"{name}{self.shard_suffix}.npy" mask_file = ( diff --git a/ScaFFold/utils/spatial_sharding.py b/ScaFFold/utils/spatial_sharding.py index fac1e06..91664dc 100644 --- a/ScaFFold/utils/spatial_sharding.py +++ b/ScaFFold/utils/spatial_sharding.py @@ -41,6 +41,8 @@ def normalize_sharding(num_shards: Iterable[int], shard_dims: Iterable[int]): def total_shards(num_shards: Iterable[int]) -> int: + """Return the total number of shards in a multi-dimensional layout.""" + return prod(tuple(int(x) for x in num_shards)) @@ -126,4 +128,6 @@ def spatial_slices( def shard_file_suffix(shard_id: int) -> str: + """Return the filename suffix for a physical shard id.""" + return f"_shard{int(shard_id):06d}" From 1a1d33768b9878b635f41a4cb1a2f57d3ac51415 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 10 Jun 2026 14:19:34 -0700 Subject: [PATCH 
09/10] update README --- README.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9438b1d..c9f3fd6 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,12 @@ The model is trained from a random initialization until convergence, which is de 1. Once fractal generation completes, run the benchmark: `torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c ScaFFold/configs/benchmark_default.yml` +### Dataset cache and sharded datagen + +`benchmark` creates or reuses datasets under `dataset_dir`. New datasets are written in the v3 format, which stores one volume and mask file per logical sample per physical shard. The physical layout is controlled by `dc_num_shards` and `dc_shard_dims`; for example, `dc_num_shards: [1, 1, 2]` writes two physical shards per logical volume, with filenames such as `120_shard000000.npy` and `120_shard000001.npy`. Datasets are generated with the same sharding configuration used for model training. + +Unsharded runs use `dc_num_shards: [1, 1, 1]`. For those runs, ScaFFold can still reuse an existing v2 full-volume dataset cache. Sharded runs require a matching v3 cache or generate a new v3 dataset. + `benchmark` creates a folder for the benchmark run(s) at `base_run_dir` set in the config file. For reproducibility, we store a copy of the benchmark run config yml. Within each run subfolder, `benchmark` creates a yml config for that specific run. After each run completes, statistics from the run are stored in `train_stats.csv`. Additionally, users can inspect plots of the training and validation losses over time in ` bottleneck layer of size 8. seed: 42 # Random seed. -batch_size: 1 # Batch sizes for each vol size. +batch_size: 1 # Batch size per rank. +dataloader_num_workers: 1 # Number of DataLoader worker processes per rank. optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp. 
+dc_num_shards: [1, 1, 1] # Physical data shards per sample for DistConv. +dc_shard_dims: [2, 3, 4] # Tensor dimensions used for physical sharding. # Internal/dev use only variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15. @@ -97,6 +107,7 @@ framework: "torch" # The DL framework to train with. Only valid checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints. checkpoint_interval: 1 # Number of epochs between saving training checkpoints. loss_freq: 1 # Number of epochs between logging the overall loss. +warmup_batches: 64 # Training and validation warmup batches per DDP rank. ``` ## How the benchmark works @@ -194,6 +205,8 @@ For n  in n_volumes: 3. Save volume and mask  to files ``` +In the current v3 dataset format, this save step writes each logical sample as one or more physical shard files, matching the requested `dc_num_shards` layout. The dataloader then reads only the shard file needed by the current DistConv rank instead of loading a full volume and slicing it locally. + ### 1. Profiling with the PyTorch Profiler Set `PROFILE_TORCH=ON` to generate a PyTorch profiling trace that can be read into [Perfetto](https://ui.perfetto.dev/). From a6697cf25e3d491cc08a7ed204e3fc549c89c4b8 Mon Sep 17 00:00:00 2001 From: Patrick Miles Date: Wed, 10 Jun 2026 14:22:11 -0700 Subject: [PATCH 10/10] volumegen: every rank ensures needed dirs exist --- ScaFFold/datagen/volumegen.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ScaFFold/datagen/volumegen.py b/ScaFFold/datagen/volumegen.py index f7d746a..570f73c 100644 --- a/ScaFFold/datagen/volumegen.py +++ b/ScaFFold/datagen/volumegen.py @@ -276,6 +276,16 @@ def main(config: Dict): if setup_err: raise RuntimeError(setup_err) + # Rank 0 creates shared metadata above; wait before local writer setup. 
+ comm.Barrier() + + for subdir in ["training", "validation"]: + os.makedirs(os.path.join(vol_path, subdir), exist_ok=True) + os.makedirs(os.path.join(mask_path, subdir), exist_ok=True) + + # Wait until every rank has ensured the writer directories exist. + comm.Barrier() + # Determine train/val split globally so all ranks know where to save num_volumes = len(volumes_contents) random.seed(config.seed)