@@ -44,11 +44,6 @@ namespace
4444 get_value_type_t <T>
4545 >;
4646 }
47- }
48-
49- namespace pyAMReX
50- {
51- using namespace amrex ;
5247
5348 /* * CPU: __array_interface__ v3
5449 *
@@ -58,12 +53,14 @@ namespace pyAMReX
5853 class T ,
5954 int NRows,
6055 int NCols,
61- Order ORDER = Order::F,
56+ amrex:: Order ORDER = amrex:: Order::F,
6257 int StartIndex = 0
6358 >
6459 py::dict
65- array_interface (SmallMatrix<T, NRows, NCols, ORDER, StartIndex> const & m)
60+ array_interface (amrex:: SmallMatrix<T, NRows, NCols, ORDER, StartIndex> const & m)
6661 {
62+ using namespace amrex ;
63+
6764 auto d = py::dict ();
6865 // provide C index order for shape and strides
6966 auto shape = m.ordering == Order::F ? py::make_tuple (
@@ -110,27 +107,26 @@ namespace pyAMReX
110107 return d;
111108 }
112109
113- template <
114- class T ,
115- int NRows,
116- int NCols,
117- Order ORDER = Order::F,
118- int StartIndex = 0
119- >
120- void make_SmallMatrix (py::module &m, std::string typestr)
110+ template <class SM >
111+ py::class_<SM>
112+ make_SmallMatrix_or_Vector (py::module &m, std::string typestr)
121113 {
122114 using namespace amrex ;
123115
116+ using T = typename SM::value_type;
124117 using T_no_cv = std::remove_cv_t <T>;
118+ static constexpr int row_size = SM::row_size;
119+ static constexpr int column_size = SM::column_size;
120+ static constexpr Order ordering = SM::ordering;
121+ static constexpr int starting_index = SM::starting_index;
125122
126123 // dispatch simpler via: py::format_descriptor<T>::format() naming
127124 // but note the _const suffix that might be needed
128125 auto const sm_name = std::string (" SmallMatrix_" )
129- .append (std::to_string (NRows )).append (" x" ).append (std::to_string (NCols ))
130- .append (" _" ).append (ORDER == Order::F ? " F" : " C" )
131- .append (" _SI" ).append (std::to_string (StartIndex ))
126+ .append (std::to_string (row_size )).append (" x" ).append (std::to_string (column_size ))
127+ .append (" _" ).append (ordering == Order::F ? " F" : " C" )
128+ .append (" _SI" ).append (std::to_string (starting_index ))
132129 .append (" _" ).append (typestr);
133- using SM = SmallMatrix<T, NRows, NCols, ORDER, StartIndex>;
134130 py::class_< SM > py_sm (m, sm_name.c_str ());
135131 py_sm
136132 .def (" __repr__" ,
@@ -177,7 +173,7 @@ namespace pyAMReX
177173 py::format_descriptor<T_no_cv>::format () +
178174 " ' and received '" + buf.format + " '!" );
179175
180- // TODO: check that strides are either exact or None in buf
176+ // TODO: check that strides are either exact or None in buf (e.g., F or C contiguous)
181177 // TODO: transpose if SM order is not C?
182178
183179 auto sm = std::make_unique< SM >();
@@ -197,7 +193,7 @@ namespace pyAMReX
197193 // CPU: __array_interface__ v3
198194 // https://numpy.org/doc/stable/reference/arrays.interface.html
199195 .def_property_readonly (" __array_interface__" , [](SM const & sm) {
200- return pyAMReX:: array_interface (sm);
196+ return array_interface (sm);
201197 })
202198
203199 // CPU: __array_function__ interface (TODO)
@@ -210,8 +206,9 @@ namespace pyAMReX
210206
211207 // Nvidia GPUs: __cuda_array_interface__ v3
212208 // https://numba.readthedocs.io/en/latest/cuda/cuda_array_interface.html
213- .def_property_readonly (" __cuda_array_interface__" , [](SM const & sm) {
214- auto d = pyAMReX::array_interface (sm);
209+ .def_property_readonly (" __cuda_array_interface__" , [](SM const & sm)
210+ {
211+ auto d = array_interface (sm);
215212
216213 // data:
217214 // Because the user of the interface may or may not be in the same context, the most common case is to use cuPointerGetAttribute with CU_POINTER_ATTRIBUTE_DEVICE_POINTER in the CUDA driver API (or the equivalent CUDA Runtime API) to retrieve a device pointer that is usable in the currently active context.
@@ -239,6 +236,21 @@ namespace pyAMReX
239236 // https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
240237 // https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
241238
239+ ;
240+
241+ return py_sm;
242+ }
243+
244+ template <class SM >
245+ void add_matrix_methods (py::class_<SM> & py_sm)
246+ {
247+ using T = typename SM::value_type;
248+ using T_no_cv = std::remove_cv_t <T>;
249+ static constexpr int row_size = SM::row_size;
250+ static constexpr int column_size = SM::column_size;
251+ static constexpr int starting_index = SM::starting_index;
252+
253+ py_sm
242254 .def (" dot" , &SM::dot)
243255 .def (" prod" , &SM::product) // NumPy name
244256 .def (" set_val" , &SM::setVal)
@@ -254,66 +266,102 @@ namespace pyAMReX
254266
255267 // getter
256268 .def (" __getitem__" , [](SM & sm, std::array<int , 2 > const & key){
257- if (key[0 ] < SM:: starting_index || key[0 ] >= SM:: row_size + SM:: starting_index ||
258- key[1 ] < SM:: starting_index || key[1 ] >= SM:: column_size + SM:: starting_index)
269+ if (key[0 ] < starting_index || key[0 ] >= row_size + starting_index ||
270+ key[1 ] < starting_index || key[1 ] >= column_size + starting_index)
259271 throw std::runtime_error (
260272 " Index out of bounds: [" +
261273 std::to_string (key[0 ]) + " , " +
262274 std::to_string (key[1 ]) + " ]" );
263275 return sm (key[0 ], key[1 ]);
264276 })
265277 ;
278+
266279 // setter
267280 if constexpr (is_not_const<T>())
268281 {
269282 py_sm
270- .def (" __setitem__" , [](SM & sm, std::array<int , 2 > const & key, T const value){
283+ .def (" __setitem__" , [](SM & sm, std::array<int , 2 > const & key, T_no_cv const value){
271284 if (key[0 ] < SM::starting_index || key[0 ] >= SM::row_size + SM::starting_index ||
272285 key[1 ] < SM::starting_index || key[1 ] >= SM::column_size + SM::starting_index)
286+ {
273287 throw std::runtime_error (
274288 " Index out of bounds: [" +
275289 std::to_string (key[0 ]) + " , " +
276290 std::to_string (key[1 ]) + " ]" );
291+ }
277292 sm (key[0 ], key[1 ]) = value;
278293 })
279294 ;
280295 }
281296
282297 // square matrix
283- if constexpr (NRows == NCols )
298+ if constexpr (row_size == column_size )
284299 {
285300 py_sm
286- .def_static (" identity" , [](){ return SM::Identity (); })
287- .def (" trace" , &SM:: trace)
288- .def (" transpose_in_place" , &SM:: transposeInPlace)
301+ .def_static (" identity" , []() { return SM::Identity (); })
302+ .def (" trace" , [](SM & sm){ return sm. trace (); } )
303+ .def (" transpose_in_place" , [](SM & sm){ return sm. transposeInPlace (); } )
289304 ;
290305 }
306+ }
291307
292- // vector
293- if constexpr (NRows == 1 || NCols == 1 )
294- {
295- py_sm
296- .def (" __getitem__" , [](SM & sm, int key){
297- if (key < SM::starting_index || key >= SM::column_size * SM::row_size + SM::starting_index)
298- throw std::runtime_error (" Index out of bounds: " + std::to_string (key));
299- return sm (key);
300- })
301- .def (" __setitem__" , [](SM & sm, int key, T const value){
302- if (key < SM::starting_index || key >= SM::column_size * SM::row_size + SM::starting_index)
303- throw std::runtime_error (" Index out of bounds: " + std::to_string (key));
304- sm (key) = value;
305- })
306- ;
307- } else {
308- using SV = SmallMatrix<T, NRows, 1 , Order::F, StartIndex>;
309- using SRV = SmallMatrix<T, 1 , NCols, Order::F, StartIndex>;
308+ template <class T_SV >
309+ void add_get_set_Vector (py::class_<T_SV> &py_v)
310+ {
311+ using self = T_SV;
312+ using T = typename T_SV::value_type;
313+ using T_no_cv = std::remove_cv_t <T>;
310314
311- // operators for matrix-matrix & matrix-vector
312- py_sm
313- .def (py::self * py::self)
314- .def (py::self * SV ())
315- .def (SRV () * py::self)
316- ;
317- }
315+ py_v
316+ .def (" __getitem__" , [](self & sm, int key){
317+ if (key < self::starting_index || key >= self::column_size * self::row_size + self::starting_index)
318+ throw std::runtime_error (" Index out of bounds: " + std::to_string (key));
319+ return sm (key);
320+ })
321+ .def (" __setitem__" , [](self & sm, int key, T_no_cv const value){
322+ if (key < self::starting_index || key >= self::column_size * self::row_size + self::starting_index)
323+ throw std::runtime_error (" Index out of bounds: " + std::to_string (key));
324+ sm (key) = value;
325+ })
326+ ;
327+ }
328+ }
329+
330+ namespace pyAMReX
331+ {
332+ template <
333+ class T ,
334+ int NRows,
335+ int NCols,
336+ amrex::Order ORDER = amrex::Order::F,
337+ int StartIndex = 0
338+ >
339+ void make_SmallMatrix (py::module &m, std::string typestr)
340+ {
341+ using namespace amrex ;
342+
343+ using SM = SmallMatrix<T, NRows, NCols, ORDER, StartIndex>;
344+ using SV = SmallMatrix<T, NRows, 1 , Order::F, StartIndex>;
345+ using SRV = SmallMatrix<T, 1 , NCols, Order::F, StartIndex>;
346+
347+ py::class_<SM> py_sm = make_SmallMatrix_or_Vector<SM>(m, typestr);
348+ py::class_<SV> py_sv = make_SmallMatrix_or_Vector<SV>(m, typestr);
349+ py::class_<SRV> py_srv = make_SmallMatrix_or_Vector<SRV>(m, typestr);
350+
351+ // methods, getter, setter
352+ add_matrix_methods (py_sm);
353+ add_matrix_methods (py_sv);
354+ add_matrix_methods (py_srv);
355+
356+ // vector setter/getter
357+ add_get_set_Vector (py_sv);
358+ add_get_set_Vector (py_srv);
359+
360+ // operators for matrix-matrix & matrix-vector
361+ py_sm
362+ .def (py::self * py::self)
363+ .def (py::self * SV ())
364+ .def (SRV () * py::self)
365+ ;
318366 }
319367}
0 commit comments