Skip to content

Commit 9dd72cd

Browse files
authored
fix gumbel (#1495)
1 parent 343aa46 commit 9dd72cd

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

python/src/random.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,10 @@ void init_random(nb::module_& parent_module) {
352352
},
353353
"shape"_a = std::vector<int>{},
354354
"dtype"_a.none() = float32,
355-
"stream"_a = nb::none(),
356355
"key"_a = nb::none(),
356+
"stream"_a = nb::none(),
357357
nb::sig(
358-
"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, stream: Optional[array] = None, key: Union[None, Stream, Device] = None) -> array"),
358+
"def gumbel(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, key: Union[None, Stream, Device] = None, stream: Optional[array] = None) -> array"),
359359
R"pbdoc(
360360
Sample from the standard Gumbel distribution.
361361
@@ -364,11 +364,14 @@ void init_random(nb::module_& parent_module) {
364364
365365
Args:
366366
shape (list(int)): The shape of the output.
367+
dtype (Dtype, optional): The data type of the output.
368+
Default: ``float32``.
367369
key (array, optional): A PRNG key. Default: ``None``.
368370
369371
Returns:
370-
array: The :class:`array` with shape ``shape`` and
371-
distributed according to the Gumbel distribution
372+
array:
373+
The :class:`array` with shape ``shape`` and distributed according
374+
to the Gumbel distribution.
372375
)pbdoc");
373376
m.def(
374377
"categorical",

0 commit comments

Comments
 (0)