8181 kernel = @metal launch= false broadcast_cartesian_static(dest, bc, Is)
8282 elements = cld(length(dest), 4 )
8383 threads = min(elements, kernel. pipeline. maxTotalThreadsPerThreadgroup)
84- groups = cld(elements, threads)
84+ groups = cld(min( elements, typemax(UInt32) . - threads) , threads)
8585 kernel(dest, bc, Is; threads, groups)
8686 return dest
8787 end
104104 kernel = @metal launch= false broadcast_linear(dest, bc)
105105 elements = cld(length(dest), 4 )
106106 threads = min(elements, kernel. pipeline. maxTotalThreadsPerThreadgroup)
107- groups = cld(elements, threads )
107+ elements = size(dest )
108108 elseif ndims(dest) == 2
109109 # # COV_EXCL_START
110110 function broadcast_2d(dest, bc)
@@ -127,11 +127,11 @@ end
127127
128128 kernel = @metal launch= false broadcast_2d(dest, bc)
129129
130- maxThreads = prevwarp( kernel. pipeline, kernel . pipeline . maxTotalThreadsPerThreadgroup - 1 )
130+ maxThreads = kernel. pipeline. maxTotalThreadsPerThreadgroup
131131 w = min(size(dest, 1 ), maxThreads)
132132 h = min(size(dest, 2 ), maxThreads ÷ w)
133133 threads = (w, h)
134- groups = cld.( size(dest), threads )
134+ elements = size(dest)
135135 elseif ndims(dest) == 3
136136 # # COV_EXCL_START
137137 function broadcast_3d(dest, bc)
@@ -159,12 +159,12 @@ end
159159
160160 kernel = @metal launch= false broadcast_3d(dest, bc)
161161
162- maxThreads = prevwarp( kernel. pipeline, kernel . pipeline . maxTotalThreadsPerThreadgroup - 1 )
162+ maxThreads = kernel. pipeline. maxTotalThreadsPerThreadgroup
163163 w = min(size(dest, 1 ), maxThreads)
164164 h = min(size(dest, 2 ), maxThreads ÷ w)
165165 d = min(size(dest, 3 ), maxThreads ÷ (w* h))
166166 threads = (w, h, d)
167- groups = cld.( size(dest), threads )
167+ elements = size(dest)
168168 else
169169 # # COV_EXCL_START
170170 function broadcast_cartesian(dest, bc)
182182 kernel = @metal launch= false broadcast_cartesian(dest, bc)
183183 elements = cld(length(dest), 4 )
184184 threads = min(elements, kernel. pipeline. maxTotalThreadsPerThreadgroup)
185- groups = cld(elements, threads)
186185 end
186+
187+ groups = cld.(min.(elements, typemax(UInt32). - threads), threads)
188+
187189 kernel(dest, bc; threads, groups)
188190
189191 return dest
0 commit comments