Skip to content

Commit 1efee9d

Browse files
authored
Add types and order in kernel name (#831)
1 parent 43abc40 commit 1efee9d

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

mlx/backend/common/compiled.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright © 2023-2024 Apple Inc.
22

33
#include "mlx/backend/common/compiled.h"
4+
#include "mlx/graph_utils.h"
45
#include "mlx/primitives.h"
56
#include "mlx/utils.h"
67

@@ -81,13 +82,27 @@ std::string build_lib_name(
8182
const std::vector<array>& outputs,
8283
const std::vector<array>& tape,
8384
const std::unordered_set<uintptr_t>& constant_ids) {
85+
NodeNamer namer;
8486
std::ostringstream os;
8587
std::ostringstream constant_hasher;
8688

89+
// Fill the input names. This is not really necessary, I just like having A,
90+
// B, C, ... as the inputs.
91+
for (auto& x : inputs) {
92+
namer.get_name(x);
93+
}
94+
8795
// The primitives describing the tape. For unary and binary primitives this
8896
// must be enough to describe the full computation.
8997
for (auto& a : tape) {
98+
// name and type of output
99+
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
100+
// computation performed
90101
a.primitive().print(os);
102+
// name of inputs to the function
103+
for (auto& inp : a.inputs()) {
104+
os << namer.get_name(inp);
105+
}
91106
}
92107
os << "_";
93108

0 commit comments

Comments
 (0)