Skip to content

Commit 92d7cb7

Browse files
authored
Fix compile (#1501)
* fix compile * fix space
1 parent 50d8bed commit 92d7cb7

File tree

5 files changed

+36
-14
lines changed

5 files changed

+36
-14
lines changed

mlx/backend/common/compiled_cpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ void* compile(
103103
source_file.close();
104104

105105
std::ostringstream build_command;
106-
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
107-
<< source_file_path << " -o " << shared_lib_path;
106+
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
107+
<< source_file_path << "' -o '" << shared_lib_path << "'";
108108
std::string build_command_str = build_command.str();
109109
auto return_code = system(build_command_str.c_str());
110110
if (return_code) {

mlx/backend/metal/compiled.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ inline void build_kernel(
9393
// a third grid dimension
9494
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
9595
} else if (work_per_thread > 1) {
96-
os << " constexpr int N = " << std::to_string(work_per_thread) << ";\n"
96+
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
9797
<< " int xshape = output_shape["
9898
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
99-
<< " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
99+
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
100100
} else {
101101
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
102102
}
@@ -141,11 +141,11 @@ inline void build_kernel(
141141
<< "in_strides + " << offset << ");\n";
142142
} else if (!dynamic_dims) {
143143
int offset = i * ndim;
144-
os << " size_t index_" << xname << " = N * pos.x * in_strides["
144+
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
145145
<< offset + ndim - 1 << "]"
146146
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
147147
} else {
148-
os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * "
148+
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
149149
<< i << " + ndim - 1]"
150150
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
151151
}
@@ -172,7 +172,7 @@ inline void build_kernel(
172172

173173
// Open per-thread loop
174174
if (work_per_thread > 1) {
175-
os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n";
175+
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
176176
}
177177

178178
// Read non-contiguous inputs into tmps
@@ -434,10 +434,15 @@ void Compiled::eval_gpu(
434434
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
435435
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
436436
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
437-
if (thread_group_size != 1024) {
438-
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
437+
int pow2;
438+
if (thread_group_size == 1024) {
439+
pow2 = 10;
440+
} else if (thread_group_size > 512) {
441+
pow2 = 9;
442+
} else {
443+
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
439444
}
440-
auto group_dims = get_block_dims(dim0, dim1, rest);
445+
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
441446
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
442447
compute_encoder.dispatchThreads(grid_dims, group_dims);
443448
}

mlx/backend/metal/utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ std::string type_to_name(const array& a) {
5252
return tname;
5353
}
5454

55-
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
55+
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
5656
int pows[3] = {0, 0, 0};
5757
int sum = 0;
5858
while (true) {
@@ -76,7 +76,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
7676
pows[2]++;
7777
sum++;
7878
}
79-
if (sum == presum || sum == 10) {
79+
if (sum == presum || sum == pow2) {
8080
break;
8181
}
8282
}

mlx/backend/metal/utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ std::string type_to_name(const array& a);
3030
// Compute the thread block dimensions which fit the given
3131
// input dimensions.
3232
// - The thread block dimensions will be powers of two
33-
// - The thread block size will be less than 1024
34-
MTL::Size get_block_dims(int dim0, int dim1, int dim2);
33+
// - The thread block size will be at most 2^pow2
34+
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
3535

3636
// Computes a 2D grid where each element is < UINT_MAX
3737
// Assumes:

python/tests/test_compile.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,6 +772,23 @@ def fn(a, b):
772772
print((out - expected).abs().max())
773773
self.assertTrue(mx.allclose(out, expected))
774774

775+
def test_compile_many_inputs(self):
776+
inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)]
777+
inputs[0] = inputs[0].T
778+
779+
@mx.compile
780+
def fun(*inputs):
781+
x = inputs[0]
782+
for y in inputs[1:10]:
783+
x = x + y
784+
a = inputs[10]
785+
for b in inputs[11:]:
786+
a = a + b
787+
return x + a
788+
789+
out = fun(*inputs)
790+
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
791+
775792

776793
if __name__ == "__main__":
777794
unittest.main()

0 commit comments

Comments
 (0)