6666 if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD
## COV_EXCL_START
# Kernel: copy `bc` into `dest` at the precomputed indices stored in `Is`.
#
# Each thread starts at its own x position in the grid and then advances by
# the grid's x extent, so it handles positions i, i + stride, i + 2*stride, ...
# for as long as `i` stays inside `1:length(dest)`.
function broadcast_cartesian_static(dest, bc, Is)
    # Convert the thread position to `Int` before doing loop arithmetic on it.
    # NOTE(review): presumably the raw position is a narrower unsigned
    # integer; confirm against the definition of `thread_position_in_grid`.
    i = Int(thread_position_in_grid().x)
    stride = threads_per_grid().x
    while 1 <= i <= length(dest)
        # `Is[i]` is used as the index into both `dest` and `bc`.
        # NOTE(review): the `@inbounds` lookup assumes `Is` has at least
        # `length(dest)` entries; verify against the caller.
        I = @inbounds Is[i]
        @inbounds dest[I] = bc[I]
        i += stride
    end
    return
end
## COV_EXCL_STOP
7979
9191 (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear))
## COV_EXCL_START
# Kernel: copy `bc` into `dest` element by element using linear indices.
#
# Each thread starts at its own x position in the grid and then advances by
# the grid's x extent until it runs past `length(dest)`.
function broadcast_linear(dest, bc)
    # Convert the thread position to `Int` before doing loop arithmetic on it.
    # NOTE(review): presumably the raw position is a narrower unsigned
    # integer; confirm against the definition of `thread_position_in_grid`.
    i = Int(thread_position_in_grid().x)
    stride = threads_per_grid().x
    while 1 <= i <= length(dest)
        # The same linear index `i` addresses both sides.
        @inbounds dest[i] = bc[i]
        i += stride
    end
    return
end
## COV_EXCL_STOP
103103
@@ -108,56 +108,74 @@ end
108108 elseif ndims(dest) == 2
## COV_EXCL_START
# Kernel: copy `bc` into a two-dimensional `dest`.
#
# A thread starts at its (x, y) position in the grid. The outer loop walks the
# first dimension in steps of the grid's x extent; for every such row index the
# inner loop restarts at the thread's own y position and walks the second
# dimension in steps of the grid's y extent. Nesting the loops means every
# (i, j) combination reachable from the thread's start is visited, rather than
# advancing both coordinates in lock-step.
function broadcast_2d(dest, bc)
    # Convert the thread position to `Int` before doing loop arithmetic on it.
    # NOTE(review): presumably the raw components are narrower unsigned
    # integers; confirm against the definition of `thread_position_in_grid`.
    i = Int(thread_position_in_grid().x)
    y = Int(thread_position_in_grid().y)
    # Only the first two grid extents are needed; the third is discarded.
    @inbounds stride1, stride2, _ = threads_per_grid()
    @inbounds dim1, dim2 = size(dest)
    while 1 <= i <= dim1
        # Restart the column walk from this thread's own y position.
        j = y
        while 1 <= j <= dim2
            I = CartesianIndex(i, j)
            @inbounds dest[I] = bc[I]
            j += stride2
        end
        i += stride1
    end
    return
end
## COV_EXCL_STOP
121127
# Compile the 2D kernel without launching it so the pipeline limits can be
# queried before choosing a launch shape.
kernel = @metal launch=false broadcast_2d(dest, bc)

# Thread budget for one threadgroup, derived from the pipeline's maximum.
# NOTE(review): `prevwarp` is defined outside this view; presumably it rounds
# the given count down to a multiple of the execution width. TODO confirm.
maxThreads = prevwarp(kernel.pipeline, kernel.pipeline.maxTotalThreadsPerThreadgroup - 1)
# Fill the first dimension first, then give the second whatever budget is left.
# NOTE(review): `maxThreads ÷ w` throws if `size(dest, 1) == 0`; this assumes
# empty destinations are handled before this point. Verify against the caller.
w = min(size(dest, 1), maxThreads)
h = min(size(dest, 2), maxThreads ÷ w)
threads = (w, h)
# Round up per dimension so the groups cover all of `dest`.
groups = cld.(size(dest), threads)
127135 elseif ndims(dest) == 3
## COV_EXCL_START
# Kernel: copy `bc` into a three-dimensional `dest`.
#
# A thread starts at its (x, y, z) position in the grid. Three nested loops
# walk the first, second and third dimensions in steps of the grid's x, y and
# z extents respectively; each inner loop restarts from the thread's own
# starting coordinate. Nesting the loops means every (i, j, k) combination
# reachable from the thread's start is visited, rather than advancing all
# coordinates in lock-step.
function broadcast_3d(dest, bc)
    # Convert the thread position to `Int` before doing loop arithmetic on it.
    # NOTE(review): presumably the raw components are narrower unsigned
    # integers; confirm against the definition of `thread_position_in_grid`.
    i = Int(thread_position_in_grid().x)
    y = Int(thread_position_in_grid().y)
    z = Int(thread_position_in_grid().z)
    @inbounds stride1, stride2, stride3 = threads_per_grid()
    @inbounds dim1, dim2, dim3 = size(dest)
    while 1 <= i <= dim1
        # Restart the second-dimension walk from this thread's own y position.
        j = y
        while 1 <= j <= dim2
            # Restart the third-dimension walk from this thread's own z position.
            k = z
            while 1 <= k <= dim3
                I = CartesianIndex(i, j, k)
                @inbounds dest[I] = bc[I]
                k += stride3
            end
            j += stride2
        end
        i += stride1
    end
    return
end
## COV_EXCL_STOP
142159
# Compile the 3D kernel without launching it so the pipeline limits can be
# queried before choosing a launch shape.
kernel = @metal launch=false broadcast_3d(dest, bc)

# Thread budget for one threadgroup, derived from the pipeline's maximum.
# NOTE(review): `prevwarp` is defined outside this view; presumably it rounds
# the given count down to a multiple of the execution width. TODO confirm.
maxThreads = prevwarp(kernel.pipeline, kernel.pipeline.maxTotalThreadsPerThreadgroup - 1)
# Fill the first dimension first, then hand the remaining budget to the second
# and finally the third, so that w * h * d never exceeds `maxThreads`.
# NOTE(review): the integer divisions throw if `dest` has a zero-sized first or
# second dimension; this assumes empty destinations are handled before this
# point. Verify against the caller.
w = min(size(dest, 1), maxThreads)
h = min(size(dest, 2), maxThreads ÷ w)
d = min(size(dest, 3), maxThreads ÷ (w * h))
threads = (w, h, d)
# Round up per dimension so the groups cover all of `dest`.
groups = cld.(size(dest), threads)
150168 else
## COV_EXCL_START
# Kernel: copy `bc` into `dest` using Cartesian indices computed on the fly.
#
# Each thread starts at its own x position in the grid and then advances by
# the grid's x extent until it runs past `length(dest)`. The linear position
# `i` is translated into a Cartesian index of `dest` on every iteration.
function broadcast_cartesian(dest, bc)
    # Convert the thread position to `Int` before doing loop arithmetic on it.
    # NOTE(review): presumably the raw position is a narrower unsigned
    # integer; confirm against the definition of `thread_position_in_grid`.
    i = Int(thread_position_in_grid().x)
    stride = threads_per_grid().x
    while 1 <= i <= length(dest)
        # Map the linear position onto the matching Cartesian index of `dest`.
        I = @inbounds CartesianIndices(dest)[i]
        @inbounds dest[I] = bc[I]
        i += stride
    end
    return
end
## COV_EXCL_STOP
163181
0 commit comments