Skip to content

Commit 5c1fa64

Browse files
authored
Custom transforms (#1246)
1 parent a3c2873 commit 5c1fa64

File tree

16 files changed

+734
-39
lines changed

16 files changed

+734
-39
lines changed

docs/src/python/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Transforms
1010

1111
eval
1212
compile
13+
custom_function
1314
disable_compile
1415
enable_compile
1516
grad

mlx/array.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ bool in_tracing() {
1717
return detail::InTracing::in_tracing();
1818
}
1919

20+
bool retain_graph() {
21+
return detail::RetainGraph::retain_graph();
22+
}
23+
2024
} // namespace
2125

2226
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -102,7 +106,7 @@ void array::eval() {
102106
}
103107

104108
bool array::is_tracer() const {
105-
return array_desc_->is_tracer && in_tracing();
109+
return array_desc_->is_tracer && in_tracing() || retain_graph();
106110
}
107111

108112
void array::set_data(allocator::Buffer buffer, deleter_t d) {

mlx/backend/accelerate/primitives.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ DEFAULT(Ceil)
3636
DEFAULT(Concatenate)
3737
DEFAULT(Conjugate)
3838
DEFAULT(Copy)
39-
DEFAULT_MULTI(CustomVJP)
39+
DEFAULT_MULTI(CustomTransforms)
4040
DEFAULT_MULTI(Depends)
4141
DEFAULT_MULTI(DivMod)
4242
DEFAULT(NumberOfElements)

mlx/backend/common/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
6666
out.copy_shared_buffer(inputs[0]);
6767
}
6868

69-
void CustomVJP::eval(
69+
void CustomTransforms::eval(
7070
const std::vector<array>& inputs,
7171
std::vector<array>& outputs) {
7272
assert(inputs.size() > outputs.size());

mlx/backend/common/default_primitives.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ DEFAULT(Convolution)
5252
DEFAULT(Copy)
5353
DEFAULT(Cos)
5454
DEFAULT(Cosh)
55-
DEFAULT_MULTI(CustomVJP)
55+
DEFAULT_MULTI(CustomTransforms)
5656
DEFAULT_MULTI(Depends)
5757
DEFAULT(Divide)
5858
DEFAULT(NumberOfElements)

mlx/backend/metal/primitives.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
171171
eval(inputs, out);
172172
}
173173

174-
void CustomVJP::eval_gpu(
174+
void CustomTransforms::eval_gpu(
175175
const std::vector<array>& inputs,
176176
std::vector<array>& outputs) {
177177
eval(inputs, outputs);

mlx/backend/no_cpu/primitives.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ NO_CPU(Convolution)
4242
NO_CPU(Copy)
4343
NO_CPU(Cos)
4444
NO_CPU(Cosh)
45-
NO_CPU_MULTI(CustomVJP)
45+
NO_CPU_MULTI(CustomTransforms)
4646
NO_CPU_MULTI(Depends)
4747
NO_CPU(Divide)
4848
NO_CPU_MULTI(DivMod)

mlx/backend/no_metal/primitives.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ NO_GPU(Convolution)
4343
NO_GPU(Copy)
4444
NO_GPU(Cos)
4545
NO_GPU(Cosh)
46-
NO_GPU_MULTI(CustomVJP)
46+
NO_GPU_MULTI(CustomTransforms)
4747
NO_GPU_MULTI(Depends)
4848
NO_GPU(Divide)
4949
NO_GPU_MULTI(DivMod)

mlx/fast.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ std::vector<array> Custom::vjp(
1818
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
1919
std::vector<array> vjp_outs;
2020
for (int i = 0, j = 0; i < vjps.size(); ++i) {
21-
if (i < argnums.size() && i == argnums[j]) {
21+
if (j < argnums.size() && i == argnums[j]) {
2222
vjp_outs.push_back(vjps[i]);
2323
j++;
2424
}
@@ -30,15 +30,16 @@ std::vector<array> Custom::jvp(
3030
const std::vector<array>& primals,
3131
const std::vector<array>& tangents,
3232
const std::vector<int>& argnums) {
33-
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
34-
std::vector<array> jvp_outs;
35-
for (int i = 0, j = 0; i < jvps.size(); ++i) {
36-
if (i < argnums.size() && i == argnums[j]) {
37-
jvp_outs.push_back(jvps[i]);
38-
j++;
33+
std::vector<array> all_tangents;
34+
for (int i = 0, j = 0; i < primals.size(); i++) {
35+
if (j < argnums.size() && i == argnums[j]) {
36+
all_tangents.emplace_back(tangents[j++]);
37+
} else {
38+
all_tangents.emplace_back(zeros_like(primals[i]));
3939
}
4040
}
41-
return jvp_outs;
41+
auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents);
42+
return jvps;
4243
}
4344

4445
std::pair<std::vector<array>, std::vector<int>> Custom::vmap(

mlx/primitives.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,17 +1113,21 @@ std::pair<std::vector<array>, std::vector<int>> Cosh::vmap(
11131113
return {{cosh(inputs[0], stream())}, axes};
11141114
}
11151115

1116-
std::vector<array> CustomVJP::vjp(
1116+
std::vector<array> CustomTransforms::vjp(
11171117
const std::vector<array>& primals,
11181118
const std::vector<array>& cotangents,
11191119
const std::vector<int>& argnums,
11201120
const std::vector<array>& outputs) {
1121-
std::vector<array> inputs(primals.begin(), primals.end() - outputs.size());
1121+
// Extract the inputs to the VJP function
1122+
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
1123+
1124+
// Compute all the vjps
11221125
auto all_vjps = vjp_fun_(inputs, cotangents, outputs);
11231126
for (const auto& cot : cotangents) {
11241127
all_vjps.emplace_back(cot);
11251128
}
11261129

1130+
// Select the vjps requested
11271131
std::vector<array> vjps;
11281132
vjps.reserve(argnums.size());
11291133
for (auto arg : argnums) {
@@ -1133,6 +1137,26 @@ std::vector<array> CustomVJP::vjp(
11331137
return vjps;
11341138
}
11351139

1140+
std::vector<array> CustomTransforms::jvp(
1141+
const std::vector<array>& primals,
1142+
const std::vector<array>& tangents,
1143+
const std::vector<int>& argnums) {
1144+
// Extract the inputs to the JVP function
1145+
std::vector<array> inputs(primals.begin(), primals.end() - num_outputs_);
1146+
1147+
// Compute the jvps
1148+
return jvp_fun_(inputs, tangents, argnums);
1149+
}
1150+
1151+
std::pair<std::vector<array>, std::vector<int>> CustomTransforms::vmap(
1152+
const std::vector<array>& inputs_,
1153+
const std::vector<int>& axes_) {
1154+
// Extract the inputs to the vmap function
1155+
std::vector<array> inputs(inputs_.begin(), inputs_.end() - num_outputs_);
1156+
std::vector<int> axes(axes_.begin(), axes_.end() - num_outputs_);
1157+
return vmap_fun_(inputs, axes);
1158+
}
1159+
11361160
std::vector<array> Depends::vjp(
11371161
const std::vector<array>& primals,
11381162
const std::vector<array>& cotangents,

0 commit comments

Comments
 (0)