Skip to content

Commit 7af18c4

Browse files
authored
update cross_entropy_grad_kernel.cu (PaddlePaddle#78638)
1 parent a1fa887 commit 7af18c4

3 files changed

Lines changed: 5 additions & 10 deletions

File tree

paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,6 @@ void CrossEntropyWithSoftmaxGradGPUKernel(const GPUContext& dev_ctx,
169169
int ignore_index,
170170
int axis,
171171
DenseTensor* logits_grad) {
172-
PADDLE_ENFORCE_EQ(
173-
dev_ctx.GetPlace().GetType(),
174-
AllocationType::GPU,
175-
common::errors::Unavailable("softmax_with_cross_entropy operator's "
176-
"CUDA kernel only runs on GPU device."));
177172
const T* loss_grad_data = loss_grad.data<T>();
178173
DenseTensor* logit_grad = logits_grad;
179174

paddle/phi/kernels/impl/gammaincc_kernel_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ HOSTDEVICE T igam(const T a, const T x) {
5656

5757
template <typename T>
5858
HOSTDEVICE T igamc(const T a, const T x) {
59-
static T big = 4.503599627370496e15;
60-
static T biginv = 2.22044604925031308085e-16;
59+
static const T big = 4.503599627370496e15;
60+
static const T biginv = 2.22044604925031308085e-16;
6161

6262
if ((x <= T{0}) || (a <= T{0})) return (T{1.0});
6363

paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
namespace phi {
2121
template <typename T>
2222
HOSTDEVICE T digamma_positive_domain(T x) {
23-
static T c = T{8.5};
24-
static T euler_mascheroni = T{0.57721566490153286060};
23+
static const T c = T{8.5};
24+
static const T euler_mascheroni = T{0.57721566490153286060};
2525
T r;
2626
T value;
2727
T x2;
@@ -54,7 +54,7 @@ HOSTDEVICE T digamma_positive_domain(T x) {
5454

5555
template <typename T>
5656
HOSTDEVICE T digamma(T x) {
57-
static T pi = T{3.14159265358979323846};
57+
static const T pi = T{3.14159265358979323846};
5858

5959
if (x == T{0.0}) {
6060
T inf = std::numeric_limits<T>::infinity();

0 commit comments

Comments
 (0)