Conversation
**Farbum** commented on Feb 13, 2026
- Add focal_loss parameter using torchvision sigmoid_focal_loss
- Add focal_loss_alpha and focal_loss_gamma parameters
- Add loss_weights tuple to scale classification and dice losses
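For context, torchvision's `sigmoid_focal_loss` scales binary cross entropy by `(1 - p_t) ** gamma` and optionally by an `alpha` class-balancing factor. Below is a minimal pure-Python scalar sketch of that formula; the function name `sigmoid_focal_loss_scalar` is hypothetical, written only to illustrate what the new parameters control:

```python
import math

def sigmoid_focal_loss_scalar(logit: float, target: float,
                              alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Scalar sketch mirroring torchvision.ops.sigmoid_focal_loss.

    Binary cross entropy scaled by (1 - p_t)**gamma, with optional
    alpha class balancing (alpha < 0 disables it, as in torchvision).
    """
    p = 1.0 / (1.0 + math.exp(-logit))
    # Binary cross entropy with logits.
    ce = -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    # Probability assigned to the true class.
    p_t = p * target + (1.0 - p) * (1.0 - target)
    loss = ce * (1.0 - p_t) ** gamma
    if alpha >= 0:
        alpha_t = alpha * target + (1.0 - alpha) * (1.0 - target)
        loss = alpha_t * loss
    return loss
```

Setting `gamma=0` and `alpha=-1` recovers plain binary cross entropy, which is a handy sanity check for tests.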
```python
weights: list[float] | None = None,
class_weights: list[float] | None = None,
dice_loss: bool = False,
focal_loss: bool = False,
```
Looks like we may add more and more losses and it will get confusing. So I think we need to define `CrossEntropyLoss`/`FocalLoss`/`DiceLoss` and then take a `losses: list[SegmentationHeadLoss]` or `dict[str, SegmentationHeadLoss]` here (`SegmentationHeadLoss` can be a subclass of `torch.nn.Module`, similar to `FeatureExtractor`, `IntermediateComponent`, etc.). It can default to cross entropy loss to preserve the existing behavior.
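The suggested composable-loss API could be sketched roughly as below, using plain Python floats instead of tensors to stay self-contained. Everything beyond the name `SegmentationHeadLoss` (the `name` attribute, the `compute_losses` helper, the exact math on probability lists) is an illustrative assumption, not the repo's actual code:

```python
import math
from abc import ABC, abstractmethod

class SegmentationHeadLoss(ABC):
    """Hypothetical base class for composable segmentation losses."""
    name: str = "loss"

    @abstractmethod
    def __call__(self, probs: list[float], targets: list[int]) -> float:
        """Compute a scalar loss from per-pixel foreground probabilities."""

class CrossEntropyLoss(SegmentationHeadLoss):
    name = "cross_entropy"

    def __call__(self, probs, targets):
        eps = 1e-8
        return -sum(
            t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
            for p, t in zip(probs, targets)
        ) / len(probs)

class DiceLoss(SegmentationHeadLoss):
    name = "dice"

    def __call__(self, probs, targets):
        eps = 1e-8
        inter = sum(p * t for p, t in zip(probs, targets))
        denom = sum(probs) + sum(targets)
        return 1.0 - (2 * inter + eps) / (denom + eps)

def compute_losses(losses, probs, targets):
    # The head iterates over configured losses and returns a dict
    # keyed by loss name, without combining them.
    return {loss.name: loss(probs, targets) for loss in losses}
```

In the real head these would presumably subclass `torch.nn.Module` and operate on logits tensors; the point here is the shape of the API, with the boolean flags replaced by a list of configured loss objects.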
```python
focal_loss: bool = False,
focal_loss_alpha: float = 0.25,
focal_loss_gamma: float = 2.0,
loss_weights: tuple[float, float] = (1.0, 1.0),
```
Since we return the losses as a dict and don't combine them, I think it is better to handle these weights in `RslearnLightningModule` so it can be more general-purpose. It may even be possible for this to replace the `loss_weights` in `MultiTaskModel`: `MultiTaskModel` would only handle the task name prefixing, and the user could configure loss weights across tasks via `RslearnLightningModule`, in addition to weighting different losses within a task (e.g. focal loss vs dice loss).
@claude can you review this PR?
Claude finished @favyen2's task in 1m 40s

### PR Review: Add focal loss support to SegmentationHead

#### Summary

The PR adds focal loss support (via torchvision's `sigmoid_focal_loss`) and a `loss_weights` tuple to `SegmentationHead`.

#### Architectural Concerns (from existing review)

These are the most important issues for this PR.

**1. Growing boolean flags approach (line 317).** @favyen2's comment is spot-on. The current API:

```python
SegmentationHead(
    dice_loss=True,
    focal_loss=True,
    focal_loss_alpha=0.25,
    focal_loss_gamma=2.0,
    loss_weights=(1.0, 1.0),
)
```

will keep growing as more losses are added. A composable loss class approach is cleaner:

```python
SegmentationHead(losses=[FocalLoss(alpha=0.25, gamma=2.0), DiceLoss()])
```

Each loss class could be a subclass of a common `SegmentationHeadLoss` base.

#### 2. Code-Level Issues

**3. Silent `class_weights` ignored with `focal_loss=True`.** When `focal_loss=True` is set, the `class_weights` argument is silently ignored.

**4. Sigmoid focal loss vs softmax semantics mismatch.** torchvision's `sigmoid_focal_loss` applies an independent sigmoid per class, while the head's default cross entropy treats classes as mutually exclusive via softmax. A softmax-based variant would look like:

```python
import torch
import torch.nn.functional as F

# Softmax focal loss variant
pt = F.softmax(logits, dim=1)
labels_one_hot = F.one_hot(labels, num_classes).permute(0, 3, 1, 2).float()
ce = -torch.log(pt + 1e-8)
focal_weight = (1 - pt) ** gamma
per_pixel_loss = (focal_weight * ce * labels_one_hot).sum(dim=1)
```

If the sigmoid variant is intentional for a specific use case, this should be documented.

#### 5. Minor Issues

**6. Docstring title is wrong.**

**7. Missing test: combined focal + dice loss.** The test suite covers focal-only and dice-only, but doesn't test both losses enabled together.

**8. Missing test.**

#### Positive Notes
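On the sigmoid-vs-softmax point raised in issue 4 above, the semantic difference is easy to demonstrate numerically: per-class sigmoids are independent probabilities that need not sum to 1 (suited to multi-label problems), while softmax yields a distribution over mutually exclusive classes. A small self-contained check:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]
sigmoid_total = sum(sigmoid(x) for x in logits)  # independent per-class probs, sum > 1
softmax_total = sum(softmax(logits))             # proper distribution, sum == 1
```

This is why a sigmoid-based focal loss can assign high probability to several classes at once for the same pixel, which is usually not what a single-label segmentation head intends.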
**Bottom line:** The implementation is functionally correct, but the architectural concerns in the existing review (#3, #4 above, and @favyen2's comments) are worth addressing before merging, as they would affect how future loss types are added to the codebase.