@@ -274,7 +274,7 @@ static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, si
274274 }
275275}
276276
277- static mp_obj_t numerical_sum_mean_std_ndarray (ndarray_obj_t * ndarray , mp_obj_t axis , uint8_t optype , size_t ddof ) {
277+ static mp_obj_t numerical_sum_mean_std_ndarray (ndarray_obj_t * ndarray , mp_obj_t axis , mp_obj_t keepdims , uint8_t optype , size_t ddof ) {
278278 COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
279279 uint8_t * array = (uint8_t * )ndarray -> array ;
280280 shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
@@ -372,15 +372,15 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
372372 mp_float_t norm = (mp_float_t )_shape_strides .shape [0 ];
373373 // re-wind the array here
374374 farray = (mp_float_t * )results -> array ;
375- for (size_t i = 0 ; i < results -> len ; i ++ ) {
375+ for (size_t i = 0 ; i < results -> len ; i ++ ) {
376376 * farray ++ *= norm ;
377377 }
378378 }
379379 } else {
380380 bool isStd = optype == NUMERICAL_STD ? 1 : 0 ;
381381 results = ndarray_new_dense_ndarray (_shape_strides .ndim , _shape_strides .shape , NDARRAY_FLOAT );
382382 farray = (mp_float_t * )results -> array ;
383- // we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
383+ // we can return the 0 array here, if the degrees of freedom are larger than the length of the axis
384384 if ((optype == NUMERICAL_STD ) && (_shape_strides .shape [0 ] <= ddof )) {
385385 return MP_OBJ_FROM_PTR (results );
386386 }
@@ -397,11 +397,9 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
397397 RUN_MEAN_STD (mp_float_t , array , farray , _shape_strides , div , isStd );
398398 }
399399 }
400- if (results -> ndim == 0 ) { // return a scalar here
401- return mp_binary_get_val_array (results -> dtype , results -> array , 0 );
402- }
403- return MP_OBJ_FROM_PTR (results );
400+ return ulab_tools_restore_dims (ndarray , results , keepdims , _shape_strides );
404401 }
402+ // we should never get to this point
405403 return mp_const_none ;
406404}
407405#endif
@@ -441,7 +439,7 @@ static mp_obj_t numerical_argmin_argmax_iterable(mp_obj_t oin, uint8_t optype) {
441439 }
442440}
443441
444- static mp_obj_t numerical_argmin_argmax_ndarray (ndarray_obj_t * ndarray , mp_obj_t axis , uint8_t optype ) {
442+ static mp_obj_t numerical_argmin_argmax_ndarray (ndarray_obj_t * ndarray , mp_obj_t keepdims , mp_obj_t axis , uint8_t optype ) {
445443 // TODO: treat the flattened array
446444 if (ndarray -> len == 0 ) {
447445 mp_raise_ValueError (MP_ERROR_TEXT ("attempt to get (arg)min/(arg)max of empty sequence" ));
@@ -521,7 +519,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
521519 int32_t * strides = m_new0 (int32_t , ULAB_MAX_DIMS );
522520
523521 numerical_reduce_axes (ndarray , ax , shape , strides );
524- uint8_t index = ULAB_MAX_DIMS - ndarray -> ndim + ax ;
522+ shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
523+
524+ uint8_t index = _shape_strides .axis ;
525525
526526 ndarray_obj_t * results = NULL ;
527527
@@ -550,8 +550,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
550550 if (results -> len == 1 ) {
551551 return mp_binary_get_val_array (results -> dtype , results -> array , 0 );
552552 }
553- return MP_OBJ_FROM_PTR ( results );
553+ return ulab_tools_restore_dims ( ndarray , results , keepdims , _shape_strides );
554554 }
555+ // we should never get to this point
555556 return mp_const_none ;
556557}
557558#endif
@@ -560,13 +561,16 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
560561 static const mp_arg_t allowed_args [] = {
561562 { MP_QSTR_ , MP_ARG_REQUIRED | MP_ARG_OBJ , { .u_rom_obj = MP_ROM_NONE } } ,
562563 { MP_QSTR_axis , MP_ARG_OBJ , { .u_rom_obj = MP_ROM_NONE } },
564+ { MP_QSTR_keepdims , MP_ARG_OBJ , { .u_rom_obj = MP_ROM_FALSE } },
563565 };
564566
565567 mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
566568 mp_arg_parse_all (n_args , pos_args , kw_args , MP_ARRAY_SIZE (allowed_args ), allowed_args , args );
567569
568570 mp_obj_t oin = args [0 ].u_obj ;
569571 mp_obj_t axis = args [1 ].u_obj ;
572+ mp_obj_t keepdims = args [2 ].u_obj ;
573+
570574 if ((axis != mp_const_none ) && (!mp_obj_is_int (axis ))) {
571575 mp_raise_TypeError (MP_ERROR_TEXT ("axis must be None, or an integer" ));
572576 }
@@ -598,11 +602,11 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
598602 case NUMERICAL_ARGMIN :
599603 case NUMERICAL_ARGMAX :
600604 COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
601- return numerical_argmin_argmax_ndarray (ndarray , axis , optype );
605+ return numerical_argmin_argmax_ndarray (ndarray , keepdims , axis , optype );
602606 case NUMERICAL_SUM :
603607 case NUMERICAL_MEAN :
604608 COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
605- return numerical_sum_mean_std_ndarray (ndarray , axis , optype , 0 );
609+ return numerical_sum_mean_std_ndarray (ndarray , axis , keepdims , optype , 0 );
606610 default :
607611 mp_raise_NotImplementedError (MP_ERROR_TEXT ("operation is not implemented on ndarrays" ));
608612 }
@@ -1385,6 +1389,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
13851389 { MP_QSTR_ , MP_ARG_REQUIRED | MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } } ,
13861390 { MP_QSTR_axis , MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } },
13871391 { MP_QSTR_ddof , MP_ARG_KW_ONLY | MP_ARG_INT , {.u_int = 0 } },
1392+ { MP_QSTR_keepdims , MP_ARG_OBJ , { .u_rom_obj = MP_ROM_FALSE } },
13881393 };
13891394
13901395 mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
@@ -1393,6 +1398,8 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
13931398 mp_obj_t oin = args [0 ].u_obj ;
13941399 mp_obj_t axis = args [1 ].u_obj ;
13951400 size_t ddof = args [2 ].u_int ;
1401+ mp_obj_t keepdims = args [2 ].u_obj ;
1402+
13961403 if ((axis != mp_const_none ) && (mp_obj_get_int (axis ) != 0 ) && (mp_obj_get_int (axis ) != 1 )) {
13971404 // this seems to pass with False, and True...
13981405 mp_raise_ValueError (MP_ERROR_TEXT ("axis must be None, or an integer" ));
@@ -1401,7 +1408,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
14011408 return numerical_sum_mean_std_iterable (oin , NUMERICAL_STD , ddof );
14021409 } else if (mp_obj_is_type (oin , & ulab_ndarray_type )) {
14031410 ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (oin );
1404- return numerical_sum_mean_std_ndarray (ndarray , axis , NUMERICAL_STD , ddof );
1411+ return numerical_sum_mean_std_ndarray (ndarray , axis , keepdims , NUMERICAL_STD , ddof );
14051412 } else {
14061413 mp_raise_TypeError (MP_ERROR_TEXT ("input must be tuple, list, range, or ndarray" ));
14071414 }
0 commit comments