@@ -2084,24 +2084,51 @@ MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
20842084#if NDARRAY_HAS_RESHAPE
20852085mp_obj_t ndarray_reshape_core (mp_obj_t oin , mp_obj_t _shape , bool inplace ) {
20862086 ndarray_obj_t * source = MP_OBJ_TO_PTR (oin );
2087- if (!mp_obj_is_type (_shape , & mp_type_tuple )) {
2088- mp_raise_TypeError (translate ("shape must be a tuple" ));
2087+ if (!mp_obj_is_type (_shape , & mp_type_tuple ) && !mp_obj_is_int (_shape )) {
2088+ mp_raise_TypeError (translate ("shape must be integer or tuple of integers" ));
2089+ }
2090+
2091+ mp_obj_tuple_t * shape ;
2092+
2093+ if (mp_obj_is_int (_shape )) {
2094+ mp_obj_t * items = m_new (mp_obj_t , 1 );
2095+ items [0 ] = _shape ;
2096+ shape = mp_obj_new_tuple (1 , items );
2097+ } else {
2098+ shape = MP_OBJ_TO_PTR (_shape );
20892099 }
20902100
2091- mp_obj_tuple_t * shape = MP_OBJ_TO_PTR (_shape );
20922101 if (shape -> len > ULAB_MAX_DIMS ) {
20932102 mp_raise_ValueError (translate ("maximum number of dimensions is " MP_STRINGIFY (ULAB_MAX_DIMS )));
20942103 }
2095- size_t * new_shape = m_new0 (size_t , ULAB_MAX_DIMS );
20962104
20972105 size_t new_length = 1 ;
2098- for (uint8_t i = 0 ; i < shape -> len ; i ++ ) {
2099- new_shape [ULAB_MAX_DIMS - i - 1 ] = mp_obj_get_int (shape -> items [shape -> len - i - 1 ]);
2100- new_length *= new_shape [ULAB_MAX_DIMS - i - 1 ];
2106+ size_t * new_shape = m_new0 (size_t , ULAB_MAX_DIMS );
2107+ uint8_t unknown_dim = 0 ;
2108+ uint8_t unknown_index = 0 ;
2109+
2110+ for (uint8_t i = 0 ; i < shape -> len ; i ++ ) {
2111+ int32_t ax_len = mp_obj_get_int (shape -> items [shape -> len - i - 1 ]);
2112+ if (ax_len >= 0 ) {
2113+ new_shape [ULAB_MAX_DIMS - i - 1 ] = (size_t )ax_len ;
2114+ new_length *= new_shape [ULAB_MAX_DIMS - i - 1 ];
2115+ } else {
2116+ unknown_dim ++ ;
2117+ unknown_index = ULAB_MAX_DIMS - i - 1 ;
2118+ }
2119+ }
2120+
2121+ if (unknown_dim > 1 ) {
2122+ mp_raise_ValueError (translate ("can only specify one unknown dimension" ));
2123+ } else if (unknown_dim == 1 ) {
2124+ new_shape [unknown_index ] = source -> len / new_length ;
2125+ new_length = source -> len ;
21012126 }
2127+
21022128 if (source -> len != new_length ) {
2103- mp_raise_ValueError (translate ("input and output shapes are not compatible " ));
2129+ mp_raise_ValueError (translate ("cannot reshape array " ));
21042130 }
2131+
21052132 ndarray_obj_t * ndarray ;
21062133 if (ndarray_is_dense (source )) {
21072134 int32_t * new_strides = strides_from_shape (new_shape , source -> dtype );
@@ -2118,7 +2145,11 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
21182145 if (inplace ) {
21192146 mp_raise_ValueError (translate ("cannot assign new shape" ));
21202147 }
2121- ndarray = ndarray_new_ndarray_from_tuple (shape , source -> dtype );
2148+ if (mp_obj_is_type (_shape , & mp_type_tuple )) {
2149+ ndarray = ndarray_new_ndarray_from_tuple (shape , source -> dtype );
2150+ } else {
2151+ ndarray = ndarray_new_linear_array (source -> len , source -> dtype );
2152+ }
21222153 ndarray_copy_array (source , ndarray , 0 );
21232154 }
21242155 return MP_OBJ_FROM_PTR (ndarray );
0 commit comments