74 changes: 60 additions & 14 deletions monai/losses/focal_loss.py
@@ -69,19 +69,22 @@ def __init__(
include_background: bool = True,
to_onehot_y: bool = False,
gamma: float = 2.0,
alpha: float | None = None,
alpha: float | Sequence[float] | None = None,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
) -> None:
"""
Args:
include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
If False, `alpha` is invalid when using softmax.
If False, a scalar `alpha` is ignored when using softmax; a sequence `alpha` (explicit per-class weights) is still applied.
to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
The value should be in [0, 1]. Defaults to None.
A scalar value should be in [0, 1].
If a sequence is provided, its length must match the number of classes
(excluding the background class if `include_background=False`).
Defaults to None.
weight: weights to apply to the voxels of each class. If None no weights are applied.
The input can be a single value (same weight for all classes), a sequence of values (the length
of the sequence should be the same as the number of classes. If not ``include_background``,
@@ -109,9 +112,15 @@ def __init__(
self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.gamma = gamma
self.alpha = alpha
self.weight = weight
self.use_softmax = use_softmax
self.alpha: float | torch.Tensor | None
if alpha is None:
self.alpha = None
elif isinstance(alpha, (float, int)):
self.alpha = float(alpha)
else:
self.alpha = torch.as_tensor(alpha)
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
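A minimal sketch of how the constructor above is expected to normalise the different `alpha` forms after this change (illustrative values, not part of the patch):

from monai.losses import FocalLoss

FocalLoss().alpha                      # None is kept as None
FocalLoss(alpha=0.25).alpha            # a scalar stays a Python float
FocalLoss(alpha=[0.25, 0.75]).alpha    # a sequence becomes a torch tensor of per-class weights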
@@ -155,13 +164,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
loss: torch.Tensor | None = None
input = input.float()
target = target.float()
alpha_arg = self.alpha
if self.use_softmax:
if not self.include_background and self.alpha is not None:
self.alpha = None
warnings.warn("`include_background=False`, `alpha` ignored when using softmax.")
loss = softmax_focal_loss(input, target, self.gamma, self.alpha)
if isinstance(self.alpha, (float, int)):
alpha_arg = None
warnings.warn(
"`include_background=False`, scalar `alpha` ignored when using softmax.", stacklevel=2
)
loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)
else:
loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)
loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)

num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
@@ -202,7 +215,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


def softmax_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -214,8 +227,22 @@ def softmax_focal_loss(
loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target

if alpha is not None:
# (1-alpha) for the background class and alpha for the other classes
alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
if isinstance(alpha, torch.Tensor):
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
else:
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)

if alpha_t.ndim == 0: # scalar
alpha_val = alpha_t.item()
# (1-alpha) for the background class and alpha for the other classes
alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
else: # tensor (sequence)
if alpha_t.shape[0] != target.shape[1]:
raise ValueError(
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
)
alpha_fac = alpha_t

broadcast_dims = [-1] + [1] * len(target.shape[2:])
alpha_fac = alpha_fac.view(broadcast_dims)
loss = alpha_fac * loss
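The broadcasting above reshapes a per-class `alpha` of shape (C,) so it multiplies each class channel of the loss map. A toy shape check, assuming an (N, C, H, W) target (not part of the patch):

import torch

target = torch.zeros(2, 3, 4, 4)                      # N=2, C=3, spatial 4x4
alpha_fac = torch.tensor([0.2, 0.8, 0.8])             # one weight per class
broadcast_dims = [-1] + [1] * len(target.shape[2:])   # [-1, 1, 1]
alpha_fac = alpha_fac.view(broadcast_dims)            # shape (3, 1, 1)
# multiplies a loss of shape (2, 3, 4, 4), weighting each class channel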
@@ -224,7 +251,7 @@


def sigmoid_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
@@ -247,8 +274,27 @@ def sigmoid_focal_loss(
loss = (invprobs * gamma).exp() * loss

if alpha is not None:
# alpha if t==1; (1-alpha) if t==0
alpha_factor = target * alpha + (1 - target) * (1 - alpha)
if isinstance(alpha, torch.Tensor):
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
else:
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)

if alpha_t.ndim == 0: # scalar
# alpha if t==1; (1-alpha) if t==0
alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
else: # tensor (sequence)
if alpha_t.shape[0] != target.shape[1]:
raise ValueError(
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
)
# Reshape alpha for broadcasting: (1, C, 1, 1...)
broadcast_dims = [-1] + [1] * len(target.shape[2:])
alpha_t = alpha_t.view(broadcast_dims)
# Apply per-class weight only to positive samples
# For positive samples (target==1): multiply by alpha[c]
# For negative samples (target==0): keep weight as 1.0
alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))

loss = alpha_factor * loss

return loss
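Taken together, a minimal usage sketch of the new per-class `alpha` support (shapes and values are illustrative, not from the patch):

import torch
from monai.losses import FocalLoss

logits = torch.randn(2, 3, 4, 4)             # (N, C, H, W) with C=3 classes
labels = torch.randint(0, 3, (2, 1, 4, 4))   # class-index labels

# softmax path: one weight per class, background included
loss_softmax = FocalLoss(to_onehot_y=True, use_softmax=True, alpha=[0.2, 0.8, 0.8])
print(loss_softmax(logits, labels))

# sigmoid path: per-class weights are applied to positive voxels only
loss_sigmoid = FocalLoss(to_onehot_y=True, alpha=[0.2, 0.8, 0.8])
print(loss_sigmoid(logits, labels))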
46 changes: 44 additions & 2 deletions tests/losses/test_focal_loss.py
@@ -21,10 +21,11 @@

from monai.losses import FocalLoss
from monai.networks import one_hot
from tests.test_utils import test_script_save
from tests.test_utils import TEST_DEVICES, test_script_save

TEST_CASES = []
for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]:
for case in TEST_DEVICES:
device = case[0]
input_data = {
"input": torch.tensor(
[[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device
@@ -77,6 +78,13 @@
TEST_CASES.append([{"to_onehot_y": True, "use_softmax": True}, input_data, 0.16276])
TEST_CASES.append([{"to_onehot_y": True, "alpha": 0.8, "use_softmax": True}, input_data, 0.08138])

TEST_ALPHA_BROADCASTING = []
for case in TEST_DEVICES:
device = case[0]
for include_background in [True, False]:
for use_softmax in [True, False]:
TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])


class TestFocalLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
@@ -374,6 +382,40 @@ def test_script(self):
test_input = torch.ones(2, 2, 8, 8)
test_script_save(loss, test_input, test_input)

@parameterized.expand(TEST_ALPHA_BROADCASTING)
def test_alpha_sequence_broadcasting(self, device, include_background, use_softmax):
"""
Test FocalLoss with alpha as a sequence for proper broadcasting.
"""
num_classes = 3
batch_size = 2
spatial_dims = (4, 4)

logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)
target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)

if include_background:
alpha_seq = [0.1, 0.5, 2.0]
else:
alpha_seq = [0.5, 2.0]

loss_func = FocalLoss(
to_onehot_y=True,
gamma=2.0,
alpha=alpha_seq,
include_background=include_background,
use_softmax=use_softmax,
reduction="mean",
)

result = loss_func(logits, target)

self.assertTrue(torch.is_tensor(result))
self.assertEqual(result.ndim, 0)
self.assertTrue(
result > 0, f"Loss should be positive. params: dev={device}, bg={include_background}, softmax={use_softmax}"
)


if __name__ == "__main__":
unittest.main()
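A possible extra sanity check (not in the patch): for the softmax path with the background included, a scalar `alpha` and the equivalent explicit sequence should agree, since the scalar is expanded to [1 - alpha, alpha, ..., alpha] internally.

import torch
from monai.losses import FocalLoss

logits = torch.randn(2, 3, 4, 4)
labels = torch.randint(0, 3, (2, 1, 4, 4))

scalar = FocalLoss(to_onehot_y=True, use_softmax=True, alpha=0.8)
explicit = FocalLoss(to_onehot_y=True, use_softmax=True, alpha=[0.2, 0.8, 0.8])
assert torch.allclose(scalar(logits, labels), explicit(logits, labels))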