Skip to content

Commit bf9de4a

Browse files
committed
Fix broadcasting with large arrays
1 parent 6d1af96 commit bf9de4a

File tree

1 file changed

+59
-41
lines changed

1 file changed

+59
-41
lines changed

src/broadcast.jl

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,14 @@ end
6666
if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
6767
## COV_EXCL_START
6868
function broadcast_cartesian_static(dest, bc, Is)
69-
i = thread_position_in_grid().x
70-
stride = threads_per_grid().x
71-
while 1 <= i <= length(dest)
69+
i = Int(thread_position_in_grid().x)
70+
stride = threads_per_grid().x
71+
while 1 <= i <= length(dest)
7272
I = @inbounds Is[i]
7373
@inbounds dest[I] = bc[I]
7474
i += stride
75-
end
76-
return
75+
end
76+
return
7777
end
7878
## COV_EXCL_STOP
7979

@@ -91,13 +91,13 @@ end
9191
(isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
9292
## COV_EXCL_START
9393
function broadcast_linear(dest, bc)
94-
i = thread_position_in_grid().x
95-
stride = threads_per_grid().x
96-
while 1 <= i <= length(dest)
97-
@inbounds dest[i] = bc[i]
98-
i += stride
99-
end
100-
return
94+
i = Int(thread_position_in_grid().x)
95+
stride = threads_per_grid().x
96+
while 1 <= i <= length(dest)
97+
@inbounds dest[i] = bc[i]
98+
i += stride
99+
end
100+
return
101101
end
102102
## COV_EXCL_STOP
103103

@@ -108,56 +108,74 @@ end
108108
elseif ndims(dest) == 2
109109
## COV_EXCL_START
110110
function broadcast_2d(dest, bc)
111-
is = Tuple(thread_position_in_grid_2d())
112-
stride = threads_per_grid_2d()
113-
while 1 <= is[1] <= size(dest, 1) && 1 <= is[2] <= size(dest, 2)
114-
I = CartesianIndex(is)
115-
@inbounds dest[I] = bc[I]
116-
is = (is[1] + stride[1], is[2] + stride[2])
117-
end
118-
return
111+
i = Int(thread_position_in_grid().x)
112+
y = Int(thread_position_in_grid().y)
113+
@inbounds stride1, stride2, _ = threads_per_grid()
114+
@inbounds dim1, dim2 = size(dest)
115+
while 1 <= i <= dim1
116+
j = y
117+
while 1 <= j <= dim2
118+
I = CartesianIndex(i, j)
119+
@inbounds dest[I] = bc[I]
120+
j += stride2
121+
end
122+
i += stride1
123+
end
124+
return
119125
end
120126
## COV_EXCL_STOP
121127

122128
kernel = @metal launch=false broadcast_2d(dest, bc)
123-
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
124-
h = min(size(dest, 2), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
129+
130+
maxThreads = prevwarp(kernel.pipeline, kernel.pipeline.maxTotalThreadsPerThreadgroup - 1)
131+
w = min(size(dest, 1), maxThreads)
132+
h = min(size(dest, 2), maxThreads ÷ w)
125133
threads = (w, h)
126134
groups = cld.(size(dest), threads)
127135
elseif ndims(dest) == 3
128136
## COV_EXCL_START
129137
function broadcast_3d(dest, bc)
130-
is = Tuple(thread_position_in_grid_3d())
131-
stride = threads_per_grid_3d()
132-
while 1 <= is[1] <= size(dest, 1) &&
133-
1 <= is[2] <= size(dest, 2) &&
134-
1 <= is[3] <= size(dest, 3)
135-
I = CartesianIndex(is)
136-
@inbounds dest[I] = bc[I]
137-
is = (is[1] + stride[1], is[2] + stride[2], is[3] + stride[3])
138-
end
139-
return
138+
i = Int(thread_position_in_grid().x)
139+
y = Int(thread_position_in_grid().y)
140+
z = Int(thread_position_in_grid().z)
141+
@inbounds stride1, stride2, stride3 = threads_per_grid()
142+
@inbounds dim1, dim2, dim3 = size(dest)
143+
while 1 <= i <= dim1
144+
j = y
145+
while 1 <= j <= dim2
146+
k = z
147+
while 1 <= k <= dim3
148+
I = CartesianIndex(i, j, k)
149+
@inbounds dest[I] = bc[I]
150+
k += stride3
151+
end
152+
j += stride2
153+
end
154+
i += stride1
155+
end
156+
return
140157
end
141158
## COV_EXCL_STOP
142159

143160
kernel = @metal launch=false broadcast_3d(dest, bc)
144-
w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth)
145-
h = min(size(dest, 2), kernel.pipeline.threadExecutionWidth,
146-
kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w)
147-
d = min(size(dest, 3), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ (w*h))
161+
162+
maxThreads = prevwarp(kernel.pipeline, kernel.pipeline.maxTotalThreadsPerThreadgroup - 1)
163+
w = min(size(dest, 1), maxThreads)
164+
h = min(size(dest, 2), maxThreads ÷ w)
165+
d = min(size(dest, 3), maxThreads ÷ (w*h))
148166
threads = (w, h, d)
149167
groups = cld.(size(dest), threads)
150168
else
151169
## COV_EXCL_START
152170
function broadcast_cartesian(dest, bc)
153-
i = thread_position_in_grid().x
154-
stride = threads_per_grid().x
155-
while 1 <= i <= length(dest)
171+
i = Int(thread_position_in_grid().x)
172+
stride = threads_per_grid().x
173+
while 1 <= i <= length(dest)
156174
I = @inbounds CartesianIndices(dest)[i]
157175
@inbounds dest[I] = bc[I]
158176
i += stride
159-
end
160-
return
177+
end
178+
return
161179
end
162180
## COV_EXCL_STOP
163181

0 commit comments

Comments
 (0)