Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions paddle/phi/kernels/impl/isfinite_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,23 @@ __global__ void IsfiniteCUDAKernel(
const T* in_data,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
typename std::enable_if<std::is_floating_point<T>::value &&
!std::is_same<T, phi::bfloat16>::value &&
!std::is_same<T, phi::float16>::value>::type* = 0) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isfinite(a);
}
}

template <typename T, typename IndexType>
__global__ void IsfiniteCUDAKernel(
const T* in_data,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_same<T, phi::bfloat16>::value ||
std::is_same<T, phi::float16>::value>::type* = 0) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
Expand Down Expand Up @@ -340,7 +356,23 @@ __global__ void IsnanCUDAKernel(
const T* in_data,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
typename std::enable_if<std::is_floating_point<T>::value &&
!std::is_same<T, phi::bfloat16>::value &&
!std::is_same<T, phi::float16>::value>::type* = 0) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isnan(a);
}
}

template <typename T, typename IndexType>
__global__ void IsnanCUDAKernel(
const T* in_data,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_same<T, phi::bfloat16>::value ||
std::is_same<T, phi::float16>::value>::type* = 0) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
Expand Down Expand Up @@ -379,7 +411,23 @@ __global__ void IsinfCUDAKernel(
const T* in_data,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
typename std::enable_if<std::is_floating_point<T>::value &&
!std::is_same<T, phi::bfloat16>::value &&
!std::is_same<T, phi::float16>::value>::type* = 0) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
out_data[i] = isinf(a);
}
}

template <typename T, typename IndexType>
__global__ void IsinfCUDAKernel(
const T* in_data,
IndexType num,
bool* out_data,
typename std::enable_if<std::is_same<T, phi::bfloat16>::value ||
std::is_same<T, phi::float16>::value>::type* = 0) {
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
const T& a = in_data[i];
Expand Down Expand Up @@ -477,9 +525,9 @@ struct IsinfFunctor<phi::GPUContext, T> {
#endif

template <typename T, typename Context>
PADDLE_API void IsfiniteKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
void IsfiniteKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
if (out && out->numel() == 0) {
dev_ctx.template Alloc<bool>(out);
return;
Expand Down
Loading