Skip to content

Commit 7cbb4ae

Browse files
authored
Doc fix (#1615)
1 parent 02bec0b commit 7cbb4ae

File tree

2 files changed

+47
-107
lines changed

2 files changed

+47
-107
lines changed

docs/src/python/nn/layers.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Layers
1212
ALiBi
1313
AvgPool1d
1414
AvgPool2d
15+
AvgPool3d
1516
BatchNorm
1617
CELU
1718
Conv1d
@@ -41,6 +42,7 @@ Layers
4142
LSTM
4243
MaxPool1d
4344
MaxPool2d
45+
MaxPool3d
4446
Mish
4547
MultiHeadAttention
4648
PReLU

python/mlx/nn/layers/pooling.py

Lines changed: 45 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,8 @@ def __init__(
185185
class MaxPool1d(_Pool1d):
186186
r"""Applies 1-dimensional max pooling.
187187
188-
Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
189-
:math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
190-
by:
191-
192-
.. math::
193-
\text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1}
194-
\text{input}(N_i, \text{stride} \times t + m, C_j),
195-
196-
where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
197-
\text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.
188+
Spatially downsamples the input by taking the maximum of a sliding window
189+
of size ``kernel_size`` and sliding stride ``stride``.
198190
199191
Args:
200192
kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -224,16 +216,8 @@ def __init__(
224216
class AvgPool1d(_Pool1d):
225217
r"""Applies 1-dimensional average pooling.
226218
227-
Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
228-
:math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
229-
by:
230-
231-
.. math::
232-
\text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1}
233-
\text{input}(N_i, \text{stride} \times t + m, C_j),
234-
235-
where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
236-
\text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.
219+
Spatially downsamples the input by taking the average of a sliding window
220+
of size ``kernel_size`` and sliding stride ``stride``.
237221
238222
Args:
239223
kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -263,26 +247,15 @@ def __init__(
263247
class MaxPool2d(_Pool2d):
264248
r"""Applies 2-dimensional max pooling.
265249
266-
Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
267-
:math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
268-
W_{out}, C)`, given by:
269-
270-
.. math::
271-
\begin{aligned}
272-
\text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
273-
& \text{input}(N_i, \text{stride[0]} \times h + m,
274-
\text{stride[1]} \times w + n, C_j),
275-
\end{aligned}
250+
Spatially downsamples the input by taking the maximum of a sliding window
251+
of size ``kernel_size`` and sliding stride ``stride``.
276252
277-
where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
278-
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
253+
The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
279254
280-
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
281-
282-
- a single ``int`` -- in which case the same value is used for both the
283-
height and width axis;
284-
- a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
285-
used for the height axis, the second ``int`` for the width axis.
255+
* a single ``int`` -- in which case the same value is used for both the
256+
height and width axis.
257+
* a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
258+
used for the height axis, the second ``int`` for the width axis.
286259
287260
Args:
288261
kernel_size (int or tuple(int, int)): The size of the pooling window.
@@ -312,26 +285,15 @@ def __init__(
312285
class AvgPool2d(_Pool2d):
313286
r"""Applies 2-dimensional average pooling.
314287
315-
Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
316-
:math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
317-
W_{out}, C)`, given by:
318-
319-
.. math::
320-
\begin{aligned}
321-
\text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
322-
& \text{input}(N_i, \text{stride[0]} \times h + m,
323-
\text{stride[1]} \times w + n, C_j),
324-
\end{aligned}
325-
326-
where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
327-
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
288+
Spatially downsamples the input by taking the average of a sliding window
289+
of size ``kernel_size`` and sliding stride ``stride``.
328290
329-
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
291+
The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
330292
331-
- a single ``int`` -- in which case the same value is used for both the
332-
height and width axis;
333-
- a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
334-
used for the height axis, the second ``int`` for the width axis.
293+
* a single ``int`` -- in which case the same value is used for both the
294+
height and width axis.
295+
* a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
296+
used for the height axis, the second ``int`` for the width axis.
335297
336298
Args:
337299
kernel_size (int or tuple(int, int)): The size of the pooling window.
@@ -359,30 +321,18 @@ def __init__(
359321

360322

361323
class MaxPool3d(_Pool3d):
362-
"""
363-
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
364-
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
365-
H_{out}, W_{out}, C)`, given by:
366-
367-
.. math::
368-
\begin{aligned}
369-
\text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
370-
& \text{input}(N_i, \text{stride[0]} \times d + l,
371-
\text{stride[1]} \times h + m,
372-
\text{stride[2]} \times w + n, C_j),
373-
\end{aligned}
374-
375-
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
376-
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
377-
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
378-
379-
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
380-
381-
- a single ``int`` -- in which case the same value is used for the depth,
382-
height and width axis;
383-
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
384-
for the depth axis, the second ``int`` for the height axis, and the third
385-
``int`` for the width axis.
324+
r"""Applies 3-dimensional max pooling.
325+
326+
Spatially downsamples the input by taking the maximum of a sliding window
327+
of size ``kernel_size`` and sliding stride ``stride``.
328+
329+
The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
330+
331+
* a single ``int`` -- in which case the same value is used for the depth,
332+
height, and width axis.
333+
* a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
334+
for the depth axis, the second ``int`` for the height axis, and the third
335+
``int`` for the width axis.
386336
387337
Args:
388338
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
@@ -410,40 +360,28 @@ def __init__(
410360

411361

412362
class AvgPool3d(_Pool3d):
413-
"""
414-
Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
415-
:math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
416-
H_{out}, W_{out}, C)`, given by:
417-
418-
.. math::
419-
\begin{aligned}
420-
\text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
421-
& \text{input}(N_i, \text{stride[0]} \times d + l,
422-
\text{stride[1]} \times h + m,
423-
\text{stride[2]} \times w + n, C_j),
424-
\end{aligned}
425-
426-
where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
427-
:math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`,
428-
:math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`.
429-
430-
The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
431-
432-
- a single ``int`` -- in which case the same value is used for the depth,
433-
height and width axis;
434-
- a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
435-
for the depth axis, the second ``int`` for the height axis, and the third
436-
``int`` for the width axis.
437-
438-
Args:
363+
r"""Applies 3-dimensional average pooling.
364+
365+
Spatially downsamples the input by taking the average of a sliding window
366+
of size ``kernel_size`` and sliding stride ``stride``.
367+
368+
The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
369+
370+
* a single ``int`` -- in which case the same value is used for the depth,
371+
height, and width axis.
372+
* a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
373+
for the depth axis, the second ``int`` for the height axis, and the third
374+
``int`` for the width axis.
375+
376+
Args:
439377
kernel_size (int or tuple(int, int, int)): The size of the pooling window.
440378
stride (int or tuple(int, int, int), optional): The stride of the pooling
441379
window. Default: ``kernel_size``.
442380
padding (int or tuple(int, int, int), optional): How much zero
443381
padding to apply to the input. The padding is applied on both sides
444382
of the depth, height and width axis. Default: ``0``.
445383
446-
Examples:
384+
Examples:
447385
>>> import mlx.core as mx
448386
>>> import mlx.nn.layers as nn
449387
>>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))

0 commit comments

Comments
 (0)