Commit 45adec1

Add contiguous_copy_gpu util for copying array (#2379)

1 parent 31fc530 · commit 45adec1

20 files changed: +40 −67 lines
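
Every call site below previously allocated an empty array by hand and then issued a general copy; the commit folds that two-line pattern into a single helper call. An illustrative before/after (variable names follow the hunks; the surrounding context differs per file):

    // Before (repeated at each call site):
    array x_copy(x.shape(), x.dtype(), nullptr, {});
    copy_gpu(x, x_copy, CopyType::General, s);

    // After:
    array x_copy = contiguous_copy_gpu(x, s);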

mlx/backend/cuda/layer_norm.cu

Lines changed: 2 additions & 5 deletions

@@ -237,8 +237,7 @@ void LayerNorm::eval_gpu(
       }
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
@@ -295,9 +294,7 @@ void LayerNormVJP::eval_gpu(
       return x;
     }
     copied = true;
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
-    return x_copy;
+    return contiguous_copy_gpu(x, s);
   };
   bool donate_x = inputs[0].is_donatable();
   bool donate_g = inputs[3].is_donatable();

mlx/backend/cuda/logsumexp.cu

Lines changed: 1 addition & 2 deletions

@@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
     return x;
   } else {
-    auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     encoder.add_temporary(x_copy);
     return x_copy;
   }

mlx/backend/cuda/matmul.cpp

Lines changed: 1 addition & 2 deletions

@@ -297,8 +297,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
   } else if (stx == 1 && sty == arr.shape(-2)) {
     return std::make_tuple(true, sty, arr);
   } else {
-    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
-    copy_gpu(arr, arr_copy, CopyType::General, s);
+    array arr_copy = contiguous_copy_gpu(arr, s);
     enc.add_temporary(arr_copy);
     return std::make_tuple(false, arr.shape(-1), arr_copy);
   }

mlx/backend/cuda/quantized.cu

Lines changed: 1 addition & 2 deletions

@@ -247,8 +247,7 @@ inline array ensure_row_contiguous(
     cu::CommandEncoder& enc,
     const Stream& s) {
   if (!x.flags().row_contiguous) {
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     enc.add_temporary(x_copy);
     return x_copy;
   } else {

mlx/backend/cuda/reduce.cu

Lines changed: 1 addition & 2 deletions

@@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   }
   if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
-    array in_copy(in.shape(), in.dtype(), nullptr, {});
-    copy_gpu(in, in_copy, CopyType::General, s);
+    array in_copy = contiguous_copy_gpu(in, s);
     encoder.add_temporary(in_copy);
     in = in_copy;
     plan = get_reduction_plan(in, axes_);

mlx/backend/cuda/rms_norm.cu

Lines changed: 2 additions & 5 deletions

@@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
       }
       return x;
     } else {
-      auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-      copy_gpu(x, x_copy, CopyType::General, s);
+      array x_copy = contiguous_copy_gpu(x, s);
       out.copy_shared_buffer(x_copy);
       return x_copy;
     }
@@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu(
       return x;
     }
     copied = true;
-    array x_copy(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
-    return x_copy;
+    return contiguous_copy_gpu(x, s);
   };
   bool donate_x = inputs[0].is_donatable();
   bool donate_g = inputs[2].is_donatable();

mlx/backend/cuda/scan.cu

Lines changed: 1 addition & 3 deletions

@@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
         in.flags());
     }
   } else {
-    array arr_copy(in.shape(), in.dtype(), nullptr, {});
-    copy_gpu(in, arr_copy, CopyType::General, s);
-    in = std::move(arr_copy);
+    in = contiguous_copy_gpu(in, s);
    out.copy_shared_buffer(in);
   }

mlx/backend/cuda/softmax.cu

Lines changed: 1 addition & 2 deletions

@@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
     return x;
   } else {
-    auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
-    copy_gpu(x, x_copy, CopyType::General, s);
+    array x_copy = contiguous_copy_gpu(x, s);
     out.copy_shared_buffer(x_copy);
     return x_copy;
   }

mlx/backend/cuda/sort.cu

Lines changed: 1 addition & 2 deletions

@@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
   if (!is_segmented_sort) {
     array trans = swapaxes_in_eval(in, axis, last_dim);
-    in = array(trans.shape(), trans.dtype(), nullptr, {});
-    copy_gpu(trans, in, CopyType::General, s);
+    in = contiguous_copy_gpu(trans, s);
     encoder.add_temporary(in);
     out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
     encoder.add_temporary(out);

mlx/backend/gpu/copy.cpp

Lines changed: 6 additions & 0 deletions

@@ -46,4 +46,10 @@ void copy_gpu_inplace(
       in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
 }

+array contiguous_copy_gpu(const array& arr, const Stream& s) {
+  array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+  copy_gpu(arr, arr_copy, CopyType::General, s);
+  return arr_copy;
+}
+
 } // namespace mlx::core
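
The diff adds only the definition; the matching declaration presumably lives in mlx/backend/gpu/copy.h next to copy_gpu (the header is not part of this view). A minimal sketch of the call-site pattern the helper enables, modeled on ensure_row_contiguous in quantized.cu above; ensure_contiguous is a hypothetical name, not part of this commit:

    // Hypothetical wrapper, assuming the declaration
    //   array contiguous_copy_gpu(const array& arr, const Stream& s);
    // is visible from mlx/backend/gpu/copy.h.
    array ensure_contiguous(const array& x, cu::CommandEncoder& enc, const Stream& s) {
      if (x.flags().row_contiguous) {
        return x; // already usable as-is, no copy needed
      }
      // General strided-to-contiguous copy enqueued on stream s.
      array x_copy = contiguous_copy_gpu(x, s);
      // Keep the temporary alive until the enqueued kernels have run.
      enc.add_temporary(x_copy);
      return x_copy;
    }

Centralizing the allocate-then-copy pair also keeps the CopyType::General choice in one place instead of repeating it at every call site.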
