Skip to content

Commit df035ec

Browse files
authored
Fix lpnormpool
Implemented the correct LpNorm Pooling and backprop steps
1 parent 1468582 commit df035ec

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/impl/pooling_direct.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ for name in (:max, :mean, :lpnorm)
9696
m += x[input_kw, input_kh, input_kd, c, batch_idx]
9797
elseif $(name == :lpnorm)
9898
# y = (∑ᵢ xᵢ^p)^(1 / p), here to calculate ∑ᵢ xᵢ^p
99-
m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
99+
m += abs(x[input_kw, input_kh, input_kd, c, batch_idx])^p
100100
else
101101
error("Unimplemented codegen path")
102102
end
@@ -151,7 +151,7 @@ for name in (:max, :mean, :lpnorm)
151151
elseif $(name == :mean)
152152
m += x[input_kw, input_kh, input_kd, c, batch_idx]
153153
elseif $(name == :lpnorm)
154-
m += x[input_kw, input_kh, input_kd, c, batch_idx]^p
154+
m += abs(x[input_kw, input_kh, input_kd, c, batch_idx])^p
155155
else
156156
error("Unimplemented codegen path")
157157
end
@@ -264,7 +264,8 @@ for name in (:max, :mean, :lpnorm)
264264
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha
265265
elseif $(name == :lpnorm)
266266
# y = (∑ᵢ xᵢ^p)^(1 / p), ∂y/∂xᵢ = xᵢ^(p-1) × y^(1-p)
267-
grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
267+
xv = x[input_kw, input_kh, input_kd, c, batch_idx]
268+
grad = abs(xv)^(p-1) * y_idx^(1-p) * sign(xv)
268269
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
269270
else
270271
error("Unimplemented codegen path")
@@ -327,7 +328,8 @@ for name in (:max, :mean, :lpnorm)
327328
elseif $(name == :mean)
328329
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]
329330
elseif $(name == :lpnorm)
330-
grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
331+
xv = x[input_kw, input_kh, input_kd, c, batch_idx]
332+
grad = abs(xv)^(p-1) * y_idx^(1-p) * sign(xv)
331333
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad
332334
else
333335
error("Unimplemented codegen path")

0 commit comments

Comments
 (0)