[Paddle Pilot] Align cross_entropy precision for soft-label, label-smoothing, and ND reduction under FLAGS_use_accuracy_compatible_kernel#78572
Open
zrr1999 wants to merge 9 commits intoPaddlePaddle:developfrom
Conversation
…opology for 4D hard-label mean path Remove the FLAGS_use_accuracy_compatible_kernel branch that routed the compatible path through a single-block butterfly reduction (GPUNLLLossForward2D_with_reduce_compatible). Both code paths now unconditionally use GPUNLLLossForward2D_with_reduce (multi-block per-sample, BlockReduceSum + gpuAtomicAdd), matching PyTorch NLLLoss2d accumulation order and eliminating the dominant accuracy divergence for rank > 2 inputs.
- Add kNumCUDAThreads4D=1024 constant matching PyTorch CUDA_NUM_THREADS - Add NLLLossWarpReduceSum/NLLLossBlockReduceSum helpers matching PyTorch BlockReduceSum two-level warp-shuffle reduction pattern - GPUNLLLossForward2D_with_reduce: use separate smem arrays, warp-shuffle block reduction, and launch with kNumCUDAThreads4D=1024 threads - blocks_per_sample now computed with 1024-thread GET_BLOCKS matching PyTorch - Remove thrust::plus dependency from nll_loss.h - 2D/1D NLL paths (no_reduce) unchanged
…unit-test compatibility
…LAG-aware validation
|
你的PR提交成功,感谢你对开源项目的贡献! |
Codecov Report❌ Patch coverage is
❌ Your patch status has failed because the patch coverage (86.66%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## develop #78572 +/- ##
==========================================
Coverage ? 86.66%
==========================================
Files ? 1
Lines ? 105
Branches ? 0
==========================================
Hits ? 91
Misses ? 14
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
PR Category
Performance optimization
Description
Align
paddle.nn.functional.cross_entropyprecision with PyTorchtorch.nn.functional.cross_entropyunderFLAGS_use_accuracy_compatible_kernel=1, covering soft-label, label-smoothing, weighted reduction, and N-D (rank 3/4) NLL loss paths.All changes are gated behind
FLAGS_use_accuracy_compatible_kernel— when the flag is off (default), behavior is identical to upstream.Changes Summary
CUDA kernels (
paddle/phi/kernels/gpu/):cross_entropy_kernel.cu: Added compatible soft-label and label-smoothing paths that match PyTorch'slog_softmax → nll_lossdecomposition under FLAG=1cross_entropy_grad_kernel.cu: Corresponding backward pass for compatible pathsnll_loss.h/nll_loss_kernel.cu: Multi-block atomic-accumulation topology for 4D hard-label mean reduction, matching PyTorch's NLLLoss2d warp-reduction semanticssoftmax_gpudnn.h: Compatible log-softmax path for precision-critical cross-entropy computationPython routing (
python/paddle/nn/functional/loss.py):_cross_entropy_compatible_soft_label_loss(): New function for soft-label cross-entropy with PyTorch-aligned weighted-mean reduction_cross_entropy_compatible_label_smoothing_loss(): New function for hard-label + label_smoothing using NLL + smooth-loss decompositioncross_entropy(): Routing logic gates compatible paths behind FLAG; label-smoothing paths additionally gated to avoid unit-test interference when FLAG=0Tests (
test/legacy_test/test_cross_entropy_loss.py):Strict Validation Results (PaddleAPITest, atol=0, rtol=0)
Improvement: +11 strict-pass cases — all from soft-label, label-smoothing, and ND reduction paths now aligned.
Remaining Residuals (46 cases, all documented)
reduction="mean"+ignore_index(float32, rank 2/3)use_softmax=False+reduction="none"use_softmax=Falsemapping to PyTorchsoft_label=True+label_smoothing(float64)soft_label=True+ int64 label + weight (torch_error)matmulrejectsint64 @ float64; PaddleAPITest reference-side limitationNone of these residuals are regressions — they existed before this PR and are either test-harness issues, architectural CUDA reduction-order gaps, or reference-side limitations.
Unit Tests
test_cross_entropy_op.pytest_cross_entropy_loss.py