@@ -927,7 +927,104 @@ array take_along_axis(
927927 int axis,
928928 StreamOrDevice s = {});
929929
930- /* * Scatter updates to given linear indices */
930+ /** Scatter updates to the given indices.
931+ *
932+ * The parameters ``indices`` and ``axes`` determine the locations of ``a``
933+ * that are updated with the values in ``updates``. Assuming 1-d ``indices``
934+ * for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which
935+ * the values in ``updates`` will be applied. Note each array in
936+ * ``indices`` is assigned to a corresponding axis and hence ``indices.size() ==
937+ * axes.size()``. If an index/axis pair is not provided then indices along that
938+ * axis are assumed to be zero.
939+ *
940+ * Note the rank of ``updates`` must be equal to the sum of the rank of the
941+ * broadcasted ``indices`` and the rank of ``a``. In other words, assuming the
942+ * arrays in ``indices`` have the same shape, ``updates.ndim() ==
943+ * indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates``
944+ * correspond to the indices, and the remaining ``a.ndim()`` dimensions are the
945+ * values that will be applied to the given location in ``a``.
946+ *
947+ * For example:
948+ *
949+ * @code
950+ * auto in = zeros({4, 4}, float32);
951+ * auto indices = array({2});
952+ * auto updates = reshape(arange(1, 3, float32), {1, 1, 2});
953+ * std::vector<int> axes{0};
954+ *
955+ * auto out = scatter(in, {indices}, updates, axes);
956+ * @endcode
957+ *
958+ * will produce:
959+ *
960+ * @code
961+ * array([[0, 0, 0, 0],
962+ * [0, 0, 0, 0],
963+ * [1, 2, 0, 0],
964+ * [0, 0, 0, 0]], dtype=float32)
965+ * @endcode
966+ *
967+ * This scatters the two-element row vector ``[1, 2]`` starting at the ``(2,
968+ * 0)`` position of ``a``.
969+ *
970+ * Adding another element to ``indices`` will scatter into another location of
971+ * ``a``. We also have to add another update for the new index:
972+ *
973+ * @code
974+ * auto in = zeros({4, 4}, float32);
975+ * auto indices = array({2, 0});
976+ * auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
977+ * std::vector<int> axes{0};
978+ *
979+ * auto out = scatter(in, {indices}, updates, axes);
980+ * @endcode
981+ *
982+ * will produce:
983+ *
984+ * @code
985+ * array([[3, 4, 0, 0],
986+ * [0, 0, 0, 0],
987+ * [1, 2, 0, 0],
988+ * [0, 0, 0, 0]], dtype=float32)
989+ * @endcode
990+ *
991+ * To control the scatter location on an additional axis, add another index
992+ * array to ``indices`` and another axis to ``axes``:
993+ *
994+ * @code
995+ * auto in = zeros({4, 4}, float32);
996+ * auto indices = std::vector{array({2, 0}), array({1, 2})};
997+ * auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
998+ * std::vector<int> axes{0, 1};
999+ *
1000+ * auto out = scatter(in, indices, updates, axes);
1001+ * @endcode
1002+ *
1003+ * will produce:
1004+ *
1005+ * @code
1006+ * array([[0, 0, 3, 4],
1007+ * [0, 0, 0, 0],
1008+ * [0, 1, 2, 0],
1009+ * [0, 0, 0, 0]], dtype=float32)
1010+ * @endcode
1011+ *
1012+ * Items in indices are broadcasted together. This means:
1013+ *
1014+ * @code
1015+ * auto indices = std::vector{array({2, 0}), array({1})};
1016+ * @endcode
1017+ *
1018+ * is equivalent to:
1019+ *
1020+ * @code
1021+ * auto indices = std::vector{array({2, 0}), array({1, 1})};
1022+ * @endcode
1023+ *
1024+ * Note, ``scatter`` does not perform bounds checking on the indices and
1025+ * updates. Out-of-bounds accesses on ``a`` are undefined and typically result
1026+ * in unintended or invalid memory writes.
1027+ */
9311028array scatter (
9321029 const array& a,
9331030 const std::vector<array>& indices,
0 commit comments