@@ -504,7 +504,20 @@ array squeeze(
504504 shape.push_back (a.shape (i));
505505 }
506506 }
507- return reshape (a, shape, s);
507+ return reshape (a, std::move (shape), s);
508+ }
509+
510+ array squeeze (const array& a, int axis, StreamOrDevice s /* = {} */ ) {
511+ int ax = axis < 0 ? axis + a.ndim () : axis;
512+ if (ax < 0 || ax >= a.ndim ()) {
513+ std::ostringstream msg;
514+ msg << " [squeeze] Invalid axis " << axis << " for array with " << a.ndim ()
515+ << " dimensions." ;
516+ throw std::invalid_argument (msg.str ());
517+ }
518+ auto shape = a.shape ();
519+ shape.erase (shape.begin () + ax);
520+ return reshape (a, std::move (shape), s);
508521}
509522
510523array squeeze (const array& a, StreamOrDevice s /* = {} */ ) {
@@ -657,10 +670,15 @@ array slice(
657670
658671array slice (
659672 const array& a,
660- const std::vector<int >& start,
661- const std::vector<int >& stop,
673+ std::vector<int > start,
674+ std::vector<int > stop,
662675 StreamOrDevice s /* = {} */ ) {
663- return slice (a, start, stop, std::vector<int >(a.ndim (), 1 ), to_stream (s));
676+ return slice (
677+ a,
678+ std::move (start),
679+ std::move (stop),
680+ std::vector<int >(a.ndim (), 1 ),
681+ to_stream (s));
664682}
665683
666684/* * Update a slice from the source array */
@@ -2715,13 +2733,43 @@ array take(
27152733 // Squeeze the axis we take over
27162734 std::vector<int > out_shape = out.shape ();
27172735 out_shape.erase (out_shape.begin () + indices.ndim () + axis);
2718- return reshape (out, out_shape, s);
2736+ return reshape (out, std::move ( out_shape) , s);
27192737}
27202738
27212739array take (const array& a, const array& indices, StreamOrDevice s /* = {} */ ) {
27222740 return take (reshape (a, {-1 }, s), indices, 0 , s);
27232741}
27242742
2743+ array take (const array& a, int index, int axis, StreamOrDevice s /* = {} */ ) {
2744+ // Check for valid axis
2745+ if (axis + static_cast <int >(a.ndim ()) < 0 ||
2746+ axis >= static_cast <int >(a.ndim ())) {
2747+ std::ostringstream msg;
2748+ msg << " [take] Received invalid axis " << axis << " for array with "
2749+ << a.ndim () << " dimensions." ;
2750+ throw std::invalid_argument (msg.str ());
2751+ }
2752+
2753+ // Check for valid take
2754+ if (a.size () == 0 ) {
2755+ throw std::invalid_argument (
2756+ " [take] Cannot do a non-empty take from an array with zero elements." );
2757+ }
2758+
2759+ // Handle negative axis
2760+ axis = axis < 0 ? a.ndim () + axis : axis;
2761+
2762+ std::vector<int > starts (a.ndim (), 0 );
2763+ std::vector<int > stops = a.shape ();
2764+ starts[axis] = index;
2765+ stops[axis] = index + 1 ;
2766+ return squeeze (slice (a, std::move (starts), std::move (stops), s), axis, s);
2767+ }
2768+
2769+ array take (const array& a, int index, StreamOrDevice s /* = {} */ ) {
2770+ return take (reshape (a, {-1 }, s), index, 0 , s);
2771+ }
2772+
27252773array take_along_axis (
27262774 const array& a,
27272775 const array& indices,
@@ -2764,7 +2812,7 @@ array take_along_axis(
27642812 // Squeeze out the slice shape
27652813 std::vector<int > out_shape (
27662814 out.shape ().begin (), out.shape ().begin () + a.ndim ());
2767- return reshape (out, out_shape, s);
2815+ return reshape (out, std::move ( out_shape) , s);
27682816}
27692817
27702818array put_along_axis (
0 commit comments