@@ -267,6 +267,11 @@ bool Arange::is_equivalent(const Primitive& other) const {
267267 step_ == a_other.step_ );
268268}
269269
270+ std::vector<Shape> Arange::output_shapes (const std::vector<array>&) {
271+ auto real_size = std::ceil ((stop_ - start_) / step_);
272+ return {{std::max (static_cast <int >(real_size), 0 )}};
273+ }
274+
270275std::vector<array> ArcCos::vjp (
271276 const std::vector<array>& primals,
272277 const std::vector<array>& cotangents,
@@ -534,11 +539,10 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
534539 return {{argsort (inputs[0 ], axis_ + axis_left, stream ())}, axes};
535540}
536541
537- std::vector<std::vector<int >> ArgReduce::output_shapes (
538- const std::vector<array>& inputs) {
542+ std::vector<Shape> ArgReduce::output_shapes (const std::vector<array>& inputs) {
539543 auto out_shape = inputs[0 ].shape ();
540544 out_shape[axis_] = 1 ;
541- return {out_shape};
545+ return {std::move ( out_shape) };
542546}
543547
544548bool ArgSort::is_equivalent (const Primitive& other) const {
@@ -787,6 +791,23 @@ std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
787791 return {outputs, std::vector<int >(outputs.size (), ax)};
788792}
789793
794+ std::vector<Shape> Eigh::output_shapes (const std::vector<array>& inputs) {
795+ auto shape = inputs[0 ].shape ();
796+ shape.pop_back (); // Remove last dimension for eigenvalues
797+ if (compute_eigenvectors_) {
798+ return {
799+ std::move (shape), inputs[0 ].shape ()}; // Eigenvalues and eigenvectors
800+ } else {
801+ return {std::move (shape)}; // Only eigenvalues
802+ }
803+ }
804+
805+ bool Eigh::is_equivalent (const Primitive& other) const {
806+ auto & e_other = static_cast <const Eigh&>(other);
807+ return uplo_ == e_other.uplo_ &&
808+ compute_eigenvectors_ == e_other.compute_eigenvectors_ ;
809+ }
810+
790811std::vector<array> Concatenate::vjp (
791812 const std::vector<array>& primals,
792813 const std::vector<array>& cotangents,
@@ -881,6 +902,15 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
881902 return axis_ == c_other.axis_ ;
882903}
883904
905+ std::vector<Shape> Concatenate::output_shapes (
906+ const std::vector<array>& inputs) {
907+ auto shape = inputs[0 ].shape ();
908+ for (int i = 1 ; i < inputs.size (); ++i) {
909+ shape[axis_] += inputs[i].shape (axis_);
910+ }
911+ return {std::move (shape)};
912+ }
913+
884914std::pair<std::vector<array>, std::vector<int >> Conjugate::vmap (
885915 const std::vector<array>& inputs,
886916 const std::vector<int >& axes) {
@@ -1811,6 +1841,15 @@ bool Gather::is_equivalent(const Primitive& other) const {
18111841 return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_ ;
18121842}
18131843
1844+ std::vector<Shape> Gather::output_shapes (const std::vector<array>& inputs) {
1845+ Shape out_shape;
1846+ if (inputs.size () > 1 ) {
1847+ out_shape = inputs[0 ].shape ();
1848+ }
1849+ out_shape.insert (out_shape.end (), slice_sizes_.begin (), slice_sizes_.end ());
1850+ return {std::move (out_shape)};
1851+ }
1852+
18141853std::pair<std::vector<array>, std::vector<int >> Greater::vmap (
18151854 const std::vector<array>& inputs,
18161855 const std::vector<int >& axes) {
@@ -2184,6 +2223,12 @@ std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
21842223 return {{matmul (a, b, stream ())}, {0 }};
21852224}
21862225
2226+ std::vector<Shape> Matmul::output_shapes (const std::vector<array>& inputs) {
2227+ auto out_shape = inputs[0 ].shape ();
2228+ out_shape.back () = inputs[1 ].shape (-1 );
2229+ return {std::move (out_shape)};
2230+ }
2231+
21872232std::vector<array> Maximum::vjp (
21882233 const std::vector<array>& primals,
21892234 const std::vector<array>& cotangents,
@@ -2608,6 +2653,15 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
26082653 transpose_ == qm_other.transpose_ ;
26092654}
26102655
2656+ std::vector<Shape> QuantizedMatmul::output_shapes (
2657+ const std::vector<array>& inputs) {
2658+ auto & w = inputs[1 ];
2659+ int w_outer_dims = (transpose_) ? w.shape (-2 ) : w.shape (-1 ) * 32 / bits_;
2660+ auto out_shape = inputs[0 ].shape ();
2661+ out_shape.back () = w_outer_dims;
2662+ return {std::move (out_shape)};
2663+ }
2664+
26112665std::pair<std::vector<array>, std::vector<int >> GatherQMM::vmap (
26122666 const std::vector<array>& inputs,
26132667 const std::vector<int >& axes) {
@@ -2937,13 +2991,12 @@ bool Reduce::is_equivalent(const Primitive& other) const {
29372991 return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_ ;
29382992}
29392993
2940- std::vector<std::vector<int >> Reduce::output_shapes (
2941- const std::vector<array>& inputs) {
2942- std::vector<int > out_shape = inputs[0 ].shape ();
2994+ std::vector<Shape> Reduce::output_shapes (const std::vector<array>& inputs) {
2995+ auto out_shape = inputs[0 ].shape ();
29432996 for (auto i : axes_) {
29442997 out_shape[i] = 1 ;
29452998 }
2946- return {out_shape};
2999+ return {std::move ( out_shape) };
29473000}
29483001
29493002std::vector<array> Round::vjp (
@@ -4209,6 +4262,15 @@ bool Transpose::is_equivalent(const Primitive& other) const {
42094262 return axes_ == t_other.axes_ ;
42104263}
42114264
4265+ std::vector<Shape> Transpose::output_shapes (const std::vector<array>& inputs) {
4266+ auto & in = inputs[0 ];
4267+ Shape shape (in.ndim (), 0 );
4268+ for (int i = 0 ; i < axes_.size (); ++i) {
4269+ shape[i] = in.shape ()[axes_[i]];
4270+ }
4271+ return {std::move (shape)};
4272+ }
4273+
42124274std::pair<std::vector<array>, std::vector<int >> NumberOfElements::vmap (
42134275 const std::vector<array>& inputs,
42144276 const std::vector<int >& axes) {
0 commit comments