-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss #8669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -22,14 +22,13 @@ | |||||||||||||||||||
|
|
||||||||||||||||||||
| class AsymmetricFocalTverskyLoss(_Loss): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. | ||||||||||||||||||||
| AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that prioritizes the foreground classes. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Actually, it's only supported for binary image segmentation now. | ||||||||||||||||||||
| It supports both binary and multi-class segmentation. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Reimplementation of the Asymmetric Focal Tversky Loss described in: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||
| """ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def __init__( | ||||||||||||||||||||
|
|
@@ -39,119 +38,200 @@ def __init__( | |||||||||||||||||||
| gamma: float = 0.75, | ||||||||||||||||||||
| epsilon: float = 1e-7, | ||||||||||||||||||||
| reduction: LossReduction | str = LossReduction.MEAN, | ||||||||||||||||||||
| use_softmax: bool = False, | ||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Args: | ||||||||||||||||||||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||||||||||||||||||||
| delta : weight of the background. Defaults to 0.7. | ||||||||||||||||||||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. | ||||||||||||||||||||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||||||||||||||||||||
| delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7. | ||||||||||||||||||||
| gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75. | ||||||||||||||||||||
| epsilon: a small value to prevent division by zero. Defaults to 1e-7. | ||||||||||||||||||||
| reduction: {``"none"``, ``"mean"``, ``"sum"``} | ||||||||||||||||||||
| Specifies the reduction to apply to the output. Defaults to ``"mean"``. | ||||||||||||||||||||
| use_softmax: whether to use softmax to transform original logits into probabilities. | ||||||||||||||||||||
| If True, softmax is used (for multi-class). If False, sigmoid is used (for binary/multi-label). | ||||||||||||||||||||
| Defaults to False. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| super().__init__(reduction=LossReduction(reduction).value) | ||||||||||||||||||||
| self.to_onehot_y = to_onehot_y | ||||||||||||||||||||
| self.delta = delta | ||||||||||||||||||||
| self.gamma = gamma | ||||||||||||||||||||
| self.epsilon = epsilon | ||||||||||||||||||||
| self.use_softmax = use_softmax | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Args: | ||||||||||||||||||||
| y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims). | ||||||||||||||||||||
| y_true: ground truth labels. Shape should match y_pred. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
|
Comment on lines
62
to
+67
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Complete the forward method docstring. Missing Returns and Raises sections. Per coding guidelines, describe return value and raised exceptions. 📝 Enhanced docstringdef forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
y_true: ground truth labels. Shape should match y_pred.
Returns:
torch.Tensor: Computed loss. Shape depends on reduction:
- "none": (B, C)
- "mean" or "sum": scalar
Raises:
ValueError: When ground truth shape differs from prediction shape.
"""🤖 Prompt for AI Agents |
||||||||||||||||||||
|
|
||||||||||||||||||||
| # Auto-handle single channel input (binary segmentation case) | ||||||||||||||||||||
| if y_pred.shape[1] == 1 and not self.use_softmax: | ||||||||||||||||||||
| y_pred = torch.sigmoid(y_pred) | ||||||||||||||||||||
| y_pred = torch.cat([1 - y_pred, y_pred], dim=1) | ||||||||||||||||||||
| is_already_prob = True | ||||||||||||||||||||
| if y_true.shape[1] == 1: | ||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=2) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| is_already_prob = False | ||||||||||||||||||||
|
|
||||||||||||||||||||
| n_pred_ch = y_pred.shape[1] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if self.to_onehot_y: | ||||||||||||||||||||
| if n_pred_ch == 1: | ||||||||||||||||||||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||||||||||||||||||||
| if y_true.shape[1] != n_pred_ch: | ||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if y_true.shape != y_pred.shape: | ||||||||||||||||||||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # clip the prediction to avoid NaN | ||||||||||||||||||||
| # Convert logits to probabilities if not already done | ||||||||||||||||||||
| if not is_already_prob: | ||||||||||||||||||||
| if self.use_softmax: | ||||||||||||||||||||
| y_pred = torch.softmax(y_pred, dim=1) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| y_pred = torch.sigmoid(y_pred) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Clip the prediction to avoid NaN | ||||||||||||||||||||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| axis = list(range(2, len(y_pred.shape))) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Calculate true positives (tp), false negatives (fn) and false positives (fp) | ||||||||||||||||||||
| tp = torch.sum(y_true * y_pred, dim=axis) | ||||||||||||||||||||
| fn = torch.sum(y_true * (1 - y_pred), dim=axis) | ||||||||||||||||||||
| fp = torch.sum((1 - y_true) * y_pred, dim=axis) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Calculate losses separately for each class, enhancing both classes | ||||||||||||||||||||
| # Calculate losses separately for each class | ||||||||||||||||||||
| # Background: Standard Dice Loss | ||||||||||||||||||||
| back_dice = 1 - dice_class[:, 0] | ||||||||||||||||||||
| fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Average class scores | ||||||||||||||||||||
| loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) | ||||||||||||||||||||
| return loss | ||||||||||||||||||||
| # Foreground: Focal Tversky Loss | ||||||||||||||||||||
| fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma) | ||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result: The Focal Tversky loss is defined as (1 − TI)^γ, where the Tversky index (TI) for a class is Sources:
🏁 Script executed: # Find and examine the unified_focal_loss.py file
fd unified_focal_loss.py -type fRepository: Project-MONAI/MONAI Length of output: 182 🏁 Script executed: # Read the file to see the actual implementation around line 116
cat -n monai/losses/unified_focal_loss.py | sed -n '100,130p'Repository: Project-MONAI/MONAI Length of output: 182 🏁 Script executed: # Check the full context of the AsymmetricFocalTverskyLoss class to understand gamma usage
cat -n monai/losses/unified_focal_loss.py | sed -n '1,150p'Repository: Project-MONAI/MONAI Length of output: 182 Fix focal loss exponent formula. Line 116 uses 🤖 Prompt for AI Agents |
||||||||||||||||||||
|
|
||||||||||||||||||||
| # Concatenate background and foreground losses | ||||||||||||||||||||
| # back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1) | ||||||||||||||||||||
| all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Apply reduction | ||||||||||||||||||||
| if self.reduction == LossReduction.MEAN.value: | ||||||||||||||||||||
| return torch.mean(all_losses) | ||||||||||||||||||||
| if self.reduction == LossReduction.SUM.value: | ||||||||||||||||||||
| return torch.sum(all_losses) | ||||||||||||||||||||
| if self.reduction == LossReduction.NONE.value: | ||||||||||||||||||||
| return all_losses | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return torch.mean(all_losses) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| class AsymmetricFocalLoss(_Loss): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. | ||||||||||||||||||||
| AsymmetricFocalLoss is a variant of Focal Loss that treats background and foreground differently. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Actually, it's only supported for binary image segmentation now. | ||||||||||||||||||||
| It supports both binary and multi-class segmentation. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Reimplementation of the Asymmetric Focal Loss described in: | ||||||||||||||||||||
|
|
||||||||||||||||||||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||
| """ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def __init__( | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| to_onehot_y: bool = False, | ||||||||||||||||||||
| delta: float = 0.7, | ||||||||||||||||||||
| gamma: float = 2, | ||||||||||||||||||||
| gamma: float = 2.0, | ||||||||||||||||||||
| epsilon: float = 1e-7, | ||||||||||||||||||||
| reduction: LossReduction | str = LossReduction.MEAN, | ||||||||||||||||||||
| use_softmax: bool = False, | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Args: | ||||||||||||||||||||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||||||||||||||||||||
| delta : weight of the background. Defaults to 0.7. | ||||||||||||||||||||
| gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. | ||||||||||||||||||||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||||||||||||||||||||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||||||||||||||||||||
| delta: weight for the foreground classes. Defaults to 0.7. | ||||||||||||||||||||
| gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0. | ||||||||||||||||||||
| epsilon: a small value to prevent calculation errors. Defaults to 1e-7. | ||||||||||||||||||||
| reduction: {``"none"``, ``"mean"``, ``"sum"``} | ||||||||||||||||||||
| use_softmax: whether to use softmax to transform logits. Defaults to False. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
ytl0623 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
| super().__init__(reduction=LossReduction(reduction).value) | ||||||||||||||||||||
| self.to_onehot_y = to_onehot_y | ||||||||||||||||||||
| self.delta = delta | ||||||||||||||||||||
| self.gamma = gamma | ||||||||||||||||||||
| self.epsilon = epsilon | ||||||||||||||||||||
| self.use_softmax = use_softmax | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Args: | ||||||||||||||||||||
| y_pred: prediction logits or probabilities. | ||||||||||||||||||||
| y_true: ground truth labels. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
|
Comment on lines
168
to
+173
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Complete forward docstring. Missing Returns and Raises sections. 📝 Enhanced docstringdef forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred: prediction logits or probabilities.
y_true: ground truth labels.
Returns:
torch.Tensor: Computed loss. Shape depends on reduction:
- "none": (B, C, spatial_dims...)
- "mean" or "sum": scalar
Raises:
ValueError: When ground truth shape differs from prediction shape.
"""🤖 Prompt for AI Agents |
||||||||||||||||||||
|
|
||||||||||||||||||||
| if y_pred.shape[1] == 1 and not self.use_softmax: | ||||||||||||||||||||
| y_pred = torch.sigmoid(y_pred) | ||||||||||||||||||||
| y_pred = torch.cat([1 - y_pred, y_pred], dim=1) | ||||||||||||||||||||
| is_already_prob = True | ||||||||||||||||||||
| if y_true.shape[1] == 1: | ||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=2) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| is_already_prob = False | ||||||||||||||||||||
|
|
||||||||||||||||||||
| n_pred_ch = y_pred.shape[1] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if self.to_onehot_y: | ||||||||||||||||||||
| if n_pred_ch == 1: | ||||||||||||||||||||
| warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||||||||||||||||||||
| if y_true.shape[1] != n_pred_ch: | ||||||||||||||||||||
| y_true = one_hot(y_true, num_classes=n_pred_ch) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if y_true.shape != y_pred.shape: | ||||||||||||||||||||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if not is_already_prob: | ||||||||||||||||||||
| if self.use_softmax: | ||||||||||||||||||||
| y_pred = torch.softmax(y_pred, dim=1) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| y_pred = torch.sigmoid(y_pred) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| cross_entropy = -y_true * torch.log(y_pred) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Background (Channel 0): Focal Loss | ||||||||||||||||||||
| back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] | ||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gamma=0 causes division by zero in Tversky loss. While this line uses 🤖 Prompt for AI Agents |
||||||||||||||||||||
| back_ce = (1 - self.delta) * back_ce | ||||||||||||||||||||
|
|
||||||||||||||||||||
| fore_ce = cross_entropy[:, 1] | ||||||||||||||||||||
| # Foreground (Channels 1+): Standard Cross Entropy | ||||||||||||||||||||
| fore_ce = cross_entropy[:, 1:] | ||||||||||||||||||||
| fore_ce = self.delta * fore_ce | ||||||||||||||||||||
|
|
||||||||||||||||||||
| loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) | ||||||||||||||||||||
| return loss | ||||||||||||||||||||
| # Concatenate losses | ||||||||||||||||||||
| all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Sum over classes (dim=1) to get total loss per pixel | ||||||||||||||||||||
| total_loss = torch.sum(all_ce, dim=1) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| class AsymmetricUnifiedFocalLoss(_Loss): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| AsymmetricUnifiedFocalLoss is a variant of Focal Loss. | ||||||||||||||||||||
| # Apply reduction | ||||||||||||||||||||
| if self.reduction == LossReduction.MEAN.value: | ||||||||||||||||||||
| return torch.mean(total_loss) | ||||||||||||||||||||
| if self.reduction == LossReduction.SUM.value: | ||||||||||||||||||||
| return torch.sum(total_loss) | ||||||||||||||||||||
| if self.reduction == LossReduction.NONE.value: | ||||||||||||||||||||
| return total_loss | ||||||||||||||||||||
| return torch.mean(total_loss) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Actually, it's only supported for binary image segmentation now | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: | ||||||||||||||||||||
| class AsymmetricUnifiedFocalLoss(_Loss): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| AsymmetricUnifiedFocalLoss is a wrapper that combines AsymmetricFocalLoss and AsymmetricFocalTverskyLoss. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", | ||||||||||||||||||||
| Michael Yeung, Computerized Medical Imaging and Graphics | ||||||||||||||||||||
| This unified loss allows for simultaneously optimizing distribution-based (CE) and region-based (Dice) metrics, | ||||||||||||||||||||
| while handling class imbalance through asymmetric weighting. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def __init__( | ||||||||||||||||||||
|
|
@@ -162,51 +242,53 @@ def __init__( | |||||||||||||||||||
| gamma: float = 0.5, | ||||||||||||||||||||
| delta: float = 0.7, | ||||||||||||||||||||
| reduction: LossReduction | str = LossReduction.MEAN, | ||||||||||||||||||||
| use_softmax: bool = False, | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Args: | ||||||||||||||||||||
| to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. | ||||||||||||||||||||
| num_classes : number of classes, it only supports 2 now. Defaults to 2. | ||||||||||||||||||||
| delta : weight of the background. Defaults to 0.7. | ||||||||||||||||||||
| gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. | ||||||||||||||||||||
| epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. | ||||||||||||||||||||
| weight : weight for each loss function, if it's none it's 0.5. Defaults to None. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Example: | ||||||||||||||||||||
ytl0623 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
| >>> import torch | ||||||||||||||||||||
| >>> from monai.losses import AsymmetricUnifiedFocalLoss | ||||||||||||||||||||
| >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) | ||||||||||||||||||||
| >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) | ||||||||||||||||||||
| >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) | ||||||||||||||||||||
| >>> fl(pred, grnd) | ||||||||||||||||||||
| to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. | ||||||||||||||||||||
| num_classes: number of classes. Defaults to 2. | ||||||||||||||||||||
| weight: weight factor to balance between Focal Loss and Tversky Loss. | ||||||||||||||||||||
| Loss = weight * FocalLoss + (1-weight) * TverskyLoss. Defaults to 0.5. | ||||||||||||||||||||
| gamma: focal exponent. Defaults to 0.5. | ||||||||||||||||||||
| delta: background/foreground balancing weight. Defaults to 0.7. | ||||||||||||||||||||
| reduction: specifies the reduction to apply to the output. Defaults to "mean". | ||||||||||||||||||||
| use_softmax: whether to use softmax for probability conversion. Defaults to False. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| super().__init__(reduction=LossReduction(reduction).value) | ||||||||||||||||||||
| self.to_onehot_y = to_onehot_y | ||||||||||||||||||||
| self.num_classes = num_classes | ||||||||||||||||||||
| self.gamma = gamma | ||||||||||||||||||||
| self.delta = delta | ||||||||||||||||||||
| self.weight: float = weight | ||||||||||||||||||||
| self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) | ||||||||||||||||||||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) | ||||||||||||||||||||
| self.weight = weight | ||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing validation: weight should be in [0, 1].
✅ Proposed fix self.weight = weight
+if not 0 <= self.weight <= 1:
+ raise ValueError(f"weight must be in [0, 1], got {self.weight}")
self.use_softmax = use_softmax📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||
| self.use_softmax = use_softmax | ||||||||||||||||||||
|
|
||||||||||||||||||||
| self.asy_focal_loss = AsymmetricFocalLoss( | ||||||||||||||||||||
| gamma=self.gamma, | ||||||||||||||||||||
| delta=self.delta, | ||||||||||||||||||||
| use_softmax=self.use_softmax, | ||||||||||||||||||||
| to_onehot_y=to_onehot_y, | ||||||||||||||||||||
| reduction=reduction, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss( | ||||||||||||||||||||
| gamma=self.gamma, | ||||||||||||||||||||
| delta=self.delta, | ||||||||||||||||||||
| use_softmax=self.use_softmax, | ||||||||||||||||||||
| to_onehot_y=to_onehot_y, | ||||||||||||||||||||
| reduction=reduction, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Comment on lines
+266
to
+279
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential double one-hot conversion. Lines 297-299 convert single-channel inputs to one-hot, then lines 305-309 may convert again if This could lead to redundant conversions or shape mismatches. Consider:
🔧 Recommended fixOption 1: Handle all conversions in wrapper self.asy_focal_loss = AsymmetricFocalLoss(
gamma=self.gamma,
delta=self.delta,
use_softmax=self.use_softmax,
- to_onehot_y=to_onehot_y,
+ to_onehot_y=False, # Wrapper handles one-hot conversion
reduction=reduction,
)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
gamma=self.gamma,
delta=self.delta,
use_softmax=self.use_softmax,
- to_onehot_y=to_onehot_y,
+ to_onehot_y=False, # Wrapper handles one-hot conversion
reduction=reduction,
)Option 2: Let sub-losses handle conversions if y_pred.shape != y_true.shape:
is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
if not self.to_onehot_y and not is_binary_logits:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
- if y_pred.shape[1] == 1:
- y_pred = one_hot(y_pred, num_classes=self.num_classes)
- y_true = one_hot(y_true, num_classes=self.num_classes)
-
- if torch.max(y_true) != self.num_classes - 1:
- raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
-
- n_pred_ch = y_pred.shape[1]
- if self.to_onehot_y:
- if n_pred_ch == 1:
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
- else:
- y_true = one_hot(y_true, num_classes=n_pred_ch)
-
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)Also applies to: 297-309 🤖 Prompt for AI Agents |
||||||||||||||||||||
|
|
||||||||||||||||||||
| # TODO: Implement this function to support multiple classes segmentation | ||||||||||||||||||||
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Args: | ||||||||||||||||||||
| y_pred : the shape should be BNH[WD], where N is the number of classes. | ||||||||||||||||||||
| It only supports binary segmentation. | ||||||||||||||||||||
| The input should be the original logits since it will be transformed by | ||||||||||||||||||||
| a sigmoid in the forward function. | ||||||||||||||||||||
| y_true : the shape should be BNH[WD], where N is the number of classes. | ||||||||||||||||||||
| It only supports binary segmentation. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Raises: | ||||||||||||||||||||
ytl0623 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
| ValueError: When input and target are different shape | ||||||||||||||||||||
| ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 | ||||||||||||||||||||
| ValueError: When num_classes | ||||||||||||||||||||
| ValueError: When the number of classes entered does not match the expected number | ||||||||||||||||||||
| y_pred: Prediction logits. Shape: (B, C, H, W, [D]). | ||||||||||||||||||||
| Supports binary (C=1 or C=2) and multi-class (C>2) segmentation. | ||||||||||||||||||||
| y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True). | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| if y_pred.shape != y_true.shape: | ||||||||||||||||||||
| is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax | ||||||||||||||||||||
| if not self.to_onehot_y and not is_binary_logits: | ||||||||||||||||||||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||||||||||||||||||||
| raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") | ||||||||||||||||||||
|
Comment on lines
288
to
292
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix duplicate raise statement. Line 292 unconditionally raises the same ValueError as line 291, making the conditional logic on lines 289-291 pointless. This prevents binary logits from being processed correctly. 🐛 Proposed fix if y_pred.shape != y_true.shape:
is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
if not self.to_onehot_y and not is_binary_logits:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
- raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")This allows binary logits (single-channel with 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.14.10)291-291: Avoid specifying long messages outside the exception class (TRY003) 292-292: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||||||
|
|
||||||||||||||||||||
| if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: | ||||||||||||||||||||
|
|
@@ -229,12 +311,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | |||||||||||||||||||
| asy_focal_loss = self.asy_focal_loss(y_pred, y_true) | ||||||||||||||||||||
| asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss | ||||||||||||||||||||
| loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if self.reduction == LossReduction.SUM.value: | ||||||||||||||||||||
| return torch.sum(loss) # sum over the batch and channel dims | ||||||||||||||||||||
| if self.reduction == LossReduction.NONE.value: | ||||||||||||||||||||
| return loss # returns [N, num_classes] losses | ||||||||||||||||||||
| if self.reduction == LossReduction.MEAN.value: | ||||||||||||||||||||
| return torch.mean(loss) | ||||||||||||||||||||
| raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') | ||||||||||||||||||||
| return loss | ||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing validation: delta should be in valid range.
deltais a weight parameter but lacks range validation. Add validation to ensure0 <= delta <= 1.✅ Proposed fix
Apply similar validation in AsymmetricFocalLoss (line 201).
📝 Committable suggestion
🤖 Prompt for AI Agents