Skip to content

Commit 44c82f0

Browse files
committed
Refactor
1 parent b23d94b commit 44c82f0

File tree

2 files changed

+104
-67
lines changed

2 files changed

+104
-67
lines changed

src/Base/SmallMatrix.H

Lines changed: 103 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ namespace
4444
get_value_type_t<T>
4545
>;
4646
}
47-
}
48-
49-
namespace pyAMReX
50-
{
51-
using namespace amrex;
5247

5348
/** CPU: __array_interface__ v3
5449
*
@@ -58,12 +53,14 @@ namespace pyAMReX
5853
class T,
5954
int NRows,
6055
int NCols,
61-
Order ORDER = Order::F,
56+
amrex::Order ORDER = amrex::Order::F,
6257
int StartIndex = 0
6358
>
6459
py::dict
65-
array_interface (SmallMatrix<T, NRows, NCols, ORDER, StartIndex> const & m)
60+
array_interface (amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex> const & m)
6661
{
62+
using namespace amrex;
63+
6764
auto d = py::dict();
6865
// provide C index order for shape and strides
6966
auto shape = m.ordering == Order::F ? py::make_tuple(
@@ -110,27 +107,26 @@ namespace pyAMReX
110107
return d;
111108
}
112109

113-
template<
114-
class T,
115-
int NRows,
116-
int NCols,
117-
Order ORDER = Order::F,
118-
int StartIndex = 0
119-
>
120-
void make_SmallMatrix(py::module &m, std::string typestr)
110+
template<class SM>
111+
py::class_<SM>
112+
make_SmallMatrix_or_Vector (py::module &m, std::string typestr)
121113
{
122114
using namespace amrex;
123115

116+
using T = typename SM::value_type;
124117
using T_no_cv = std::remove_cv_t<T>;
118+
static constexpr int row_size = SM::row_size;
119+
static constexpr int column_size = SM::column_size;
120+
static constexpr Order ordering = SM::ordering;
121+
static constexpr int starting_index = SM::starting_index;
125122

126123
// dispatch simpler via: py::format_descriptor<T>::format() naming
127124
// but note the _const suffix that might be needed
128125
auto const sm_name = std::string("SmallMatrix_")
129-
.append(std::to_string(NRows)).append("x").append(std::to_string(NCols))
130-
.append("_").append(ORDER == Order::F ? "F" : "C")
131-
.append("_SI").append(std::to_string(StartIndex))
126+
.append(std::to_string(row_size)).append("x").append(std::to_string(column_size))
127+
.append("_").append(ordering == Order::F ? "F" : "C")
128+
.append("_SI").append(std::to_string(starting_index))
132129
.append("_").append(typestr);
133-
using SM = SmallMatrix<T, NRows, NCols, ORDER, StartIndex>;
134130
py::class_< SM > py_sm(m, sm_name.c_str());
135131
py_sm
136132
.def("__repr__",
@@ -177,7 +173,7 @@ namespace pyAMReX
177173
py::format_descriptor<T_no_cv>::format() +
178174
"' and received '" + buf.format + "'!");
179175

180-
// TODO: check that strides are either exact or None in buf
176+
// TODO: check that strides are either exact or None in buf (e.g., F or C contiguous)
181177
// TODO: transpose if SM order is not C?
182178

183179
auto sm = std::make_unique< SM >();
@@ -197,7 +193,7 @@ namespace pyAMReX
197193
// CPU: __array_interface__ v3
198194
// https://numpy.org/doc/stable/reference/arrays.interface.html
199195
.def_property_readonly("__array_interface__", [](SM const & sm) {
200-
return pyAMReX::array_interface(sm);
196+
return array_interface(sm);
201197
})
202198

203199
// CPU: __array_function__ interface (TODO)
@@ -210,8 +206,9 @@ namespace pyAMReX
210206

211207
// Nvidia GPUs: __cuda_array_interface__ v3
212208
// https://numba.readthedocs.io/en/latest/cuda/cuda_array_interface.html
213-
.def_property_readonly("__cuda_array_interface__", [](SM const & sm) {
214-
auto d = pyAMReX::array_interface(sm);
209+
.def_property_readonly("__cuda_array_interface__", [](SM const & sm)
210+
{
211+
auto d = array_interface(sm);
215212

216213
// data:
217214
// Because the user of the interface may or may not be in the same context, the most common case is to use cuPointerGetAttribute with CU_POINTER_ATTRIBUTE_DEVICE_POINTER in the CUDA driver API (or the equivalent CUDA Runtime API) to retrieve a device pointer that is usable in the currently active context.
@@ -239,6 +236,21 @@ namespace pyAMReX
239236
// https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
240237
// https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
241238

239+
;
240+
241+
return py_sm;
242+
}
243+
244+
template<class SM>
245+
void add_matrix_methods (py::class_<SM> & py_sm)
246+
{
247+
using T = typename SM::value_type;
248+
using T_no_cv = std::remove_cv_t<T>;
249+
static constexpr int row_size = SM::row_size;
250+
static constexpr int column_size = SM::column_size;
251+
static constexpr int starting_index = SM::starting_index;
252+
253+
py_sm
242254
.def("dot", &SM::dot)
243255
.def("prod", &SM::product) // NumPy name
244256
.def("set_val", &SM::setVal)
@@ -254,66 +266,102 @@ namespace pyAMReX
254266

255267
// getter
256268
.def("__getitem__", [](SM & sm, std::array<int, 2> const & key){
257-
if (key[0] < SM::starting_index || key[0] >= SM::row_size + SM::starting_index ||
258-
key[1] < SM::starting_index || key[1] >= SM::column_size + SM::starting_index)
269+
if (key[0] < starting_index || key[0] >= row_size + starting_index ||
270+
key[1] < starting_index || key[1] >= column_size + starting_index)
259271
throw std::runtime_error(
260272
"Index out of bounds: [" +
261273
std::to_string(key[0]) + ", " +
262274
std::to_string(key[1]) + "]");
263275
return sm(key[0], key[1]);
264276
})
265277
;
278+
266279
// setter
267280
if constexpr (is_not_const<T>())
268281
{
269282
py_sm
270-
.def("__setitem__", [](SM & sm, std::array<int, 2> const & key, T const value){
283+
.def("__setitem__", [](SM & sm, std::array<int, 2> const & key, T_no_cv const value){
271284
if (key[0] < SM::starting_index || key[0] >= SM::row_size + SM::starting_index ||
272285
key[1] < SM::starting_index || key[1] >= SM::column_size + SM::starting_index)
286+
{
273287
throw std::runtime_error(
274288
"Index out of bounds: [" +
275289
std::to_string(key[0]) + ", " +
276290
std::to_string(key[1]) + "]");
291+
}
277292
sm(key[0], key[1]) = value;
278293
})
279294
;
280295
}
281296

282297
// square matrix
283-
if constexpr (NRows == NCols)
298+
if constexpr (row_size == column_size)
284299
{
285300
py_sm
286-
.def_static("identity", [](){ return SM::Identity(); })
287-
.def("trace", &SM::trace)
288-
.def("transpose_in_place", &SM::transposeInPlace)
301+
.def_static("identity", []() { return SM::Identity(); })
302+
.def("trace", [](SM & sm){ return sm.trace(); })
303+
.def("transpose_in_place", [](SM & sm){ return sm.transposeInPlace(); })
289304
;
290305
}
306+
}
291307

292-
// vector
293-
if constexpr (NRows == 1 || NCols == 1)
294-
{
295-
py_sm
296-
.def("__getitem__", [](SM & sm, int key){
297-
if (key < SM::starting_index || key >= SM::column_size * SM::row_size + SM::starting_index)
298-
throw std::runtime_error("Index out of bounds: " + std::to_string(key));
299-
return sm(key);
300-
})
301-
.def("__setitem__", [](SM & sm, int key, T const value){
302-
if (key < SM::starting_index || key >= SM::column_size * SM::row_size + SM::starting_index)
303-
throw std::runtime_error("Index out of bounds: " + std::to_string(key));
304-
sm(key) = value;
305-
})
306-
;
307-
} else {
308-
using SV = SmallMatrix<T, NRows, 1, Order::F, StartIndex>;
309-
using SRV = SmallMatrix<T, 1, NCols, Order::F, StartIndex>;
308+
template<class T_SV>
309+
void add_get_set_Vector (py::class_<T_SV> &py_v)
310+
{
311+
using self = T_SV;
312+
using T = typename T_SV::value_type;
313+
using T_no_cv = std::remove_cv_t<T>;
310314

311-
// operators for matrix-matrix & matrix-vector
312-
py_sm
313-
.def(py::self * py::self)
314-
.def(py::self * SV())
315-
.def(SRV() * py::self)
316-
;
317-
}
315+
py_v
316+
.def("__getitem__", [](self & sm, int key){
317+
if (key < self::starting_index || key >= self::column_size * self::row_size + self::starting_index)
318+
throw std::runtime_error("Index out of bounds: " + std::to_string(key));
319+
return sm(key);
320+
})
321+
.def("__setitem__", [](self & sm, int key, T_no_cv const value){
322+
if (key < self::starting_index || key >= self::column_size * self::row_size + self::starting_index)
323+
throw std::runtime_error("Index out of bounds: " + std::to_string(key));
324+
sm(key) = value;
325+
})
326+
;
327+
}
328+
}
329+
330+
namespace pyAMReX
331+
{
332+
template<
333+
class T,
334+
int NRows,
335+
int NCols,
336+
amrex::Order ORDER = amrex::Order::F,
337+
int StartIndex = 0
338+
>
339+
void make_SmallMatrix (py::module &m, std::string typestr)
340+
{
341+
using namespace amrex;
342+
343+
using SM = SmallMatrix<T, NRows, NCols, ORDER, StartIndex>;
344+
using SV = SmallMatrix<T, NRows, 1, Order::F, StartIndex>;
345+
using SRV = SmallMatrix<T, 1, NCols, Order::F, StartIndex>;
346+
347+
py::class_<SM> py_sm = make_SmallMatrix_or_Vector<SM>(m, typestr);
348+
py::class_<SV> py_sv = make_SmallMatrix_or_Vector<SV>(m, typestr);
349+
py::class_<SRV> py_srv = make_SmallMatrix_or_Vector<SRV>(m, typestr);
350+
351+
// methods, getter, setter
352+
add_matrix_methods(py_sm);
353+
add_matrix_methods(py_sv);
354+
add_matrix_methods(py_srv);
355+
356+
// vector setter/getter
357+
add_get_set_Vector(py_sv);
358+
add_get_set_Vector(py_srv);
359+
360+
// operators for matrix-matrix & matrix-vector
361+
py_sm
362+
.def(py::self * py::self)
363+
.def(py::self * SV())
364+
.def(SRV() * py::self)
365+
;
318366
}
319367
}

src/Base/SmallMatrix.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,9 @@ void init_SmallMatrix (py::module &m)
1616
{
1717
constexpr int NRows = 6;
1818
constexpr int NCols = 6;
19-
constexpr Order ORDER = Order::F;
19+
constexpr amrex::Order ORDER = amrex::Order::F;
2020
constexpr int StartIndex = 1;
2121

22-
// Vector
23-
make_SmallMatrix< float, NRows, 1, ORDER, StartIndex >(m, "float");
24-
make_SmallMatrix< double, NRows, 1, ORDER, StartIndex >(m, "double");
25-
make_SmallMatrix< long double, NRows, 1, ORDER, StartIndex >(m, "longdouble");
26-
27-
// RowVector
28-
make_SmallMatrix< float, 1, NCols, ORDER, StartIndex >(m, "float");
29-
make_SmallMatrix< double, 1, NCols, ORDER, StartIndex >(m, "double");
30-
make_SmallMatrix< long double, 1, NCols, ORDER, StartIndex >(m, "longdouble");
31-
32-
// Matrix
3322
make_SmallMatrix< float, NRows, NCols, ORDER, StartIndex >(m, "float");
3423
make_SmallMatrix< double, NRows, NCols, ORDER, StartIndex >(m, "double");
3524
make_SmallMatrix< long double, NRows, NCols, ORDER, StartIndex >(m, "longdouble");

0 commit comments

Comments
 (0)