Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlx/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ if(MLX_METAL_JIT)
make_jit_source(softmax)
make_jit_source(scan)
make_jit_source(sort)
make_jit_source(radix_select)
make_jit_source(
reduce kernels/reduction/reduce_all.h kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h kernels/reduction/reduce_init.h)
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/metal/jit/includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const char* scan();
const char* scatter_axis();
const char* softmax();
const char* sort();
const char* radix_select();
const char* reduce();

const char* gemm();
Expand Down
38 changes: 38 additions & 0 deletions mlx/backend/metal/jit_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,44 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_radix_select_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::radix_select();
for (bool is_arg_partition : {true, false}) {
std::string bool_string = is_arg_partition ? "true" : "false";
std::string func_string = is_arg_partition ? "carg_" : "c_";
kernel_source << get_template_definition(
func_string + lib_name,
"radix_select_partition",
in_type,
out_type,
bool_string,
bn,
tn);
kernel_source << get_template_definition(
"n" + func_string + lib_name,
"radix_select_partition_nc",
in_type,
out_type,
bool_string,
bn,
tn);
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}

MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
8 changes: 8 additions & 0 deletions mlx/backend/metal/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
int bn,
int tn);

MTL::ComputePipelineState* get_radix_select_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn);

MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ if(NOT MLX_METAL_JIT)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)
build_kernel(radix_select radix_select.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
Expand Down
Loading