
Commit c7edafb

gboduljak and awni authored
implemented InstanceNorm (#244)
* implemented instancenorm
* implemented vector_norm in cpp, added linalg to mlx
* implemented vector_norm python binding
* renamed vector_norm to norm, implemented norm without provided ord
* completed the implementation of the norm
* added tests
* removed unused import in linalg.cpp
* updated python bindings
* added some tests for python bindings
* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy
* added better docs and examples
* refactored mlx.linalg.norm bindings
* reused existing util for implementation of linalg.norm
* more tests
* fixed a bug with no ord and axis provided
* removed unused imports
* some style and API consistency updates to linalg norm
* remove unused includes
* fix python tests
* fixed a bug with frobenius norm of a complex-valued matrix
* complex for vector too
* addressed PR review comments
* fixed import order in __init__
* expected values in instancenorm tests are simple lists
* minor return expression style change
* added InstanceNorm to docs
* doc string nits
* added myself to individual contributors

--------

Co-authored-by: Awni Hannun <[email protected]>
1 parent dff4a38 commit c7edafb
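Most of the commit-message bullets above track the new `mlx.core.linalg.norm`. As a quick orientation for reviewers, here is a minimal usage sketch; it assumes the numpy-compatible `norm(a, ord=None, axis=None)` behavior the message describes, so treat it as an illustration rather than a definitive API reference:

```python
import mlx.core as mx

x = mx.arange(9, dtype=mx.float32) - 4  # [-4, -3, ..., 4]
m = x.reshape(3, 3)

# Default: 2-norm for vectors, Frobenius norm for matrices,
# matching the numpy-compatible behavior described above.
print(mx.linalg.norm(x))  # ~7.746
print(mx.linalg.norm(m))  # ~7.746 (Frobenius)

# inf / -inf orders are handled as numpy does (max / min of |x|).
print(mx.linalg.norm(x, ord=float("inf")))   # 4.0
print(mx.linalg.norm(x, ord=-float("inf")))  # 0.0
```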

File tree

5 files changed: +287 -1 lines changed

- ACKNOWLEDGMENTS.md
- docs/src/python/nn/layers.rst
- python/mlx/nn/layers/__init__.py
- python/mlx/nn/layers/normalization.py
- python/tests/test_nn.py


ACKNOWLEDGMENTS.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -11,6 +11,7 @@ MLX was developed with contributions from the following individuals:
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot` and safetensor support
+- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
 
 # Third-Party Software
```

docs/src/python/nn/layers.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@ Layers
    LayerNorm
    RMSNorm
    GroupNorm
+   InstanceNorm
    Dropout
    Dropout2d
    Dropout3d
```

python/mlx/nn/layers/__init__.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -46,7 +46,13 @@
 from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Bilinear, Identity, Linear
-from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import (
+    BatchNorm,
+    GroupNorm,
+    InstanceNorm,
+    LayerNorm,
+    RMSNorm,
+)
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
```

python/mlx/nn/layers/normalization.py

Lines changed: 60 additions & 0 deletions
```diff
@@ -6,6 +6,66 @@
 from mlx.nn.layers.base import Module
 
 
+class InstanceNorm(Module):
+    r"""Applies instance normalization [1] on the inputs.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,
+    if :attr:`affine` is ``True``.
+
+    Args:
+        dims (int): The number of features of the input.
+        eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
+        affine (bool): If ``True``, apply a learned affine transform after normalization. Default: ``False``.
+
+    Shape:
+        - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
+        - Output: Same shape as the input.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> x = mx.random.normal((8, 4, 4, 16))
+        >>> inorm = nn.InstanceNorm(dims=16)
+        >>> output = inorm(x)
+
+    References:
+        [1]: https://arxiv.org/abs/1607.08022
+    """
+
+    def __init__(
+        self,
+        dims: int,
+        eps: float = 1e-5,
+        affine: bool = False,
+    ):
+        super().__init__()
+        if affine:
+            self.weight = mx.ones((dims,))
+            self.bias = mx.zeros((dims,))
+        self.dims = dims
+        self.eps = eps
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        reduction_axes = tuple(range(1, x.ndim - 1))
+        # Compute stats
+        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
+        var = mx.var(x, axis=reduction_axes, keepdims=True)
+        # Normalize
+        x = (x - mean) * mx.rsqrt(var + self.eps)
+        # Scale and shift if necessary
+        return (self.weight * x + self.bias) if "weight" in self else x
+
+
 class LayerNorm(Module):
     r"""Applies layer normalization [1] on the inputs.
 
```
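To make the reduction-axis convention above concrete, here is a small usage sketch (not part of the diff): with a channels-last input, each instance is normalized per channel over the remaining axes, and the `weight`/`bias` parameters exist only when `affine=True`.

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((8, 4, 4, 16))  # (N, H, W, C), channels last

# affine=False: pure normalization, no learnable parameters
inorm = nn.InstanceNorm(dims=16)
y = inorm(x)

# Per instance and per channel, mean ~ 0 and variance ~ 1 over
# the spatial axes (all axes except the first and the last).
mean = mx.mean(y, axis=(1, 2))
var = mx.var(y, axis=(1, 2))
print(mean.shape, var.shape)  # (8, 16) (8, 16), values ~0 and ~1

# affine=True adds `weight` (init 1) and `bias` (init 0) of size dims
inorm_affine = nn.InstanceNorm(dims=16, affine=True)
print("weight" in inorm_affine)  # True
```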

python/tests/test_nn.py

Lines changed: 218 additions & 0 deletions
```diff
@@ -172,6 +172,224 @@ def test_group_norm(self):
         self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
         self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
 
+    def test_instance_norm(self):
+        # Test InstanceNorm1d
+        x = mx.array(
+            [
+                [
+                    [-0.0119524, 1.1263, 2.02223],
+                    [-0.500331, 0.517899, -1.21143],
+                    [1.12958, -0.21413, -2.48738],
+                    [1.39955, 0.891329, 1.63289],
+                ],
+                [
+                    [0.241417, -0.619157, -0.77484],
+                    [-1.42512, 0.970817, -1.31352],
+                    [2.739, -1.2506, 1.56844],
+                    [-1.23175, 0.32756, 1.13969],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(dims=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [-0.657082, 1.07593, 1.0712],
+                [-1.27879, -0.123074, -0.632505],
+                [0.796101, -1.56572, -1.30476],
+                [1.13978, 0.612862, 0.866067],
+            ],
+            [
+                [0.0964426, -0.557906, -0.759885],
+                [-0.904772, 1.30444, -1.20013],
+                [1.59693, -1.29752, 1.15521],
+                [-0.7886, 0.550987, 0.804807],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test InstanceNorm2d
+        x = mx.array(
+            [
+                [
+                    [
+                        [-0.458824, 0.483254, -0.58611],
+                        [-0.447996, -0.176577, -0.622545],
+                        [0.0486988, -0.0611224, 1.8845],
+                    ],
+                    [
+                        [1.13049, 0.345315, -0.926389],
+                        [0.301795, 0.99207, -0.184927],
+                        [-2.23876, -0.758631, -1.12639],
+                    ],
+                    [
+                        [0.0986325, -1.82973, -0.241765],
+                        [-1.25257, 0.154442, -0.556204],
+                        [-0.329399, -0.319107, 0.830584],
+                    ],
+                ],
+                [
+                    [
+                        [1.04407, 0.073752, 0.407081],
+                        [0.0800776, 1.2513, 1.20627],
+                        [0.782321, -0.444367, 0.563132],
+                    ],
+                    [
+                        [0.671423, -1.21689, -1.88979],
+                        [-0.110299, -1.42248, 1.17838],
+                        [0.159905, 0.516452, -0.539121],
+                    ],
+                    [
+                        [0.810252, 1.50456, 1.08659],
+                        [0.182597, 0.0576239, 0.973883],
+                        [-0.0621687, 0.184253, 0.784216],
+                    ],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(dims=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [
+                    [-0.120422, 0.801503, -0.463983],
+                    [-0.108465, -0.0608611, -0.504602],
+                    [0.440008, 0.090032, 2.29032],
+                ],
+                [
+                    [1.63457, 0.621224, -0.843335],
+                    [0.719488, 1.4665, -0.0167344],
+                    [-2.08591, -0.821575, -1.0663],
+                ],
+                [
+                    [0.495147, -2.22145, -0.0800989],
+                    [-0.996913, 0.371763, -0.430643],
+                    [0.022495, -0.24714, 1.11538],
+                ],
+            ],
+            [
+                [
+                    [1.5975, 0.0190292, -0.0123306],
+                    [-0.776381, 1.28291, 0.817237],
+                    [0.952927, -0.537076, 0.149652],
+                ],
+                [
+                    [0.679836, -1.36624, -2.39651],
+                    [-1.24519, -1.5869, 0.788287],
+                    [-0.579802, 0.494186, -0.994499],
+                ],
+                [
+                    [1.02171, 1.55474, 0.693008],
+                    [-0.523922, 0.00171862, 0.576016],
+                    [-1.12667, 0.137632, 0.37914],
+                ],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test InstanceNorm3d
+        x = mx.array(
+            [
+                [
+                    [
+                        [[0.777621, 0.528145, -1.56133], [-2.1722, 0.128192, 0.153862]],
+                        [
+                            [-1.41317, 0.476288, -1.20411],
+                            [0.284446, -0.649858, 0.152112],
+                        ],
+                    ],
+                    [
+                        [[0.11, -0.12431, 1.18768], [-0.837743, 1.93502, 0.00236324]],
+                        [
+                            [-2.40205, -1.25873, -2.04243],
+                            [0.336682, -0.261986, 1.54289],
+                        ],
+                    ],
+                    [
+                        [
+                            [0.789185, -1.63747, 0.67917],
+                            [-1.42998, -1.73247, -0.402572],
+                        ],
+                        [
+                            [-0.459489, -2.15559, -0.249959],
+                            [0.0298199, 0.10275, -0.821897],
+                        ],
+                    ],
+                ],
+                [
+                    [
+                        [
+                            [-2.12354, 0.643973, 0.72391],
+                            [0.317797, -0.682916, 0.016364],
+                        ],
+                        [
+                            [-0.146628, -0.987925, 0.573199],
+                            [0.0329215, 1.54086, 0.213092],
+                        ],
+                    ],
+                    [
+                        [
+                            [-1.55784, 0.71179, -0.0678402],
+                            [2.41031, -0.290786, 0.00449439],
+                        ],
+                        [
+                            [0.226341, 0.057712, -1.58342],
+                            [0.265387, -0.742304, 1.28133],
+                        ],
+                    ],
+                    [
+                        [
+                            [0.990317, -0.399875, -0.357647],
+                            [0.475161, -1.10479, -1.07389],
+                        ],
+                        [
+                            [-1.37804, 1.40097, 0.141618],
+                            [-0.501041, 0.0723374, -0.386141],
+                        ],
+                    ],
+                ],
+            ]
+        )
+        inorm = nn.InstanceNorm(dims=3)
+        y = inorm(x)
+        expected_y = [
+            [
+                [
+                    [[1.23593, 0.821849, -1.30944], [-1.54739, 0.462867, 0.357126]],
+                    [[-0.831204, 0.775304, -0.962338], [0.770588, -0.23548, 0.355425]],
+                ],
+                [
+                    [[0.605988, 0.236231, 1.36163], [-0.288258, 2.0846, 0.209922]],
+                    [[-1.76427, -0.78198, -1.77689], [0.819875, 0.112659, 1.70677]],
+                ],
+                [
+                    [[1.24684, -1.12192, 0.867539], [-0.847068, -1.20719, -0.183531]],
+                    [
+                        [0.0686449, -1.58697, -0.0352458],
+                        [0.530334, 0.440032, -0.590967],
+                    ],
+                ],
+            ],
+            [
+                [
+                    [[-1.75315, 0.733967, 1.04349], [0.343736, -0.822472, 0.080661]],
+                    [[-0.0551618, -1.18025, 0.838402], [0.0990544, 1.78602, 0.348368]],
+                ],
+                [
+                    [[-1.26726, 0.813517, -0.033924], [2.14101, -0.362504, 0.0645089]],
+                    [[0.265184, 0.0462839, -2.09632], [0.298721, -0.892134, 1.80203]],
+                ],
+                [
+                    [[0.921369, -0.490465, -0.428293], [0.478897, -1.31732, -1.40296]],
+                    [[-1.11283, 1.62192, 0.251107], [-0.35957, 0.0634394, -0.467067]],
+                ],
+            ],
+        ]
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(np.allclose(y, expected_y, atol=1e-5))
+        # Test repr
+        self.assertTrue(str(inorm) == "InstanceNorm(3, eps=1e-05, affine=False)")
+
     def test_batch_norm(self):
         mx.random.seed(42)
         x = mx.random.normal((5, 4), dtype=mx.float32)
```
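As a sanity check (again, not part of the diff), the expected values above can be reproduced with plain numpy by normalizing one instance per channel over its remaining axes:

```python
import numpy as np

# First instance of the 1d test above, shape (4, 3): rows are positions,
# columns are channels; InstanceNorm reduces over axis 0 here.
x0 = np.array(
    [
        [-0.0119524, 1.1263, 2.02223],
        [-0.500331, 0.517899, -1.21143],
        [1.12958, -0.21413, -2.48738],
        [1.39955, 0.891329, 1.63289],
    ]
)
y0 = (x0 - x0.mean(axis=0)) / np.sqrt(x0.var(axis=0) + 1e-5)
print(y0[0])  # ~[-0.657082, 1.07593, 1.0712], matching expected_y above
```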
