Add focal loss support to SegmentationHead #509

Draft

Farbum wants to merge 6 commits into master from hadriens/focal_loss

Conversation

@Farbum (Contributor) commented Feb 13, 2026

  • Add focal_loss parameter using torchvision sigmoid_focal_loss
  • Add focal_loss_alpha and focal_loss_gamma parameters
  • Add loss_weights tuple to scale classification and dice losses
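
A minimal usage sketch of the parameters listed above (the constructor arguments are from this PR's diff; the import path is an assumption based on the file touched in the review below):

# Hypothetical usage; constructor arguments are from this PR's diff,
# the import path is assumed from rslearn/train/tasks/segmentation.py.
from rslearn.train.tasks.segmentation import SegmentationHead

head = SegmentationHead(
    dice_loss=True,            # existing option
    focal_loss=True,           # new: use torchvision's sigmoid_focal_loss
    focal_loss_alpha=0.25,     # new: focal loss alpha (PR default)
    focal_loss_gamma=2.0,      # new: focal loss gamma (PR default)
    loss_weights=(1.0, 0.5),   # new: (classification, dice) loss scales
)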

root and others added 4 commits January 27, 2026 05:30
Farbum marked this pull request as ready for review February 18, 2026 17:32
Comment thread: rslearn/train/tasks/segmentation.py (outdated)
weights: list[float] | None = None,
class_weights: list[float] | None = None,
dice_loss: bool = False,
focal_loss: bool = False,
@favyen2 (Collaborator) commented on the lines above:
Looks like we may add more and more losses, and it will get confusing. So I think we need to define CrossEntropyLoss/FocalLoss/DiceLoss and then take a losses: list[SegmentationHeadLoss] or dict[str, SegmentationHeadLoss] here (SegmentationHeadLoss can be a subclass of torch.nn.Module, similar to FeatureExtractor, IntermediateComponent, etc.). It can default to cross entropy loss to preserve the existing behavior.
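
For concreteness, a rough sketch of that interface (the class names are from the comment above; the bodies are hypothetical):

import torch

class SegmentationHeadLoss(torch.nn.Module):
    """Base class for pluggable segmentation losses (hypothetical)."""
    def forward(self, logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

class CrossEntropyLoss(SegmentationHeadLoss):
    """Wraps the current default so existing configs keep working."""
    def forward(self, logits, labels, mask):
        per_pixel = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
        return (per_pixel * mask).mean()

class SegmentationHead(torch.nn.Module):
    def __init__(self, losses: list[SegmentationHeadLoss] | None = None) -> None:
        super().__init__()
        # Defaulting to cross entropy preserves the existing behavior.
        self.losses = torch.nn.ModuleList(losses or [CrossEntropyLoss()])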

focal_loss: bool = False,
focal_loss_alpha: float = 0.25,
focal_loss_gamma: float = 2.0,
loss_weights: tuple[float, float] = (1.0, 1.0),
@favyen2 (Collaborator) commented on the lines above:

Since we return the losses as a dict and don't combine them here, I think it is better to handle these weights in RslearnLightningModule so the mechanism is more general-purpose. This could also replace the loss_weights in MultiTaskModel: MultiTaskModel would just handle the task-name prefixing, and the user could configure loss weights across tasks via RslearnLightningModule, in addition to weighting the different losses within a task (e.g. focal vs. dice).

@favyen2 (Collaborator) commented Mar 2, 2026

@claude can you review this PR?

@claude (Bot) commented Mar 2, 2026

Claude finished @favyen2's task in 1m 40s.


PR Review: Add focal loss support to SegmentationHead

Tasks:

  • Gather context and understand request
  • Read changed files and analyze diff
  • Post review feedback

Summary

The PR adds focal loss support and a loss_weights parameter to SegmentationHead. The implementation is functional and well-tested, but the existing review comments from @favyen2 raise important architectural concerns that should be addressed before merging.


Architectural Concerns (from existing review)

These are the most important issues for this PR:

1. Growing boolean flags approach (line 317). @favyen2's comment is spot-on. The current API:

SegmentationHead(
    dice_loss=True,
    focal_loss=True,
    focal_loss_alpha=0.25,
    focal_loss_gamma=2.0,
    loss_weights=(1.0, 1.0),
)

will keep growing as more losses are added. A composable loss class approach is cleaner:

SegmentationHead(losses=[FocalLoss(alpha=0.25, gamma=2.0), DiceLoss()])

Each loss class could be a subclass of torch.nn.Module and implement forward(logits, labels, mask) -> Tensor, similar to the existing FeatureExtractor/IntermediateComponent pattern in the codebase.
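
A sketch of one concrete loss under that contract, assuming the forward(logits, labels, mask) signature above (the torchvision call mirrors what this PR already uses; everything else is hypothetical):

import torch
from torchvision.ops import sigmoid_focal_loss

class FocalLoss(torch.nn.Module):
    """Hypothetical SegmentationHeadLoss wrapping torchvision's focal loss."""
    def __init__(self, alpha: float = 0.25, gamma: float = 2.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # sigmoid_focal_loss expects per-class binary targets, hence the one-hot.
        targets = torch.nn.functional.one_hot(labels, logits.shape[1]).permute(0, 3, 1, 2).float()
        per_class = sigmoid_focal_loss(logits, targets, alpha=self.alpha, gamma=self.gamma, reduction="none")
        return (per_class.sum(dim=1) * mask).mean()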

2. loss_weights placement (line 320). RslearnLightningModule already aggregates the loss dict with train_loss = sum(loss_dict.values()) (lightning_module.py:345). Moving weight configuration there would avoid coupling weighting logic to the head, and would make it general-purpose across all tasks (including multi-task setups).
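
A sketch of what that could look like (the unweighted sum is from lightning_module.py:345; the loss_weights mapping and helper name are hypothetical):

import torch

# Hypothetical replacement for train_loss = sum(loss_dict.values())
# in RslearnLightningModule; loss_weights would come from config.
def aggregate_losses(loss_dict: dict[str, torch.Tensor], loss_weights: dict[str, float] | None = None) -> torch.Tensor:
    if not loss_weights:
        return sum(loss_dict.values())
    # Losses without an explicit weight default to 1.0, so partial configs work.
    return sum(loss_weights.get(name, 1.0) * value for name, value in loss_dict.items())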


Code-Level Issues

3. Silent class_weights ignored with focal_loss=True — When focal_loss=True, any configured class_weights are silently ignored. sigmoid_focal_loss has no equivalent weight parameter. This should either raise an error, or the docstring should explicitly document the incompatibility.

4. Sigmoid focal loss vs softmax semantics mismatch. sigmoid_focal_loss (from torchvision/RetinaNet) treats each class as an independent binary prediction using a sigmoid activation. The baseline cross-entropy uses softmax, which treats classes as mutually exclusive. For standard multi-class segmentation this is semantically inconsistent: sigmoid focal loss was designed for multi-label/object-detection scenarios where classes are not mutually exclusive. A softmax-based focal loss would be more appropriate:

# Softmax focal loss variant (sketch; logits [N, C, H, W], labels [N, H, W])
import torch
import torch.nn.functional as F

pt = F.softmax(logits, dim=1)                                     # class probabilities
labels_one_hot = F.one_hot(labels, num_classes).permute(0, 3, 1, 2).float()
ce = -torch.log(pt + 1e-8)                                        # per-class cross entropy
focal_weight = (1 - pt) ** gamma                                  # down-weight easy pixels
per_pixel_loss = (focal_weight * ce * labels_one_hot).sum(dim=1)  # keep only the true class

If the sigmoid variant is intentional for a specific use case, this should be documented.

5. loss_weights[1] silently unused when dice_loss=False — If a user sets loss_weights=(1.0, 0.5) without dice_loss=True, the second value is silently ignored with no warning.
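
Issues 3 and 5 could both be caught with early validation in __init__; a hypothetical sketch (parameter names are from this PR, the helper name is invented):

import warnings

# Hypothetical guards for SegmentationHead.__init__.
def validate_loss_config(focal_loss, class_weights, dice_loss, loss_weights):
    if focal_loss and class_weights is not None:
        raise ValueError("class_weights is not supported with focal_loss (sigmoid_focal_loss has no weight parameter)")
    if not dice_loss and loss_weights[1] != 1.0:
        warnings.warn("loss_weights[1] has no effect unless dice_loss=True")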


Minor Issues

6. Docstring title is wrong. The SegmentationHead.__init__ docstring says "Initialize a new SegmentationTask" (line 323); it should say SegmentationHead.

7. Missing test: combined focal + dice loss — The test suite covers focal-only and dice-only, but doesn't test focal_loss=True, dice_loss=True together.

8. Missing test: class_weights with focal_loss=True — There's no test demonstrating that class_weights is ignored when focal_loss=True, which would expose the silent failure noted in issue #3.
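
A sketch of the missing combined-loss test (hypothetical: the head's call signature and returned keys are not shown in this thread, so the real test should mirror the existing focal-only test):

import torch

def test_focal_and_dice_loss_together():
    # Hypothetical invocation; mirror the existing focal-only test in practice.
    head = SegmentationHead(focal_loss=True, dice_loss=True, loss_weights=(1.0, 0.5))
    logits = torch.randn(2, 3, 8, 8)
    labels = torch.randint(0, 3, (2, 8, 8))
    loss_dict = head(logits, labels)  # assumed to return the loss dict
    assert all(torch.isfinite(v) for v in loss_dict.values())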


Positive Notes

  • The use of einops.rearrange is consistent with the rest of the codebase (addresses @robmarkcole's feedback).
  • The test_losses_match_manual_computation test is thorough and correctly verifies numerical accuracy against reference computations.
  • Backward compatibility is preserved — all new parameters have sensible defaults.
  • The loss_weights default of (1.0, 1.0) means existing configs are unaffected.

Bottom line: The implementation is functionally correct, but the architectural concerns from @favyen2's review and issues #3 and #4 above are worth addressing before merging, since they affect how future loss types will be added to the codebase.

Farbum marked this pull request as draft March 26, 2026 16:34