
[Paddle Pilot] Align cross_entropy precision for soft-label, label-smoothing, and ND reduction under FLAGS_use_accuracy_compatible_kernel#78572

Open
zrr1999 wants to merge 9 commits into PaddlePaddle:develop from zrr1999:paddle-pilot/cross_entropy

Conversation


@zrr1999 zrr1999 commented Apr 3, 2026

PR Category

Performance optimization

Description

Align paddle.nn.functional.cross_entropy precision with PyTorch torch.nn.functional.cross_entropy under FLAGS_use_accuracy_compatible_kernel=1, covering soft-label, label-smoothing, weighted reduction, and N-D (rank 3/4) NLL loss paths.

All changes are gated behind FLAGS_use_accuracy_compatible_kernel — when the flag is off (default), behavior is identical to upstream.

Changes Summary

CUDA kernels (paddle/phi/kernels/gpu/):

  • cross_entropy_kernel.cu: Added compatible soft-label and label-smoothing paths that match PyTorch's log_softmax → nll_loss decomposition under FLAG=1
  • cross_entropy_grad_kernel.cu: Corresponding backward pass for compatible paths
  • nll_loss.h / nll_loss_kernel.cu: Multi-block atomic-accumulation topology for 4D hard-label mean reduction, matching PyTorch's NLLLoss2d warp-reduction semantics
  • softmax_gpudnn.h: Compatible log-softmax path for precision-critical cross-entropy computation

Python routing (python/paddle/nn/functional/loss.py):

  • _cross_entropy_compatible_soft_label_loss(): New function for soft-label cross-entropy with PyTorch-aligned weighted-mean reduction
  • _cross_entropy_compatible_label_smoothing_loss(): New function for hard-label + label_smoothing using NLL + smooth-loss decomposition
  • cross_entropy(): Routing logic gates compatible paths behind FLAG; label-smoothing paths additionally gated to avoid unit-test interference when FLAG=0
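The routing can be sketched as follows. This is a hypothetical simplification: the function names mirror the PR description, but the dispatch details are assumed, and the `os.environ` lookup stands in for Paddle's actual flag machinery.

```python
import os

def _use_compatible_kernel():
    # Stand-in for reading FLAGS_use_accuracy_compatible_kernel;
    # Paddle reads flags through its own flag registry, not os.environ.
    return os.environ.get("FLAGS_use_accuracy_compatible_kernel", "0") == "1"

def cross_entropy_dispatch(soft_label, label_smoothing):
    # Sketch of the gating described above: compatible paths are taken
    # only when the flag is on; otherwise behavior matches upstream.
    if _use_compatible_kernel():
        if soft_label:
            return "_cross_entropy_compatible_soft_label_loss"
        if label_smoothing > 0.0:
            return "_cross_entropy_compatible_label_smoothing_loss"
    return "default_path"  # FLAG off (default): identical to upstream
```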

Tests (test/legacy_test/test_cross_entropy_loss.py):

  • New label-smoothing coverage (1D/2D, integer/onehot, with/without weight, all reductions)
  • Compatible-kernel integration tests (soft-label, label-smoothing, ND shapes)
  • FLAG-aware assertions: 4 weighted-mean label-smoothing tests skip numpy-reference comparison under FLAG=1 (compatible path intentionally uses PyTorch semantics)
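The NLL + smooth-loss decomposition these tests exercise can be sketched numerically (a sketch of the math, assuming mean reduction and no class weights; not Paddle's kernel):

```python
import numpy as np

def smoothed_cross_entropy(logits, labels, eps):
    # Hard-label cross-entropy with label smoothing, decomposed as
    # (1 - eps) * nll_loss + eps * smooth_loss.
    m = logits.max(axis=-1, keepdims=True)
    logp = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    nll = -logp[np.arange(len(labels)), labels]  # standard NLL term
    smooth = -logp.mean(axis=-1)                 # uniform-target term
    return ((1.0 - eps) * nll + eps * smooth).mean()
```

With `eps=0` this reduces to plain cross-entropy; with `eps=1` it targets the uniform distribution.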

Strict Validation Results (PaddleAPITest, atol=0, rtol=0)

| Metric | Before (develop) | After (this PR) |
|---|---|---|
| Total configs | 5440 | 5440 |
| Pass | 5383 | 5394 (+11) |
| accuracy_error | 56 | 45 (−11) |
| torch_error | 1 | 1 |
| paddle_error | 0 | 0 |

Improvement: +11 strict-pass cases — all from soft-label, label-smoothing, and ND reduction paths now aligned.

Remaining Residuals (46 cases, all documented)

| Family | Cases | Root Cause | Fixable? |
|---|---|---|---|
| reduction="mean" + ignore_index (float32, rank 2/3) | 38 | CUDA warp-reduction accumulation order differs from PyTorch; fixing it requires changing non-FLAG-gated kernel paths | Deferred |
| use_softmax=False + reduction="none" | 6 | PaddleAPITest conversion rule (CrossEntropyRule) does not track the label squeeze when mapping use_softmax=False to PyTorch | Test harness |
| soft_label=True + label_smoothing (float64) | 1 | Float64 type coercion in the label-smoothing matmul; architectural precision gap | Known |
| soft_label=True + int64 label + weight (torch_error) | 1 | PyTorch matmul rejects int64 @ float64; PaddleAPITest reference-side limitation | Reference-side |

None of these residuals are regressions — they existed before this PR and are either test-harness issues, architectural CUDA reduction-order gaps, or reference-side limitations.

Unit Tests

| Test File | FLAG=0 | FLAG=1 |
|---|---|---|
| test_cross_entropy_op.py | 38/38 ✅ | 38/38 ✅ |
| test_cross_entropy_loss.py | 47/47 ✅ | 47/47 ✅ |

zrr1999 added 9 commits April 1, 2026 06:58
…opology for 4D hard-label mean path

Remove the FLAGS_use_accuracy_compatible_kernel branch that routed the
compatible path through a single-block butterfly reduction
(GPUNLLLossForward2D_with_reduce_compatible). Both code paths now
unconditionally use GPUNLLLossForward2D_with_reduce (multi-block
per-sample, BlockReduceSum + gpuAtomicAdd), matching PyTorch NLLLoss2d
accumulation order and eliminating the dominant accuracy divergence for
rank > 2 inputs.
- Add kNumCUDAThreads4D=1024 constant matching PyTorch CUDA_NUM_THREADS
- Add NLLLossWarpReduceSum/NLLLossBlockReduceSum helpers matching PyTorch
  BlockReduceSum two-level warp-shuffle reduction pattern
- GPUNLLLossForward2D_with_reduce: use separate smem arrays, warp-shuffle
  block reduction, and launch with kNumCUDAThreads4D=1024 threads
- blocks_per_sample now computed with 1024-thread GET_BLOCKS matching PyTorch
- Remove thrust::plus dependency from nll_loss.h
- 2D/1D NLL paths (no_reduce) unchanged
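The two-level reduction topology referenced above (per-chunk partial sums, then a combine across chunks) can be illustrated in plain Python. This is only a shape of the algorithm: real kernels reduce within a warp via shuffles, within a block via shared memory, and across blocks via gpuAtomicAdd; the chunk size here is arbitrary.

```python
def block_reduce_sum(values, block_size=4):
    # Stage 1: each "block" reduces its chunk to one partial sum
    # (stands in for the warp-shuffle BlockReduceSum inside a thread block).
    partials = [sum(values[i:i + block_size])
                for i in range(0, len(values), block_size)]
    # Stage 2: partial sums are combined across blocks
    # (stands in for gpuAtomicAdd accumulation into the output).
    return sum(partials)
```

Because floating-point addition is not associative, the grouping chosen here is exactly what determines the accumulation order that this PR aligns with PyTorch's NLLLoss2d.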
@paddle-bot

paddle-bot bot commented Apr 3, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@codecov-commenter

Codecov Report

❌ Patch coverage is 86.66667% with 14 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@24ebc90). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| python/paddle/nn/functional/loss.py | 86.66% | 14 Missing ⚠️ |

❌ Your patch status has failed because the patch coverage (86.66%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #78572   +/-   ##
==========================================
  Coverage           ?   86.66%           
==========================================
  Files              ?        1           
  Lines              ?      105           
  Branches           ?        0           
==========================================
  Hits               ?       91           
  Misses             ?       14           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.