Describe the bug
mx.conv3d produces silently wrong results when the unfolded matrix size M * K > 2^32, where M = B * T_out * H_out * W_out and K = kT * kH * kW * C_in. No error is raised — the output is just incorrect.
Conv3d always routes through explicit_gemm_conv_ND_gpu → naive_unfold_Nd, which materializes a full im2col buffer. In the unfold kernel (mlx/backend/metal/kernels/conv.metal line 33):
out += gid.z * filter_size + gid.y * (params->C);
gid.z is uint (32-bit) and filter_size is int (32-bit). When gid.z * filter_size > 2^32, the product wraps around, causing writes to wrong buffer offsets.
Conv2d is unaffected because it has implicit GEMM and Winograd paths that never materialize the unfolded matrix. Conv3d has no such optimized path — it always uses explicit GEMM.
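The wraparound can be emulated in plain Python. The names `gid_z` and `filter_size` mirror the kernel; the values below are taken from the failing T=4, 240x416, C_in=512 case (K = 13,824, M = 399,360), with `gid_z` chosen as one of the output rows past the overflow point:

```python
# Emulate the 32-bit offset arithmetic in naive_unfold_Nd.
filter_size = 3 * 3 * 3 * 512      # K = 13,824
gid_z = 350_000                    # an output row index with gid_z * K > 2^32

true_offset = gid_z * filter_size              # 64-bit result: 4,838,400,000
wrapped_offset = (gid_z * filter_size) % 2**32 # what uint32 math produces

print(true_offset)     # 4838400000
print(wrapped_offset)  # 543432704 -- the kernel writes ~4.3 GB before this row
```

Every output row past the boundary is written at an offset that is short by a multiple of 2^32 elements, which is why the results degrade rather than crash.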
To Reproduce
```python
import mlx.core as mx
import numpy as np
import torch
import torch.nn.functional as F


def test_conv3d(T, H, W, C_in, C_out, kernel=(3, 3, 3), label=""):
    """Test mx.conv3d against PyTorch conv3d with identical inputs and weights."""
    kT, kH, kW = kernel
    padding = (kT - 1, kH // 2, kW // 2)  # causal temporal, symmetric spatial
    M = T * H * W
    K = kT * kH * kW * C_in
    product = M * K

    # Deterministic shared input and weights
    np.random.seed(42)
    x_np = np.random.randn(1, T, H, W, C_in).astype(np.float32) * 0.01
    w_np = np.random.randn(C_out, kT, kH, kW, C_in).astype(np.float32) * 0.01
    b_np = np.zeros(C_out, dtype=np.float32)

    # --- PyTorch (reference) ---
    # PyTorch layout: [B, C, T, H, W], weight: [O, I, kT, kH, kW]
    x_pt = torch.tensor(x_np.transpose(0, 4, 1, 2, 3))  # BTHWC -> BCTHW
    w_pt = torch.tensor(w_np.transpose(0, 4, 1, 2, 3))  # OkTkHkWI -> OITHW
    b_pt = torch.tensor(b_np)
    # Causal temporal padding (left only) + symmetric spatial
    x_pt_padded = F.pad(
        x_pt, (padding[2], padding[2], padding[1], padding[1], padding[0], 0)
    )
    y_pt = F.conv3d(x_pt_padded, w_pt, b_pt, stride=1, padding=0)
    y_pt_np = y_pt.numpy().transpose(0, 2, 3, 4, 1)  # BCTHW -> BTHWC

    # --- MLX ---
    x_padded_np = np.pad(
        x_np,
        [
            (0, 0),
            (padding[0], 0),
            (padding[1], padding[1]),
            (padding[2], padding[2]),
            (0, 0),
        ],
    )
    x_mlx = mx.array(x_padded_np)
    w_mlx = mx.array(w_np)
    b_mlx = mx.array(b_np)
    y_mlx = mx.conv3d(x_mlx, w_mlx, stride=(1, 1, 1), padding=0) + b_mlx
    mx.eval(y_mlx)
    y_mlx_np = np.array(y_mlx)

    # --- Compare ---
    max_diff = np.abs(y_pt_np - y_mlx_np).max()
    corr = np.corrcoef(y_pt_np.flatten(), y_mlx_np.flatten())[0, 1]
    overflow = product > 2**32
    status = "FAIL" if max_diff > 0.01 else "OK"
    print(
        f" {label:45s} M*K={product:>14,} "
        f"{'> 2^32' if overflow else '< 2^32'} "
        f"max_diff={max_diff:.6f} corr={corr:.6f} [{status}]"
    )
    return status == "OK"


def main():
    print("=" * 110)
    print("mx.conv3d bug: uint32 overflow in Metal unfold kernel for large tensors")
    print("=" * 110)
    try:
        print(f"MLX version: {mx.__version__}")
    except AttributeError:
        print("MLX version: unknown")
    print(f"MLX device: {mx.default_device()}")
    print(f"PyTorch version: {torch.__version__}")
    print()

    all_pass = True

    # ---- Test 1: Vary temporal dimension ----
    print("Test 1: Vary T (H=240, W=416, C_in=512, kernel=3x3x3)")
    print(" K = 3*3*3*512 = 13,824")
    print()
    for T in [1, 2, 3, 4]:
        ok = test_conv3d(T, 240, 416, 512, 256, label=f"T={T}")
        all_pass &= ok
    print()

    # ---- Test 2: Vary spatial size ----
    print("Test 2: Vary spatial (T=4, C_in=512, kernel=3x3x3)")
    print()
    for H, W in [(120, 208), (200, 350), (220, 380), (240, 416)]:
        ok = test_conv3d(4, H, W, 512, 256, label=f"{H}x{W}")
        all_pass &= ok
    print()

    # ---- Test 3: Vary input channels ----
    print("Test 3: Vary C_in (T=4, H=240, W=416, kernel=3x3x3)")
    print()
    for C_in, C_out in [(128, 64), (256, 128), (512, 256)]:
        ok = test_conv3d(4, 240, 416, C_in, C_out, label=f"C_in={C_in}")
        all_pass &= ok
    print()

    # ---- Test 4: Boundary around 2^32 ----
    print("Test 4: Exact boundary (M*K near 2^32 = 4,294,967,296)")
    print()
    for H, W in [(278, 278), (278, 280), (280, 280)]:
        M = 4 * H * W
        MK = M * 13824
        ok = test_conv3d(4, H, W, 512, 256, label=f"{H}x{W}, M*K={MK:,}")
        all_pass &= ok
    print()

    # ---- Test 5: 1x1x1 kernel (small K, no overflow) ----
    print("Test 5: Small kernel (1x1x1) — same tensor size, no overflow")
    print()
    ok = test_conv3d(
        4,
        240,
        416,
        512,
        256,
        kernel=(1, 1, 1),
        label="1x1x1 kernel, T=4, 240x416, C_in=512",
    )
    all_pass &= ok
    print()

    # ---- Summary ----
    print("=" * 110)
    if not all_pass:
        print("mx.conv3d produces wrong results when M*K > 2^32")
    else:
        print("All tests passed — bug may be fixed in this MLX version.")


if __name__ == "__main__":
    main()
```

Expected behavior
All tests should produce max_diff ≈ 0.0 and corr ≈ 1.0 matching PyTorch output, regardless of tensor size. The bug is silent — no error is raised, results are just wrong.
Sample output:

```
Test 1: Vary T (H=240, W=416, C_in=512, kernel=3x3x3)
 K = 3*3*3*512 = 13,824
 T=1     M*K= 1,380,188,160 < 2^32 max_diff=0.000000 corr=1.000000 [OK]
 T=2     M*K= 2,760,376,320 < 2^32 max_diff=0.000000 corr=1.000000 [OK]
 T=3     M*K= 4,140,564,480 < 2^32 max_diff=0.000000 corr=1.000000 [OK]
 T=4     M*K= 5,520,752,640 > 2^32 max_diff=0.074222 corr=0.637116 [FAIL]

Test 2: Vary spatial (T=4, C_in=512, kernel=3x3x3)
 120x208 M*K= 1,380,188,160 < 2^32 max_diff=0.000000 corr=1.000000 [OK]
 200x350 M*K= 3,870,720,000 < 2^32 max_diff=0.000000 corr=1.000000 [OK]
 220x380 M*K= 4,622,745,600 > 2^32 max_diff=0.080086 corr=0.888318 [FAIL]
 240x416 M*K= 5,520,752,640 > 2^32 max_diff=0.074222 corr=0.637116 [FAIL]

Test 3: Vary C_in (T=4, H=240, W=416, kernel=3x3x3)
 C_in=128 M*K= 1,380,188,160 < 2^32 max_diff=0.000000 corr=1.000000 [OK]
 C_in=256 M*K= 2,760,376,320 < 2^32 max_diff=0.000000 corr=1.000000 [OK]
 C_in=512 M*K= 5,520,752,640 > 2^32 max_diff=0.074222 corr=0.637116 [FAIL]

Test 4: Exact boundary (M*K near 2^32 = 4,294,967,296)
 278x278, M*K=4,273,496,064 M*K= 4,273,496,064 < 2^32 max_diff=0.000000 corr=1.000000 [OK]
 278x280, M*K=4,304,240,640 M*K= 4,304,240,640 > 2^32 max_diff=0.068097 corr=0.997114 [FAIL]
 280x280, M*K=4,335,206,400 M*K= 4,335,206,400 > 2^32 max_diff=0.068373 corr=0.985941 [FAIL]

Test 5: Small kernel (1x1x1) — same tensor size, no overflow
 1x1x1 kernel, T=4, 240x416, C_in=512         M*K=   204,472,320 < 2^32 max_diff=0.000000 corr=1.000000 [OK]

==============================================================================================================
mx.conv3d produces wrong results when M*K > 2^32
```
Desktop:
- OS Version: macOS 15.4
- MLX Version: 0.30.3
- Device: Apple M4 Max
Additional context
Root cause: mlx/backend/metal/kernels/conv.metal line 33 — gid.z * filter_size uses 32-bit arithmetic. The boundary is exactly at M * K = 2^32. Both fp16 and fp32 produce the same wrong results (not a precision issue).
Why only conv3d: conv_3D_gpu (line 1038 in mlx/backend/metal/conv.cpp) always calls explicit_gemm_conv_ND_gpu, which uses the naive_unfold_Nd kernel. Conv2d has implicit GEMM and Winograd paths that avoid materializing the unfolded matrix entirely.
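A quick predicate, derived from the boundary above, for whether a given conv3d call will cross the 32-bit limit (the helper name is mine, not an MLX API). It reproduces the Test 4 boundary exactly:

```python
def unfold_size(B, T_out, H_out, W_out, C_in, kT, kH, kW):
    """M * K for the materialized im2col buffer.

    Rows M = B * T_out * H_out * W_out, columns K = kT * kH * kW * C_in.
    Results are silently wrong on the Metal backend when this exceeds 2^32.
    """
    return (B * T_out * H_out * W_out) * (kT * kH * kW * C_in)

# Test 4 boundary cases: 278x278 lands just under 2^32, 278x280 just over.
assert unfold_size(1, 4, 278, 278, 512, 3, 3, 3) == 4_273_496_064  # < 2^32, OK
assert unfold_size(1, 4, 278, 280, 512, 3, 3, 3) == 4_304_240_640  # > 2^32, FAIL
```

Until the kernel promotes the offset computation to 64-bit, a check like this can at least turn the silent corruption into a loud error at the call site.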