Skip to content

Commit f7b0829

Browse files
Fix broadcasting on huge arrays (#728)
* Fix broadcasting with large arrays
* Ensure gridsize is less than 2^32
* Bump version
1 parent 3ef42d9 commit f7b0829

File tree

2 files changed

+67
-47
lines changed

2 files changed

+67
-47
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
name = "Metal"
22
uuid = "dde4c033-4e86-420c-a63e-0dd931031962"
3-
version = "1.9.1"
3+
version = "1.9.2"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/broadcast.jl

Lines changed: 66 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -66,22 +66,22 @@ end
6666
if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
6767
## COV_EXCL_START
6868
function broadcast_cartesian_static(dest, bc, Is)
69-
i = thread_position_in_grid().x
70-
stride = threads_per_grid().x
71-
while 1 <= i <= length(dest)
69+
i = Int(thread_position_in_grid().x)
70+
stride = threads_per_grid().x
71+
while 1 <= i <= length(dest)
7272
I = @inbounds Is[i]
7373
@inbounds dest[I] = bc[I]
7474
i += stride
75-
end
76-
return
75+
end
76+
return
7777
end
7878
## COV_EXCL_STOP
7979

8080
Is = StaticCartesianIndices(Is)
8181
kernel = @metal launch=false broadcast_cartesian_static(dest, bc, Is)
8282
elements = cld(length(dest), 4)
8383
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
84-
groups = cld(elements, threads)
84+
groups = cld(min(elements, typemax(UInt32).-threads), threads)
8585
kernel(dest, bc, Is; threads, groups)
8686
return dest
8787
end
@@ -91,81 +91,101 @@ end
9191
(isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
9292
## COV_EXCL_START
9393
function broadcast_linear(dest, bc)
94-
i = thread_position_in_grid().x
95-
stride = threads_per_grid().x
96-
while 1 <= i <= length(dest)
97-
@inbounds dest[i] = bc[i]
98-
i += stride
99-
end
100-
return
94+
i = Int(thread_position_in_grid().x)
95+
stride = threads_per_grid().x
96+
while 1 <= i <= length(dest)
97+
@inbounds dest[i] = bc[i]
98+
i += stride
99+
end
100+
return
101101
end
102102
## COV_EXCL_STOP
103103

104104
kernel = @metal launch=false broadcast_linear(dest, bc)
105105
elements = cld(length(dest), 4)
106106
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
107-
groups = cld(elements, threads)
107+
elements = size(dest)
108108
elseif ndims(dest) == 2
109109
## COV_EXCL_START
110110
function broadcast_2d(dest, bc)
111-
is = Tuple(thread_position_in_grid_2d())
112-
stride = threads_per_grid_2d()
113-
while 1 <= is[1] <= size(dest, 1) && 1 <= is[2] <= size(dest, 2)
114-
I = CartesianIndex(is)
115-
@inbounds dest[I] = bc[I]
116-
is = (is[1] + stride[1], is[2] + stride[2])
117-
end
118-
return
111+
i = Int(thread_position_in_grid().x)
112+
y = Int(thread_position_in_grid().y)
113+
@inbounds stride1, stride2, _ = threads_per_grid()
114+
@inbounds dim1, dim2 = size(dest)
115+
while 1 <= i <= dim1
116+
j = y
117+
while 1 <= j <= dim2
118+
I = CartesianIndex(i, j)
119+
@inbounds dest[I] = bc[I]
120+
j += stride2
121+
end
122+
i += stride1
123+
end
124+
return
119125
end
120126
## COV_EXCL_STOP
121127

122128
kernel = @metal launch=false broadcast_2d(dest, bc)
123-
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
124-
h = min(size(dest, 2), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
129+
130+
maxThreads = kernel.pipeline.maxTotalThreadsPerThreadgroup
131+
w = min(size(dest, 1), maxThreads)
132+
h = min(size(dest, 2), maxThreads ÷ w)
125133
threads = (w, h)
126-
groups = cld.(size(dest), threads)
134+
elements = size(dest)
127135
elseif ndims(dest) == 3
128136
## COV_EXCL_START
129137
function broadcast_3d(dest, bc)
130-
is = Tuple(thread_position_in_grid_3d())
131-
stride = threads_per_grid_3d()
132-
while 1 <= is[1] <= size(dest, 1) &&
133-
1 <= is[2] <= size(dest, 2) &&
134-
1 <= is[3] <= size(dest, 3)
135-
I = CartesianIndex(is)
136-
@inbounds dest[I] = bc[I]
137-
is = (is[1] + stride[1], is[2] + stride[2], is[3] + stride[3])
138-
end
139-
return
138+
i = Int(thread_position_in_grid().x)
139+
y = Int(thread_position_in_grid().y)
140+
z = Int(thread_position_in_grid().z)
141+
@inbounds stride1, stride2, stride3 = threads_per_grid()
142+
@inbounds dim1, dim2, dim3 = size(dest)
143+
while 1 <= i <= dim1
144+
j = y
145+
while 1 <= j <= dim2
146+
k = z
147+
while 1 <= k <= dim3
148+
I = CartesianIndex(i, j, k)
149+
@inbounds dest[I] = bc[I]
150+
k += stride3
151+
end
152+
j += stride2
153+
end
154+
i += stride1
155+
end
156+
return
140157
end
141158
## COV_EXCL_STOP
142159

143160
kernel = @metal launch=false broadcast_3d(dest, bc)
144-
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
145-
h = min(size(dest, 2), kernel.pipeline.threadExecutionWidth,
146-
kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
147-
d = min(size(dest, 3), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ (w*h))
161+
162+
maxThreads = kernel.pipeline.maxTotalThreadsPerThreadgroup
163+
w = min(size(dest, 1), maxThreads)
164+
h = min(size(dest, 2), maxThreads ÷ w)
165+
d = min(size(dest, 3), maxThreads ÷ (w*h))
148166
threads = (w, h, d)
149-
groups = cld.(size(dest), threads)
167+
elements = size(dest)
150168
else
151169
## COV_EXCL_START
152170
function broadcast_cartesian(dest, bc)
153-
i = thread_position_in_grid().x
154-
stride = threads_per_grid().x
155-
while 1 <= i <= length(dest)
171+
i = Int(thread_position_in_grid().x)
172+
stride = threads_per_grid().x
173+
while 1 <= i <= length(dest)
156174
I = @inbounds CartesianIndices(dest)[i]
157175
@inbounds dest[I] = bc[I]
158176
i += stride
159-
end
160-
return
177+
end
178+
return
161179
end
162180
## COV_EXCL_STOP
163181

164182
kernel = @metal launch=false broadcast_cartesian(dest, bc)
165183
elements = cld(length(dest), 4)
166184
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
167-
groups = cld(elements, threads)
168185
end
186+
187+
groups = cld.(min.(elements, typemax(UInt32).-threads), threads)
188+
169189
kernel(dest, bc; threads, groups)
170190

171191
return dest

0 commit comments

Comments (0)