Skip to content

Commit d0b6cb0

Browse files
authored
More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape * more shape * fix * fix
1 parent 95c4a2e commit d0b6cb0

File tree

5 files changed

+160
-81
lines changed

5 files changed

+160
-81
lines changed

mlx/compile.cpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// Copyright © 2023-2024 Apple Inc.
2-
32
#include <cstdlib>
43
#include <map>
54
#include <unordered_map>
@@ -8,6 +7,7 @@
87
#include "mlx/allocator.h"
98
#include "mlx/compile.h"
109
#include "mlx/compile_impl.h"
10+
#include "mlx/fast_primitives.h"
1111
#include "mlx/primitives.h"
1212
#include "mlx/transforms.h"
1313
#include "mlx/transforms_impl.h"
@@ -73,11 +73,18 @@ bool is_fusable(const Primitive& p) {
7373
}
7474

7575
bool allows_shapeless(const Primitive& p) {
76-
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
77-
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
78-
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
79-
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
80-
typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);
76+
return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) ||
77+
is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) ||
78+
typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) ||
79+
typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) ||
80+
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
81+
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
82+
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
83+
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
84+
typeid(p) == typeid(fast::AffineQuantize) ||
85+
typeid(p) == typeid(fast::LayerNorm) ||
86+
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
87+
typeid(p) == typeid(fast::ScaledDotProductAttention);
8188
}
8289

8390
Compiled::Compiled(
@@ -93,23 +100,23 @@ Compiled::Compiled(
93100
constant_ids_(std::move(constant_ids)) {}
94101

95102
std::vector<array> Compiled::vjp(
96-
const std::vector<array>& primals,
97-
const std::vector<array>& cotangents,
98-
const std::vector<int>& argnums,
99-
const std::vector<array>& outputs) {
103+
const std::vector<array>&,
104+
const std::vector<array>&,
105+
const std::vector<int>&,
106+
const std::vector<array>&) {
100107
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
101108
}
102109

103110
std::vector<array> Compiled::jvp(
104-
const std::vector<array>& primals,
105-
const std::vector<array>& tangents,
106-
const std::vector<int>& argnums) {
111+
const std::vector<array>&,
112+
const std::vector<array>&,
113+
const std::vector<int>&) {
107114
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
108115
}
109116

110117
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
111-
const std::vector<array>& inputs,
112-
const std::vector<int>& axes) {
118+
const std::vector<array>&,
119+
const std::vector<int>&) {
113120
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
114121
}
115122

@@ -134,21 +141,20 @@ void Compiled::print(std::ostream& os) {
134141
}
135142
}
136143

137-
std::vector<std::vector<int>> Compiled::output_shapes(
138-
const std::vector<array>& inputs) {
144+
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
139145
size_t nd = 0;
140146
for (auto& in : inputs) {
141147
nd = std::max(nd, in.ndim());
142148
}
143-
std::vector<int> out_shape(nd, 0);
149+
Shape out_shape(nd, 0);
144150
for (auto& in : inputs) {
145151
auto dd = nd - in.ndim();
146152
for (auto i = dd; i < nd; ++i) {
147153
out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);
148154
}
149155
}
150156
// All outputs have the same shape
151-
return std::vector<std::vector<int>>(outputs_.size(), out_shape);
157+
return std::vector<Shape>(outputs_.size(), out_shape);
152158
}
153159

154160
namespace detail {
@@ -553,14 +559,12 @@ void compile_fuse(
553559
// - Collect inputs to the new compiled primitive
554560
// - Add fusable primitives to a tape in the correct order
555561

556-
std::function<void(
557-
const array&, int, const Stream&, const std::vector<int>&)>
558-
recurse;
562+
std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
559563
std::unordered_set<uintptr_t> cache;
560564
recurse = [&](const array& a,
561565
int depth,
562566
const Stream& s,
563-
const std::vector<int>& shape) {
567+
const Shape& shape) {
564568
if (cache.find(a.id()) != cache.end()) {
565569
return;
566570
}
@@ -667,7 +671,7 @@ void compile_fuse(
667671
}
668672
old_outputs.push_back(arr);
669673

670-
std::vector<std::vector<int>> shapes;
674+
std::vector<Shape> shapes;
671675
std::vector<Dtype> types;
672676
for (auto& o : old_outputs) {
673677
if (o.shape() != old_outputs.back().shape()) {
@@ -771,7 +775,7 @@ std::vector<array> compile_replace(
771775
for (auto& o : trace_out) {
772776
types.push_back(o.dtype());
773777
}
774-
std::vector<std::vector<int>> shapes;
778+
std::vector<Shape> shapes;
775779
if (shapeless) {
776780
shapes = a.primitive().output_shapes(real_inputs);
777781
} else {

mlx/fast.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,31 @@ array affine_dequantize(
915915
return fallback({w, scales, biases})[0];
916916
}
917917

918+
bool AffineQuantize::is_equivalent(const Primitive& other) const {
919+
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
920+
return (
921+
p_other.group_size_ == group_size_ && p_other.bits_ == bits_ &&
922+
p_other.dequantize_ == dequantize_);
923+
}
924+
925+
std::vector<Shape> AffineQuantize::output_shapes(
926+
const std::vector<array>& inputs) {
927+
auto& w = inputs[0];
928+
if (dequantize_) {
929+
auto out_size = w.shape(-1) * 32 / bits_;
930+
auto out_shape = w.shape();
931+
out_shape.back() = out_size;
932+
return {std::move(out_shape)};
933+
} else {
934+
auto wq_shape = w.shape();
935+
wq_shape.back() = w.shape(-1) * bits_ / 32;
936+
auto sshape = w.shape();
937+
sshape.back() = w.shape(-1) / group_size_;
938+
auto bshape = sshape;
939+
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
940+
}
941+
}
942+
918943
std::string write_signature(
919944
std::string func_name,
920945
const std::string& header,

mlx/fast_primitives.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class RMSNorm : public Custom {
5858

5959
DEFINE_PRINT(RMSNorm)
6060
bool is_equivalent(const Primitive& other) const override;
61+
DEFINE_INPUT_OUTPUT_SHAPE()
6162

6263
private:
6364
std::function<std::vector<array>(std::vector<array>)> fallback_;
@@ -110,6 +111,7 @@ class LayerNorm : public Custom {
110111

111112
DEFINE_PRINT(LayerNorm)
112113
bool is_equivalent(const Primitive& other) const override;
114+
DEFINE_INPUT_OUTPUT_SHAPE()
113115

114116
private:
115117
std::function<std::vector<array>(std::vector<array>)> fallback_;
@@ -173,6 +175,7 @@ class RoPE : public Custom {
173175

174176
DEFINE_PRINT(RoPE)
175177
bool is_equivalent(const Primitive& other) const override;
178+
DEFINE_INPUT_OUTPUT_SHAPE()
176179

177180
private:
178181
std::function<std::vector<array>(std::vector<array>)> fallback_;
@@ -207,6 +210,7 @@ class ScaledDotProductAttention : public Custom {
207210
bool is_equivalent(const Primitive& other) const override;
208211

209212
DEFINE_PRINT(ScaledDotProductAttention);
213+
DEFINE_INPUT_OUTPUT_SHAPE()
210214

211215
private:
212216
std::function<std::vector<array>(std::vector<array>)> fallback_;
@@ -235,6 +239,9 @@ class AffineQuantize : public Custom {
235239

236240
DEFINE_PRINT(AffineQuantize);
237241

242+
bool is_equivalent(const Primitive& other) const override;
243+
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
244+
238245
private:
239246
std::function<std::vector<array>(std::vector<array>)> fallback_;
240247
int group_size_;

mlx/primitives.cpp

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,11 @@ bool Arange::is_equivalent(const Primitive& other) const {
267267
step_ == a_other.step_);
268268
}
269269

270+
std::vector<Shape> Arange::output_shapes(const std::vector<array>&) {
271+
auto real_size = std::ceil((stop_ - start_) / step_);
272+
return {{std::max(static_cast<int>(real_size), 0)}};
273+
}
274+
270275
std::vector<array> ArcCos::vjp(
271276
const std::vector<array>& primals,
272277
const std::vector<array>& cotangents,
@@ -534,11 +539,10 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
534539
return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
535540
}
536541

537-
std::vector<std::vector<int>> ArgReduce::output_shapes(
538-
const std::vector<array>& inputs) {
542+
std::vector<Shape> ArgReduce::output_shapes(const std::vector<array>& inputs) {
539543
auto out_shape = inputs[0].shape();
540544
out_shape[axis_] = 1;
541-
return {out_shape};
545+
return {std::move(out_shape)};
542546
}
543547

544548
bool ArgSort::is_equivalent(const Primitive& other) const {
@@ -787,6 +791,23 @@ std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
787791
return {outputs, std::vector<int>(outputs.size(), ax)};
788792
}
789793

794+
std::vector<Shape> Eigh::output_shapes(const std::vector<array>& inputs) {
795+
auto shape = inputs[0].shape();
796+
shape.pop_back(); // Remove last dimension for eigenvalues
797+
if (compute_eigenvectors_) {
798+
return {
799+
std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors
800+
} else {
801+
return {std::move(shape)}; // Only eigenvalues
802+
}
803+
}
804+
805+
bool Eigh::is_equivalent(const Primitive& other) const {
806+
auto& e_other = static_cast<const Eigh&>(other);
807+
return uplo_ == e_other.uplo_ &&
808+
compute_eigenvectors_ == e_other.compute_eigenvectors_;
809+
}
810+
790811
std::vector<array> Concatenate::vjp(
791812
const std::vector<array>& primals,
792813
const std::vector<array>& cotangents,
@@ -881,6 +902,15 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
881902
return axis_ == c_other.axis_;
882903
}
883904

905+
std::vector<Shape> Concatenate::output_shapes(
906+
const std::vector<array>& inputs) {
907+
auto shape = inputs[0].shape();
908+
for (int i = 1; i < inputs.size(); ++i) {
909+
shape[axis_] += inputs[i].shape(axis_);
910+
}
911+
return {std::move(shape)};
912+
}
913+
884914
std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
885915
const std::vector<array>& inputs,
886916
const std::vector<int>& axes) {
@@ -1811,6 +1841,15 @@ bool Gather::is_equivalent(const Primitive& other) const {
18111841
return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;
18121842
}
18131843

1844+
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
1845+
Shape out_shape;
1846+
if (inputs.size() > 1) {
1847+
out_shape = inputs[0].shape();
1848+
}
1849+
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
1850+
return {std::move(out_shape)};
1851+
}
1852+
18141853
std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
18151854
const std::vector<array>& inputs,
18161855
const std::vector<int>& axes) {
@@ -2184,6 +2223,12 @@ std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
21842223
return {{matmul(a, b, stream())}, {0}};
21852224
}
21862225

2226+
std::vector<Shape> Matmul::output_shapes(const std::vector<array>& inputs) {
2227+
auto out_shape = inputs[0].shape();
2228+
out_shape.back() = inputs[1].shape(-1);
2229+
return {std::move(out_shape)};
2230+
}
2231+
21872232
std::vector<array> Maximum::vjp(
21882233
const std::vector<array>& primals,
21892234
const std::vector<array>& cotangents,
@@ -2608,6 +2653,15 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
26082653
transpose_ == qm_other.transpose_;
26092654
}
26102655

2656+
std::vector<Shape> QuantizedMatmul::output_shapes(
2657+
const std::vector<array>& inputs) {
2658+
auto& w = inputs[1];
2659+
int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1) * 32 / bits_;
2660+
auto out_shape = inputs[0].shape();
2661+
out_shape.back() = w_outer_dims;
2662+
return {std::move(out_shape)};
2663+
}
2664+
26112665
std::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(
26122666
const std::vector<array>& inputs,
26132667
const std::vector<int>& axes) {
@@ -2937,13 +2991,12 @@ bool Reduce::is_equivalent(const Primitive& other) const {
29372991
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
29382992
}
29392993

2940-
std::vector<std::vector<int>> Reduce::output_shapes(
2941-
const std::vector<array>& inputs) {
2942-
std::vector<int> out_shape = inputs[0].shape();
2994+
std::vector<Shape> Reduce::output_shapes(const std::vector<array>& inputs) {
2995+
auto out_shape = inputs[0].shape();
29432996
for (auto i : axes_) {
29442997
out_shape[i] = 1;
29452998
}
2946-
return {out_shape};
2999+
return {std::move(out_shape)};
29473000
}
29483001

29493002
std::vector<array> Round::vjp(
@@ -4209,6 +4262,15 @@ bool Transpose::is_equivalent(const Primitive& other) const {
42094262
return axes_ == t_other.axes_;
42104263
}
42114264

4265+
std::vector<Shape> Transpose::output_shapes(const std::vector<array>& inputs) {
4266+
auto& in = inputs[0];
4267+
Shape shape(in.ndim(), 0);
4268+
for (int i = 0; i < axes_.size(); ++i) {
4269+
shape[i] = in.shape()[axes_[i]];
4270+
}
4271+
return {std::move(shape)};
4272+
}
4273+
42124274
std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
42134275
const std::vector<array>& inputs,
42144276
const std::vector<int>& axes) {

0 commit comments

Comments
 (0)