Skip to content

Commit 9c27b59

Browse files
authored
[API-Compat] scatter/gather API with overload spec (PaddlePaddle#75187)
1 parent f38d3cb commit 9c27b59

File tree

1 file changed

+65
-2
lines changed

1 file changed

+65
-2
lines changed

python/paddle/tensor/manipulation.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4124,7 +4124,7 @@ def _take_along_axis_wrapper(
41244124
dim: int,
41254125
index: Tensor,
41264126
out: Tensor | None = None,
4127-
):
4127+
) -> Tensor:
41284128
"""Wrapper for take_along_axis"""
41294129
res = paddle.take_along_axis(input, index, dim, broadcast=False)
41304130
if out is not None:
@@ -4193,6 +4193,25 @@ def _gather_wrapper(
41934193
return res
41944194

41954195

4196+
@overload
4197+
def gather(
4198+
x: Tensor,
4199+
index: Tensor,
4200+
axis: Tensor | int | None = None,
4201+
name: str | None = None,
4202+
out: Tensor | None = None,
4203+
) -> Tensor: ...
4204+
4205+
4206+
@overload
4207+
def gather(
4208+
input: Tensor,
4209+
dim: int,
4210+
index: Tensor,
4211+
out: Tensor | None = None,
4212+
) -> Tensor: ...
4213+
4214+
41964215
def gather(*args: Any, **kwargs: Any) -> Tensor:
41974216
"""
41984217
This function has two functionalities, depending on the parameters passed:
@@ -4442,6 +4461,27 @@ def _scatter_inplace_wrapper(
44424461
return _C_ops.scatter_(x, index, updates, overwrite)
44434462

44444463

4464+
@overload
4465+
def scatter_(
4466+
x: Tensor,
4467+
index: Tensor,
4468+
updates: Tensor,
4469+
overwrite: bool = True,
4470+
name: str | None = None,
4471+
) -> Tensor: ...
4472+
4473+
4474+
@overload
4475+
def scatter_(
4476+
input: Tensor,
4477+
dim: int,
4478+
index: Tensor,
4479+
src: Tensor | None = None,
4480+
reduce: str | None = None,
4481+
value: Tensor | None = None,
4482+
) -> Tensor: ...
4483+
4484+
44454485
@inplace_apis_in_dygraph_only
44464486
def scatter_(*args: Any, **kwargs: Any) -> Tensor:
44474487
"""
@@ -4513,7 +4553,7 @@ def _put_along_axis_wrapper(
45134553
reduce: str | None = None,
45144554
out: Tensor | None = None,
45154555
value: Tensor | None = None,
4516-
):
4556+
) -> Tensor:
45174557
"""A PyTorch Compatible wrapper for put_along_axis
45184558
This API is not directly available for users. One can only call this API via torch.Tensor.scatter or torch.scatter
45194559
"""
@@ -4538,6 +4578,29 @@ def _put_along_axis_wrapper(
45384578
return res
45394579

45404580

4581+
@overload
4582+
def scatter(
4583+
x: Tensor,
4584+
index: Tensor,
4585+
updates: Tensor,
4586+
overwrite: bool = True,
4587+
name: str | None = None,
4588+
out: Tensor | None = None,
4589+
) -> Tensor: ...
4590+
4591+
4592+
@overload
4593+
def scatter(
4594+
input: Tensor,
4595+
dim: int,
4596+
index: Tensor,
4597+
src: Tensor | None = None,
4598+
reduce: str | None = None,
4599+
out: Tensor | None = None,
4600+
value: Tensor | None = None,
4601+
) -> Tensor: ...
4602+
4603+
45414604
def scatter(*args: Any, **kwargs: Any) -> Tensor:
45424605
"""
45434606

0 commit comments

Comments
 (0)