File tree Expand file tree Collapse file tree 2 files changed +5
-3
lines changed
Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -1052,13 +1052,15 @@ void write_signature(
10521052 index++;
10531053 }
10541054 // Add metal attributes e.g. `threadgroup_index_in_grid`
1055+ index = 0 ;
10551056 for (const auto & [attr, dtype] : attrs) {
10561057 kernel_source << " " << dtype << " " << attr << " [[" << attr << " ]]" ;
10571058 if (index < attrs.size () - 1 ) {
10581059 kernel_source << " ," << std::endl;
10591060 } else {
10601061 kernel_source << " ) {" << std::endl;
10611062 }
1063+ index++;
10621064 }
10631065 kernel_source << source << std::endl;
10641066 kernel_source << " }" << std::endl;
Original file line number Diff line number Diff line change @@ -618,12 +618,12 @@ def test_custom_kernel_strides(self):
618618 uint elem = thread_position_in_grid.x;
619619 uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
620620 T tmp = inp[loc];
621- out[elem] = metal::exp(tmp);
621+ out[elem] = metal::exp(tmp) * threads_per_simdgroup ;
622622 """
623623 source_contig = """
624624 uint elem = thread_position_in_grid.x;
625625 T tmp = inp[elem];
626- out[elem] = metal::exp(tmp);
626+ out[elem] = metal::exp(tmp) * threads_per_simdgroup ;
627627 """
628628
629629 # non contiguous
@@ -644,7 +644,7 @@ def test_custom_kernel_strides(self):
644644 output_dtypes = {"out" : a .dtype },
645645 stream = mx .gpu ,
646646 )
647- self .assertTrue (mx .allclose (mx .exp (a ), outputs ["out" ]))
647+ self .assertTrue (mx .allclose (mx .exp (a ) * 32 , outputs ["out" ]))
648648
649649
650650if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments