Skip to content

Commit 0e585b4

Browse files
nicolovawni
andauthored
Add docstring for scatter (#1189)
* Add docstring for scatter * docs nits --------- Co-authored-by: Awni Hannun <[email protected]>
1 parent 0163a8e commit 0e585b4

File tree

1 file changed

+98
-1
lines changed

1 file changed

+98
-1
lines changed

mlx/ops.h

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,104 @@ array take_along_axis(
927927
int axis,
928928
StreamOrDevice s = {});
929929

930-
/** Scatter updates to given linear indices */
930+
/** Scatter updates to the given indices.
931+
*
932+
* The parameters ``indices`` and ``axes`` determine the locations of ``a``
933+
* that are updated with the values in ``updates``. Assuming 1-d ``indices``
934+
* for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which
935+
* the values in ``updates`` will be applied. Note each array in
936+
* ``indices`` is assigned to a corresponding axis and hence ``indices.size() ==
937+
* axes.size()``. If an index/axis pair is not provided then indices along that
938+
* axis are assumed to be zero.
939+
*
940+
* Note the rank of ``updates`` must be equal to the sum of the rank of the
941+
* broadcasted ``indices`` and the rank of ``a``. In other words, assuming the
942+
* arrays in ``indices`` have the same shape, ``updates.ndim() ==
943+
* indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates``
944+
* correspond to the indices, and the remaining ``a.ndim()`` dimensions are the
945+
* values that will be applied to the given location in ``a``.
946+
*
947+
* For example:
948+
*
949+
* @code
950+
* auto in = zeros({4, 4}, float32);
951+
* auto indices = array({2});
952+
* auto updates = reshape(arange(1, 3, float32), {1, 1, 2});
953+
* std::vector<int> axes{0};
954+
*
955+
* auto out = scatter(in, {indices}, updates, axes);
956+
* @endcode
957+
*
958+
* will produce:
959+
*
960+
* @code
961+
* array([[0, 0, 0, 0],
962+
* [0, 0, 0, 0],
963+
* [1, 2, 0, 0],
964+
* [0, 0, 0, 0]], dtype=float32)
965+
* @endcode
966+
*
967+
* This scatters the two-element row vector ``[1, 2]`` starting at the ``(2,
968+
* 0)`` position of ``a``.
969+
*
970+
* Adding another element to ``indices`` will scatter into another location of
971+
* ``a``. We also have to add an another update for the new index:
972+
*
973+
* @code
974+
* auto in = zeros({4, 4}, float32);
975+
* auto indices = array({2, 0});
976+
* auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
977+
* std::vector<int> axes{0};
978+
*
979+
* auto out = scatter(in, {indices}, updates, axes):
980+
* @endcode
981+
*
982+
* will produce:
983+
*
984+
* @code
985+
* array([[3, 4, 0, 0],
986+
* [0, 0, 0, 0],
987+
* [1, 2, 0, 0],
988+
* [0, 0, 0, 0]], dtype=float32)
989+
* @endcode
990+
*
991+
* To control the scatter location on an additional axis, add another index
992+
* array to ``indices`` and another axis to ``axes``:
993+
*
994+
* @code
995+
* auto in = zeros({4, 4}, float32);
996+
* auto indices = std::vector{array({2, 0}), array({1, 2})};
997+
* auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
998+
* std::vector<int> axes{0, 1};
999+
*
1000+
* auto out = scatter(in, indices, updates, axes);
1001+
* @endcode
1002+
*
1003+
* will produce:
1004+
*
1005+
* @code
1006+
* array([[0, 0, 3, 4],
1007+
* [0, 0, 0, 0],
1008+
* [0, 1, 2, 0],
1009+
* [0, 0, 0, 0]], dtype=float32)
1010+
* @endcode
1011+
*
1012+
* Items in indices are broadcasted together. This means:
1013+
*
1014+
* @code
1015+
* auto indices = std::vector{array({2, 0}), array({1})};
1016+
* @endcode
1017+
*
1018+
* is equivalent to:
1019+
*
1020+
* @code
1021+
* auto indices = std::vector{array({2, 0}), array({1, 1})};
1022+
* @endcode
1023+
*
1024+
* Note, ``scatter`` does not perform bounds checking on the indices and
1025+
* updates. Out-of-bounds accesses on ``a`` are undefined and typically result
1026+
* in unintended or invalid memory writes.
1027+
*/
9311028
array scatter(
9321029
const array& a,
9331030
const std::vector<array>& indices,

0 commit comments

Comments
 (0)