
[BUG] mx.conv3d produces silently wrong results when the unfolded matrix has M*K > 2^32 elements #3138

@belkakari


Describe the bug
mx.conv3d silently produces wrong results when the unfolded (im2col) matrix has more than 2^32 elements, i.e. when M * K > 2^32, where M = B * T_out * H_out * W_out and K = kT * kH * kW * C_in. No error is raised; the output is simply incorrect.
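To check whether a given call falls in the affected range, M and K can be computed from the array shapes. The helpers below (unfold_elements, may_overflow) are hypothetical, not part of MLX, and are a sketch under these assumptions: MLX's channels-last layouts (input B, T, H, W, C_in; weight C_out, kT, kH, kW, C_in), symmetric integer padding, dilation 1 and groups 1. The reproduction script pads the input itself and passes padding=0, so the example shapes are padded shapes.

```python
UINT32_RANGE = 2**32  # offsets at or beyond this wrap in 32-bit unsigned arithmetic


def unfold_elements(x_shape, w_shape, stride=(1, 1, 1), padding=(0, 0, 0)):
    """Return (M, K) for the im2col matrix of a 3D convolution.

    Hypothetical helper: x_shape is (B, T, H, W, C_in) and w_shape is
    (C_out, kT, kH, kW, C_in); dilation 1 and groups 1 are assumed.
    """
    B, *spatial, C_in = x_shape
    _, *kernel, _ = w_shape
    M = B
    for size, k, s, p in zip(spatial, kernel, stride, padding):
        M *= (size + 2 * p - k) // s + 1  # output length along this axis
    K = C_in
    for k in kernel:
        K *= k
    return M, K


def may_overflow(x_shape, w_shape, **kwargs):
    """True when M * K exceeds 2**32, the threshold reported in this issue."""
    M, K = unfold_elements(x_shape, w_shape, **kwargs)
    return M * K > UINT32_RANGE


if __name__ == "__main__":
    # Test 1, T=4: input after causal/spatial padding is (1, 4+2, 240+2, 416+2, 512)
    print(unfold_elements((1, 6, 242, 418, 512), (256, 3, 3, 3, 512)))  # (399360, 13824)
    print(may_overflow((1, 6, 242, 418, 512), (256, 3, 3, 3, 512)))     # True
```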

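Conv3d always routes through explicit_gemm_conv_ND_gpu → naive_unfold_Nd, which materializes the full im2col buffer. In the unfold kernel (mlx/backend/metal/kernels/conv.metal, line 33):

```metal
out += gid.z * filter_size + gid.y * (params->C);
```

gid.z is uint (32-bit) and filter_size is int (32-bit), so filter_size is converted to uint and the product is computed modulo 2^32. gid.z runs over the M rows of the unfolded matrix and filter_size equals K, so the largest offsets approach M * K. Once gid.z * filter_size reaches 2^32 the offset wraps around, and the affected threads write to the wrong offsets in the buffer.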
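The wraparound can be reproduced on the CPU with NumPy, whose uint32 array arithmetic also wraps modulo 2^32. This sketch only emulates the row offset gid.z * filter_size, not the whole kernel; first_wrapped_row is a hypothetical name and the shapes are those of Test 1's T=4 case. Presumably a fix would widen the multiplication to 64 bits (for example by casting gid.z to a 64-bit type such as size_t before multiplying), but that is only a suggestion here.

```python
import numpy as np


def first_wrapped_row(M, K):
    """Return (row, intended_offset, wrapped_offset) for the first gid.z whose
    32-bit offset gid.z * K differs from the exact value, or None if none wrap."""
    rows = np.arange(M, dtype=np.uint32)            # gid.z values (uint in Metal)
    wrapped = rows * np.uint32(K)                   # 32-bit product: wraps mod 2**32
    exact = rows.astype(np.uint64) * np.uint64(K)   # the offset the kernel intended
    bad = np.flatnonzero(wrapped != exact)
    if bad.size == 0:
        return None
    r = int(bad[0])
    return r, int(exact[r]), int(wrapped[r])


if __name__ == "__main__":
    # Test 1, T=4: M = 4*240*416 rows, filter_size = K = 3*3*3*512
    print(first_wrapped_row(399_360, 13_824))  # (310690, 4294978560, 11264)
    print(first_wrapped_row(309_136, 13_824))  # None (Test 4, 278x278)
```

Row 310,690 should start at element 4,294,978,560 but lands at 11,264, inside row 0 (offsets 0 to 13,823). It overwrites part of an earlier row, and in this configuration the slot it was meant to fill is never written by the unfold kernel.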

Conv2d is unaffected because it has implicit GEMM and Winograd paths that never materialize the unfolded matrix. Conv3d has no such optimized path — it always uses explicit GEMM.

To Reproduce

```python
import mlx.core as mx
import numpy as np
import torch
import torch.nn.functional as F

# Note: the largest cases are memory hungry. The im2col buffer that mx.conv3d
# materializes holds M*K values, roughly 22 GB in float32 for M*K = 5,520,752,640.


def test_conv3d(T, H, W, C_in, C_out, kernel=(3, 3, 3), label=""):
    """Test mx.conv3d against PyTorch conv3d with identical inputs and weights."""
    kT, kH, kW = kernel
    padding = (kT - 1, kH // 2, kW // 2)  # causal temporal, symmetric spatial

    # im2col size with B = 1; this padding keeps T_out, H_out, W_out equal to T, H, W
    M = T * H * W
    K = kT * kH * kW * C_in
    product = M * K

    # Deterministic shared input and weights
    np.random.seed(42)
    x_np = np.random.randn(1, T, H, W, C_in).astype(np.float32) * 0.01
    w_np = np.random.randn(C_out, kT, kH, kW, C_in).astype(np.float32) * 0.01
    b_np = np.zeros(C_out, dtype=np.float32)

    # --- PyTorch (reference) ---
    # PyTorch layout: [B, C, T, H, W], weight: [O, I, kT, kH, kW]
    x_pt = torch.tensor(x_np.transpose(0, 4, 1, 2, 3))  # BTHWC -> BCTHW
    w_pt = torch.tensor(w_np.transpose(0, 4, 1, 2, 3))  # [O, kT, kH, kW, I] -> [O, I, kT, kH, kW]
    b_pt = torch.tensor(b_np)

    # Causal temporal padding (left only) + symmetric spatial
    x_pt_padded = F.pad(
        x_pt, (padding[2], padding[2], padding[1], padding[1], padding[0], 0)
    )
    y_pt = F.conv3d(x_pt_padded, w_pt, b_pt, stride=1, padding=0)
    y_pt_np = y_pt.numpy().transpose(0, 2, 3, 4, 1)  # BCTHW -> BTHWC

    # --- MLX ---
    x_padded_np = np.pad(
        x_np,
        [
            (0, 0),
            (padding[0], 0),
            (padding[1], padding[1]),
            (padding[2], padding[2]),
            (0, 0),
        ],
    )
    x_mlx = mx.array(x_padded_np)
    w_mlx = mx.array(w_np)
    b_mlx = mx.array(b_np)
    y_mlx = mx.conv3d(x_mlx, w_mlx, stride=(1, 1, 1), padding=0) + b_mlx
    mx.eval(y_mlx)
    y_mlx_np = np.array(y_mlx)

    # --- Compare ---
    max_diff = np.abs(y_pt_np - y_mlx_np).max()
    corr = np.corrcoef(y_pt_np.flatten(), y_mlx_np.flatten())[0, 1]

    overflow = product > 2**32
    status = "FAIL" if max_diff > 0.01 else "OK"
    print(
        f"  {label:45s}  M*K={product:>14,}  "
        f"{'> 2^32' if overflow else '< 2^32'}  "
        f"max_diff={max_diff:.6f}  corr={corr:.6f}  [{status}]"
    )
    return status == "OK"


def main():
    print("=" * 110)
    print("mx.conv3d bug: uint32 overflow in Metal unfold kernel for large tensors")
    print("=" * 110)
    try:
        print(f"MLX version: {mx.__version__}")
    except AttributeError:
        print("MLX version: unknown")
    print(f"MLX device: {mx.default_device()}")
    print(f"PyTorch version: {torch.__version__}")
    print()

    all_pass = True

    # ---- Test 1: Vary temporal dimension ----
    print("Test 1: Vary T (H=240, W=416, C_in=512, kernel=3x3x3)")
    print("  K = 3*3*3*512 = 13,824")
    print()
    for T in [1, 2, 3, 4]:
        ok = test_conv3d(T, 240, 416, 512, 256, label=f"T={T}")
        all_pass &= ok

    print()

    # ---- Test 2: Vary spatial size ----
    print("Test 2: Vary spatial (T=4, C_in=512, kernel=3x3x3)")
    print()
    for H, W in [(120, 208), (200, 350), (220, 380), (240, 416)]:
        ok = test_conv3d(4, H, W, 512, 256, label=f"{H}x{W}")
        all_pass &= ok

    print()

    # ---- Test 3: Vary input channels ----
    print("Test 3: Vary C_in (T=4, H=240, W=416, kernel=3x3x3)")
    print()
    for C_in, C_out in [(128, 64), (256, 128), (512, 256)]:
        ok = test_conv3d(4, 240, 416, C_in, C_out, label=f"C_in={C_in}")
        all_pass &= ok

    print()

    # ---- Test 4: Boundary around 2^32 ----
    print("Test 4: Exact boundary (M*K near 2^32 = 4,294,967,296)")
    print()
    for H, W in [(278, 278), (278, 280), (280, 280)]:
        M = 4 * H * W
        MK = M * 13824
        ok = test_conv3d(4, H, W, 512, 256, label=f"{H}x{W}, M*K={MK:,}")
        all_pass &= ok

    print()

    # ---- Test 5: 1x1x1 kernel (small K, no overflow) ----
    print("Test 5: Small kernel (1x1x1) — same tensor size, no overflow")
    print()
    ok = test_conv3d(
        4,
        240,
        416,
        512,
        256,
        kernel=(1, 1, 1),
        label="1x1x1 kernel, T=4, 240x416, C_in=512",
    )
    all_pass &= ok

    print()

    # ---- Summary ----
    print("=" * 110)
    if not all_pass:
        print("mx.conv3d produces wrong results when M*K > 2^32")
    else:
        print("All tests passed — bug may be fixed in this MLX version.")


if __name__ == "__main__":
    main()
```

Expected behavior
All tests should produce max_diff ≈ 0.0 and corr ≈ 1.0 against the PyTorch output, regardless of tensor size. Instead, the bug is silent: no error is raised, and the results are simply wrong.
Sample output:

```text
Test 1: Vary T (H=240, W=416, C_in=512, kernel=3x3x3)
  K = 3*3*3*512 = 13,824

  T=1                                            M*K= 1,380,188,160  < 2^32  max_diff=0.000000  corr=1.000000  [OK]
  T=2                                            M*K= 2,760,376,320  < 2^32  max_diff=0.000000  corr=1.000000  [OK]
  T=3                                            M*K= 4,140,564,480  < 2^32  max_diff=0.000000  corr=1.000000  [OK]
  T=4                                            M*K= 5,520,752,640  > 2^32  max_diff=0.074222  corr=0.637116  [FAIL]

Test 2: Vary spatial (T=4, C_in=512, kernel=3x3x3)

  120x208                                        M*K= 1,380,188,160  < 2^32  max_diff=0.000000  corr=1.000000  [OK]
  200x350                                        M*K= 3,870,720,000  < 2^32  max_diff=0.000000  corr=1.000000  [OK]
  220x380                                        M*K= 4,622,745,600  > 2^32  max_diff=0.080086  corr=0.888318  [FAIL]
  240x416                                        M*K= 5,520,752,640  > 2^32  max_diff=0.074222  corr=0.637116  [FAIL]

Test 3: Vary C_in (T=4, H=240, W=416, kernel=3x3x3)

  C_in=128                                       M*K= 1,380,188,160  < 2^32  max_diff=0.000000  corr=1.000000  [OK]
  C_in=256                                       M*K= 2,760,376,320  < 2^32  max_diff=0.000000  corr=1.000000  [OK]
  C_in=512                                       M*K= 5,520,752,640  > 2^32  max_diff=0.074222  corr=0.637116  [FAIL]

Test 4: Exact boundary (M*K near 2^32 = 4,294,967,296)

  278x278, M*K=4,273,496,064                     M*K= 4,273,496,064  < 2^32  max_diff=0.000000  corr=1.000000  [OK]
  278x280, M*K=4,304,240,640                     M*K= 4,304,240,640  > 2^32  max_diff=0.068097  corr=0.997114  [FAIL]
  280x280, M*K=4,335,206,400                     M*K= 4,335,206,400  > 2^32  max_diff=0.068373  corr=0.985941  [FAIL]

Test 5: Small kernel (1x1x1) — same tensor size, no overflow

  1x1x1 kernel, T=4, 240x416, C_in=512           M*K=   204,472,320  < 2^32  max_diff=0.000000  corr=1.000000  [OK]

==============================================================================================================
mx.conv3d produces wrong results when M*K > 2^32
```

Desktop:

  • OS Version: macOS 15.4
  • MLX Version: 0.30.3
  • Device: Apple M4 Max

Additional context
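Root cause: in mlx/backend/metal/kernels/conv.metal (line 33), gid.z * filter_size is computed in 32-bit arithmetic. The boundary is exactly at M * K = 2^32: in Test 4, M*K = 4,273,496,064 passes while 4,304,240,640 fails. fp16 and fp32 both fail in the same way, so this is not a precision issue; the wrapped quantity is an element offset, not a byte count.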
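To quantify how much of the im2col buffer is affected near the boundary, the sketch below counts unfolded elements whose write offset gid.z * K + gid.y * C reaches 2^32 and therefore wraps. It assumes, beyond the kernel line quoted above, that each (gid.z, gid.y) thread pair writes C consecutive channels starting at that offset; misplaced_elements is a hypothetical helper. The count is in elements, so it does not depend on the dtype.

```python
import numpy as np

UINT32_RANGE = 2**32


def misplaced_elements(M, C, P):
    """Count unfolded elements written through a wrapped 32-bit offset.

    Assumed model of the unfold kernel: thread (gid.z, gid.y) computes
    offset = gid.z * K + gid.y * C with K = P * C, then writes C channels
    starting there. M = rows, C = input channels, P = kT * kH * kW.
    """
    K = P * C
    rows = np.arange(M, dtype=np.int64)
    # Exact (64-bit) distance from each row's start offset up to 2**32.
    remaining = UINT32_RANGE - rows * K
    # Smallest gid.y whose offset reaches 2**32 (ceil division), clipped to [0, P].
    first_bad_y = np.clip(-(-remaining // C), 0, P)
    return int((P - first_bad_y).sum()) * C


if __name__ == "__main__":
    K = 27 * 512
    for label, M in [("278x278", 4 * 278 * 278), ("278x280", 4 * 278 * 280),
                     ("280x280", 4 * 280 * 280), ("240x416", 4 * 240 * 416)]:
        bad = misplaced_elements(M, 512, 27)
        print(f"{label}: {bad:,} misplaced ({bad / (M * K):.2%} of M*K)")
```

For these T=4 shapes it reports 0 for 278x278 and about 0.22%, 0.93% and 22.20% of the buffer for 278x280, 280x280 and 240x416, which lines up with which cases pass or fail in the sample output. Because C = 512 divides 2^32, the count here equals M*K - 2^32 exactly.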

Why only conv3d: conv_3D_gpu (line 1038 in mlx/backend/metal/conv.cpp) always calls explicit_gemm_conv_ND_gpu, which uses the naive_unfold_Nd kernel. Conv2d has implicit GEMM and Winograd paths that avoid materializing the unfolded matrix entirely.
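Until the kernel is fixed, one possible workaround is to split a stride-1 conv3d along the time axis so that each call's M * K stays at or below 2^32, then concatenate the pieces. This is a sketch under stated assumptions, not an MLX API: chunked_conv3d, plan_time_chunks and reference_conv3d are hypothetical names; the input is assumed to be pre-padded as in the reproduction script (stride 1, padding 0, dilation 1, groups 1, channels-last); and the conv and concatenate functions are passed in so the chunking logic can be checked against a slow NumPy reference. With MLX, the call would presumably be chunked_conv3d(x_mlx, w_mlx, mx.conv3d, mx.concatenate).

```python
import numpy as np

UINT32_RANGE = 2**32  # keep each call's unfolded matrix at or below this many elements


def plan_time_chunks(T_out, kT, frames):
    """Yield (in_start, in_stop) input slices covering `frames` output frames each.

    Assumes stride 1 and dilation 1 on an already padded input, so output
    frames [a, b) depend exactly on input frames [a, b + kT - 1).
    """
    for a in range(0, T_out, frames):
        b = min(a + frames, T_out)
        yield a, b + kT - 1


def chunked_conv3d(x, w, conv3d, concatenate, limit=UINT32_RANGE):
    """Run conv3d in time chunks whose unfolded size M * K stays <= limit.

    x: (B, T, H, W, C_in), pre-padded; w: (C_out, kT, kH, kW, C_in).
    """
    B, T, H, W, C_in = x.shape
    _, kT, kH, kW, _ = w.shape
    T_out, H_out, W_out = T - kT + 1, H - kH + 1, W - kW + 1
    K = kT * kH * kW * C_in
    per_frame = B * H_out * W_out * K  # unfolded elements per output frame
    frames = limit // per_frame
    if frames < 1:
        raise ValueError("one output frame already exceeds the limit; split another axis")
    pieces = [conv3d(x[:, a:b], w) for a, b in plan_time_chunks(T_out, kT, frames)]
    return concatenate(pieces, axis=1)


def reference_conv3d(x, w):
    """Slow NumPy stand-in for a stride-1, unpadded, channels-last conv3d."""
    B, T, H, W, _ = x.shape
    O, kT, kH, kW, _ = w.shape
    out = np.zeros((B, T - kT + 1, H - kH + 1, W - kW + 1, O), dtype=x.dtype)
    for t in range(out.shape[1]):
        for h in range(out.shape[2]):
            for v in range(out.shape[3]):
                patch = x[:, t:t + kT, h:h + kH, v:v + kW, :]
                out[:, t, h, v, :] = np.einsum("bijkc,oijkc->bo", patch, w)
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((1, 6, 4, 4, 2)).astype(np.float32)
    w = rng.standard_normal((3, 3, 3, 3, 2)).astype(np.float32)
    full = reference_conv3d(x, w)
    # A tiny limit (500 elements) forces 2 output frames per chunk here.
    split = chunked_conv3d(x, w, reference_conv3d, np.concatenate, limit=500)
    print(np.allclose(full, split))  # True
```

For Test 1's T=4 case, one output frame has M*K = 1,380,188,160, so this plan would run 3 frames and then 1 frame, each below 2^32. Splitting along time keeps the extra work to the kT - 1 overlapping input frames per chunk.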
