Skip to content

Commit 881f09b

Browse files
authored
Allow querying the allocator for the buffer size (#1404)
1 parent 8b30acd commit 881f09b

File tree

10 files changed

+42
-13
lines changed

10 files changed

+42
-13
lines changed

mlx/allocator.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,22 @@ void free(Buffer buffer) {
2323
}
2424

2525
Buffer CommonAllocator::malloc(size_t size, bool) {
26-
return Buffer{std::malloc(size)};
26+
void* ptr = std::malloc(size + sizeof(size_t));
27+
if (ptr != nullptr) {
28+
*static_cast<size_t*>(ptr) = size;
29+
}
30+
return Buffer{ptr};
2731
}
2832

2933
void CommonAllocator::free(Buffer buffer) {
30-
std::free(buffer.raw_ptr());
34+
std::free(buffer.ptr());
35+
}
36+
37+
size_t CommonAllocator::size(Buffer buffer) const {
38+
if (buffer.ptr() == nullptr) {
39+
return 0;
40+
}
41+
return *static_cast<size_t*>(buffer.ptr());
3142
}
3243

3344
Buffer malloc_or_wait(size_t size) {

mlx/allocator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class Allocator {
4141
public:
4242
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
4343
virtual void free(Buffer buffer) = 0;
44+
virtual size_t size(Buffer buffer) const = 0;
4445

4546
Allocator() = default;
4647
Allocator(const Allocator& other) = delete;
@@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
5758
public:
5859
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
5960
virtual void free(Buffer buffer) override;
61+
virtual size_t size(Buffer buffer) const override;
6062

6163
private:
6264
CommonAllocator() = default;

mlx/array.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ class array {
324324
return array_desc_->data->buffer;
325325
}
326326

327+
size_t buffer_size() const {
328+
return allocator::allocator().size(buffer());
329+
}
330+
327331
// Return a copy of the shared pointer
328332
// to the array::Data struct
329333
std::shared_ptr<Data> data_shared_ptr() const {

mlx/backend/common/binary.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,15 @@ void set_binary_op_output_data(
4343
array& out,
4444
BinaryOpType bopt,
4545
bool donate_with_move = false) {
46+
bool b_donatable = is_donatable(b, out);
47+
bool a_donatable = is_donatable(a, out);
4648
switch (bopt) {
4749
case BinaryOpType::ScalarScalar:
4850
out.set_data(
4951
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
5052
break;
5153
case BinaryOpType::ScalarVector:
52-
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
54+
if (b_donatable) {
5355
if (donate_with_move) {
5456
out.move_shared_buffer(b);
5557
} else {
@@ -64,7 +66,7 @@ void set_binary_op_output_data(
6466
}
6567
break;
6668
case BinaryOpType::VectorScalar:
67-
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
69+
if (a_donatable) {
6870
if (donate_with_move) {
6971
out.move_shared_buffer(a);
7072
} else {
@@ -79,13 +81,13 @@ void set_binary_op_output_data(
7981
}
8082
break;
8183
case BinaryOpType::VectorVector:
82-
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
84+
if (a_donatable) {
8385
if (donate_with_move) {
8486
out.move_shared_buffer(a);
8587
} else {
8688
out.copy_shared_buffer(a);
8789
}
88-
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
90+
} else if (b_donatable) {
8991
if (donate_with_move) {
9092
out.move_shared_buffer(b);
9193
} else {
@@ -100,16 +102,14 @@ void set_binary_op_output_data(
100102
}
101103
break;
102104
case BinaryOpType::General:
103-
if (a.is_donatable() && a.flags().row_contiguous &&
104-
a.itemsize() == out.itemsize() && a.size() == out.size()) {
105+
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
105106
if (donate_with_move) {
106107
out.move_shared_buffer(a);
107108
} else {
108109
out.copy_shared_buffer(a);
109110
}
110111
} else if (
111-
b.is_donatable() && b.flags().row_contiguous &&
112-
b.itemsize() == out.itemsize() && b.size() == out.size()) {
112+
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
113113
if (donate_with_move) {
114114
out.move_shared_buffer(b);
115115
} else {

mlx/backend/common/ternary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void set_ternary_op_output_data(
4141
TernaryOpType topt,
4242
bool donate_with_move = false) {
4343
auto maybe_donate = [&out, donate_with_move](const array& x) {
44-
if (x.is_donatable() && x.itemsize() == out.itemsize()) {
44+
if (is_donatable(x, out)) {
4545
if (donate_with_move) {
4646
out.move_shared_buffer(x);
4747
} else {

mlx/backend/common/unary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace mlx::core {
1212
namespace {
1313

1414
void set_unary_output_data(const array& in, array& out) {
15-
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
15+
if (is_donatable(in, out)) {
1616
out.copy_shared_buffer(in);
1717
} else {
1818
auto size = in.data_size();

mlx/backend/common/utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,11 @@ inline auto check_contiguity(
155155
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
156156
}
157157

158+
inline bool is_donatable(const array& in, const array& out) {
159+
constexpr size_t donation_extra = 16384;
160+
161+
return in.is_donatable() && in.itemsize() == out.itemsize() &&
162+
in.buffer_size() <= out.nbytes() + donation_extra;
163+
}
164+
158165
} // namespace mlx::core

mlx/backend/metal/allocator.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ void MetalAllocator::free(Buffer buffer) {
241241
}
242242
}
243243

244+
size_t MetalAllocator::size(Buffer buffer) const {
245+
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
246+
}
247+
244248
MetalAllocator& allocator() {
245249
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
246250
// not be called on exit and all the buffers will be leaked. This is necessary

mlx/backend/metal/allocator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class MetalAllocator : public allocator::Allocator {
5656
public:
5757
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
5858
virtual void free(Buffer buffer) override;
59+
virtual size_t size(Buffer buffer) const override;
5960
size_t get_active_memory() {
6061
return active_memory_;
6162
};

mlx/backend/no_metal/allocator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Allocator& allocator() {
1010
}
1111

1212
void* Buffer::raw_ptr() {
13-
return ptr_;
13+
return static_cast<size_t*>(ptr_) + 1;
1414
}
1515

1616
} // namespace mlx::core::allocator

0 commit comments

Comments
 (0)