Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions paddle/phi/kernels/impl/abs_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ struct AbsGradCUDAFunctor {
};

template <>
struct AbsGradCUDAFunctor<phi::bfloat16> {
struct AbsGradCUDAFunctor<bfloat16> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}

HOSTDEVICE inline phi::bfloat16 operator()(const phi::bfloat16 x,
const phi::bfloat16 dout) const {
phi::bfloat16 output;
if (x == phi::bfloat16(0)) {
output = static_cast<phi::bfloat16>(0);
HOSTDEVICE inline bfloat16 operator()(const bfloat16 x,
const bfloat16 dout) const {
bfloat16 output;
if (x == bfloat16(0)) {
output = static_cast<bfloat16>(0);
} else {
output = (dout) * (x / abs(x));
}
Expand All @@ -55,30 +55,30 @@ struct AbsGradCUDAFunctor<phi::bfloat16> {
};

template <>
struct AbsGradCUDAFunctor<phi::complex64> {
struct AbsGradCUDAFunctor<complex64> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::complex64 operator()(const phi::complex64 x,
const float dout) const {
phi::complex64 output;
if (x == phi::complex64(0)) {
output = phi::complex64(0);
HOSTDEVICE inline complex64 operator()(const complex64 x,
const float dout) const {
complex64 output;
if (x == complex64(0)) {
output = complex64(0);
} else {
output = phi::complex64(dout) * (x / phi::complex64(abs(x)));
output = complex64(dout) * (x / complex64(abs(x)));
}
return output;
}
};

template <>
struct AbsGradCUDAFunctor<phi::complex128> {
struct AbsGradCUDAFunctor<complex128> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::complex128 operator()(const phi::complex128 x,
const double dout) const {
phi::complex128 output;
if (x == phi::complex128(0)) {
output = phi::complex128(0);
HOSTDEVICE inline complex128 operator()(const complex128 x,
const double dout) const {
complex128 output;
if (x == complex128(0)) {
output = complex128(0);
} else {
output = phi::complex128(dout) * (x / phi::complex128(abs(x)));
output = complex128(dout) * (x / complex128(abs(x)));
}
return output;
}
Expand Down Expand Up @@ -110,7 +110,7 @@ void AbsGradKernel(const Context& dev_ctx,
const DenseTensor& dout,
DenseTensor* dx) {
auto numel = dout.numel();
auto* dout_data = dout.data<phi::dtype::Real<T>>();
auto* dout_data = dout.data<dtype::Real<T>>();
auto* x_data = x.data<T>();

dev_ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
Expand Down
43 changes: 21 additions & 22 deletions paddle/phi/kernels/impl/accuracy_check_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ struct AccuracyCheckFunctor<CPUContext, T> {
};

template <typename T>
struct AccuracyCheckFunctor<CPUContext, phi::dtype::complex<T>> {
struct AccuracyCheckFunctor<CPUContext, dtype::complex<T>> {
void operator()(const CPUContext& dev_ctx,
const DenseTensor& in,
const DenseTensor& other,
Expand All @@ -99,8 +99,8 @@ struct AccuracyCheckFunctor<CPUContext, phi::dtype::complex<T>> {
const double atol,
bool equal_nan,
DenseTensor* output) {
auto* in_a = in.data<phi::dtype::complex<T>>();
auto* in_b = other.data<phi::dtype::complex<T>>();
auto* in_a = in.data<dtype::complex<T>>();
auto* in_b = other.data<dtype::complex<T>>();
auto* out_data = dev_ctx.template Alloc<bool>(output);
auto num = in.numel();
// *out_data = true;
Expand All @@ -110,7 +110,7 @@ struct AccuracyCheckFunctor<CPUContext, phi::dtype::complex<T>> {
bool val = false;
int res_index = -1;
for (int i = 0; i < num; i++) {
const phi::dtype::complex<T> a = in_a[i], b = in_b[i];
const dtype::complex<T> a = in_a[i], b = in_b[i];
if (std::isnan(a) || std::isnan(b)) {
val = equal_nan && std::isnan(a) == std::isnan(b);
} else {
Expand Down Expand Up @@ -146,7 +146,7 @@ __global__ void AccuracyCheckCUDAKernel(const T* in_data,
bool* out_data) {
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
bool val;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
using MPType = typename dtype::MPTypeTrait<T>::Type;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
const double a = static_cast<MPType>(in_data[i]);
const double b = static_cast<MPType>(other_data[i]);
Expand All @@ -166,19 +166,18 @@ __global__ void AccuracyCheckCUDAKernel(const T* in_data,
}
}
template <>
__global__ void AccuracyCheckCUDAKernel<phi::complex64>(
const phi::complex64* in_data,
const phi::complex64* other_data,
const double rtol,
const double atol,
bool equal_nan,
int64_t num,
bool* out_data) {
__global__ void AccuracyCheckCUDAKernel<complex64>(const complex64* in_data,
const complex64* other_data,
const double rtol,
const double atol,
bool equal_nan,
int64_t num,
bool* out_data) {
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
bool val;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
const phi::complex64 a = in_data[i];
const phi::complex64 b = other_data[i];
const complex64 a = in_data[i];
const complex64 b = other_data[i];
if (isnan(a) || isnan(b)) {
val = equal_nan && isnan(a) == isnan(b);
} else {
Expand All @@ -196,9 +195,9 @@ __global__ void AccuracyCheckCUDAKernel<phi::complex64>(
}

template <>
__global__ void AccuracyCheckCUDAKernel<phi::complex128>(
const phi::complex128* in_data,
const phi::complex128* other_data,
__global__ void AccuracyCheckCUDAKernel<complex128>(
const complex128* in_data,
const complex128* other_data,
const double rtol,
const double atol,
bool equal_nan,
Expand All @@ -207,8 +206,8 @@ __global__ void AccuracyCheckCUDAKernel<phi::complex128>(
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
bool val;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
const phi::complex128 a = in_data[i];
const phi::complex128 b = other_data[i];
const complex128 a = in_data[i];
const complex128 b = other_data[i];
if (isnan(a) || isnan(b)) {
val = equal_nan && isnan(a) == isnan(b);
} else {
Expand All @@ -226,8 +225,8 @@ __global__ void AccuracyCheckCUDAKernel<phi::complex128>(
}

template <typename T>
struct AccuracyCheckFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
struct AccuracyCheckFunctor<GPUContext, T> {
void operator()(const GPUContext& dev_ctx,
const DenseTensor& in,
const DenseTensor& other,
const std::string& fn_name,
Expand Down
Loading
Loading