Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@
"all",
"sum",
"mean",
"nansum",
# logical
"logical_and",
"logical_or",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,22 @@ bool NanmedianOpInferSymbolicShape(
return true;
}

bool NansumOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
bool keepdim = GetBoolAttr(op, "keepdim");

std::vector<int64_t> axis;
const auto attributes = op->attributes();
if (attributes.find("axis") != attributes.end()) {
axis = op->attribute<paddle::dialect::IntArrayAttribute>("axis")
.data()
.GetData();
}
bool reduce_all = axis.size() == 0;

return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all);
}

bool NormOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
auto x_shape_or_data =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool2dWithIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3dWithIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multinomial)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nanmedian)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nansum)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel)
Expand Down
76 changes: 76 additions & 0 deletions paddle/phi/kernels/cpu/reduce_nansum_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/reduce_nansum_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/reduce_grad.h"

namespace phi {

template <typename T, typename Context>
void NansumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
if (x_grad && x_grad->numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
return;
}

// Step 1: broadcast out_grad to x_grad shape (same as sum_grad)
ReduceGradKernel<Context, T, funcs::SumGradFunctor, true>(dev_ctx,
x,
paddle::none,
out_grad,
dims.GetData(),
keep_dim,
reduce_all,
x_grad);

// Step 2: zero out gradient where x is NaN
const T* x_data = x.data<T>();
T* x_grad_data = x_grad->data<T>();
int64_t numel = x.numel();
for (int64_t i = 0; i < numel; ++i) {
if (x_data[i] != x_data[i]) {
x_grad_data[i] = static_cast<T>(0);
}
}
}

} // namespace phi

PD_REGISTER_KERNEL(nansum_grad,
CPU,
ALL_LAYOUT,
phi::NansumGradKernel,
bool,
float,
double,
phi::float16,
phi::bfloat16,
int16_t,
int,
int64_t,
phi::complex64,
phi::complex128) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
81 changes: 81 additions & 0 deletions paddle/phi/kernels/cpu/reduce_nansum_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/reduce_nansum_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

template <typename T, typename Context>
void NansumKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
DataType out_dtype,
bool keep_dim,
DenseTensor* out) {
if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
out_dtype = out->dtype();
}

if (x.numel() == 0) {
dev_ctx.template Alloc<T>(out);
if (out_dtype == DataType::INT64) {
Full<int64_t, Context>(dev_ctx, out->dims(), 0, out);
} else {
Full<T, Context>(dev_ctx, out->dims(), 0, out);
}
return;
}

// Replace NaN with 0
DenseTensor cleaned_x;
cleaned_x.Resize(x.dims());
dev_ctx.template Alloc<T>(&cleaned_x);
const T* x_data = x.data<T>();
T* clean_data = cleaned_x.data<T>();
int64_t numel = x.numel();
for (int64_t i = 0; i < numel; ++i) {
clean_data[i] = (x_data[i] != x_data[i]) ? static_cast<T>(0) : x_data[i];
}

// Delegate to SumRawKernel
bool reduce_all = recompute_reduce_all(x, dims);
SumRawKernel<T, Context>(
dev_ctx, cleaned_x, dims, keep_dim, reduce_all, out_dtype, out);
}

} // namespace phi

PD_REGISTER_KERNEL(nansum,
CPU,
ALL_LAYOUT,
phi::NansumKernel,
bool,
float,
double,
phi::float16,
phi::bfloat16,
int16_t,
int8_t,
uint8_t,
int,
int64_t,
phi::complex64,
phi::complex128) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
83 changes: 83 additions & 0 deletions paddle/phi/kernels/gpu/reduce_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
// limitations under the License.

#include "paddle/phi/kernels/reduce_kernel.h"
#include "paddle/phi/kernels/reduce_nansum_grad_kernel.h"

#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/gpu/reduce_amin_amax_common.h"
#include "paddle/phi/kernels/reduce_amin_grad_kernel.h"
#include "paddle/phi/kernels/reduce_max_grad_kernel.h"
Expand Down Expand Up @@ -224,6 +227,67 @@ void ReduceKernel(const Context& dev_ctx,
#endif
}

template <typename T>
struct NanMaskFunctor {
const T* x_data;
T* x_grad_data;

NanMaskFunctor(const T* x_data, T* x_grad_data)
: x_data(x_data), x_grad_data(x_grad_data) {}

HOSTDEVICE void operator()(size_t idx) const {
// NaN != NaN for floating-point; always false for integral types
if (x_data[idx] != x_data[idx]) {
x_grad_data[idx] = static_cast<T>(0);
}
}
};

template <typename T, typename Context>
void NansumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
if (x_grad && x_grad->numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
return;
}

// Step 1: broadcast out_grad to x_grad shape (same as sum_grad)
int dim_size = x.dims().size();
std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);

auto update_dims = vectorize(x.dims());
for (auto i : reduce_dims) {
update_dims[i] = 1;
}

DenseTensor new_out_grad(out_grad.dtype());
new_out_grad.ShareDataWith(out_grad);
new_out_grad.Resize(make_ddim(update_dims));

dev_ctx.Alloc(x_grad, x.dtype());
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
phi::ReduceGrad<kps::IdentityFunctor<T, MPType>>(
dev_ctx,
&new_out_grad,
x_grad,
x.dtype(),
kps::IdentityFunctor<T, MPType>());

// Step 2: zero out gradient where x is NaN
const T* x_data = x.data<T>();
T* x_grad_data = x_grad->data<T>();
int64_t numel = x.numel();
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
for_range(NanMaskFunctor<T>(x_data, x_grad_data));
}

} // namespace phi

#if NCCL_VERSION_CODE >= 21000
Expand Down Expand Up @@ -330,3 +394,22 @@ PD_REGISTER_KERNEL(sum_grad,
phi::complex128) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}

PD_REGISTER_KERNEL(nansum_grad,
GPU,
ALL_LAYOUT,
phi::NansumGradKernel,
bool,
float,
double,
phi::float16,
phi::bfloat16,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
phi::complex64,
phi::complex128) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
55 changes: 55 additions & 0 deletions paddle/phi/kernels/kps/reduce_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/kernels/reduce_mean_kernel.h"
#include "paddle/phi/kernels/reduce_min_kernel.h"
#include "paddle/phi/kernels/reduce_nansum_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/phi/kernels/funcs/eigen/common.h"
Expand Down Expand Up @@ -272,6 +273,37 @@ void SumRawKernel(const Context& dev_ctx,
dev_ctx, x, reduce_all, dims.GetData(), out_dtype, out);
#endif
}

template <typename T, typename Context>
void NansumKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
DataType out_dtype,
bool keep_dim,
DenseTensor* out) {
if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
out_dtype = out->dtype();
}

if (x.numel() == 0) {
dev_ctx.template Alloc<T>(out);
if (out_dtype == DataType::INT64) {
Full<int64_t, Context>(dev_ctx, out->dims(), 0, out);
} else {
Full<T, Context>(dev_ctx, out->dims(), 0, out);
}
return;
}

bool reduce_all = recompute_reduce_all(x, dims);
#ifdef PADDLE_WITH_XPU_KP
phi::Reduce<T, kps::AddFunctor, kps::NanToZeroFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
#else
phi::Reduce<T, kps::NansumOps>(
dev_ctx, x, reduce_all, dims.GetData(), out_dtype, out);
#endif
}
} // namespace phi

#ifdef PADDLE_WITH_XPU_KP
Expand All @@ -296,6 +328,10 @@ PD_REGISTER_KERNEL(min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float) {}
PD_REGISTER_KERNEL(sum_raw, KPS, ALL_LAYOUT, phi::SumRawKernel, float) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}

PD_REGISTER_KERNEL(nansum, KPS, ALL_LAYOUT, phi::NansumKernel, float) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
#else
using float16 = phi::float16;
using bfloat16 = phi::bfloat16;
Expand Down Expand Up @@ -406,6 +442,25 @@ PD_REGISTER_KERNEL(sum_raw,
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}

PD_REGISTER_KERNEL(nansum,
KPS,
ALL_LAYOUT,
phi::NansumKernel,
bool,
float,
double,
float16,
bfloat16,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
complex64,
complex128) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}

PD_REGISTER_KERNEL(prod,
KPS,
ALL_LAYOUT,
Expand Down
Loading
Loading