@@ -17,8 +17,7 @@ namespace {
1717
1818std::pair<std::vector<int >, std::vector<int >> compute_reduce_shape (
1919 const std::vector<int >& axes,
20- const std::vector<int >& shape,
21- bool keepdims) {
20+ const std::vector<int >& shape) {
2221 std::set<int > axes_set;
2322 auto ndim = shape.size ();
2423 for (auto ax : axes) {
@@ -38,7 +37,7 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
3837 for (int i = 0 ; i < ndim; ++i) {
3938 if (axes_set.count (i) == 0 ) {
4039 out_shape.push_back (shape[i]);
41- } else if (keepdims) {
40+ } else {
4241 out_shape.push_back (1 );
4342 }
4443 }
@@ -1217,13 +1216,16 @@ array all(
12171216 if (axes.empty ()) {
12181217 return astype (a, bool_, s);
12191218 }
1220- auto [out_shape, sorted_axes] =
1221- compute_reduce_shape (axes, a.shape (), keepdims);
1222- return array (
1219+ auto [out_shape, sorted_axes] = compute_reduce_shape (axes, a.shape ());
1220+ auto out = array (
12231221 out_shape,
12241222 bool_,
12251223 std::make_unique<Reduce>(to_stream (s), Reduce::And, sorted_axes),
12261224 {a});
1225+ if (!keepdims) {
1226+ out = squeeze (out, sorted_axes, s);
1227+ }
1228+ return out;
12271229}
12281230
12291231array all (
@@ -1248,13 +1250,16 @@ array any(
12481250 if (axes.empty ()) {
12491251 return astype (a, bool_, s);
12501252 }
1251- auto [out_shape, sorted_axes] =
1252- compute_reduce_shape (axes, a.shape (), keepdims);
1253- return array (
1253+ auto [out_shape, sorted_axes] = compute_reduce_shape (axes, a.shape ());
1254+ auto out = array (
12541255 out_shape,
12551256 bool_,
12561257 std::make_unique<Reduce>(to_stream (s), Reduce::Or, sorted_axes),
12571258 {a});
1259+ if (!keepdims) {
1260+ out = squeeze (out, sorted_axes, s);
1261+ }
1262+ return out;
12581263}
12591264
12601265array any (
@@ -1279,14 +1284,17 @@ array sum(
12791284 if (axes.empty ()) {
12801285 return a;
12811286 }
1282- auto [out_shape, sorted_axes] =
1283- compute_reduce_shape (axes, a.shape (), keepdims);
1287+ auto [out_shape, sorted_axes] = compute_reduce_shape (axes, a.shape ());
12841288 auto out_type = a.dtype () == bool_ ? int32 : a.dtype ();
1285- return array (
1289+ auto out = array (
12861290 out_shape,
12871291 out_type,
12881292 std::make_unique<Reduce>(to_stream (s), Reduce::Sum, sorted_axes),
12891293 {a});
1294+ if (!keepdims) {
1295+ out = squeeze (out, sorted_axes, s);
1296+ }
1297+ return out;
12901298}
12911299
12921300array sum (
@@ -1374,13 +1382,16 @@ array prod(
13741382 if (axes.empty ()) {
13751383 return a;
13761384 }
1377- auto [out_shape, sorted_axes] =
1378- compute_reduce_shape (axes, a.shape (), keepdims);
1379- return array (
1385+ auto [out_shape, sorted_axes] = compute_reduce_shape (axes, a.shape ());
1386+ auto out = array (
13801387 out_shape,
13811388 a.dtype (),
13821389 std::make_unique<Reduce>(to_stream (s), Reduce::Prod, sorted_axes),
13831390 {a});
1391+ if (!keepdims) {
1392+ out = squeeze (out, sorted_axes, s);
1393+ }
1394+ return out;
13841395}
13851396
13861397array prod (
@@ -1408,13 +1419,16 @@ array max(
14081419 if (axes.empty ()) {
14091420 return a;
14101421 }
1411- auto [out_shape, sorted_axes] =
1412- compute_reduce_shape (axes, a.shape (), keepdims);
1413- return array (
1422+ auto [out_shape, sorted_axes] = compute_reduce_shape (axes, a.shape ());
1423+ auto out = array (
14141424 out_shape,
14151425 a.dtype (),
14161426 std::make_unique<Reduce>(to_stream (s), Reduce::Max, sorted_axes),
14171427 {a});
1428+ if (!keepdims) {
1429+ out = squeeze (out, sorted_axes, s);
1430+ }
1431+ return out;
14181432}
14191433
14201434array max (
@@ -1442,13 +1456,16 @@ array min(
14421456 if (axes.empty ()) {
14431457 return a;
14441458 }
1445- auto [out_shape, sorted_axes] =
1446- compute_reduce_shape (axes, a.shape (), keepdims);
1447- return array (
1459+ auto [out_shape, sorted_axes] = compute_reduce_shape (axes, a.shape ());
1460+ auto out = array (
14481461 out_shape,
14491462 a.dtype (),
14501463 std::make_unique<Reduce>(to_stream (s), Reduce::Min, sorted_axes),
14511464 {a});
1465+ if (!keepdims) {
1466+ out = squeeze (out, sorted_axes, s);
1467+ }
1468+ return out;
14521469}
14531470
14541471array min (
@@ -1477,14 +1494,17 @@ array argmin(
14771494 throw std::invalid_argument (
14781495 " [argmin] Cannot argmin reduce zero size array." );
14791496 }
1480- auto [out_shape, sorted_axes] =
1481- compute_reduce_shape ({axis}, a.shape (), keepdims);
1482- return array (
1497+ auto [out_shape, sorted_axes] = compute_reduce_shape ({axis}, a.shape ());
1498+ auto out = array (
14831499 out_shape,
14841500 uint32,
14851501 std::make_unique<ArgReduce>(
14861502 to_stream (s), ArgReduce::ArgMin, sorted_axes[0 ]),
14871503 {a});
1504+ if (!keepdims) {
1505+ out = squeeze (out, sorted_axes, s);
1506+ }
1507+ return out;
14881508}
14891509
14901510array argmax (const array& a, bool keepdims, StreamOrDevice s /* = {} */ ) {
@@ -1505,14 +1525,17 @@ array argmax(
15051525 throw std::invalid_argument (
15061526 " [argmax] Cannot argmax reduce zero size array." );
15071527 }
1508- auto [out_shape, sorted_axes] =
1509- compute_reduce_shape ({axis}, a.shape (), keepdims);
1510- return array (
1528+ auto [out_shape, sorted_axes] = compute_reduce_shape ({axis}, a.shape ());
1529+ auto out = array (
15111530 out_shape,
15121531 uint32,
15131532 std::make_unique<ArgReduce>(
15141533 to_stream (s), ArgReduce::ArgMax, sorted_axes[0 ]),
15151534 {a});
1535+ if (!keepdims) {
1536+ out = squeeze (out, sorted_axes, s);
1537+ }
1538+ return out;
15161539}
15171540
15181541/* * Returns a sorted copy of the flattened array. */
0 commit comments