Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ The model is trained from a random initialization until convergence, which is de
1. Once fractal generation completes, run the benchmark:
`torchrun-hpc -N 1 -n 4 --gpus-per-proc 1 $(which scaffold) benchmark -c ScaFFold/configs/benchmark_default.yml`

### Dataset cache and sharded datagen

`benchmark` creates or reuses datasets under `dataset_dir`. New datasets are written in the v3 format, which stores one volume and mask file per logical sample per physical shard. The physical layout is controlled by `dc_num_shards` and `dc_shard_dims`; for example, `dc_num_shards: [1, 1, 2]` writes two physical shards per logical volume, with filenames such as `120_shard000000.npy` and `120_shard000001.npy`. Datasets are generated with the same sharding configuration used for model training.

Unsharded runs use `dc_num_shards: [1, 1, 1]`. For those runs, ScaFFold can still reuse an existing v2 full-volume dataset cache. Sharded runs require a matching v3 cache or generate a new v3 dataset.

`benchmark` creates a folder for the benchmark run(s) at `base_run_dir` set in the config file. For reproducibility, we store a copy of the benchmark run config yml. Within each run subfolder, `benchmark` creates a yml config for that specific run.

After each run completes, statistics from the run are stored in `train_stats.csv`. Additionally, users can inspect plots of the training and validation losses over time in `<base_run_dir/figures`.
Expand All @@ -69,14 +75,18 @@ Parameters are set in a `.yml` config file and can be modified by the user:
```yml
# External/user-facing
base_run_dir: "benchmark_runs" # Subfolder of $(pwd) in which to run jobs.
dataset_dir: "datasets" # Directory in which to store and query for datasets.
fract_base_dir: "fractals" # Base directory for fractal IFS and instances.
n_categories: 5 # Number of fractal categories present in the dataset.
n_instances_used_per_fractal: 145 # Number of unique instances to pull from each fractal class. There are 145 unique; exceeding this number will reuse some instances.
problem_scale: 6 # Determines dataset resolution and number of unet layers. Default is 6.
unet_bottleneck_dim: 3 # Power of 2 of the unet bottleneck layer dimension. Default of 3 -> bottleneck layer of size 8.
seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
batch_size: 1 # Batch size per rank.
dataloader_num_workers: 1 # Number of DataLoader worker processes per rank.
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp.
dc_num_shards: [1, 1, 1] # Physical data shards per sample for DistConv.
dc_shard_dims: [2, 3, 4] # Tensor dimensions used for physical sharding.

# Internal/dev use only
variance_threshold: 0.15 # Variance threshold for valid fractals. Default is 0.15.
Expand All @@ -97,6 +107,7 @@ framework: "torch" # The DL framework to train with. Only valid
checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
checkpoint_interval: 1 # Number of epochs between saving training checkpoints.
loss_freq: 1 # Number of epochs between logging the overall loss.
warmup_batches: 64 # Training and validation warmup batches per DDP rank.
```

## How the benchmark works
Expand Down Expand Up @@ -194,6 +205,8 @@ For n  in n_volumes:
3. Save volume and mask  to files
```

In the current v3 dataset format, this save step writes each logical sample as one or more physical shard files, matching the requested `dc_num_shards` layout. The dataloader then reads only the shard file needed by the current DistConv rank instead of loading a full volume and slicing it locally.

### 1. Profiling with the PyTorch Profiler

Set `PROFILE_TORCH=ON` to generate a PyTorch profiling trace that can be read into [Perfetto](https://ui.perfetto.dev/).
Expand Down
6 changes: 6 additions & 0 deletions ScaFFold/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ def main():
nargs=3,
help="DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum",
)
benchmark_parser.add_argument(
"--dc-shard-dims",
type=int,
nargs=3,
help="DistConv param: dimensions on which to shard.",
)
benchmark_parser.add_argument(
"--epochs",
type=int,
Expand Down
4 changes: 2 additions & 2 deletions ScaFFold/configs/benchmark_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ seed: 42 # Random seed.
batch_size: 1 # Batch sizes for each vol size.
dataloader_num_workers: 4 # Number of DataLoader worker processes per rank.
optimizer: "ADAM" # "ADAM" is preferred option, otherwise training defautls to RMSProp.
num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
shard_dim: [2, 3, 4] # DistConv param: dimension on which to shard
dc_num_shards: [1, 1, 1] # DistConv param: number of shards to divide the tensor into. It's best to choose the fewest ranks needed to fit one sample in GPU memory, since that keeps communication at a minimum
dc_shard_dims: [2, 3, 4] # DistConv param: dimension on which to shard
checkpoint_interval: -1 # Checkpoint every C epochs; set to -1 to disable checkpointing entirely.

# Internal/dev use only
Expand Down
65 changes: 41 additions & 24 deletions ScaFFold/datagen/category_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def generate_single_category(config: Config) -> tuple[bool, np.array, bool, bool
A bool for whether a valid category was found on this attempt.
params : np.array
A numpy array containing IFS parameters for this category attempt, if attempt was valid.
(not nan_check_pass) : bool
A bool for whether this attempt passed the NaN check.
(not value_check_pass) : bool
A bool for whether this attempt passed the NaN/non-finite check.
(not variance_check_pass) : bool
A bool for whether this attempt passed the variance check.
(not runaway_check_pass) : bool
Expand Down Expand Up @@ -80,31 +80,40 @@ def generate_single_category(config: Config) -> tuple[bool, np.array, bool, bool
),
)

# Sum number of NaNs
# Sum number of NaNs and reject infinities before normalization.
nan_count = np.isnan(points).sum()
nan_check_pass = nan_count == 0
value_check_pass = nan_count == 0 and np.isfinite(points).all()
variance_check_pass = False

if nan_check_pass:
if value_check_pass:
# Normalize + center
mins = points.min(axis=0)
maxs = points.max(axis=0)
means = points.mean(axis=0)
scales = (2 * config.normalize) / (maxs - mins)
points = (points - means) * scales

# Calc dimension-wise variance and compare to threshold
points_variance = np.var(points, axis=1)
variance_check_pass = np.all(points_variance > config.variance_threshold)
if variance_check_pass and nan_check_pass and runaway_check_pass:
with np.errstate(over="ignore", invalid="ignore"):
ranges = maxs - mins
value_check_pass = np.all(np.isfinite(ranges)) and np.all(ranges > 0)
if value_check_pass:
scales = (2 * config.normalize) / ranges
with np.errstate(over="ignore", invalid="ignore"):
points = (points - means) * scales

value_check_pass = np.isfinite(points).all()
if value_check_pass:
# Calc dimension-wise variance and compare to threshold
points_variance = np.var(points, axis=0)
variance_check_pass = np.all(
points_variance > config.variance_threshold
)
if variance_check_pass and value_check_pass and runaway_check_pass:
valid = True

# Return result
return (
valid,
params,
not nan_check_pass,
not variance_check_pass,
bool(not value_check_pass),
bool(value_check_pass and not variance_check_pass),
not runaway_check_pass,
)

Expand All @@ -129,7 +138,7 @@ def generate_categories_batch(
params : np.array
A numpy array containing IFS parameters for this category attempt, if attempt was valid.
failed_nan_check_count : int
The number of attempts in this batch which failed the nan check.
The number of attempts in this batch which failed the NaN/non-finite check.
failed_var_check_count : int
The number of attempts in this batch which failed the var check.
runaway_failure_count : int
Expand Down Expand Up @@ -186,7 +195,11 @@ def main(config: Config) -> None:
rank = comm.Get_rank()
size = comm.Get_size()

datagen_batch_size = 10000
datagen_batch_size = int(getattr(config, "datagen_batch_size", 10000))
if datagen_batch_size < 1:
raise ValueError(
f"datagen_batch_size must be positive, got {datagen_batch_size}"
)

# FIXME anything else to ensure determinism?
np.random.seed(config.seed + rank)
Expand Down Expand Up @@ -224,7 +237,7 @@ def main(config: Config) -> None:
var_fail_count = 0
runaway_fail_count = 0
while categories_remaining > 0:
attempts += size
attempts += datagen_batch_size * size

# Each rank attempts to generate datagen_batch_size categories
(
Expand All @@ -245,12 +258,15 @@ def main(config: Config) -> None:
# Process IFS params one at a time, writing each to a CSV
if rank == 0:
params_valid = [item for sublist in gathered_params for item in sublist]
if attempts % 10000 * size / datagen_batch_size == 0:
print(
f"cat_remaining = {categories_remaining} | total attempts = {attempts} | stats for rank 0: invalid_value_fail_count = {nan_fail_count}, var_fail_count = {var_fail_count}, runaway_fail_count = {runaway_fail_count}",
flush=True,
)
if len(params_valid) > 0:
print(
f"cat_remaining = {categories_remaining} | total attempts = {attempts} | stats for rank 0: nan_fail_count = {nan_fail_count}, var_fail_count = {var_fail_count}, runaway_fail_count = {runaway_fail_count}"
f"Processing {len(params_valid)} valid param sets from this batch",
flush=True,
)
if len(params_valid) > 0:
print(f"Processing {len(params_valid)} param sets from this attempt")
for p in params_valid:
# Ensure we don't save more categories than needed
if categories_remaining > 0:
Expand Down Expand Up @@ -284,14 +300,15 @@ def main(config: Config) -> None:
global_runaway_fail_count = comm.reduce(runaway_fail_count, op=MPI.SUM, root=0)

if rank == 0 and attempts > 0:
generated_categories = config.n_categories - existing_categories
print(
f"Generated {config.n_categories - existing_categories} new categories in {attempts * datagen_batch_size} total attempts | {attempts * datagen_batch_size / (config.n_categories - existing_categories)} Attempts per category | Total categories is now {config.n_categories}"
f"Generated {generated_categories} new categories in {attempts} total attempts | {attempts / generated_categories} Attempts per category | Total categories is now {config.n_categories}"
)
print(
f"Failures experienced: {global_nan_fail_count} nan attempts, {100 * global_nan_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts, {global_var_fail_count} var fail attempts, {100 * global_var_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts, {global_runaway_fail_count} runaway attempts, {100 * global_runaway_fail_count / (attempts * datagen_batch_size):.4f}% of all attempts"
f"Failures experienced: {global_nan_fail_count} invalid-value attempts, {100 * global_nan_fail_count / attempts:.4f}% of all attempts, {global_var_fail_count} var fail attempts, {100 * global_var_fail_count / attempts:.4f}% of all attempts, {global_runaway_fail_count} runaway attempts, {100 * global_runaway_fail_count / attempts:.4f}% of all attempts"
)
print(
f"Rank 0 wall time = {rank_total_time:.2f} | Total CPU time = {global_sum_time:.2f} | Avg wall time per rank {global_sum_time / size:.2f} | {attempts * datagen_batch_size / rank_total_time:.2f} total attempts per wall second | {attempts * datagen_batch_size / rank_total_time / size:.2f} attempts per wall second per rank"
f"Rank 0 wall time = {rank_total_time:.2f} | Total CPU time = {global_sum_time:.2f} | Avg wall time per rank {global_sum_time / size:.2f} | {attempts / rank_total_time:.2f} total attempts per wall second | {attempts / rank_total_time / size:.2f} attempts per wall second per rank"
)

return 0
Loading
Loading