
@embedding-shapes

Summary

Adds ignore_index support to cross-entropy and NLL loss functions via new *_with_ignore functions, enabling proper handling of padding tokens in sequence modeling.

Motivation

Sequence modeling (NLP, time series) requires excluding padding tokens from loss computation. Currently users must manually mask losses, which is error-prone and less efficient.

Changes

  • Add loss::nll_with_ignore(inp, target, ignore_index) function
  • Add loss::cross_entropy_with_ignore(inp, target, ignore_index) function
  • Implement safe target sanitization to prevent OOB errors
  • Add 4 tests including safety and PyTorch parity checks
  • Fully backward compatible (new functions, existing ones unchanged)
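
A minimal usage sketch of the two new functions listed above (an illustration, assuming they live in `candle_nn::loss` and return a scalar `Result<Tensor>` like the existing `nll`/`cross_entropy`; the shapes, pad id, and variable names are made up):

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{loss, ops};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Logits for 3 positions over a vocabulary of 5 tokens.
    let logits = Tensor::randn(0f32, 1.0, (3, 5), &device)?;
    // The last position is padding; token id 0 is used as the pad id here.
    let targets = Tensor::new(&[4u32, 2, 0], &device)?;

    // Cross-entropy averaged over the non-padded positions only.
    let ce = loss::cross_entropy_with_ignore(&logits, &targets, 0)?;

    // Equivalent NLL path: log-softmax first, then nll_with_ignore.
    let log_probs = ops::log_softmax(&logits, 1)?;
    let nll = loss::nll_with_ignore(&log_probs, &targets, 0)?;

    println!("ce = {}, nll = {}", ce.to_scalar::<f32>()?, nll.to_scalar::<f32>()?);
    Ok(())
}
```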

PyTorch Compatibility

Matches PyTorch functionality with two documented divergences:

  1. All-ignored case: Returns 0.0 (PyTorch returns nan) for better ergonomics
  2. Type: Uses u32 (PyTorch uses i32) to match Candle's target dtype
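
Continuing the usage sketch above, divergence 1 would look roughly like this (same imports and `device`; the pad id 7 is arbitrary):

```rust
// Every target equals the ignore_index, so nothing contributes to the loss.
let logits = Tensor::randn(0f32, 1.0, (2, 5), &device)?;
let targets = Tensor::new(&[7u32, 7], &device)?;
let loss = loss::cross_entropy_with_ignore(&logits, &targets, 7)?;
assert_eq!(loss.to_scalar::<f32>()?, 0.0); // PyTorch's cross_entropy would yield NaN here
```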

Open Questions

  • Is adding two new functions preferable to changing the signature of the existing ones? I'm fine with either; I opted for the approach that reduces API churn for now, but I'm happy to change it if you prefer.
  • Are there too many comments? I erred on the side of "too much" rather than "too little". I'm still a beginner ML enthusiast, so the comments were mostly there to help me during the implementation; we may want to remove some of them if they're deemed too basic.

Roughly matches PyTorch's functionality, with minor behavioural changes to
fit better with Candle.

- Added 4 basic tests to assert correctness
- All existing tests pass
- Added new functions rather than changing existing ones, to avoid
  forcing callers to refactor; the existing `nll` and `cross_entropy`
  are unchanged

Alternatives I thought of but rejected in favor of the chosen design:

- Adding an `ignore_index: Option<u32>` parameter to the existing
  functions: would force every caller to update, i.e. a breaking change.
- Creating some sort of builder pattern: over-engineered for a single
  optional parameter.

Divergences from PyTorch:

- PyTorch allows ignore_index to be negative (commonly -100, as an i32);
  this implementation uses u32, so the value must be non-negative. Users
  need to specify actual padding token IDs rather than sentinel values.
  However, out-of-range values (like 999 with a vocab size of 100) still
  work correctly thanks to target sanitization.
- PyTorch returns `nan` when all targets are ignored; our implementation
  returns `0.0` instead, for safety and ergonomics. It feels easier to
  debug (a NaN reads like an error signal and could be confusing).
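
For illustration, here is a minimal sketch of how the sanitization and the all-ignored handling described above could look. This is not the PR's actual code: the function name is hypothetical, and the input is assumed to hold log-probabilities, like the existing `loss::nll`.

```rust
use candle_core::{DType, Result, Tensor};

fn nll_with_ignore_sketch(log_probs: &Tensor, target: &Tensor, ignore_index: u32) -> Result<Tensor> {
    // Mask of positions to keep: 1 where target != ignore_index, 0 otherwise.
    let ignore = Tensor::full(ignore_index, target.dims(), target.device())?;
    let keep = target.ne(&ignore)?; // u8 mask
    // Clamp ignored targets to index 0 so the gather never reads out of
    // bounds, even when ignore_index lies outside the vocabulary
    // (e.g. 999 with a vocab size of 100).
    let safe_target = (target * &keep.to_dtype(DType::U32)?)?;
    // Pick the log-probability of each target class: (N, C) -> (N,).
    let picked = log_probs.gather(&safe_target.unsqueeze(1)?, 1)?.squeeze(1)?;
    // Zero out the contribution of ignored positions.
    let per_elem = (picked.neg()? * &keep.to_dtype(log_probs.dtype())?)?;
    // Average over the non-ignored positions only.
    let count = keep.to_dtype(DType::U32)?.sum_all()?.to_scalar::<u32>()?;
    if count == 0 {
        // All targets ignored: return 0.0 instead of PyTorch's NaN.
        return Tensor::zeros((), log_probs.dtype(), log_probs.device());
    }
    per_elem.sum_all()? / (count as f64)
}
```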