diff --git a/paddle/phi/kernels/impl/isfinite_kernel_impl.h b/paddle/phi/kernels/impl/isfinite_kernel_impl.h index 6d0172808ebfe8..c0cec1d97fe836 100644 --- a/paddle/phi/kernels/impl/isfinite_kernel_impl.h +++ b/paddle/phi/kernels/impl/isfinite_kernel_impl.h @@ -301,7 +301,23 @@ __global__ void IsfiniteCUDAKernel( const T* in_data, IndexType num, bool* out_data, - typename std::enable_if::value>::type* = 0) { + typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value>::type* = 0) { + IndexType idx = threadIdx.x + blockIdx.x * blockDim.x; + for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) { + const T& a = in_data[i]; + out_data[i] = isfinite(a); + } +} + +template +__global__ void IsfiniteCUDAKernel( + const T* in_data, + IndexType num, + bool* out_data, + typename std::enable_if::value || + std::is_same::value>::type* = 0) { IndexType idx = threadIdx.x + blockIdx.x * blockDim.x; for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) { const T& a = in_data[i]; @@ -340,7 +356,23 @@ __global__ void IsnanCUDAKernel( const T* in_data, IndexType num, bool* out_data, - typename std::enable_if::value>::type* = 0) { + typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value>::type* = 0) { + IndexType idx = threadIdx.x + blockIdx.x * blockDim.x; + for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) { + const T& a = in_data[i]; + out_data[i] = isnan(a); + } +} + +template +__global__ void IsnanCUDAKernel( + const T* in_data, + IndexType num, + bool* out_data, + typename std::enable_if::value || + std::is_same::value>::type* = 0) { IndexType idx = threadIdx.x + blockIdx.x * blockDim.x; for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) { const T& a = in_data[i]; @@ -379,7 +411,23 @@ __global__ void IsinfCUDAKernel( const T* in_data, IndexType num, bool* out_data, - typename std::enable_if::value>::type* = 0) { + typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value>::type* = 0) { + IndexType idx = threadIdx.x + blockIdx.x * blockDim.x; + for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) { + const T& a = in_data[i]; + out_data[i] = isinf(a); + } +} + +template +__global__ void IsinfCUDAKernel( + const T* in_data, + IndexType num, + bool* out_data, + typename std::enable_if::value || + std::is_same::value>::type* = 0) { IndexType idx = threadIdx.x + blockIdx.x * blockDim.x; for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) { const T& a = in_data[i]; @@ -477,9 +525,9 @@ struct IsinfFunctor { #endif template -PADDLE_API void IsfiniteKernel(const Context& dev_ctx, - const DenseTensor& x, - DenseTensor* out) { +void IsfiniteKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { if (out && out->numel() == 0) { dev_ctx.template Alloc(out); return;