Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 114 additions & 17 deletions src/neuron_proofreader/machine_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,35 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from neuron_proofreader.utils import img_util, ml_util, util


class FocalLoss(nn.Module):
    """Binary focal loss for imbalanced classification.

    Downweights easy examples (high confidence, correct) so training
    concentrates on the hard cases that drive false positives and false
    negatives. The per-element loss is

        alpha_t * (1 - p_t) ** gamma * BCE(logit, target)

    where ``p_t`` is the probability the model assigns to the true class
    and ``alpha_t = alpha`` for positives, ``1 - alpha`` for negatives.

    Parameters
    ----------
    alpha : float, optional
        Weight applied to positive examples; negatives receive
        ``1 - alpha``. Values above 0.5 upweight the positive class,
        values below 0.5 upweight the negative class. Note that the
        default of 0.25 therefore weights negatives 3x more than
        positives.
    gamma : float, optional
        Focusing exponent. ``gamma=0`` removes the focusing term, leaving
        alpha-weighted BCE. Default is 2.0.
    reduction : str, optional
        One of "mean", "sum" or "none". Default is "mean".

    Raises
    ------
    ValueError
        If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Compute the focal loss between raw logits and binary targets.

        Parameters
        ----------
        logits : torch.Tensor
            Unnormalized model outputs.
        targets : torch.Tensor
            Binary labels with the same shape as "logits". Integer and
            boolean tensors are accepted and cast to the logits dtype.

        Returns
        -------
        torch.Tensor
            Scalar loss for "mean"/"sum" reduction, otherwise a tensor of
            per-element losses.
        """
        # BCE-with-logits requires floating targets; integer/bool labels
        # would otherwise raise a dtype error.
        targets = targets.to(dtype=logits.dtype)

        bce = F.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        # For binary targets, exp(-BCE) is the probability of the true class.
        pt = torch.exp(-bce)
        alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha)
        loss = alpha_t * (1 - pt) ** self.gamma * bce

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

logger = logging.getLogger(__name__)
_LOG_EVERY = 100 # batches between progress log lines
_LOG_EVERY = 1000 # batches between progress log lines


class Trainer:
Expand Down Expand Up @@ -83,6 +106,8 @@ def __init__(
warmup_epochs=5,
scheduler_type="cosine",
pos_weight=None,
focal_gamma=None,
focal_alpha=0.25,
save_val_logits=False,
save_mistake_mips=False,
on_best_model_saved=None,
Expand Down Expand Up @@ -124,6 +149,7 @@ def __init__(
# Instance attributes
self.best_f1 = 0
self.best_val_loss = float("inf")
self.best_f1_at_95recall = 0.0
self.device = device
self.log_dir = log_dir
self.max_epochs = max_epochs
Expand All @@ -138,11 +164,15 @@ def __init__(
self.save_mistake_mips = save_mistake_mips
self.on_best_model_saved = on_best_model_saved

if pos_weight is None:
if focal_gamma is not None:
self.criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
print(f"Loss: FocalLoss(alpha={focal_alpha}, gamma={focal_gamma})")
elif pos_weight is None:
self.criterion = nn.BCEWithLogitsLoss()
else:
pos_weight_tensor = torch.tensor([pos_weight], device=device)
self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
print(f"Loss: BCEWithLogitsLoss(pos_weight={pos_weight})")
self.model = model.to(device)
self.optimizer = optim.AdamW(
self._build_param_groups(self.model, lr, head_lr),
Expand Down Expand Up @@ -256,16 +286,29 @@ def run(self, train_dataloader, val_dataloader):
# Train-Validate
train_stats = self.train_step(train_dataloader, epoch)
val_stats = self.validate_step(val_dataloader, epoch)

new_best_loss = val_stats["loss"] < self.best_val_loss
if new_best_loss:
self.best_val_loss = val_stats["loss"]

f1_95 = val_stats.get("f1_at_95recall", 0.0)
new_best_f1_95 = f1_95 > self.best_f1_at_95recall
if new_best_f1_95:
self.best_f1_at_95recall = f1_95

# Checkpoint: use F1@95recall once the model achieves it; fall back
# to val loss before that threshold is first reached.
if new_best_f1_95:
self.save_model(epoch, tag="best_f1_at_95recall")
if self.save_val_logits:
self._save_val_logits(
val_dataloader, self._last_val_y, self._last_val_hat_y, epoch
)
elif new_best_loss and self.best_f1_at_95recall == 0.0:
self.save_model(epoch, tag="best_loss")
if self.save_val_logits:
self._save_val_logits(
val_dataloader,
self._last_val_y,
self._last_val_hat_y,
epoch,
val_dataloader, self._last_val_y, self._last_val_hat_y, epoch
)

# Log learning rate
Expand All @@ -274,7 +317,12 @@ def run(self, train_dataloader, val_dataloader):
self.writer.add_scalar("lr", current_lr, epoch)

# Report results
print(f"\nEpoch {epoch}: " + ("New Best!" if new_best_loss else " "))
is_new_best = new_best_f1_95 or (new_best_loss and self.best_f1_at_95recall == 0.0)
criterion_label = (
f"F1@95R={f1_95:.4f}" if self.best_f1_at_95recall > 0.0
else f"loss={val_stats['loss']:.4f}"
)
print(f"\nEpoch {epoch}: " + (f"New Best! ({criterion_label})" if is_new_best else ""))
self.report_stats(train_stats, is_train=True)
self.report_stats(val_stats, is_train=False)

Expand All @@ -289,8 +337,8 @@ def run(self, train_dataloader, val_dataloader):
if new != old:
print(f" LR reduced: group {i} {old:.2e} -> {new:.2e}")

# Early stopping check
if new_best_loss:
# Early stopping: track whichever criterion is active
if is_new_best:
self.epochs_without_improvement = 0
else:
self.epochs_without_improvement += 1
Expand Down Expand Up @@ -485,6 +533,27 @@ def forward_pass(self, x, y):
return hat_y, loss

# --- Helpers ---
@staticmethod
def _f1_at_recall_target(y, hat_y_logits, recall_target=0.95):
"""Return the best F1 achievable at >= recall_target recall.

Sweeps 200 probability thresholds and returns the maximum F1 among
those where recall >= recall_target. Returns 0.0 if the model never
achieves the target recall at any threshold.
"""
y_arr = np.array(y, dtype=int)
probs = 1.0 / (1.0 + np.exp(-np.array(hat_y_logits)))
thresholds = np.unique(np.percentile(probs, np.linspace(0, 100, 200)))
best_f1 = 0.0
for t in thresholds:
preds = (probs >= t).astype(int)
r = recall_score(y_arr, preds, zero_division=0)
if r >= recall_target:
p = precision_score(y_arr, preds, zero_division=0)
f1 = 2 * p * r / max(p + r, 1e-8)
best_f1 = max(best_f1, f1)
return best_f1

@staticmethod
def compute_stats(y, hat_y):
"""
Expand Down Expand Up @@ -515,8 +584,10 @@ def compute_stats(y, hat_y):
avg_recall = recall_score(y, hat_y, zero_division=np.nan)
avg_f1 = 2 * avg_prec * avg_recall / max((avg_prec + avg_recall), 1e-8)
avg_acc = accuracy_score(y, hat_y)
f1_at_95recall = Trainer._f1_at_recall_target(y, hat_y_arr)
stats = {
"f1": avg_f1,
"f1_at_95recall": f1_at_95recall,
"precision": avg_prec,
"recall": avg_recall,
"accuracy": avg_acc,
Expand Down Expand Up @@ -800,6 +871,8 @@ def __init__(
warmup_epochs=5,
scheduler_type="cosine",
pos_weight=None,
focal_gamma=None,
focal_alpha=0.25,
save_val_logits=False,
save_mistake_mips=False
):
Expand Down Expand Up @@ -876,6 +949,7 @@ def __init__(
# Now initialize parent class attributes without creating directories
self.best_f1 = 0
self.best_val_loss = float("inf")
self.best_f1_at_95recall = 0.0
self.device = device
self.log_dir = log_dir
self.max_epochs = max_epochs
Expand All @@ -889,11 +963,17 @@ def __init__(
self.save_val_logits = save_val_logits
self.save_mistake_mips = save_mistake_mips

if pos_weight is None:
if focal_gamma is not None:
self.criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
if self.rank == 0:
print(f"Loss: FocalLoss(alpha={focal_alpha}, gamma={focal_gamma})")
elif pos_weight is None:
self.criterion = nn.BCEWithLogitsLoss()
else:
pos_weight_tensor = torch.tensor([pos_weight], device=device)
self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
if self.rank == 0:
print(f"Loss: BCEWithLogitsLoss(pos_weight={pos_weight})")
self.model = model.to(device)
self.scaler = torch.cuda.amp.GradScaler(enabled=True)

Expand Down Expand Up @@ -1092,15 +1172,32 @@ def run(self, train_dataloader, val_dataloader):
new_best_loss = val_stats["loss"] < self.best_val_loss
if new_best_loss:
self.best_val_loss = val_stats["loss"]

f1_95 = val_stats.get("f1_at_95recall", 0.0)
new_best_f1_95 = f1_95 > self.best_f1_at_95recall
if new_best_f1_95:
self.best_f1_at_95recall = f1_95

# Checkpoint: F1@95recall once achieved, val loss as fallback
if new_best_f1_95:
self.save_model(epoch, tag="best_f1_at_95recall")
if self.save_val_logits:
self._save_val_logits(
val_dataloader, self._last_val_y, self._last_val_hat_y, epoch
)
elif new_best_loss and self.best_f1_at_95recall == 0.0:
self.save_model(epoch, tag="best_loss")
if self.save_val_logits:
self._save_val_logits(
val_dataloader,
self._last_val_y,
self._last_val_hat_y,
epoch,
val_dataloader, self._last_val_y, self._last_val_hat_y, epoch
)
print(f"\nEpoch {epoch}: ", "New Best!" if new_best_loss else "")

is_new_best = new_best_f1_95 or (new_best_loss and self.best_f1_at_95recall == 0.0)
criterion_label = (
f"F1@95R={f1_95:.4f}" if self.best_f1_at_95recall > 0.0
else f"loss={val_stats['loss']:.4f}"
)
print(f"\nEpoch {epoch}: " + (f"New Best! ({criterion_label})" if is_new_best else ""))
self.report_stats(train_stats, is_train=True)
self.report_stats(val_stats, is_train=False)

Expand All @@ -1117,8 +1214,8 @@ def run(self, train_dataloader, val_dataloader):
print(f" LR reduced: group {i} {old:.2e} -> {new:.2e}")

if rank == 0:
# Early stopping check
if new_best_loss:
# Early stopping: track whichever criterion is active
if is_new_best:
self.epochs_without_improvement = 0
else:
self.epochs_without_improvement += 1
Expand Down
11 changes: 8 additions & 3 deletions src/neuron_proofreader/machine_learning/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,12 @@ def __init__(
self.encoder_dim = self.model.encoder_dim
self.n_prefix_tokens = 1 + encoder.n_register_tokens
self.grid_size = tuple(int(g) for g in encoder.grid_size)
self.pool_power = pool_power
# Learnable pooling power γ (log-parameterized so it stays positive).
# exp(log(pool_power)) = pool_power at init, so the initial weighting
# matches the previous fixed-power behaviour (with pool_power=2:
# skeleton=1.0, segment=0.25, background=0.0). Requires pool_power > 0.
self.pool_log_power = nn.Parameter(
torch.tensor(float(pool_power)).log()
)

# Dual-stream classifier: [CLS, mask-pooled] → 1
self.classifier = nn.Sequential(
Expand All @@ -203,8 +208,8 @@ def forward(self, x):
weights = F.adaptive_max_pool3d(mask, self.grid_size)
weights = weights.reshape(weights.shape[0], -1) # (B, n_patches)

# Power-scale: skeleton=1.0, segment=0.25, bg=0.0 (with power=2)
weights = weights ** self.pool_power
# Power-scale with learned γ; exp keeps γ strictly positive.
weights = weights ** self.pool_log_power.exp()

# Normalize weights to sum to 1
weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
Expand Down
13 changes: 13 additions & 0 deletions src/neuron_proofreader/merge_proofreading/merge_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,19 @@ def load_images(
segmentation_prefixes = util.read_json(segmentation_prefixes_path)
brain_ids = get_brain_ids(dataset.merge_sites_df, is_test)

# Filter to brains present in both prefix maps. A missing key would
# KeyError inside the dict comprehension below before any futures are
# submitted, bypassing the try/except that guards future.result().
loadable = []
for bid in brain_ids:
if bid not in img_prefixes:
logger.warning("No image prefix for brain %s — skipping image load", bid)
elif bid not in segmentation_prefixes:
logger.warning("No seg prefix for brain %s — skipping image load", bid)
else:
loadable.append(bid)
brain_ids = loadable

_log_ram("before images")
logger.info("Loading images (%d brains, %d workers)", len(brain_ids), max_workers)
completed = 0
Expand Down
Loading