|
66 | 66 | if _broadcast_shapes[Is] > BROADCAST_SPECIALIZATION_THRESHOLD |
67 | 67 | ## COV_EXCL_START |
68 | 68 | function broadcast_cartesian_static(dest, bc, Is) |
69 | | - i = thread_position_in_grid().x |
70 | | - stride = threads_per_grid().x |
71 | | - while 1 <= i <= length(dest) |
| 69 | + i = Int(thread_position_in_grid().x) |
| 70 | + stride = threads_per_grid().x |
| 71 | + while 1 <= i <= length(dest) |
72 | 72 | I = @inbounds Is[i] |
73 | 73 | @inbounds dest[I] = bc[I] |
74 | 74 | i += stride |
75 | | - end |
76 | | - return |
| 75 | + end |
| 76 | + return |
77 | 77 | end |
78 | 78 | ## COV_EXCL_STOP |
79 | 79 |
|
80 | 80 | Is = StaticCartesianIndices(Is) |
81 | 81 | kernel = @metal launch=false broadcast_cartesian_static(dest, bc, Is) |
82 | 82 | elements = cld(length(dest), 4) |
83 | 83 | threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup) |
84 | | - groups = cld(elements, threads) |
| 84 | + groups = cld(min(elements, typemax(UInt32).-threads), threads) |
85 | 85 | kernel(dest, bc, Is; threads, groups) |
86 | 86 | return dest |
87 | 87 | end |
|
91 | 91 | (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear)) |
92 | 92 | ## COV_EXCL_START |
93 | 93 | function broadcast_linear(dest, bc) |
94 | | - i = thread_position_in_grid().x |
95 | | - stride = threads_per_grid().x |
96 | | - while 1 <= i <= length(dest) |
97 | | - @inbounds dest[i] = bc[i] |
98 | | - i += stride |
99 | | - end |
100 | | - return |
| 94 | + i = Int(thread_position_in_grid().x) |
| 95 | + stride = threads_per_grid().x |
| 96 | + while 1 <= i <= length(dest) |
| 97 | + @inbounds dest[i] = bc[i] |
| 98 | + i += stride |
| 99 | + end |
| 100 | + return |
101 | 101 | end |
102 | 102 | ## COV_EXCL_STOP |
103 | 103 |
|
104 | 104 | kernel = @metal launch=false broadcast_linear(dest, bc) |
105 | 105 | elements = cld(length(dest), 4) |
106 | 106 | threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup) |
107 | | - groups = cld(elements, threads) |
| 107 | + elements = size(dest) |
108 | 108 | elseif ndims(dest) == 2 |
109 | 109 | ## COV_EXCL_START |
110 | 110 | function broadcast_2d(dest, bc) |
111 | | - is = Tuple(thread_position_in_grid_2d()) |
112 | | - stride = threads_per_grid_2d() |
113 | | - while 1 <= is[1] <= size(dest, 1) && 1 <= is[2] <= size(dest, 2) |
114 | | - I = CartesianIndex(is) |
115 | | - @inbounds dest[I] = bc[I] |
116 | | - is = (is[1] + stride[1], is[2] + stride[2]) |
117 | | - end |
118 | | - return |
| 111 | + i = Int(thread_position_in_grid().x) |
| 112 | + y = Int(thread_position_in_grid().y) |
| 113 | + @inbounds stride1, stride2, _ = threads_per_grid() |
| 114 | + @inbounds dim1, dim2 = size(dest) |
| 115 | + while 1 <= i <= dim1 |
| 116 | + j = y |
| 117 | + while 1 <= j <= dim2 |
| 118 | + I = CartesianIndex(i, j) |
| 119 | + @inbounds dest[I] = bc[I] |
| 120 | + j += stride2 |
| 121 | + end |
| 122 | + i += stride1 |
| 123 | + end |
| 124 | + return |
119 | 125 | end |
120 | 126 | ## COV_EXCL_STOP |
121 | 127 |
|
122 | 128 | kernel = @metal launch=false broadcast_2d(dest, bc) |
123 | | - w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth) |
124 | | - h = min(size(dest, 2), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w) |
| 129 | + |
| 130 | + maxThreads = kernel.pipeline.maxTotalThreadsPerThreadgroup |
| 131 | + w = min(size(dest, 1), maxThreads) |
| 132 | + h = min(size(dest, 2), maxThreads ÷ w) |
125 | 133 | threads = (w, h) |
126 | | - groups = cld.(size(dest), threads) |
| 134 | + elements = size(dest) |
127 | 135 | elseif ndims(dest) == 3 |
128 | 136 | ## COV_EXCL_START |
129 | 137 | function broadcast_3d(dest, bc) |
130 | | - is = Tuple(thread_position_in_grid_3d()) |
131 | | - stride = threads_per_grid_3d() |
132 | | - while 1 <= is[1] <= size(dest, 1) && |
133 | | - 1 <= is[2] <= size(dest, 2) && |
134 | | - 1 <= is[3] <= size(dest, 3) |
135 | | - I = CartesianIndex(is) |
136 | | - @inbounds dest[I] = bc[I] |
137 | | - is = (is[1] + stride[1], is[2] + stride[2], is[3] + stride[3]) |
138 | | - end |
139 | | - return |
| 138 | + i = Int(thread_position_in_grid().x) |
| 139 | + y = Int(thread_position_in_grid().y) |
| 140 | + z = Int(thread_position_in_grid().z) |
| 141 | + @inbounds stride1, stride2, stride3 = threads_per_grid() |
| 142 | + @inbounds dim1, dim2, dim3 = size(dest) |
| 143 | + while 1 <= i <= dim1 |
| 144 | + j = y |
| 145 | + while 1 <= j <= dim2 |
| 146 | + k = z |
| 147 | + while 1 <= k <= dim3 |
| 148 | + I = CartesianIndex(i, j, k) |
| 149 | + @inbounds dest[I] = bc[I] |
| 150 | + k += stride3 |
| 151 | + end |
| 152 | + j += stride2 |
| 153 | + end |
| 154 | + i += stride1 |
| 155 | + end |
| 156 | + return |
140 | 157 | end |
141 | 158 | ## COV_EXCL_STOP |
142 | 159 |
|
143 | 160 | kernel = @metal launch=false broadcast_3d(dest, bc) |
144 | | - w = min(size(dest, 1), kernel.pipeline.threadExecutionWidth) |
145 | | - h = min(size(dest, 2), kernel.pipeline.threadExecutionWidth, |
146 | | - kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ w) |
147 | | - d = min(size(dest, 3), kernel.pipeline.maxTotalThreadsPerThreadgroup ÷ (w*h)) |
| 161 | + |
| 162 | + maxThreads = kernel.pipeline.maxTotalThreadsPerThreadgroup |
| 163 | + w = min(size(dest, 1), maxThreads) |
| 164 | + h = min(size(dest, 2), maxThreads ÷ w) |
| 165 | + d = min(size(dest, 3), maxThreads ÷ (w*h)) |
148 | 166 | threads = (w, h, d) |
149 | | - groups = cld.(size(dest), threads) |
| 167 | + elements = size(dest) |
150 | 168 | else |
151 | 169 | ## COV_EXCL_START |
152 | 170 | function broadcast_cartesian(dest, bc) |
153 | | - i = thread_position_in_grid().x |
154 | | - stride = threads_per_grid().x |
155 | | - while 1 <= i <= length(dest) |
| 171 | + i = Int(thread_position_in_grid().x) |
| 172 | + stride = threads_per_grid().x |
| 173 | + while 1 <= i <= length(dest) |
156 | 174 | I = @inbounds CartesianIndices(dest)[i] |
157 | 175 | @inbounds dest[I] = bc[I] |
158 | 176 | i += stride |
159 | | - end |
160 | | - return |
| 177 | + end |
| 178 | + return |
161 | 179 | end |
162 | 180 | ## COV_EXCL_STOP |
163 | 181 |
|
164 | 182 | kernel = @metal launch=false broadcast_cartesian(dest, bc) |
165 | 183 | elements = cld(length(dest), 4) |
166 | 184 | threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup) |
167 | | - groups = cld(elements, threads) |
168 | 185 | end |
| 186 | + |
| 187 | + groups = cld.(min.(elements, typemax(UInt32).-threads), threads) |
| 188 | + |
169 | 189 | kernel(dest, bc; threads, groups) |
170 | 190 |
|
171 | 191 | return dest |
|
0 commit comments