Skip to content

Commit 25f70d4

Browse files
authored
Fix divide types + floor divide (//) (#138)
* divide types * fix black + test
1 parent 02de234 commit 25f70d4

File tree

4 files changed

+42
-6
lines changed

4 files changed

+42
-6
lines changed

python/mlx/nn/layers/activations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def selu(x):
176176
See also :func:`elu`.
177177
"""
178178
return elu(x, 1.67326) * 1.0507
179+
180+
179181
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
180182
r"""Applies the element-wise function:
181183

python/src/array.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -623,25 +623,41 @@ void init_array(py::module_& m) {
623623
.def(
624624
"__truediv__",
625625
[](const array& a, const ScalarOrArray v) {
626-
return divide(a, to_array(v, float32));
626+
return divide(a, to_array(v, a.dtype()));
627627
},
628628
"other"_a)
629629
.def(
630630
"__div__",
631631
[](const array& a, const ScalarOrArray v) {
632-
return divide(a, to_array(v, float32));
632+
return divide(a, to_array(v, a.dtype()));
633+
},
634+
"other"_a)
635+
.def(
636+
"__floordiv__",
637+
[](const array& a, const ScalarOrArray v) {
638+
auto b = to_array(v, a.dtype());
639+
auto t = promote_types(a.dtype(), b.dtype());
640+
return astype(divide(a, b), t);
633641
},
634642
"other"_a)
635643
.def(
636644
"__rtruediv__",
637645
[](const array& a, const ScalarOrArray v) {
638-
return divide(to_array(v, float32), a);
646+
return divide(to_array(v, a.dtype()), a);
647+
},
648+
"other"_a)
649+
.def(
650+
"__rfloordiv__",
651+
[](const array& a, const ScalarOrArray v) {
652+
auto b = to_array(v, a.dtype());
653+
auto t = promote_types(a.dtype(), b.dtype());
654+
return astype(divide(b, a), t);
639655
},
640656
"other"_a)
641657
.def(
642658
"__rdiv__",
643659
[](const array& a, const ScalarOrArray v) {
644-
return divide(to_array(v, float32), a);
660+
return divide(to_array(v, a.dtype()), a);
645661
},
646662
"other"_a)
647663
.def(

python/tests/mlx_tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
import unittest
5-
from typing import Callable, List, Tuple
5+
from typing import Callable, List, Tuple, Union
66

77
import mlx.core as mx
88
import numpy as np
@@ -21,7 +21,7 @@ def tearDown(self):
2121

2222
def assertEqualArray(
2323
self,
24-
args: List[mx.array | float | int],
24+
args: List[Union[mx.array, float, int]],
2525
mlx_func: Callable[..., mx.array],
2626
expected: mx.array,
2727
atol=1e-2,

python/tests/test_ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,24 @@ def test_divide(self):
236236
self.assertEqual(z.dtype, mx.float32)
237237
self.assertEqual(z.item(), 0.5)
238238

239+
x = x.astype(mx.float16)
240+
z = x / 4.0
241+
self.assertEqual(z.dtype, mx.float16)
242+
243+
x = x.astype(mx.float16)
244+
z = 4.0 / x
245+
self.assertEqual(z.dtype, mx.float16)
246+
247+
x = mx.array(5)
248+
y = mx.array(2)
249+
z = x / y
250+
self.assertEqual(z.dtype, mx.float32)
251+
self.assertEqual(z.item(), 2.5)
252+
253+
z = x // y
254+
self.assertEqual(z.dtype, mx.int32)
255+
self.assertEqual(z.item(), 2)
256+
239257
def test_remainder(self):
240258
for dt in [mx.int32, mx.float32]:
241259
x = mx.array(2, dtype=dt)

0 commit comments

Comments
 (0)