Skip to content

Commit 776c3d2

Browse files
jagrit06awni
andauthored
Convolution update (#651)
* Init steel conv and update Conv primitive * Update slow CPU implementation to support flipping and input dilation winograd conv routing Co-authored-by: Awni Hannun <[email protected]>
1 parent f5f18b7 commit 776c3d2

File tree

27 files changed

+2831
-907
lines changed

27 files changed

+2831
-907
lines changed

benchmarks/python/conv_bench.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import argparse
2+
import math
3+
import os
4+
import subprocess
5+
import time
6+
7+
import mlx.core as mx
8+
import numpy as np
9+
import torch
10+
11+
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
12+
device_name = device_name.decode("utf-8").strip("\n")
13+
14+
N_warmup = 10
15+
N_iter_bench = 100
16+
N_iter_func = 5
17+
18+
19+
def bench(f, a, b):
20+
for i in range(N_warmup):
21+
f(a, b)
22+
torch.mps.synchronize()
23+
24+
s = time.perf_counter_ns()
25+
for i in range(N_iter_bench):
26+
f(a, b)
27+
e = time.perf_counter_ns()
28+
return (e - s) * 1e-9
29+
30+
31+
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
32+
def mx_conv_2D(a, b):
33+
ys = []
34+
for i in range(N_iter_func):
35+
y = mx.conv2d(a, b, stride=strides, padding=padding)
36+
ys.append(y)
37+
mx.eval(ys)
38+
return ys
39+
40+
return mx_conv_2D
41+
42+
43+
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
44+
@torch.no_grad()
45+
def pt_conv_2D(a, b):
46+
ys = []
47+
for i in range(N_iter_func):
48+
y = torch.conv2d(a, b, stride=strides, padding=padding)
49+
ys.append(y)
50+
torch.mps.synchronize()
51+
return ys
52+
53+
return pt_conv_2D
54+
55+
56+
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
57+
58+
scale = 1.0 / math.sqrt(kH * kH * C)
59+
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
60+
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
61+
62+
a_mx = mx.array(a_np)
63+
b_mx = mx.array(b_np)
64+
65+
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
66+
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
67+
68+
torch.mps.synchronize()
69+
70+
f_mx = make_mx_conv_2D(strides, padding)
71+
f_pt = make_pt_conv_2D(strides, padding)
72+
73+
time_torch = bench(f_pt, a_pt, b_pt)
74+
time_mlx = bench(f_mx, a_mx, b_mx)
75+
76+
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
77+
out_pt = torch.conv2d(
78+
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
79+
)
80+
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
81+
out_pt = out_pt.numpy(force=True)
82+
83+
atol = 2e-5 if np_dtype == np.float32 else 1e-4
84+
85+
if not np.allclose(out_pt, out_mx, atol=atol):
86+
print(
87+
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
88+
)
89+
90+
return time_mlx, time_torch
91+
92+
93+
if __name__ == "__main__":
94+
parser = argparse.ArgumentParser(description="Run conv benchmarks")
95+
96+
dtypes = ("float32",)
97+
shapes = (
98+
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
99+
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
100+
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
101+
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
102+
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
103+
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
104+
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
105+
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
106+
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
107+
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
108+
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
109+
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
110+
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
111+
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
112+
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
113+
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
114+
)
115+
116+
for dtype in dtypes:
117+
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
118+
for N, H, W, C, kH, kW, O, strides, padding in shapes:
119+
np_dtype = getattr(np, dtype)
120+
time_mlx, time_torch = bench_shape(
121+
N, H, W, C, kH, kW, O, strides, padding, np_dtype
122+
)
123+
diff = time_torch / time_mlx - 1.0
124+
125+
print(
126+
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
127+
)
128+
if time_mlx >= 2.0 * time_torch:
129+
print("ATTENTION ^^^^^^^")

docs/src/python/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Operations
3535
convolve
3636
conv1d
3737
conv2d
38+
conv_general
3839
cos
3940
cosh
4041
dequantize

0 commit comments

Comments
 (0)