Skip to content

Commit da8deb2

Browse files
barronalexAlex Barron
andauthored
fix bug with multiple attributes (#1348)
Co-authored-by: Alex Barron <[email protected]>
1 parent 98b6ce3 commit da8deb2

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

mlx/fast.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,13 +1052,15 @@ void write_signature(
10521052
index++;
10531053
}
10541054
// Add metal attributes e.g. `threadgroup_index_in_grid`
1055+
index = 0;
10551056
for (const auto& [attr, dtype] : attrs) {
10561057
kernel_source << " " << dtype << " " << attr << " [[" << attr << "]]";
10571058
if (index < attrs.size() - 1) {
10581059
kernel_source << "," << std::endl;
10591060
} else {
10601061
kernel_source << ") {" << std::endl;
10611062
}
1063+
index++;
10621064
}
10631065
kernel_source << source << std::endl;
10641066
kernel_source << "}" << std::endl;

python/tests/test_fast.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,12 +618,12 @@ def test_custom_kernel_strides(self):
618618
uint elem = thread_position_in_grid.x;
619619
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
620620
T tmp = inp[loc];
621-
out[elem] = metal::exp(tmp);
621+
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
622622
"""
623623
source_contig = """
624624
uint elem = thread_position_in_grid.x;
625625
T tmp = inp[elem];
626-
out[elem] = metal::exp(tmp);
626+
out[elem] = metal::exp(tmp) * threads_per_simdgroup;
627627
"""
628628

629629
# non contiguous
@@ -644,7 +644,7 @@ def test_custom_kernel_strides(self):
644644
output_dtypes={"out": a.dtype},
645645
stream=mx.gpu,
646646
)
647-
self.assertTrue(mx.allclose(mx.exp(a), outputs["out"]))
647+
self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
648648

649649

650650
if __name__ == "__main__":

0 commit comments

Comments
 (0)