@@ -185,16 +185,8 @@ def __init__(
185185class MaxPool1d (_Pool1d ):
186186 r"""Applies 1-dimensional max pooling.
187187
188- Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
189- :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
190- by:
191-
192- .. math::
193- \text{out}(N_i, t, C_j) = \max_{m=0, \ldots, k - 1}
194- \text{input}(N_i, \text{stride} \times t + m, C_j),
195-
196- where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
197- \text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.
188+ Spatially downsamples the input by taking the maximum of a sliding window
189+ of size ``kernel_size`` and sliding stride ``stride``.
198190
199191 Args:
200192 kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -224,16 +216,8 @@ def __init__(
224216class AvgPool1d (_Pool1d ):
225217 r"""Applies 1-dimensional average pooling.
226218
227- Assuming an input of shape :math:`(N, L, C)` and ``kernel_size`` is
228- :math:`k`, the output is a tensor of shape :math:`(N, L_{out}, C)`, given
229- by:
230-
231- .. math::
232- \text{out}(N_i, t, C_j) = \frac{1}{k} \sum_{m=0, \ldots, k - 1}
233- \text{input}(N_i, \text{stride} \times t + m, C_j),
234-
235- where :math:`L_{out} = \left\lfloor \frac{L + 2 \times \text{padding} -
236- \text{kernel\_size}}{\text{stride}}\right\rfloor + 1`.
219+ Spatially downsamples the input by taking the average of a sliding window
220+ of size ``kernel_size`` and sliding stride ``stride``.
237221
238222 Args:
239223 kernel_size (int or tuple(int)): The size of the pooling window kernel.
@@ -263,26 +247,15 @@ def __init__(
263247class MaxPool2d (_Pool2d ):
264248 r"""Applies 2-dimensional max pooling.
265249
266- Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
267- :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
268- W_{out}, C)`, given by:
269-
270- .. math::
271- \begin{aligned}
272- \text{out}(N_i, h, w, C_j) = & \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
273- & \text{input}(N_i, \text{stride[0]} \times h + m,
274- \text{stride[1]} \times w + n, C_j),
275- \end{aligned}
250+ Spatially downsamples the input by taking the maximum of a sliding window
251+ of size ``kernel_size`` and sliding stride ``stride``.
276252
277- where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
278- :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
253+ The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
279254
280- The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
281-
282- - a single ``int`` -- in which case the same value is used for both the
283- height and width axis;
284- - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
285- used for the height axis, the second ``int`` for the width axis.
255+ * a single ``int`` -- in which case the same value is used for both the
256+ height and width axis.
257+ * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
258+ used for the height axis, the second ``int`` for the width axis.
286259
287260 Args:
288261 kernel_size (int or tuple(int, int)): The size of the pooling window.
@@ -312,26 +285,15 @@ def __init__(
312285class AvgPool2d (_Pool2d ):
313286 r"""Applies 2-dimensional average pooling.
314287
315- Assuming an input of shape :math:`(N, H, W, C)` and ``kernel_size`` is
316- :math:`(k_H, k_W)`, the output is a tensor of shape :math:`(N, H_{out},
317- W_{out}, C)`, given by:
318-
319- .. math::
320- \begin{aligned}
321- \text{out}(N_i, h, w, C_j) = & \frac{1}{k_H k_W} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
322- & \text{input}(N_i, \text{stride[0]} \times h + m,
323- \text{stride[1]} \times w + n, C_j),
324- \end{aligned}
325-
326- where :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`,
327- :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`.
288+ Spatially downsamples the input by taking the average of a sliding window
289+ of size ``kernel_size`` and sliding stride ``stride``.
328290
329- The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
291+ The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
330292
331- - a single ``int`` -- in which case the same value is used for both the
332- height and width axis;
333- - a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
334- used for the height axis, the second ``int`` for the width axis.
293+ * a single ``int`` -- in which case the same value is used for both the
294+ height and width axis.
295+ * a ``tuple`` of two ``int`` s -- in which case, the first ``int`` is
296+ used for the height axis, the second ``int`` for the width axis.
335297
336298 Args:
337299 kernel_size (int or tuple(int, int)): The size of the pooling window.
@@ -359,30 +321,18 @@ def __init__(
359321
360322
361323class MaxPool3d (_Pool3d ):
362- """
363- Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
364- :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
365- H_{out}, W_{out}, C)`, given by:
366-
367- .. math::
368- \b egin{aligned}
369- \t ext{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\
370- & \t ext{input}(N_i, \t ext{stride[0]} \t imes d + l,
371- \t ext{stride[1]} \t imes h + m,
372- \t ext{stride[2]} \t imes w + n, C_j),
373- \end{aligned}
374-
375- where :math:`D_{out} = \left\lfloor\f rac{D + 2 * \t ext{padding[0]} - \t ext{kernel\_size[0]}}{\t ext{stride[0]}}\r ight\r floor + 1`,
376- :math:`H_{out} = \left\lfloor\f rac{H + 2 * \t ext{padding[1]} - \t ext{kernel\_size[1]}}{\t ext{stride[1]}}\r ight\r floor + 1`,
377- :math:`W_{out} = \left\lfloor\f rac{W + 2 * \t ext{padding[2]} - \t ext{kernel\_size[2]}}{\t ext{stride[2]}}\r ight\r floor + 1`.
378-
379- The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
380-
381- - a single ``int`` -- in which case the same value is used for the depth,
382- height and width axis;
383- - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
384- for the depth axis, the second ``int`` for the height axis, and the third
385- ``int`` for the width axis.
324+ r"""Applies 3-dimensional max pooling.
325+
326+ Spatially downsamples the input by taking the maximum of a sliding window
327+ of size ``kernel_size`` and sliding stride ``stride``.
328+
329+ The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
330+
331+ * a single ``int`` -- in which case the same value is used for the depth,
332+ height, and width axis.
333+ * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
334+ for the depth axis, the second ``int`` for the height axis, and the third
335+ ``int`` for the width axis.
386336
387337 Args:
388338 kernel_size (int or tuple(int, int, int)): The size of the pooling window.
@@ -410,40 +360,28 @@ def __init__(
410360
411361
412362class AvgPool3d (_Pool3d ):
413- """
414- Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is
415- :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out},
416- H_{out}, W_{out}, C)`, given by:
417-
418- .. math::
419- \b egin{aligned}
420- \t ext{out}(N_i, d, h, w, C_j) = & \f rac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\
421- & \t ext{input}(N_i, \t ext{stride[0]} \t imes d + l,
422- \t ext{stride[1]} \t imes h + m,
423- \t ext{stride[2]} \t imes w + n, C_j),
424- \end{aligned}
425-
426- where :math:`D_{out} = \left\lfloor\f rac{D + 2 * \t ext{padding[0]} - \t ext{kernel\_size[0]}}{\t ext{stride[0]}}\r ight\r floor + 1`,
427- :math:`H_{out} = \left\lfloor\f rac{H + 2 * \t ext{padding[1]} - \t ext{kernel\_size[1]}}{\t ext{stride[1]}}\r ight\r floor + 1`,
428- :math:`W_{out} = \left\lfloor\f rac{W + 2 * \t ext{padding[2]} - \t ext{kernel\_size[2]}}{\t ext{stride[2]}}\r ight\r floor + 1`.
429-
430- The parameters ``kernel_size``, ``stride``, ``padding``, can either be:
431-
432- - a single ``int`` -- in which case the same value is used for the depth,
433- height and width axis;
434- - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
435- for the depth axis, the second ``int`` for the height axis, and the third
436- ``int`` for the width axis.
437-
438- Args:
363+ r"""Applies 3-dimensional average pooling.
364+
365+ Spatially downsamples the input by taking the average of a sliding window
366+ of size ``kernel_size`` and sliding stride ``stride``.
367+
368+ The parameters ``kernel_size``, ``stride``, and ``padding`` can either be:
369+
370+ * a single ``int`` -- in which case the same value is used for the depth,
371+ height, and width axis.
372+ * a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used
373+ for the depth axis, the second ``int`` for the height axis, and the third
374+ ``int`` for the width axis.
375+
376+ Args:
439377 kernel_size (int or tuple(int, int, int)): The size of the pooling window.
440378 stride (int or tuple(int, int, int), optional): The stride of the pooling
441379 window. Default: ``kernel_size``.
442380 padding (int or tuple(int, int, int), optional): How much zero
443381 padding to apply to the input. The padding is applied on both sides
444382 of the depth, height and width axis. Default: ``0``.
445383
446- Examples:
384+ Examples:
447385 >>> import mlx.core as mx
448386 >>> import mlx.nn.layers as nn
449387 >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
0 commit comments