Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 149 additions & 73 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@

class AsymmetricFocalTverskyLoss(_Loss):
"""
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that prioritizes the foreground classes.

Actually, it's only supported for binary image segmentation now.
It supports both binary and multi-class segmentation.

Reimplementation of the Asymmetric Focal Tversky Loss described in:

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
Michael Yeung, Computerized Medical Imaging and Graphics
"""

def __init__(
Expand All @@ -39,119 +38,200 @@ def __init__(
gamma: float = 0.75,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
) -> None:
"""
Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
delta: weight of the background class (used in the Tversky index denominator). Defaults to 0.7.
gamma: focal exponent value to down-weight easy foreground examples. Defaults to 0.75.
epsilon: a small value to prevent division by zero. Defaults to 1e-7.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
use_softmax: whether to use softmax to transform original logits into probabilities.
If True, softmax is used (for multi-class). If False, sigmoid is used (for binary/multi-label).
Defaults to False.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Missing validation: delta should be in valid range.

delta is a weight parameter but lacks range validation. Add validation to ensure 0 <= delta <= 1.

✅ Proposed fix
 self.delta = delta
+if not 0 <= self.delta <= 1:
+    raise ValueError(f"delta must be in [0, 1], got {self.delta}")
 self.gamma = gamma

Apply similar validation in AsymmetricFocalLoss (line 201).

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.delta = delta
self.delta = delta
if not 0 <= self.delta <= 1:
raise ValueError(f"delta must be in [0, 1], got {self.delta}")
🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py at line 62, The delta assignment in
UnifiedFocalLoss currently lacks validation; add a check in the UnifiedFocalLoss
initializer after receiving delta to ensure 0 <= delta <= 1 and raise a
ValueError with a clear message if out of range; follow the same pattern used in
AsymmetricFocalLoss (the delta validation at/around line 201) so both classes
validate delta consistently.

self.gamma = gamma
self.epsilon = epsilon
self.use_softmax = use_softmax

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
y_true: ground truth labels. Shape should match y_pred.
"""
Comment on lines 62 to +67
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Complete the forward method docstring.

Missing Returns and Raises sections. Per coding guidelines, describe return value and raised exceptions.

📝 Enhanced docstring
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Args:
        y_pred: prediction logits or probabilities. Shape should be (B, C, spatial_dims).
        y_true: ground truth labels. Shape should match y_pred.

    Returns:
        torch.Tensor: Computed loss. Shape depends on reduction:
            - "none": (B, C)
            - "mean" or "sum": scalar

    Raises:
        ValueError: When ground truth shape differs from prediction shape.
    """
🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 62 - 67, The forward
method's docstring in unified_focal_loss.py is missing Returns and Raises
sections; update the forward(self, y_pred: torch.Tensor, y_true: torch.Tensor)
-> torch.Tensor docstring to describe the returned torch.Tensor (including shape
behavior for reductions: "none" -> (B, C), "mean" or "sum" -> scalar) and
document that a ValueError is raised when y_true shape differs from y_pred;
place these descriptions under "Returns:" and "Raises:" in the forward docstring
for clarity.


# Auto-handle single channel input (binary segmentation case)
if y_pred.shape[1] == 1 and not self.use_softmax:
y_pred = torch.sigmoid(y_pred)
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
is_already_prob = True
if y_true.shape[1] == 1:
y_true = one_hot(y_true, num_classes=2)
else:
is_already_prob = False

n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
if y_true.shape[1] != n_pred_ch:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

# clip the prediction to avoid NaN
# Convert logits to probabilities if not already done
if not is_already_prob:
if self.use_softmax:
y_pred = torch.softmax(y_pred, dim=1)
else:
y_pred = torch.sigmoid(y_pred)

# Clip the prediction to avoid NaN
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)

axis = list(range(2, len(y_pred.shape)))

# Calculate true positives (tp), false negatives (fn) and false positives (fp)
tp = torch.sum(y_true * y_pred, dim=axis)
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
fp = torch.sum((1 - y_true) * y_pred, dim=axis)

dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)

# Calculate losses separately for each class, enhancing both classes
# Calculate losses separately for each class
# Background: Standard Dice Loss
back_dice = 1 - dice_class[:, 0]
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)

# Average class scores
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
return loss
# Foreground: Focal Tversky Loss
fore_dice = torch.pow(1 - dice_class[:, 1:], 1 / self.gamma)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Unified Focal Loss Generalising Dice and Cross Entropy-based Losses Handle Class Imbalanced Medical Image Segmentation Michael Yeung focal Tversky formula

💡 Result:

The Focal Tversky loss is defined as (1 − TI)^γ, where the Tversky index (TI) for a class is
TI = TP / (TP + α·FP + β·FN),
with TP, FP, FN the true positives, false positives and false negatives (often α+β=1; e.g. α=0.3, β=0.7). Thus the Focal Tversky loss:
FTL = (1 − TI)^γ,
with γ controlling focus on hard examples (γ=1 → Tversky loss). [1][2]

Sources:

  • Yeung et al., "Unified Focal loss…" (PMC), Eq.12 and surrounding text. [1]
  • Abraham & Khan, "A novel focal Tversky loss…" (original Focal Tversky proposal). [2]

🏁 Script executed:

# Find and examine the unified_focal_loss.py file
fd unified_focal_loss.py -type f

Repository: Project-MONAI/MONAI

Length of output: 182


🏁 Script executed:

# Read the file to see the actual implementation around line 116
cat -n monai/losses/unified_focal_loss.py | sed -n '100,130p'

Repository: Project-MONAI/MONAI

Length of output: 182


🏁 Script executed:

# Check the full context of the AsymmetricFocalTverskyLoss class to understand gamma usage
cat -n monai/losses/unified_focal_loss.py | sed -n '1,150p'

Repository: Project-MONAI/MONAI

Length of output: 182


Fix focal loss exponent formula.

Line 116 uses torch.pow(1 - dice_class[:, 1:], 1 / self.gamma) but should use gamma directly as the exponent, not its reciprocal. Per Yeung et al.'s "Unified Focal Loss" (Eq. 12), the Focal Tversky loss is defined as (1 − TI)^γ, where γ controls focus on hard examples.

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py at line 116, The focal exponent is
inverted: replace the use of 1/self.gamma with self.gamma so that fore_dice is
computed as (1 - dice_class[:, 1:])**self.gamma; update the expression where
fore_dice is assigned (variable fore_dice, using dice_class and self.gamma) to
raise (1 - dice_class[:, 1:]) to the power self.gamma instead of 1/self.gamma.


# Concatenate background and foreground losses
# back_dice needs unsqueeze to match dimensions: (B,) -> (B, 1)
all_losses = torch.cat([back_dice.unsqueeze(1), fore_dice], dim=1)

# Apply reduction
if self.reduction == LossReduction.MEAN.value:
return torch.mean(all_losses)
if self.reduction == LossReduction.SUM.value:
return torch.sum(all_losses)
if self.reduction == LossReduction.NONE.value:
return all_losses

return torch.mean(all_losses)


class AsymmetricFocalLoss(_Loss):
"""
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
AsymmetricFocalLoss is a variant of Focal Loss that treats background and foreground differently.

Actually, it's only supported for binary image segmentation now.
It supports both binary and multi-class segmentation.

Reimplementation of the Asymmetric Focal Loss described in:

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
Michael Yeung, Computerized Medical Imaging and Graphics
"""

def __init__(
self,
to_onehot_y: bool = False,
delta: float = 0.7,
gamma: float = 2,
gamma: float = 2.0,
epsilon: float = 1e-7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
delta: weight for the foreground classes. Defaults to 0.7.
gamma: focusing parameter for the background class (to down-weight easy background examples). Defaults to 2.0.
epsilon: a small value to prevent calculation errors. Defaults to 1e-7.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
use_softmax: whether to use softmax to transform logits. Defaults to False.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.delta = delta
self.gamma = gamma
self.epsilon = epsilon
self.use_softmax = use_softmax

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred: prediction logits or probabilities.
y_true: ground truth labels.
"""
Comment on lines 168 to +173
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Complete forward docstring.

Missing Returns and Raises sections.

📝 Enhanced docstring
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Args:
        y_pred: prediction logits or probabilities.
        y_true: ground truth labels.

    Returns:
        torch.Tensor: Computed loss. Shape depends on reduction:
            - "none": (B, C, spatial_dims...)
            - "mean" or "sum": scalar

    Raises:
        ValueError: When ground truth shape differs from prediction shape.
    """
🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 169 - 174, Update the
forward method's docstring in unified_focal_loss.py to include explicit Returns
and Raises sections: under Returns describe the torch.Tensor loss and its shape
behavior for reductions ("none" -> (B, C, spatial_dims...), "mean"/"sum" ->
scalar), and under Raises document a ValueError when y_true shape differs from
y_pred; keep the existing Args description and ensure the docstring remains in
the same triple-quoted block for the forward(self, y_pred: torch.Tensor, y_true:
torch.Tensor) method.


if y_pred.shape[1] == 1 and not self.use_softmax:
y_pred = torch.sigmoid(y_pred)
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
is_already_prob = True
if y_true.shape[1] == 1:
y_true = one_hot(y_true, num_classes=2)
else:
is_already_prob = False

n_pred_ch = y_pred.shape[1]

if self.to_onehot_y:
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
y_true = one_hot(y_true, num_classes=n_pred_ch)
if y_true.shape[1] != n_pred_ch:
y_true = one_hot(y_true, num_classes=n_pred_ch)

if y_true.shape != y_pred.shape:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

if not is_already_prob:
if self.use_softmax:
y_pred = torch.softmax(y_pred, dim=1)
else:
y_pred = torch.sigmoid(y_pred)

y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)

cross_entropy = -y_true * torch.log(y_pred)

# Background (Channel 0): Focal Loss
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Gamma=0 causes division by zero in Tversky loss.

While this line uses torch.pow multiplicatively (safe with gamma=0), the AsymmetricFocalTverskyLoss uses 1/gamma which will fail. For consistency and to prevent confusion, validate gamma > 0 in both classes.

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py at line 275, Add a guard that enforces
gamma > 0 in the constructors for both the UnifiedFocalLoss-related class that
uses self.gamma (referenced where back_ce = torch.pow(1 - y_pred[:, 0],
self.gamma) * ...) and the AsymmetricFocalTverskyLoss class (which uses
1/gamma). In each __init__ (look for the class constructors named around
UnifiedFocalLoss/AsymmetricFocalTverskyLoss), raise a ValueError with a clear
message if gamma <= 0 so callers cannot pass zero or negative gamma; keep the
rest of logic unchanged.

back_ce = (1 - self.delta) * back_ce

fore_ce = cross_entropy[:, 1]
# Foreground (Channels 1+): Standard Cross Entropy
fore_ce = cross_entropy[:, 1:]
fore_ce = self.delta * fore_ce

loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
return loss
# Concatenate losses
all_ce = torch.cat([back_ce.unsqueeze(1), fore_ce], dim=1)

# Sum over classes (dim=1) to get total loss per pixel
total_loss = torch.sum(all_ce, dim=1)

class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
# Apply reduction
if self.reduction == LossReduction.MEAN.value:
return torch.mean(total_loss)
if self.reduction == LossReduction.SUM.value:
return torch.sum(total_loss)
if self.reduction == LossReduction.NONE.value:
return total_loss
return torch.mean(total_loss)

Actually, it's only supported for binary image segmentation now

Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
class AsymmetricUnifiedFocalLoss(_Loss):
"""
AsymmetricUnifiedFocalLoss is a wrapper that combines AsymmetricFocalLoss and AsymmetricFocalTverskyLoss.

- "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
Michael Yeung, Computerized Medical Imaging and Graphics
This unified loss allows for simultaneously optimizing distribution-based (CE) and region-based (Dice) metrics,
while handling class imbalance through asymmetric weighting.
"""

def __init__(
Expand All @@ -162,51 +242,53 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
):
"""
Args:
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
num_classes : number of classes, it only supports 2 now. Defaults to 2.
delta : weight of the background. Defaults to 0.7.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.

Example:
>>> import torch
>>> from monai.losses import AsymmetricUnifiedFocalLoss
>>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
>>> fl(pred, grnd)
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
num_classes: number of classes. Defaults to 2.
weight: weight factor to balance between Focal Loss and Tversky Loss.
Loss = weight * FocalLoss + (1-weight) * TverskyLoss. Defaults to 0.5.
gamma: focal exponent. Defaults to 0.5.
delta: background/foreground balancing weight. Defaults to 0.7.
reduction: specifies the reduction to apply to the output. Defaults to "mean".
use_softmax: whether to use softmax for probability conversion. Defaults to False.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
self.weight = weight
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Missing validation: weight should be in [0, 1].

weight balances two losses but lacks range validation.

✅ Proposed fix
 self.weight = weight
+if not 0 <= self.weight <= 1:
+    raise ValueError(f"weight must be in [0, 1], got {self.weight}")
 self.use_softmax = use_softmax
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.weight = weight
self.weight = weight
if not 0 <= self.weight <= 1:
raise ValueError(f"weight must be in [0, 1], got {self.weight}")
🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py at line 348, The initializer sets
self.weight without validating its range; add a check in the UnifiedFocalLoss
(or its __init__) where self.weight is assigned to ensure weight is a numeric
value between 0 and 1 inclusive, and raise a ValueError with a clear message
like "weight must be between 0 and 1" if it is outside that range or not a
number; keep the assignment to self.weight only after the validation passes.

self.use_softmax = use_softmax

self.asy_focal_loss = AsymmetricFocalLoss(
gamma=self.gamma,
delta=self.delta,
use_softmax=self.use_softmax,
to_onehot_y=to_onehot_y,
reduction=reduction,
)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
gamma=self.gamma,
delta=self.delta,
use_softmax=self.use_softmax,
to_onehot_y=to_onehot_y,
reduction=reduction,
)
Comment on lines +266 to +279
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Potential double one-hot conversion.

Lines 297-299 convert single-channel inputs to one-hot, then lines 305-309 may convert again if to_onehot_y=True. Additionally, the sub-losses on lines 266-279 are initialized with to_onehot_y, so they will also attempt one-hot conversion.

This could lead to redundant conversions or shape mismatches. Consider:

  1. Performing one-hot conversion only in the wrapper (lines 297-309), then passing to_onehot_y=False to sub-losses, or
  2. Letting sub-losses handle all conversions and removing lines 297-309
🔧 Recommended fix

Option 1: Handle all conversions in wrapper

         self.asy_focal_loss = AsymmetricFocalLoss(
             gamma=self.gamma,
             delta=self.delta,
             use_softmax=self.use_softmax,
-            to_onehot_y=to_onehot_y,
+            to_onehot_y=False,  # Wrapper handles one-hot conversion
             reduction=reduction,
         )
         self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
             gamma=self.gamma,
             delta=self.delta,
             use_softmax=self.use_softmax,
-            to_onehot_y=to_onehot_y,
+            to_onehot_y=False,  # Wrapper handles one-hot conversion
             reduction=reduction,
         )

Option 2: Let sub-losses handle conversions

         if y_pred.shape != y_true.shape:
             is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
             if not self.to_onehot_y and not is_binary_logits:
                 raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
 
         if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
             raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
 
-        if y_pred.shape[1] == 1:
-            y_pred = one_hot(y_pred, num_classes=self.num_classes)
-            y_true = one_hot(y_true, num_classes=self.num_classes)
-
-        if torch.max(y_true) != self.num_classes - 1:
-            raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}")
-
-        n_pred_ch = y_pred.shape[1]
-        if self.to_onehot_y:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
-            else:
-                y_true = one_hot(y_true, num_classes=n_pred_ch)
-
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

Also applies to: 297-309

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 266 - 279, The wrapper
UnifiedFocalLoss is performing an explicit one-hot conversion in its forward
path while the child losses (asy_focal_loss and asy_focal_tversky_loss) are also
constructed with to_onehot_y=True, causing double conversion and shape errors;
fix by centralizing conversion in the wrapper: keep the wrapper's one-hot
conversion and change the initialization of AsymmetricFocalLoss and
AsymmetricFocalTverskyLoss to to_onehot_y=False (so they assume inputs are
already one-hot), or alternatively remove the wrapper conversion and leave the
sub-losses with to_onehot_y=True—prefer the first option (do conversion once in
UnifiedFocalLoss and set the sub-losses' to_onehot_y=False).


# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
y_true : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.

Raises:
ValueError: When input and target are different shape
ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
ValueError: When num_classes
ValueError: When the number of classes entered does not match the expected number
y_pred: Prediction logits. Shape: (B, C, H, W, [D]).
Supports binary (C=1 or C=2) and multi-class (C>2) segmentation.
y_true: Ground truth labels. Shape should match y_pred (or be indices if to_onehot_y is True).
"""
if y_pred.shape != y_true.shape:
is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
if not self.to_onehot_y and not is_binary_logits:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
Comment on lines 288 to 292
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix duplicate raise statement.

Line 292 unconditionally raises the same ValueError as line 291, making the conditional logic on lines 289-291 pointless. This prevents binary logits from being processed correctly.

🐛 Proposed fix
         if y_pred.shape != y_true.shape:
             is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
             if not self.to_onehot_y and not is_binary_logits:
                 raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
-            raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

This allows binary logits (single-channel with use_softmax=False) and cases where to_onehot_y=True to proceed despite shape mismatch.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if y_pred.shape != y_true.shape:
is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
if not self.to_onehot_y and not is_binary_logits:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
if y_pred.shape != y_true.shape:
is_binary_logits = y_pred.shape[1] == 1 and not self.use_softmax
if not self.to_onehot_y and not is_binary_logits:
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
🧰 Tools
🪛 Ruff (0.14.10)

291-291: Avoid specifying long messages outside the exception class

(TRY003)


292-292: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In @monai/losses/unified_focal_loss.py around lines 288 - 292, The code
currently always raises a ValueError when y_pred.shape != y_true.shape due to a
duplicate raise; update the conditional in the shape check inside the unified
focal loss logic so that it only raises when neither binary-logits nor
to-onehot_Y handling applies: compute is_binary_logits = (y_pred.shape[1] == 1
and not self.use_softmax) and then if not self.to_onehot_y and not
is_binary_logits: raise the ValueError, otherwise allow execution to continue to
handle binary logits or one-hot conversion paths; reference the y_pred, y_true,
is_binary_logits, to_onehot_y, and use_softmax symbols when locating the change.


if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
Expand All @@ -229,12 +311,6 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss
loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss

if self.reduction == LossReduction.SUM.value:
return torch.sum(loss) # sum over the batch and channel dims
if self.reduction == LossReduction.NONE.value:
return loss # returns [N, num_classes] losses
if self.reduction == LossReduction.MEAN.value:
return torch.mean(loss)
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
return loss
Loading
Loading