Skip to content

Commit 6a3acf2

Browse files
authored
[CUDA] Set bias as input when using bias epilogue (#2584)
1 parent: d6977f2 · commit: 6a3acf2

File tree

4 files changed

+13
-8
lines changed

4 files changed

+13
-8
lines changed

mlx/backend/cuda/gemms/cublas_gemm.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,20 @@ void CublasGemm::set_out(
230230
batch_stride);
231231
}
232232

233-
void CublasGemm::set_bias(void* bias) {
233+
void CublasGemm::set_bias(cu::CommandEncoder& encoder, const array& bias) {
234+
encoder.set_input_array(bias);
234235
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
235236
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
236237
matmul_desc_,
237238
CUBLASLT_MATMUL_DESC_EPILOGUE,
238239
&epilogue,
239240
sizeof(epilogue)));
241+
auto* bias_ptr = bias.data<void>();
240242
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
241-
matmul_desc_, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
243+
matmul_desc_,
244+
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
245+
&bias_ptr,
246+
sizeof(bias_ptr)));
242247
}
243248

244249
void CublasGemm::run(

mlx/backend/cuda/gemms/cublas_gemm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class CublasGemm {
5555
int32_t batch_count,
5656
int64_t batch_stride);
5757

58-
void set_bias(void* bias);
58+
void set_bias(cu::CommandEncoder& encoder, const array& bias);
5959

6060
void run(
6161
cu::CommandEncoder& encoder,

mlx/backend/cuda/matmul.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void gemm_and_bias(
4141
array& out,
4242
const array& a,
4343
const array& b,
44-
void* bias = nullptr,
44+
const std::optional<array>& bias = std::nullopt,
4545
float alpha = 1.0f) {
4646
// Check and collapse batch dimensions
4747
auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
@@ -93,7 +93,7 @@ void gemm_and_bias(
9393
a_batch_strides.back(),
9494
b_batch_strides.back());
9595
if (bias) {
96-
gemm.set_bias(bias);
96+
gemm.set_bias(encoder, *bias);
9797
}
9898
gemm.run(
9999
encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides, alpha);
@@ -171,7 +171,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
171171
out,
172172
a,
173173
b,
174-
c.data<void>(),
174+
c,
175175
alpha_);
176176
return;
177177
}

python/tests/test_blas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,15 +702,15 @@ def test_addmm(self):
702702
b = mx.ones((5, 5))
703703
out = mx.addmm(a, b, a, beta=beta, alpha=alpha)
704704
expected = beta * a + alpha * (b @ a)
705-
self.assertTrue(mx.allclose(expected, out, atol=1e-5))
705+
self.assertTrue(mx.allclose(expected, out))
706706

707707
# Broadcast c
708708
a = mx.ones((5, 5))
709709
b = mx.ones((5, 5))
710710
c = mx.ones((1, 5))
711711
out = mx.addmm(c, a, b, beta=beta, alpha=alpha)
712712
expected = beta * c + alpha * (a @ b)
713-
self.assertTrue(mx.allclose(expected, out, atol=1e-5))
713+
self.assertTrue(mx.allclose(expected, out))
714714

715715
def test_addmm_grad(self):
716716
def make_ref_addmm(alpha, beta):

0 commit comments

Comments (0)