Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ScaFFold/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ def main():
type=int,
help="Number of warmup batches to run per rank before training.",
)
benchmark_parser.add_argument(
"--group-norm-groups",
type=int,
help="Number of groups used by GroupNorm in the UNet blocks.",
)
benchmark_parser.add_argument(
"--dataloader-num-workers",
type=int,
Expand Down
1 change: 1 addition & 0 deletions ScaFFold/configs/benchmark_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ framework: "torch" # The DL framework to train with. Only valid
checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Cateogry search normalization parameter
group_norm_groups: 8 # Number of groups used by GroupNorm in the UNet blocks.
warmup_batches: 64 # How many warmup batches per rank to run before training.
ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
Expand Down
1 change: 1 addition & 0 deletions ScaFFold/configs/benchmark_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ framework: "torch" # The DL framework to train with. Only valid
checkpoint_dir: "checkpoints" # Subfolder in which to save training checkpoints.
loss_freq: 1 # Number of epochs between logging the overall loss.
normalize: 1 # Cateogry search normalization parameter
group_norm_groups: 8 # Number of groups used by GroupNorm in the UNet blocks.
warmup_batches: 5 # How many warmup batches per rank to run before training.
ce_weight_sample_fraction: 0.1 # Fraction of training masks to sample when estimating background vs foreground CE weights.
dataset_reuse_enforce_commit_id: 0 # Enforce matching commit IDs for dataset reuse.
Expand Down
42 changes: 36 additions & 6 deletions ScaFFold/unet/unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,63 @@ class UNet(nn.Module):
"""

@_unet_annotate
def __init__(self, n_channels, n_classes, trilinear=False, layers=4):
def __init__(
self,
n_channels,
n_classes,
trilinear=False,
layers=4,
group_norm_groups=8,
):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.trilinear = trilinear
self.layers = layers
self.group_norm_groups = group_norm_groups
factor = 2 if trilinear else 1

self.down_list = nn.ModuleList([])
layer_channels = 64
self.down_list.append(DoubleConv(n_channels, layer_channels))
self.down_list.append(
DoubleConv(n_channels, layer_channels, self.group_norm_groups)
)

for i in range(self.layers - 1):
self.down_list.append(Down(layer_channels, layer_channels * 2))
self.down_list.append(
Down(layer_channels, layer_channels * 2, self.group_norm_groups)
)
layer_channels *= 2

self.down_list.append(Down(layer_channels, (layer_channels * 2) // factor))
self.down_list.append(
Down(
layer_channels,
(layer_channels * 2) // factor,
self.group_norm_groups,
)
)
layer_channels *= 2

self.up_list = nn.ModuleList([])
for i in range(self.layers - 1):
self.up_list.append(
Up(layer_channels, (layer_channels // 2) // factor, trilinear)
Up(
layer_channels,
(layer_channels // 2) // factor,
self.group_norm_groups,
trilinear,
)
)
layer_channels //= 2

self.up_list.append(Up(layer_channels, layer_channels // 2, trilinear))
self.up_list.append(
Up(
layer_channels,
layer_channels // 2,
self.group_norm_groups,
trilinear,
)
)
layer_channels //= 2

self.up_list.append(OutConv(layer_channels, n_classes))
Expand Down
32 changes: 23 additions & 9 deletions ScaFFold/unet/unet_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,27 @@
_outconv_annotate = annotate(fmt="OutConv.{}")


def _group_norm(num_groups, num_channels):
if num_channels % num_groups != 0:
raise ValueError(
f"group_norm_groups={num_groups} must evenly divide num_channels={num_channels}"
)
return nn.GroupNorm(num_groups, num_channels)


class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
"""(convolution => GroupNorm => ReLU) * 2"""

def __init__(self, in_channels, out_channels, mid_channels=None):
def __init__(self, in_channels, out_channels, group_norm_groups, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm3d(mid_channels),
_group_norm(group_norm_groups, mid_channels),
nn.ReLU(inplace=True),
nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm3d(out_channels),
_group_norm(group_norm_groups, out_channels),
nn.ReLU(inplace=True),
)

Expand All @@ -50,10 +58,11 @@ def forward(self, x):
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""

def __init__(self, in_channels, out_channels):
def __init__(self, in_channels, out_channels, group_norm_groups):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool3d(2), DoubleConv(in_channels, out_channels)
nn.MaxPool3d(2),
DoubleConv(in_channels, out_channels, group_norm_groups),
)

@_down_annotate
Expand All @@ -64,18 +73,23 @@ def forward(self, x):
class Up(nn.Module):
"""Upscaling then double conv"""

def __init__(self, in_channels, out_channels, trilinear=True):
def __init__(self, in_channels, out_channels, group_norm_groups, trilinear=True):
super().__init__()

# if trilinear, use the normal convolutions to reduce the number of channels
if trilinear:
self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
self.conv = DoubleConv(
in_channels,
out_channels,
group_norm_groups,
in_channels // 2,
)
else:
self.up = nn.ConvTranspose3d(
in_channels, in_channels // 2, kernel_size=2, stride=2
)
self.conv = DoubleConv(in_channels, out_channels)
self.conv = DoubleConv(in_channels, out_channels, group_norm_groups)

@_up_annotate
def forward(self, x1, x2):
Expand Down
1 change: 1 addition & 0 deletions ScaFFold/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, config_dict):
self.loss_freq = config_dict["loss_freq"]
self.checkpoint_dir = config_dict["checkpoint_dir"]
self.normalize = config_dict["normalize"]
self.group_norm_groups = config_dict.get("group_norm_groups", 8)
self.warmup_batches = config_dict.get("warmup_batches")
self.ce_weight_sample_fraction = config_dict.get(
"ce_weight_sample_fraction", 0.1
Expand Down
1 change: 1 addition & 0 deletions ScaFFold/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def main(kwargs_dict: dict = {}):
n_classes=config.n_categories + 1,
trilinear=False,
layers=config.unet_layers,
group_norm_groups=config.group_norm_groups,
)
if config.dist:
# DDP + DistConv setup
Expand Down
Loading