Skip to content

Commit 47179ef

Browse files
authored
[Precision Depth Alignment]support p_norm_grad (#78477)
1 parent 8a61375 commit 47179ef

File tree

18 files changed

+476
-27
lines changed

18 files changed

+476
-27
lines changed

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2806,8 +2806,8 @@ struct SimpleOpTypeSetTeller : public Teller {
28062806
}
28072807
bool asvector = PADDLE_GET_CONST(bool, desc.GetAttr("asvector"));
28082808
int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis"));
2809-
float porder = PADDLE_GET_CONST(float, desc.GetAttr("porder"));
2810-
if (asvector || porder != 2.0f || axis != -1) {
2809+
double porder = PADDLE_GET_CONST(double, desc.GetAttr("porder"));
2810+
if (asvector || porder != 2.0 || axis != -1) {
28112811
VLOG(3) << op_type
28122812
<< " op only support asvector=False, porder=2, axis = -1.";
28132813
return false;

paddle/fluid/pir/serialize_deserialize/patch/4.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ op_patches:
121121
object : using_weighted_combine
122122
type : pir::BoolAttribute
123123
data : "false"
124+
- op_name : pd_op.p_norm
125+
actions :
126+
- action : modify_attr
127+
object : porder
128+
type : pir::DoubleAttribute
129+
data : 2
124130
- op_name : pd_op.layer_norm
125131
actions:
126132
- action : modify_attr

paddle/phi/infermeta/unary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3781,7 +3781,7 @@ void PixelUnshuffleInferMeta(const MetaTensor& x,
37813781
}
37823782

37833783
void PNormInferMeta(const MetaTensor& x,
3784-
float porder,
3784+
double porder,
37853785
int axis,
37863786
float epsilon,
37873787
bool keepdim,

paddle/phi/infermeta/unary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ PADDLE_API void PixelUnshuffleInferMeta(const MetaTensor& x,
591591
MetaTensor* out);
592592

593593
PADDLE_API void PNormInferMeta(const MetaTensor& x,
594-
float porder,
594+
double porder,
595595
int axis,
596596
float epsilon,
597597
bool keepdim,

paddle/phi/kernels/cpu/p_norm_grad_kernel.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ void PNormGradKernel(const Context& dev_ctx,
4444
const DenseTensor& x,
4545
const DenseTensor& out,
4646
const DenseTensor& out_grad,
47-
float porder,
47+
double porder,
4848
int axis,
4949
float epsilon,
5050
bool keepdim UNUSED,
@@ -88,8 +88,8 @@ void PNormGradKernel(const Context& dev_ctx,
8888
xr.sign() * norm_dy.broadcast(bcast);
8989
} else {
9090
dx.device(*place) =
91-
(xr.abs()).pow(porder - 1.0f) /
92-
((norm.broadcast(bcast)).pow(porder - 1.0f) + xr.constant(eps));
91+
(xr.abs()).pow(porder - 1.0) /
92+
((norm.broadcast(bcast)).pow(porder - 1.0) + xr.constant(eps));
9393
dx.device(*place) = dx * norm_dy.broadcast(bcast) * xr.sign();
9494
}
9595
}

paddle/phi/kernels/cpu/p_norm_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ inline void GetDims(
4444
template <typename T, typename Context>
4545
void PNormKernel(const Context& dev_ctx,
4646
const DenseTensor& x,
47-
float porder,
47+
double porder,
4848
int axis,
4949
float epsilon UNUSED,
5050
bool keepdim UNUSED,
@@ -90,7 +90,7 @@ void PNormKernel(const Context& dev_ctx,
9090
} else if (porder == -INFINITY) {
9191
norm.device(*place) = xr.abs().minimum(rdim);
9292
} else {
93-
norm.device(*place) = xr.abs().pow(porder).sum(rdim).pow(1.0f / porder);
93+
norm.device(*place) = xr.abs().pow(porder).sum(rdim).pow(1.0 / porder);
9494
}
9595
}
9696
} // namespace phi

0 commit comments

Comments
 (0)