Skip to content

Commit a8b25ef

Browse files
authored
implement axes keyword argument in transpose (#735)
* implement axis keyword of transpose * fix keyword typo * add 2D transpose tests * update documentation * clean up expected values of test scripts * update missed test script
1 parent 11eefea commit a8b25ef

18 files changed

+512
-233
lines changed

code/ndarray.c

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,28 +1874,110 @@ mp_obj_t ndarray_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
18741874
#endif /* NDARRAY_HAS_UNARY_OPS */
18751875

18761876
#if NDARRAY_HAS_TRANSPOSE
1877-
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
1878-
#if ULAB_MAX_DIMS == 1
1879-
return self_in;
1880-
#endif
1881-
// TODO: check, what happens to the offset here, if we have a view
1877+
// We have to implement the T property separately, for the property can't take keyword arguments
1878+
1879+
#if ULAB_MAX_DIMS == 1
1880+
// isolating the one-dimensional case saves space, because the transpose is sort of meaningless
1881+
mp_obj_t ndarray_T(mp_obj_t self_in) {
1882+
return self_in;
1883+
}
1884+
#else
1885+
mp_obj_t ndarray_T(mp_obj_t self_in) {
1886+
// without argument, simply return a view with axes in reverse order
18821887
ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in);
18831888
if(self->ndim == 1) {
18841889
return self_in;
18851890
}
18861891
size_t *shape = m_new(size_t, self->ndim);
18871892
int32_t *strides = m_new(int32_t, self->ndim);
1888-
for(uint8_t i=0; i < self->ndim; i++) {
1893+
for(uint8_t i = 0; i < self->ndim; i++) {
18891894
shape[ULAB_MAX_DIMS - 1 - i] = self->shape[ULAB_MAX_DIMS - self->ndim + i];
18901895
strides[ULAB_MAX_DIMS - 1 - i] = self->strides[ULAB_MAX_DIMS - self->ndim + i];
18911896
}
1892-
// TODO: I am not sure ndarray_new_view is OK here...
1893-
// should be deep copy...
18941897
ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
18951898
return MP_OBJ_FROM_PTR(ndarray);
18961899
}
1900+
#endif /* ULAB_MAX_DIMS == 1 */
1901+
1902+
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_T_obj, ndarray_T);
1903+
1904+
# if ULAB_MAX_DIMS == 1
1905+
// again, nothing to do, if there is only one dimension, though, the arguments might still have to be parsed...
1906+
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
1907+
return self_in;
1908+
}
18971909

18981910
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
1911+
#else
1912+
mp_obj_t ndarray_transpose(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
1913+
static const mp_arg_t allowed_args[] = {
1914+
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
1915+
{ MP_QSTR_axes, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
1916+
};
1917+
1918+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
1919+
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
1920+
1921+
ndarray_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj);
1922+
1923+
if(self->ndim == 1) {
1924+
return args[0].u_obj;
1925+
}
1926+
1927+
size_t *shape = m_new(size_t, self->ndim);
1928+
int32_t *strides = m_new(int32_t, self->ndim);
1929+
uint8_t *order = m_new(uint8_t, self->ndim);
1930+
1931+
mp_obj_t axes = args[1].u_obj;
1932+
1933+
if(axes == mp_const_none) {
1934+
// simply swap the order of the axes
1935+
for(uint8_t i = 0; i < self->ndim; i++) {
1936+
order[i] = self->ndim - 1 - i;
1937+
}
1938+
} else {
1939+
if(!mp_obj_is_type(axes, &mp_type_tuple)) {
1940+
mp_raise_TypeError(MP_ERROR_TEXT("keyword argument must be tuple of integers"));
1941+
}
1942+
// start with the straight array, and then swap only those specified in the argument
1943+
for(uint8_t i = 0; i < self->ndim; i++) {
1944+
order[i] = i;
1945+
}
1946+
1947+
mp_obj_tuple_t *axes_tuple = MP_OBJ_TO_PTR(axes);
1948+
1949+
if(axes_tuple->len > self->ndim) {
1950+
mp_raise_ValueError(MP_ERROR_TEXT("too many axes specified"));
1951+
}
1952+
1953+
for(uint8_t i = 0; i < axes_tuple->len; i++) {
1954+
int32_t ax = mp_obj_get_int(axes_tuple->items[i]);
1955+
if((ax >= self->ndim) || (ax < 0)) {
1956+
mp_raise_ValueError(MP_ERROR_TEXT("axis index out of bounds"));
1957+
} else {
1958+
order[i] = (uint8_t)ax;
1959+
// TODO: check that no two identical numbers appear in the tuple
1960+
for(uint8_t j = 0; j < i; j++) {
1961+
if(order[i] == order[j]) {
1962+
mp_raise_ValueError(MP_ERROR_TEXT("repeated indices"));
1963+
}
1964+
}
1965+
}
1966+
}
1967+
}
1968+
1969+
uint8_t axis_offset = ULAB_MAX_DIMS - self->ndim;
1970+
for(uint8_t i = 0; i < self->ndim; i++) {
1971+
shape[axis_offset + i] = self->shape[axis_offset + order[i]];
1972+
strides[axis_offset + i] = self->strides[axis_offset + order[i]];
1973+
}
1974+
1975+
ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
1976+
return MP_OBJ_FROM_PTR(ndarray);
1977+
}
1978+
1979+
MP_DEFINE_CONST_FUN_OBJ_KW(ndarray_transpose_obj, 1, ndarray_transpose);
1980+
#endif /* ULAB_MAX_DIMS == 1 */
18991981
#endif /* NDARRAY_HAS_TRANSPOSE */
19001982

19011983
#if ULAB_MAX_DIMS > 1

code/ndarray.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,16 @@ MP_DECLARE_CONST_FUN_OBJ_1(ndarray_tolist_obj);
265265
#endif
266266

267267
#if NDARRAY_HAS_TRANSPOSE
268+
mp_obj_t ndarray_T(mp_obj_t );
269+
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_T_obj);
270+
#if ULAB_MAX_DIMS == 1
268271
mp_obj_t ndarray_transpose(mp_obj_t );
269272
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_transpose_obj);
270-
#endif
273+
#else
274+
mp_obj_t ndarray_transpose(size_t , const mp_obj_t *, mp_map_t *);
275+
MP_DECLARE_CONST_FUN_OBJ_KW(ndarray_transpose_obj);
276+
#endif /* ULAB_MAX_DIMS == 1 */
277+
#endif /* NDARRAY_HAS_TRANSPOSE */
271278

272279
#if ULAB_NUMPY_HAS_NDINFO
273280
mp_obj_t ndarray_info(mp_obj_t );

code/ndarray_properties.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void ndarray_properties_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
6464
#endif
6565
#if NDARRAY_HAS_TRANSPOSE
6666
case MP_QSTR_T:
67-
dest[0] = ndarray_transpose(self_in);
67+
dest[0] = ndarray_T(self_in);
6868
break;
6969
#endif
7070
#if ULAB_SUPPORTS_COMPLEX

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.10.0
36+
#define ULAB_VERSION 6.11.0
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

docs/manual/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
author = 'Zoltán Vörös'
2828

2929
# The full version, including alpha/beta/rc tags
30-
release = '6.9.0'
30+
release = '6.11.0'
3131

3232

3333
# -- General configuration ---------------------------------------------------

docs/manual/source/ulab-ndarray.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,6 +1564,9 @@ dimensions is larger than 1.
15641564
15651565
15661566
1567+
The method also accepts the ``axes`` keyword argument, if permutation of
1568+
the returned axes is required.
1569+
15671570
The transpose of the array can also be gotten through the ``T``
15681571
property:
15691572

docs/ulab-convert.ipynb

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
},
1515
{
1616
"cell_type": "code",
17-
"execution_count": 2,
17+
"execution_count": 1,
1818
"metadata": {
1919
"ExecuteTime": {
2020
"end_time": "2022-02-09T06:27:15.118699Z",
@@ -61,7 +61,7 @@
6161
"author = 'Zoltán Vörös'\n",
6262
"\n",
6363
"# The full version, including alpha/beta/rc tags\n",
64-
"release = '6.9.0'\n",
64+
"release = '6.11.0'\n",
6565
"\n",
6666
"\n",
6767
"# -- General configuration ---------------------------------------------------\n",
@@ -217,7 +217,7 @@
217217
},
218218
{
219219
"cell_type": "code",
220-
"execution_count": 3,
220+
"execution_count": 2,
221221
"metadata": {
222222
"ExecuteTime": {
223223
"end_time": "2022-02-09T06:27:21.647179Z",
@@ -258,7 +258,7 @@
258258
},
259259
{
260260
"cell_type": "code",
261-
"execution_count": 4,
261+
"execution_count": null,
262262
"metadata": {
263263
"ExecuteTime": {
264264
"end_time": "2022-02-09T06:27:42.024028Z",
@@ -270,34 +270,6 @@
270270
"name": "stderr",
271271
"output_type": "stream",
272272
"text": [
273-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
274-
" _, nbc = validator.normalize(nbc)\n",
275-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
276-
" _, nbc = validator.normalize(nbc)\n",
277-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
278-
" _, nbc = validator.normalize(nbc)\n",
279-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
280-
" _, nbc = validator.normalize(nbc)\n",
281-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
282-
" _, nbc = validator.normalize(nbc)\n",
283-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
284-
" _, nbc = validator.normalize(nbc)\n",
285-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
286-
" _, nbc = validator.normalize(nbc)\n",
287-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
288-
" _, nbc = validator.normalize(nbc)\n",
289-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
290-
" _, nbc = validator.normalize(nbc)\n",
291-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
292-
" _, nbc = validator.normalize(nbc)\n",
293-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
294-
" _, nbc = validator.normalize(nbc)\n",
295-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
296-
" _, nbc = validator.normalize(nbc)\n",
297-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
298-
" _, nbc = validator.normalize(nbc)\n",
299-
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
300-
" _, nbc = validator.normalize(nbc)\n",
301273
"/home/v923z/anaconda3/lib/python3.11/site-packages/nbconvert/exporters/exporter.py:349: MissingIDFieldWarning: Code cell is missing an id field, this will become a hard error in future nbformat versions. You may want to use `normalize()` on your notebooks before validations (available since nbformat 5.1.4). Previous versions of nbformat are fixing this issue transparently, and will stop doing so in the future.\n",
302274
" _, nbc = validator.normalize(nbc)\n"
303275
]

docs/ulab-ndarray.ipynb

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2237,6 +2237,13 @@
22372237
"print('shape of a:', a.shape)"
22382238
]
22392239
},
2240+
{
2241+
"cell_type": "markdown",
2242+
"metadata": {},
2243+
"source": [
2244+
"The method also accepts the `axes` keyword argument, if permutation of the returned axes is required."
2245+
]
2246+
},
22402247
{
22412248
"cell_type": "markdown",
22422249
"metadata": {},
@@ -3731,11 +3738,8 @@
37313738
}
37323739
],
37333740
"metadata": {
3734-
"interpreter": {
3735-
"hash": "ce9a02f9f7db620716422019cafa4bc1786ca85daa298b819f6da075e7993842"
3736-
},
37373741
"kernelspec": {
3738-
"display_name": "Python 3",
3742+
"display_name": "base",
37393743
"language": "python",
37403744
"name": "python3"
37413745
},
@@ -3749,7 +3753,7 @@
37493753
"name": "python",
37503754
"nbconvert_exporter": "python",
37513755
"pygments_lexer": "ipython3",
3752-
"version": "3.8.5"
3756+
"version": "3.11.7"
37533757
},
37543758
"toc": {
37553759
"base_numbering": 1,

0 commit comments

Comments
 (0)