@@ -93,10 +93,10 @@ inline void build_kernel(
9393 // a third grid dimension
9494 os << " size_t index = pos.x + grid.x * size_t(pos.y);\n " ;
9595 } else if (work_per_thread > 1 ) {
96- os << " constexpr int N = " << std::to_string (work_per_thread) << " ;\n "
96+ os << " constexpr int N_ = " << std::to_string (work_per_thread) << " ;\n "
9797 << " int xshape = output_shape["
9898 << (dynamic_dims ? " ndim - 1" : std::to_string (ndim - 1 )) << " ];\n "
99- << " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n " ;
99+ << " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n " ;
100100 } else {
101101 os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n " ;
102102 }
@@ -141,11 +141,11 @@ inline void build_kernel(
141141 << " in_strides + " << offset << " );\n " ;
142142 } else if (!dynamic_dims) {
143143 int offset = i * ndim;
144- os << " size_t index_" << xname << " = N * pos.x * in_strides["
144+ os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
145145 << offset + ndim - 1 << " ]"
146146 << " + pos.y * in_strides[" << offset + ndim - 2 << " ];\n " ;
147147 } else {
148- os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * "
148+ os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
149149 << i << " + ndim - 1]"
150150 << " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n " ;
151151 }
@@ -172,7 +172,7 @@ inline void build_kernel(
172172
173173 // Open per-thread loop
174174 if (work_per_thread > 1 ) {
175- os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n " ;
175+ os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n " ;
176176 }
177177
178178 // Read non-contiguous inputs into tmps
@@ -434,10 +434,15 @@ void Compiled::eval_gpu(
434434 int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1 ;
435435 dim0 = (dim0 + work_per_thread - 1 ) / work_per_thread;
436436 NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup ();
437- if (thread_group_size != 1024 ) {
438- throw std::runtime_error (" [Metal::binary] Must use 1024 sized block" );
437+ int pow2;
438+ if (thread_group_size == 1024 ) {
439+ pow2 = 10 ;
440+ } else if (thread_group_size > 512 ) {
441+ pow2 = 9 ;
442+ } else {
443+ throw std::runtime_error (" [Metal::compiled] Must use > 512 sized block" );
439444 }
440- auto group_dims = get_block_dims (dim0, dim1, rest);
445+ auto group_dims = get_block_dims (dim0, dim1, rest, pow2 );
441446 MTL::Size grid_dims = MTL::Size (dim0, dim1, rest);
442447 compute_encoder.dispatchThreads (grid_dims, group_dims);
443448 }
0 commit comments