Skip to content

Commit 6c15b18

Browse files
authored
Merge pull request #78 from itan1/add-leakyrelu
Add leakyrelu
2 parents 8e1ccba + 1f2ea0d commit 6c15b18

File tree

6 files changed

+27
-4
lines changed

6 files changed

+27
-4
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ Gemm
8787
GlobalAveragePool
8888
GlobalMaxPool
8989
LSTM
90+
LeakyRelu
9091
MatMul
9192
MaxPool
9293
Mul

src/deserialize/ops.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ constant(::Val{:value}, val) = val
5151
actfuns[:Relu] = params -> Flux.relu
5252
actfuns[:Sigmoid] = params -> Flux.σ
5353

54+
actfuns[:LeakyRelu] = function(params)
55+
α = get(params, :alpha, 0.01f0)
56+
return x -> Flux.leakyrelu(x, oftype(x, α))
57+
end
58+
rnnactfuns[:LeakyRelu] = (ind, params) -> actfuns[:LeakyRelu](Dict(:alpha => get(params, :activation_alpha, ntuple(i -> 0.01f0, ind))[ind]))
59+
5460
actfuns[:Elu] = function(params)
5561
α = get(params, :alpha, 1)
5662
return x -> Flux.elu(x, oftype(x, α))

src/serialize/serialize.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ function attribfun(fhshape, optype, pps::AbstractProbe...; attributes = ONNX.Att
365365
end
366366

367367
Flux.relu(pp::AbstractProbe) = attribfun(identity, "Relu", pp)
368+
Flux.leakyrelu(pp::AbstractProbe, α=0.01f0) = attribfun(identity, "LeakyRelu", pp; attributes = [ONNX.AttributeProto("alpha", α)])
368369
Flux.elu(pp::AbstractProbe, α=1f0) = attribfun(identity, "Elu", pp; attributes = [ONNX.AttributeProto("alpha", α)])
369370
Flux.selu(pp::AbstractProbe) = attribfun(identity, "Selu", pp)
370371
Flux.selu(pp::AbstractProbe, γ, α) = attribfun(identity, "Selu", pp; attributes = ONNX.AttributeProto.(["gamma", "alpha"], [γ, α]))

test/deserialize/Artifacts.toml

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ git-tree-sha1 = "8be97aa969ebdbe7599798d511a7790eba0697f2"
8383
git-tree-sha1 = "41eca620fb09f7d90ec9b875a80388566baadada"
8484

8585
[test_div]
86-
git-tree-sha1 = "430fd9135e60076904f717971bd174b0f16e1c54"
86+
git-tree-sha1 = "57dd66f7274aac0e2a462e49dadf5a551c4e5e80"
8787

8888
[test_dropout_default]
8989
git-tree-sha1 = "70fe420142b8d29b708578e4b6f2929e6907cb4c"
@@ -169,14 +169,23 @@ git-tree-sha1 = "377710458916cc790bb7eec00c8e3f0719680cf8"
169169
[test_globalmaxpool_precomputed]
170170
git-tree-sha1 = "6d72b58370176351d46937ca3df65ba2fd114f04"
171171

172+
[test_leakyrelu]
173+
git-tree-sha1 = "07afe319b71db2cb6bc295ff9409482721473817"
174+
175+
[test_leakyrelu_default]
176+
git-tree-sha1 = "2751dbd14e5feaf6c59798e84ae9e7d9700240b6"
177+
178+
[test_leakyrelu_example]
179+
git-tree-sha1 = "b7e814cb5b5d6d538db1d6af49c02b786cb0036e"
180+
172181
[test_lstm_defaults]
173182
git-tree-sha1 = "c8b0d06dc9733222906bb6471c39a9c41270d149"
174183

175184
[test_lstm_with_initial_bias]
176185
git-tree-sha1 = "19fe9305067a4225c6dd76264bb342cf966546ae"
177186

178187
[test_matmul_2d]
179-
git-tree-sha1 = "481de0ea5b1fb4692f10920215ce701df1b7ba09"
188+
git-tree-sha1 = "3008bfa77da6160c406f6dae414b19861fef9f13"
180189

181190
[test_maxpool_1d_default]
182191
git-tree-sha1 = "9b0a2b97518eb68122276b242313529582e4be95"
@@ -269,10 +278,10 @@ git-tree-sha1 = "44c37442a35def50d4e1230cdcff8a986899d18d"
269278
git-tree-sha1 = "dc4a6180985e796aca6997ae79137fee6f9c05e9"
270279

271280
[test_sigmoid]
272-
git-tree-sha1 = "2bb16571d0809d1e6216a1b6eb5b27d332fac4f0"
281+
git-tree-sha1 = "6f0e41cd8b1498f3c60b3d00b9558bf311217bf8"
273282

274283
[test_sigmoid_example]
275-
git-tree-sha1 = "710a8b85b3a7e9e301af21f8f1c0e7b087536ba5"
284+
git-tree-sha1 = "b8b836fd3cb97d2801777c03e98d6cd41bc17b2d"
276285

277286
[test_softmax_axis_0]
278287
git-tree-sha1 = "4f090b0b0f540b5f133176e4cf8a118c24db1886"

test/deserialize/deserialize.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ end
139139
(name="test_elu_default", ninputs=1, noutputs=1),
140140
(name="test_elu_example", ninputs=1, noutputs=1),
141141
(name="test_relu", ninputs=1, noutputs=1),
142+
(name="test_leakyrelu", ninputs=1, noutputs=1),
143+
(name="test_leakyrelu_default", ninputs=1, noutputs=1),
144+
(name="test_leakyrelu_example", ninputs=1, noutputs=1),
142145
(name="test_selu", ninputs=1, noutputs=1),
143146
(name="test_selu_default", ninputs=1, noutputs=1),
144147
(name="test_selu_example", ninputs=1, noutputs=1),

test/serialize/serialize.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242

4343
@testset "Paramfree op $(tc.op) attrs: $(pairs(tc.attr))" for tc in (
4444
(op=:Relu, attr = Dict(), fd=actfuns),
45+
(op=:LeakyRelu, attr = Dict(), fd=actfuns),
46+
(op=:LeakyRelu, attr = Dict(:alpha => 0.05f0), fd=actfuns),
4547
(op=:Elu, attr = Dict(), fd=actfuns),
4648
(op=:Elu, attr = Dict(:alpha => 0.5f0), fd=actfuns),
4749
(op=:Selu, attr = Dict(), fd=actfuns),
@@ -154,6 +156,7 @@
154156

155157
@testset "Layer with activation function $actfun" for actfun in (
156158
relu,
159+
leakyrelu,
157160
elu,
158161
selu,
159162
tanh,

0 commit comments

Comments
 (0)