Skip to content

Commit 06207bb

Browse files
committed
Ensure gridsize is less than 2^32
1 parent bf9de4a commit 06207bb

File tree

1 file changed: 9 additions and 7 deletions

1 file changed: 9 additions and 7 deletions

src/broadcast.jl

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ end
8181
kernel = @metal launch=false broadcast_cartesian_static(dest, bc, Is)
8282
elements = cld(length(dest), 4)
8383
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
84-
groups = cld(elements, threads)
84+
groups = cld(min(elements, typemax(UInt32).-threads), threads)
8585
kernel(dest, bc, Is; threads, groups)
8686
return dest
8787
end
@@ -104,7 +104,7 @@ end
104104
kernel = @metal launch=false broadcast_linear(dest, bc)
105105
elements = cld(length(dest), 4)
106106
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
107-
groups = cld(elements, threads)
107+
elements = size(dest)
108108
elseif ndims(dest) == 2
109109
## COV_EXCL_START
110110
function broadcast_2d(dest, bc)
@@ -127,11 +127,11 @@ end
127127

128128
kernel = @metal launch=false broadcast_2d(dest, bc)
129129

130-
maxThreads = prevwarp(kernel.pipeline, kernel.pipeline.maxTotalThreadsPerThreadgroup - 1)
130+
maxThreads = kernel.pipeline.maxTotalThreadsPerThreadgroup
131131
w = min(size(dest, 1), maxThreads)
132132
h = min(size(dest, 2), maxThreads ÷ w)
133133
threads = (w, h)
134-
groups = cld.(size(dest), threads)
134+
elements = size(dest)
135135
elseif ndims(dest) == 3
136136
## COV_EXCL_START
137137
function broadcast_3d(dest, bc)
@@ -159,12 +159,12 @@ end
159159

160160
kernel = @metal launch=false broadcast_3d(dest, bc)
161161

162-
maxThreads = prevwarp(kernel.pipeline, kernel.pipeline.maxTotalThreadsPerThreadgroup - 1)
162+
maxThreads = kernel.pipeline.maxTotalThreadsPerThreadgroup
163163
w = min(size(dest, 1), maxThreads)
164164
h = min(size(dest, 2), maxThreads ÷ w)
165165
d = min(size(dest, 3), maxThreads ÷ (w*h))
166166
threads = (w, h, d)
167-
groups = cld.(size(dest), threads)
167+
elements = size(dest)
168168
else
169169
## COV_EXCL_START
170170
function broadcast_cartesian(dest, bc)
@@ -182,8 +182,10 @@ end
182182
kernel = @metal launch=false broadcast_cartesian(dest, bc)
183183
elements = cld(length(dest), 4)
184184
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
185-
groups = cld(elements, threads)
186185
end
186+
187+
groups = cld.(min.(elements, typemax(UInt32).-threads), threads)
188+
187189
kernel(dest, bc; threads, groups)
188190

189191
return dest

0 commit comments

Comments
 (0)