Skip to content

Commit dad1b00

Browse files
authored
fix (#1523)
1 parent 430ffef commit dad1b00

File tree

2 files changed

+3
-11
lines changed

2 files changed

+3
-11
lines changed

mlx/backend/metal/device.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,11 @@ CommandEncoder::~CommandEncoder() {
130130
enc_->release();
131131
}
132132

133-
void CommandEncoder::set_array(
133+
void CommandEncoder::set_input_array(
134134
const array& a,
135135
int idx,
136136
int64_t offset /* = 0 */) {
137+
all_inputs_.insert(a.buffer().ptr());
137138
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
138139
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
139140
// Insert a barrier
@@ -149,20 +150,12 @@ void CommandEncoder::set_array(
149150
enc_->setBuffer(a_buf, base_offset, idx);
150151
}
151152

152-
void CommandEncoder::set_input_array(
153-
const array& a,
154-
int idx,
155-
int64_t offset /* = 0 */) {
156-
all_inputs_.insert(a.buffer().ptr());
157-
set_array(a, idx, offset);
158-
}
159-
160153
void CommandEncoder::set_output_array(
161154
array& a,
162155
int idx,
163156
int64_t offset /* = 0 */) {
164157
// Add barriers before adding the output to the output set
165-
set_array(a, idx, offset);
158+
set_input_array(a, idx, offset);
166159
all_outputs_.insert(a.buffer().ptr());
167160
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
168161
if (concurrent_) {

mlx/backend/metal/device.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ struct CommandEncoder {
8383
};
8484

8585
private:
86-
void set_array(const array& a, int idx, int64_t offset);
8786
MTL::ComputeCommandEncoder* enc_;
8887
bool concurrent_{false};
8988
std::unordered_set<MTL::Resource*> outputs_;

0 commit comments

Comments
 (0)