Skip to content

Commit 975e265

Browse files
feat: Add numpy constants (#428)
* add numpy constants * feat: add unittests * add newaxis * add test for newaxis transformation * refactor
1 parent c92a134 commit 975e265

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

python/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pybind11_add_module(
1212
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
1313
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
1414
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
15+
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
1516
)
1617

1718
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)

python/src/constants.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// init_constants.cpp
2+
3+
#include <pybind11/pybind11.h>
4+
#include <limits>
5+
6+
namespace py = pybind11;
7+
8+
void init_constants(py::module_& m) {
9+
m.attr("Inf") = std::numeric_limits<double>::infinity();
10+
m.attr("Infinity") = std::numeric_limits<double>::infinity();
11+
m.attr("NAN") = NAN;
12+
m.attr("NINF") = -std::numeric_limits<double>::infinity();
13+
m.attr("NZERO") = -0.0;
14+
m.attr("NaN") = NAN;
15+
m.attr("PINF") = std::numeric_limits<double>::infinity();
16+
m.attr("PZERO") = 0.0;
17+
m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
18+
m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
19+
m.attr("inf") = std::numeric_limits<double>::infinity();
20+
m.attr("infty") = std::numeric_limits<double>::infinity();
21+
m.attr("nan") = NAN;
22+
m.attr("newaxis") = pybind11::none();
23+
m.attr("pi") = 3.1415926535897932384626433;
24+
}

python/src/mlx.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ void init_transforms(py::module_&);
1616
void init_random(py::module_&);
1717
void init_fft(py::module_&);
1818
void init_linalg(py::module_&);
19+
void init_constants(py::module_&);
1920

2021
PYBIND11_MODULE(core, m) {
2122
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -31,5 +32,6 @@ PYBIND11_MODULE(core, m) {
3132
init_random(m);
3233
init_fft(m);
3334
init_linalg(m);
35+
init_constants(m);
3436
m.attr("__version__") = TOSTRING(_VERSION_);
3537
}

python/tests/test_constants.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright © 2023 Apple Inc.
2+
3+
import unittest
4+
5+
import mlx.core as mx
6+
import mlx_tests
7+
import numpy as np
8+
9+
10+
class TestConstants(mlx_tests.MLXTestCase):
11+
def test_constants_values(self):
12+
# Check if mlx constants match expected values
13+
self.assertAlmostEqual(mx.Inf, float("inf"))
14+
self.assertAlmostEqual(mx.Infinity, float("inf"))
15+
self.assertTrue(np.isnan(mx.NAN))
16+
self.assertAlmostEqual(mx.NINF, float("-inf"))
17+
self.assertEqual(mx.NZERO, -0.0)
18+
self.assertTrue(np.isnan(mx.NaN))
19+
self.assertAlmostEqual(mx.PINF, float("inf"))
20+
self.assertEqual(mx.PZERO, 0.0)
21+
self.assertAlmostEqual(
22+
mx.e, 2.71828182845904523536028747135266249775724709369995
23+
)
24+
self.assertAlmostEqual(
25+
mx.euler_gamma, 0.5772156649015328606065120900824024310421
26+
)
27+
self.assertAlmostEqual(mx.inf, float("inf"))
28+
self.assertAlmostEqual(mx.infty, float("inf"))
29+
self.assertTrue(np.isnan(mx.nan))
30+
self.assertIsNone(mx.newaxis)
31+
self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433)
32+
33+
def test_constants_availability(self):
34+
# Check if mlx constants are available
35+
self.assertTrue(hasattr(mx, "Inf"))
36+
self.assertTrue(hasattr(mx, "Infinity"))
37+
self.assertTrue(hasattr(mx, "NAN"))
38+
self.assertTrue(hasattr(mx, "NINF"))
39+
self.assertTrue(hasattr(mx, "NaN"))
40+
self.assertTrue(hasattr(mx, "PINF"))
41+
self.assertTrue(hasattr(mx, "NZERO"))
42+
self.assertTrue(hasattr(mx, "PZERO"))
43+
self.assertTrue(hasattr(mx, "e"))
44+
self.assertTrue(hasattr(mx, "euler_gamma"))
45+
self.assertTrue(hasattr(mx, "inf"))
46+
self.assertTrue(hasattr(mx, "infty"))
47+
self.assertTrue(hasattr(mx, "nan"))
48+
self.assertTrue(hasattr(mx, "newaxis"))
49+
self.assertTrue(hasattr(mx, "pi"))
50+
51+
def test_newaxis_for_reshaping_arrays(self):
52+
arr_1d = mx.array([1, 2, 3, 4, 5])
53+
arr_2d_column = arr_1d[:, mx.newaxis]
54+
expected_result = mx.array([[1], [2], [3], [4], [5]])
55+
self.assertTrue(mx.array_equal(arr_2d_column, expected_result))
56+
57+
58+
if __name__ == "__main__":
59+
unittest.main()

0 commit comments

Comments
 (0)