@@ -669,26 +669,27 @@ array concatenate(
669669 int axis,
670670 StreamOrDevice s /* = {} */ ) {
671671 if (arrays.size () == 0 ) {
672- throw std::invalid_argument (" No arrays provided for concatenation" );
672+ throw std::invalid_argument (
673+ " [concatenate] No arrays provided for concatenation" );
673674 }
674675
675676 // Normalize the given axis
676677 auto ax = axis < 0 ? axis + arrays[0 ].ndim () : axis;
677678 if (ax < 0 || ax >= arrays[0 ].ndim ()) {
678679 std::ostringstream msg;
679- msg << " Invalid axis (" << axis << " ) passed to concatenate"
680+ msg << " [concatenate] Invalid axis (" << axis << " ) passed to concatenate"
680681 << " for array with shape " << arrays[0 ].shape () << " ." ;
681682 throw std::invalid_argument (msg.str ());
682683 }
683684
684685 auto throw_invalid_shapes = [&]() {
685686 std::ostringstream msg;
686- msg << " All the input array dimensions must match exactly except "
687- << " for the concatenation axis. However, the provided shapes are " ;
687+ msg << " [concatenate] All the input array dimensions must match exactly "
688+ << " except for the concatenation axis. However, the provided shapes are " ;
688689 for (auto & a : arrays) {
689690 msg << a.shape () << " , " ;
690691 }
691- msg << " and the concatenation axis is " << axis;
692+ msg << " and the concatenation axis is " << axis << " . " ;
692693 throw std::invalid_argument (msg.str ());
693694 };
694695
@@ -697,6 +698,13 @@ array concatenate(
697698 // Make the output shape and validate that all arrays have the same shape
698699 // except for the concatenation axis.
699700 for (auto & a : arrays) {
701+ if (a.ndim () != shape.size ()) {
702+ std::ostringstream msg;
703+ msg << " [concatenate] All the input arrays must have the same number of "
704+ << " dimensions. However, got arrays with dimensions " << shape.size ()
705+ << " and " << a.ndim () << " ." ;
706+ throw std::invalid_argument (msg.str ());
707+ }
700708 for (int i = 0 ; i < a.ndim (); i++) {
701709 if (i == ax) {
702710 continue ;
0 commit comments