From 1788146cae20b617275c494766694f41bcdfd494 Mon Sep 17 00:00:00 2001
From: Michael McKinsey
Date: Mon, 15 Jun 2026 13:52:23 -0700
Subject: [PATCH] replace batchnorm with groupnorm

Swap the nn.BatchNorm3d layers in DoubleConv for nn.GroupNorm and thread a
group_norm_groups setting (default 8) from the CLI and config files through
UNet, Down and Up. The _group_norm() helper rejects group counts that are
not positive or that do not evenly divide the channel count, so a bad
setting fails with a ValueError when the model is built.
---
Review notes (text in this section is ignored by git am):
- This patch was rebuilt from a copy whose newlines and whitespace runs had
  been collapsed. Indentation was inferred from the Python syntax and blank
  context lines were placed to satisfy the hunk line counts. In the
  unet_model.py hunk two of the seven blank context lines could not be
  located with certainty (placed after each for-loop body here). Run
  `git apply --check` and adjust if the hunk does not match.
- Column alignment inside the YAML context lines was lost; apply with
  `git apply --ignore-whitespace` or regenerate from the branch.
- The author's e-mail address is missing from the From: header and must be
  restored before `git am` will accept the message.
- The nesting depth of the worker.py context lines is assumed (UNet(...)
  call at function-body level) -- TODO confirm.
- _group_norm() gained a `num_groups < 1` guard and a docstring relative to
  the first revision; unet_parts.py hunk headers and the diffstat reflect
  the three extra lines. Blob ids on the index line are from the first
  revision.

 ScaFFold/cli.py                        |  5 +++
 ScaFFold/configs/benchmark_default.yml |  1 +
 ScaFFold/configs/benchmark_testing.yml |  1 +
 ScaFFold/unet/unet_model.py            | 42 ++++++++++++++++++++++----
 ScaFFold/unet/unet_parts.py            | 35 +++++++++++++++------
 ScaFFold/utils/config_utils.py         |  1 +
 ScaFFold/worker.py                     |  1 +
 7 files changed, 71 insertions(+), 15 deletions(-)

diff --git a/ScaFFold/cli.py b/ScaFFold/cli.py
index 469c71e..9249aac 100644
--- a/ScaFFold/cli.py
+++ b/ScaFFold/cli.py
@@ -150,6 +150,11 @@ def main():
         type=int,
         help="Number of warmup batches to run per rank before training.",
     )
+    benchmark_parser.add_argument(
+        "--group-norm-groups",
+        type=int,
+        help="Number of groups used by GroupNorm in the UNet blocks.",
+    )
     benchmark_parser.add_argument(
         "--dataloader-num-workers",
         type=int,
diff --git a/ScaFFold/configs/benchmark_default.yml b/ScaFFold/configs/benchmark_default.yml
index 1b0310c..687c785 100644
--- a/ScaFFold/configs/benchmark_default.yml
+++ b/ScaFFold/configs/benchmark_default.yml
@@ -33,6 +33,7 @@ framework: "torch" # The DL framework to train with. Only valid
 checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Cateogry search normalization parameter
+group_norm_groups: 8 # Number of groups used by GroupNorm in the UNet blocks.
 warmup_batches: 64 # How many warmup batches per rank to run before training.
 ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
 dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
diff --git a/ScaFFold/configs/benchmark_testing.yml b/ScaFFold/configs/benchmark_testing.yml
index 5167de1..e12bb26 100644
--- a/ScaFFold/configs/benchmark_testing.yml
+++ b/ScaFFold/configs/benchmark_testing.yml
@@ -33,6 +33,7 @@ framework: "torch" # The DL framework to train with. Only valid
 checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
 loss_freq: 1 # Number of epochs between logging the overall loss.
 normalize: 1 # Cateogry search normalization parameter
+group_norm_groups: 8 # Number of groups used by GroupNorm in the UNet blocks.
 warmup_batches: 5 # How many warmup batches per rank to run before training.
 ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
 dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
diff --git a/ScaFFold/unet/unet_model.py b/ScaFFold/unet/unet_model.py
index a35e232..41ff4f8 100644
--- a/ScaFFold/unet/unet_model.py
+++ b/ScaFFold/unet/unet_model.py
@@ -41,33 +41,63 @@ class UNet(nn.Module):
     """

     @_unet_annotate
-    def __init__(self, n_channels, n_classes, trilinear=False, layers=4):
+    def __init__(
+        self,
+        n_channels,
+        n_classes,
+        trilinear=False,
+        layers=4,
+        group_norm_groups=8,
+    ):
         super(UNet, self).__init__()
         self.n_channels = n_channels
         self.n_classes = n_classes
         self.trilinear = trilinear
         self.layers = layers
+        self.group_norm_groups = group_norm_groups

         factor = 2 if trilinear else 1

         self.down_list = nn.ModuleList([])
         layer_channels = 64
-        self.down_list.append(DoubleConv(n_channels, layer_channels))
+        self.down_list.append(
+            DoubleConv(n_channels, layer_channels, self.group_norm_groups)
+        )
         for i in range(self.layers - 1):
-            self.down_list.append(Down(layer_channels, layer_channels * 2))
+            self.down_list.append(
+                Down(layer_channels, layer_channels * 2, self.group_norm_groups)
+            )
             layer_channels *= 2

-        self.down_list.append(Down(layer_channels, (layer_channels * 2) // factor))
+        self.down_list.append(
+            Down(
+                layer_channels,
+                (layer_channels * 2) // factor,
+                self.group_norm_groups,
+            )
+        )
         layer_channels *= 2

         self.up_list = nn.ModuleList([])
         for i in range(self.layers - 1):
             self.up_list.append(
-                Up(layer_channels, (layer_channels // 2) // factor, trilinear)
+                Up(
+                    layer_channels,
+                    (layer_channels // 2) // factor,
+                    self.group_norm_groups,
+                    trilinear,
+                )
             )
             layer_channels //= 2

-        self.up_list.append(Up(layer_channels, layer_channels // 2, trilinear))
+        self.up_list.append(
+            Up(
+                layer_channels,
+                layer_channels // 2,
+                self.group_norm_groups,
+                trilinear,
+            )
+        )
         layer_channels //= 2
         self.up_list.append(OutConv(layer_channels, n_classes))

diff --git a/ScaFFold/unet/unet_parts.py b/ScaFFold/unet/unet_parts.py
index c363763..bd44cee 100644
--- a/ScaFFold/unet/unet_parts.py
+++ b/ScaFFold/unet/unet_parts.py
@@ -26,19 +26,30 @@
 _outconv_annotate = annotate(fmt="OutConv.{}")


+def _group_norm(num_groups, num_channels):
+    """Build a GroupNorm layer after validating the group count."""
+    if num_groups < 1:
+        raise ValueError(f"group_norm_groups={num_groups} must be a positive integer")
+    if num_channels % num_groups != 0:
+        raise ValueError(
+            f"group_norm_groups={num_groups} must evenly divide num_channels={num_channels}"
+        )
+    return nn.GroupNorm(num_groups, num_channels)
+
+
 class DoubleConv(nn.Module):
-    """(convolution => [BN] => ReLU) * 2"""
+    """(convolution => GroupNorm => ReLU) * 2"""

-    def __init__(self, in_channels, out_channels, mid_channels=None):
+    def __init__(self, in_channels, out_channels, group_norm_groups, mid_channels=None):
         super().__init__()
         if not mid_channels:
             mid_channels = out_channels
         self.double_conv = nn.Sequential(
             nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
-            nn.BatchNorm3d(mid_channels),
+            _group_norm(group_norm_groups, mid_channels),
             nn.ReLU(inplace=True),
             nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
-            nn.BatchNorm3d(out_channels),
+            _group_norm(group_norm_groups, out_channels),
             nn.ReLU(inplace=True),
         )

@@ -50,10 +61,11 @@ def forward(self, x):
 class Down(nn.Module):
     """Downscaling with maxpool then double conv"""

-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_channels, out_channels, group_norm_groups):
         super().__init__()
         self.maxpool_conv = nn.Sequential(
-            nn.MaxPool3d(2), DoubleConv(in_channels, out_channels)
+            nn.MaxPool3d(2),
+            DoubleConv(in_channels, out_channels, group_norm_groups),
         )

     @_down_annotate
@@ -64,18 +76,23 @@ def forward(self, x):
 class Up(nn.Module):
     """Upscaling then double conv"""

-    def __init__(self, in_channels, out_channels, trilinear=True):
+    def __init__(self, in_channels, out_channels, group_norm_groups, trilinear=True):
         super().__init__()

         # if trilinear, use the normal convolutions to reduce the number of channels
         if trilinear:
             self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
-            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+            self.conv = DoubleConv(
+                in_channels,
+                out_channels,
+                group_norm_groups,
+                in_channels // 2,
+            )
         else:
             self.up = nn.ConvTranspose3d(
                 in_channels, in_channels // 2, kernel_size=2, stride=2
             )
-            self.conv = DoubleConv(in_channels, out_channels)
+            self.conv = DoubleConv(in_channels, out_channels, group_norm_groups)

     @_up_annotate
     def forward(self, x1, x2):
diff --git a/ScaFFold/utils/config_utils.py b/ScaFFold/utils/config_utils.py
index 36f1603..9cb632c 100644
--- a/ScaFFold/utils/config_utils.py
+++ b/ScaFFold/utils/config_utils.py
@@ -73,6 +73,7 @@ def __init__(self, config_dict):
         self.loss_freq = config_dict["loss_freq"]
         self.checkpoint_dir = config_dict["checkpoint_dir"]
         self.normalize = config_dict["normalize"]
+        self.group_norm_groups = config_dict.get("group_norm_groups", 8)
         self.warmup_batches = config_dict.get("warmup_batches")
         self.ce_weight_sample_fraction = config_dict.get(
             "ce_weight_sample_fraction", 0.1
diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py
index fde4582..a4e212b 100644
--- a/ScaFFold/worker.py
+++ b/ScaFFold/worker.py
@@ -174,6 +174,7 @@ def main(kwargs_dict: dict = {}):
         n_classes=config.n_categories + 1,
         trilinear=False,
         layers=config.unet_layers,
+        group_norm_groups=config.group_norm_groups,
     )
     if config.dist:
         # DDP + DistConv setup