
Commit 02de234

Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish
* added tests
* updated bench
* remove torch refs, add init to PReLU
* added arXiv reference to mish
* added missing docs
1 parent f5df47e commit 02de234

File tree

8 files changed, +133 -31 lines


benchmarks/python/comparative/bench_mlx.py

Lines changed: 37 additions & 7 deletions

@@ -96,7 +96,35 @@ def softmax_fused(axis, x):
 def relu(x):
     y = x
     for i in range(100):
-        y = mx.maximum(y, 0)
+        y = nn.relu(y)
+    mx.eval(y)
+
+
+def leaky_relu(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.leaky_relu(y)
+    mx.eval(y)
+
+
+def prelu(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.prelu(y, mx.ones(1))
+    mx.eval(y)
+
+
+def softplus(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.softplus(y)
+    mx.eval(y)
+
+
+def mish(x: mx.array):
+    y = x
+    for i in range(100):
+        y = nn.mish(y)
     mx.eval(y)

@@ -334,24 +362,26 @@ def selu(x):
     elif args.benchmark == "relu":
         print(bench(relu, x))

-    elif args.benchmark == "leaky_relu":
-        print(bench(leaky_relu, x))
-
     elif args.benchmark == "elu":
         print(bench(elu, x))

     elif args.benchmark == "relu6":
         print(bench(relu6, x))

-    elif args.benchmark == "softplus":
-        print(bench(softplus, x))
-
     elif args.benchmark == "celu":
         print(bench(celu, x))

     elif args.benchmark == "log_sigmoid":
         print(bench(log_sigmoid, x))

+    elif args.benchmark == "leaky_relu":
+        print(bench(leaky_relu, x))
+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+    elif args.benchmark == "softplus":
+        print(bench(softplus, x))
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))

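The new MLX benchmark functions follow the existing pattern: build a lazy chain of 100 activation calls, then force evaluation with mx.eval so the whole chain is actually computed. Below is a minimal standalone sketch of the same idea; it is not the repo's bench harness, and the 32x16x1024 shape simply mirrors the sizes used in compare.py further down.

    import time

    import mlx.core as mx
    import mlx.nn as nn

    def mish_chain(x):
        # Same work as the benchmark above: 100 chained calls, then one eval.
        y = x
        for _ in range(100):
            y = nn.mish(y)
        mx.eval(y)

    x = mx.random.uniform(shape=[32, 16, 1024])
    mx.eval(x)  # materialize the input before timing

    start = time.perf_counter()
    mish_chain(x)
    print(f"mish x100: {time.perf_counter() - start:.4f} s")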
benchmarks/python/comparative/bench_torch.py

Lines changed: 20 additions & 0 deletions

@@ -163,6 +163,22 @@ def log_sigmoid(x):
     sync_if_needed(x)


+@torch.no_grad()
+def prelu(x: torch.Tensor) -> torch.Tensor:
+    y = x
+    for _ in range(100):
+        y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def mish(x: torch.Tensor) -> torch.Tensor:
+    y = x
+    for _ in range(100):
+        y = torch.nn.functional.mish(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def scalar_mult(x):
     y = x

@@ -376,6 +392,10 @@ def selu(x):
     elif args.benchmark == "log_sigmoid":
         print(bench(log_sigmoid, x))

+    elif args.benchmark == "prelu":
+        print(bench(prelu, x))
+    elif args.benchmark == "mish":
+        print(bench(mish, x))
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))

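Both the Torch and MLX prelu benchmarks pass a single-element weight of ones, so numerically the op reduces to the identity (max(0, x) + 1 * min(0, x) = x); the kernel still executes, which is all the timing needs. For reference, a quick check of the functional form with a 0.25 slope, as a sketch assuming a recent PyTorch; the values match the expectations used in test_nn.py below.

    import torch

    x = torch.tensor([1.0, -1.0, 0.0, 0.5])
    w = torch.full((1,), 0.25)  # one shared negative slope, broadcast over channels
    print(torch.nn.functional.prelu(x, w))  # tensor([ 1.0000, -0.2500,  0.0000,  0.5000])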
benchmarks/python/comparative/compare.py

Lines changed: 5 additions & 0 deletions

@@ -209,6 +209,11 @@ def predicate(x):
     compare_filtered("step --size 32x16x1024 --cpu")
     compare_filtered("selu --size 32x16x1024")
     compare_filtered("selu --size 32x16x1024 --cpu")
+    # compare_filtered("mish --size 32x16x1024")  NOTE: Torch does not implement Mish in MPS atm
+    compare_filtered("mish --size 32x16x1024 --cpu")
+    compare_filtered("prelu --size 32x16x1024")
+    compare_filtered("prelu --size 32x16x1024 --cpu")
+
     compare_filtered("scalar_mul --size 32x16x1024")
     compare_filtered("scalar_mul --size 32x16x1024 --cpu")
     compare_filtered("cross_entropy --size 256x1024")

docs/src/python/nn.rst

Lines changed: 4 additions & 0 deletions

@@ -146,10 +146,12 @@ Neural Network Layers

    Embedding
    ReLU
+   PReLU
    GELU
    SiLU
    Step
    SELU
+   Mish
    Linear
    Conv1d
    Conv2d

@@ -171,9 +173,11 @@ simple functions.
    gelu_approx
    gelu_fast_approx
    relu
+   prelu
    silu
    step
    selu
+   mish

 Loss Functions
 --------------

python/mlx/nn/layers/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -7,6 +7,8 @@
     SELU,
     LeakyReLU,
     LogSigmoid,
+    Mish,
+    PReLU,
     ReLU,
     ReLU6,
     SiLU,

@@ -19,6 +21,8 @@
     gelu_fast_approx,
     leaky_relu,
     log_sigmoid,
+    mish,
+    prelu,
     relu,
     relu6,
     selu,

python/mlx/nn/layers/activations.py

Lines changed: 36 additions & 0 deletions

@@ -176,6 +176,33 @@ def selu(x):
     See also :func:`elu`.
     """
     return elu(x, 1.67326) * 1.0507
+def prelu(x: mx.array, alpha: mx.array) -> mx.array:
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+
+    Here :math:`a` is an array.
+    """
+    return mx.maximum(0, x) + alpha * mx.minimum(0, x)
+
+
+def mish(x: mx.array) -> mx.array:
+    r"""Applies the Mish function, element-wise.
+    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+    """
+    return x * mx.tanh(softplus(x))
+
+
+@_make_activation_module(mish)
+class Mish(Module):
+    pass


 @_make_activation_module(relu)

@@ -257,6 +284,15 @@ class LogSigmoid(Module):
     pass


+class PReLU(Module):
+    def __init__(self, num_parameters=1, init=0.25):
+        super().__init__()
+        self.weight = mx.full([num_parameters], init)
+
+    def __call__(self, x: mx.array):
+        return prelu(x, self.weight)
+
+
 class GELU(Module):
     r"""Applies the Gaussian Error Linear Units.

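With this change, both the functional and the module forms are available from mlx.nn. A short usage sketch of the API added above (outputs rounded; the Mish values match the expectations in test_nn.py below):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, -1.0, 0.0, 0.5])

    # Functional forms
    print(nn.prelu(x, mx.array([0.25])))  # [1.0, -0.25, 0.0, 0.5]
    print(nn.mish(x))                     # ~[0.8651, -0.3034, 0.0, 0.3752]

    # Module forms; PReLU stores a `weight` array of size num_parameters, filled with `init`
    print(nn.PReLU(num_parameters=1, init=0.25)(x))  # same as the functional call above
    print(nn.Mish()(x))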
python/tests/mlx_tests.py

Lines changed: 15 additions & 0 deletions

@@ -2,8 +2,10 @@

 import os
 import unittest
+from typing import Callable, List, Tuple

 import mlx.core as mx
+import numpy as np


 class MLXTestCase(unittest.TestCase):

@@ -16,3 +18,16 @@ def setUp(self):

     def tearDown(self):
         mx.set_default_device(self.default)
+
+    def assertEqualArray(
+        self,
+        args: List[mx.array | float | int],
+        mlx_func: Callable[..., mx.array],
+        expected: mx.array,
+        atol=1e-2,
+        rtol=1e-2,
+    ):
+        mx_res = mlx_func(*args)
+        assert tuple(mx_res.shape) == tuple(expected.shape), "shape mismatch"
+        assert mx_res.dtype == expected.dtype, "dtype mismatch"
+        np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)

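assertEqualArray takes the callable's positional arguments as a list, the callable itself, and the expected array, then checks shape, dtype, and values within the given tolerances. The new tests in test_nn.py below use it with activation modules; here is a hypothetical test showing the multi-argument case (the test name, values, and the mlx_tests import pattern are illustrative, not part of the commit):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx_tests

    class ActivationTests(mlx_tests.MLXTestCase):
        def test_prelu_functional(self):
            # args are unpacked as mlx_func(*args): here (x, alpha)
            self.assertEqualArray(
                [mx.array([1.0, -1.0, 0.0, 0.5]), mx.array([0.25])],
                nn.prelu,
                mx.array([1.0, -0.25, 0.0, 0.5]),
            )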
python/tests/test_nn.py

Lines changed: 12 additions & 24 deletions

@@ -449,31 +449,19 @@ def test_log_sigmoid(self):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)

-    def test_step_activation(self):
-        x = mx.arange(-3, 4)
-        expected = mx.array([0, 0, 0, 0, 0, 1, 1])
-        y = nn.Step()(x)
-        self.assertTrue(mx.array_equal(y, expected))
-
-        y = nn.Step(2)(x)
-        expected = mx.array([0, 0, 0, 0, 0, 0, 1])
-        self.assertTrue(mx.array_equal(y, expected))
-
-    def test_selu(self):
-        x = mx.arange(-3, 4)
-        expected = mx.array(
-            [
-                -1.670563817024231,
-                -1.5201621055603027,
-                -1.1113275289535522,
-                0.0,
-                1.0506999492645264,
-                2.1013998985290527,
-                3.152099847793579,
-            ]
+    def test_prelu(self):
+        self.assertEqualArray(
+            [mx.array([1.0, -1.0, 0.0, 0.5])],
+            nn.PReLU(),
+            mx.array([1.0, -0.25, 0.0, 0.5]),
+        )
+
+    def test_mish(self):
+        self.assertEqualArray(
+            [mx.array([1.0, -1.0, 0.0, 0.5])],
+            nn.Mish(),
+            mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
         )
-        y = nn.SELU()(x)
-        self.assertTrue(mx.allclose(y, expected))


 if __name__ == "__main__":
