
Commit dd4f53d

use fp32 for testing, add more complex ops (#2322)
1 parent 3d5e17e commit dd4f53d

File tree

6 files changed: +68 −40 lines


mlx/backend/cuda/device/unary_ops.cuh

Lines changed: 47 additions & 7 deletions
@@ -27,6 +27,8 @@ struct ArcCos {
   __device__ T operator()(T x) {
     return acos(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };

 struct ArcCosh {
@@ -41,6 +43,8 @@ struct ArcSin {
   __device__ T operator()(T x) {
     return asin(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };

 struct ArcSinh {
@@ -55,6 +59,8 @@ struct ArcTan {
   __device__ T operator()(T x) {
     return atan(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x);
 };

 struct ArcTanh {
@@ -261,13 +267,6 @@ struct Round {
   }
 };

-struct Rsqrt {
-  template <typename T>
-  __device__ T operator()(T x) {
-    return rsqrt(x);
-  }
-};
-
 struct Sigmoid {
   template <typename T>
   __device__ T operator()(T x) {
@@ -333,6 +332,29 @@ struct Sqrt {
   __device__ T operator()(T x) {
     return sqrt(x);
   }
+
+  __device__ cuComplex operator()(cuComplex x) {
+    auto xr = cuCrealf(x);
+    auto xi = cuCimagf(x);
+    if (xr == 0.0f && xi == 0.0f) {
+      return {0.0f, 0.0f};
+    }
+    auto r = cuCrealf(Abs{}(x));
+    auto a = sqrt((r + xr) / 2.0f);
+    auto b_abs = sqrt((r - xr) / 2.0f);
+    auto b = copysign(b_abs, xi);
+    return {a, b};
+  }
+};
+
+struct Rsqrt {
+  template <typename T>
+  __device__ T operator()(T x) {
+    return rsqrt(x);
+  }
+  __device__ cuComplex operator()(cuComplex x) {
+    return 1.0f / Sqrt{}(x);
+  }
 };

 struct Tan {
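
The complex overload added to Sqrt above computes the principal square root in real arithmetic: for z = x + iy with r = |z|, the result a + ib has a = sqrt((r + x)/2) and |b| = sqrt((r - x)/2), with b taking the sign of y; Rsqrt then simply reuses it as 1.0f / Sqrt{}(x). A minimal host-side sketch of the same formula, checked against Python's cmath (not part of the commit; sqrt_via_halves is a hypothetical name):

    import cmath
    import math

    def sqrt_via_halves(z):
        # Same arithmetic as the cuComplex overload of Sqrt above.
        x, y = z.real, z.imag
        if x == 0.0 and y == 0.0:
            return complex(0.0, 0.0)
        r = abs(z)  # r = |z|
        a = math.sqrt((r + x) / 2.0)  # real part
        b = math.copysign(math.sqrt((r - x) / 2.0), y)  # imaginary part, sign of y
        return complex(a, b)

    for z in [3 + 4j, -3 + 4j, 2j, -4 + 0j]:
        assert cmath.isclose(sqrt_via_halves(z), cmath.sqrt(z))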
@@ -365,4 +387,22 @@ struct Tanh {
   }
 };

+__device__ cuComplex ArcCos::operator()(cuComplex x) {
+  auto i = cuComplex{0.0, 1.0};
+  auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
+  return {cuCimagf(y), -cuCrealf(y)};
+};
+
+__device__ cuComplex ArcSin::operator()(cuComplex x) {
+  auto i = cuComplex{0.0f, 1.0f};
+  auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
+  return {cuCimagf(y), -cuCrealf(y)};
+};
+
+__device__ cuComplex ArcTan::operator()(cuComplex x) {
+  auto i = cuComplex{0.0f, 1.0f};
+  auto ix = i * x;
+  return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
+};
+
 } // namespace mlx::core::cu
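
These out-of-line definitions use the standard principal-branch identities arccos z = -i·log(z + i·sqrt(1 - z²)), arcsin z = -i·log(iz + sqrt(1 - z²)), and arctan z = (1/2i)·log((1 + iz)/(1 - iz)); returning {cuCimagf(y), -cuCrealf(y)} is just multiplication of y by -i spelled out componentwise. The identities are easy to sanity-check against a reference such as Python's cmath at points away from the branch cuts (a sketch, not part of the commit):

    import cmath

    def arccos(z):
        return -1j * cmath.log(z + 1j * cmath.sqrt(1 - z * z))

    def arcsin(z):
        return -1j * cmath.log(1j * z + cmath.sqrt(1 - z * z))

    def arctan(z):
        return (1 / 2j) * cmath.log((1 + 1j * z) / (1 - 1j * z))

    for z in [0.3 + 0.4j, -1.5 + 0.2j, 0.5 - 0.7j]:
        assert cmath.isclose(arccos(z), cmath.acos(z))
        assert cmath.isclose(arcsin(z), cmath.asin(z))
        assert cmath.isclose(arctan(z), cmath.atan(z))

On the cuts themselves (real |x| > 1 for arccos/arcsin, imaginary |y| > 1 for arctan), naive formulas like these can differ from a careful library in the branch chosen for signed zeros, so exact branch-cut behavior should not be read off this sketch.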

mlx/backend/cuda/layer_norm.cu

Lines changed: 0 additions & 2 deletions
@@ -342,8 +342,6 @@ void LayerNormVJP::eval_gpu(
       encoder.add_temporary(gw_temp);
     }
   }
-  gw.set_data(allocator::malloc(gw.nbytes()));
-  gb.set_data(allocator::malloc(gb.nbytes()));

   // Finish with the gradient for b in case we had a b.
   if (gb.ndim() == 1 && gb.size() == axis_size) {

mlx/backend/cuda/rms_norm.cu

Lines changed: 0 additions & 1 deletion
@@ -304,7 +304,6 @@ void RMSNormVJP::eval_gpu(
       encoder.add_temporary(gw_temp);
     }
   }
-  gw.set_data(allocator::malloc(gw.nbytes()));

   encoder.set_input_array(x);
   encoder.set_input_array(w);

mlx/backend/cuda/unary.cu

Lines changed: 16 additions & 19 deletions
@@ -20,38 +20,35 @@ namespace cu {
 template <typename Op, typename In, typename Out>
 constexpr bool supports_unary_op() {
   if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
-      std::is_same_v<Op, Sign>) {
+      std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
     return std::is_same_v<In, Out>;
   }
-  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcCosh> ||
-      std::is_same_v<Op, ArcSin> || std::is_same_v<Op, ArcSinh> ||
-      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
-      std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
-      std::is_same_v<Op, Expm1> || std::is_same_v<Op, Sigmoid> ||
-      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
+  if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
+      std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
+      std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
+      std::is_same_v<Op, Sigmoid>) {
     return std::is_same_v<In, Out> && is_floating_v<In>;
   }
-  if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
-      std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p>) {
-    return std::is_same_v<In, Out> && is_inexact_v<In>;
-  }
   if (std::is_same_v<Op, BitwiseInvert>) {
     return std::is_same_v<In, Out> && std::is_integral_v<In> &&
         !std::is_same_v<In, bool>;
   }
-  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor> ||
-      std::is_same_v<Op, Square>) {
+  if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
     return std::is_same_v<In, Out> && !std::is_same_v<In, complex64_t>;
   }
   if (std::is_same_v<Op, Conjugate>) {
     return std::is_same_v<In, Out> && std::is_same_v<In, complex64_t>;
   }
-  if (std::is_same_v<Op, Cos> || std::is_same_v<Op, Cosh> ||
-      std::is_same_v<Op, Exp> || std::is_same_v<Op, Round> ||
-      std::is_same_v<Op, Sin> || std::is_same_v<Op, Sinh> ||
-      std::is_same_v<Op, Tan> || std::is_same_v<Op, Tanh>) {
-    return std::is_same_v<In, Out> &&
-        (is_floating_v<In> || std::is_same_v<In, complex64_t>);
+  if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
+      std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
+      std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
+      std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
+      std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
+      std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
+      std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
+      std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
+      std::is_same_v<Op, Tanh>) {
+    return std::is_same_v<In, Out> && is_inexact_v<In>;
   }
   if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
     return std::is_same_v<In, complex64_t> && std::is_same_v<Out, float>;
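
The predicate rewrite folds the float-or-complex special case into is_inexact_v and moves ArcCos, ArcSin, ArcTan, Sqrt, and Rsqrt out of the float-only bucket, so those kernels are now instantiated for complex64 inputs as well; Square likewise joins the any-dtype bucket alongside Abs, Negative, and Sign. A rough user-level illustration, assuming a CUDA build that includes this commit:

    import mlx.core as mx

    # Complex inputs to these unary ops should now dispatch on the CUDA
    # backend rather than failing as unsupported.
    z = mx.array([2.0 + 0.0j, -1.0 + 1.0j], dtype=mx.complex64)
    print(mx.sqrt(z))    # principal square roots
    print(mx.rsqrt(z))   # 1 / sqrt(z)
    print(mx.arccos(z))  # complex arccos via the log/sqrt identity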

python/tests/cuda_skip.py

Lines changed: 1 addition & 11 deletions
@@ -1,25 +1,15 @@
 cuda_skip = {
-    "TestArray.test_api",
-    "TestBF16.test_arg_reduction_ops",
-    "TestBlas.test_complex_gemm",
-    "TestEinsum.test_ellipses",
-    "TestEinsum.test_opt_einsum_test_cases",
     "TestLoad.test_load_f8_e4m3",
-    "TestLayers.test_group_norm",
-    "TestLayers.test_pooling",
     "TestLayers.test_quantized_embedding",
-    "TestLayers.test_sin_pe",
-    "TestLayers.test_upsample",
-    "TestOps.test_complex_ops",
     "TestOps.test_dynamic_slicing",
     "TestReduce.test_dtypes",
-    "TestUpsample.test_torch_upsample",
     # Block masked matmul NYI
     "TestBlas.test_block_masked_matmul",
     # Gather matmul NYI
     "TestBlas.test_gather_matmul",
     "TestBlas.test_gather_matmul_grad",
     # Scan NYI
+    "TestArray.test_api",
     "TestAutograd.test_cumprod_grad",
     "TestOps.test_scans",
     "TestOps.test_logcumsumexp",

python/tests/mlx_tests.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,10 @@
 # Copyright © 2023 Apple Inc.

 import os
+
+# Use regular fp32 precision for tests
+os.environ["MLX_ENABLE_TF32"] = "0"
+
 import platform
 import unittest
 from typing import Any, Callable, List, Tuple, Union
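
On Ampere and newer GPUs, TF32 executes float32 matmuls with a reduced (roughly 10-bit) mantissa, which is too coarse for test tolerances written against full fp32; setting MLX_ENABLE_TF32=0 keeps float32 math at regular precision, as the diff's comment says. Placing the assignment before anything imports mlx.core matters if the backend reads the variable once at initialization (an assumption in this sketch):

    import os

    # Set before importing mlx.core so the CUDA backend sees it at startup.
    os.environ["MLX_ENABLE_TF32"] = "0"

    import mlx.core as mx

    a = mx.random.normal((256, 256))
    b = mx.random.normal((256, 256))
    print((a @ b).dtype)  # float32, computed at full fp32 precision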
