@@ -120,6 +120,7 @@ void copy_gpu_inplace(
   compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
   compute_encoder.set_output_array(out, 1, out_offset);
 
+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
     std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
     std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
@@ -145,7 +146,6 @@ void copy_gpu_inplace(
     }
 
     // NB assuming thread_group_size is a power of 2 larger than 32 x 32
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
       throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
     }
@@ -155,13 +155,12 @@ void copy_gpu_inplace(
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
-    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                                 : MTL::Size(nthreads, 1, 1);
-    NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size > nthreads) {
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+    MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                                 : MTL::Size(nthreads, 1, 1);
     compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
@@ -205,14 +204,14 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   compute_encoder.set_input_array(val, 0);
   compute_encoder.set_output_array(out, 1);
 
+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   size_t nthreads = out.data_size();
-  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
-                               : MTL::Size(nthreads, 1, 1);
-  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (thread_group_size > nthreads) {
    thread_group_size = nthreads;
   }
   MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
+  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
+                               : MTL::Size(nthreads, 1, 1);
   compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
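For reference, the dispatch pattern the rearranged contiguous branches share — query maxTotalThreadsPerThreadgroup() once up front, clamp it to the element count, then build group_dims and grid_dims and call dispatchThreads — looks roughly like the following minimal metal-cpp sketch. The function name and parameters (dispatch_contiguous, encoder, kernel, n_elements) are placeholders rather than MLX API, and the use_2d / get_2d_grid_dims path from the diff is omitted.

#include <Metal/Metal.hpp>

// Sketch only: assumed names, 1D grid case; not the MLX implementation.
void dispatch_contiguous(
    MTL::ComputeCommandEncoder* encoder,
    MTL::ComputePipelineState* kernel,
    size_t n_elements) {
  if (n_elements == 0) {
    return; // nothing to launch
  }
  // Query the pipeline's threadgroup limit once, before any branching.
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  // Clamp so the threadgroup never exceeds the grid itself.
  if (thread_group_size > n_elements) {
    thread_group_size = n_elements;
  }
  MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
  MTL::Size grid_dims = MTL::Size(n_elements, 1, 1);
  // dispatchThreads tolerates a grid that is not a multiple of the group size.
  encoder->dispatchThreads(grid_dims, group_dims);
}

The general-copy path in the first two hunks keeps its stricter check and throws unless the queried limit is exactly 1024, per the power-of-two, 32 x 32 assumption noted in its comment.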