@@ -4124,7 +4124,7 @@ def _take_along_axis_wrapper(
41244124 dim : int ,
41254125 index : Tensor ,
41264126 out : Tensor | None = None ,
4127- ):
4127+ ) -> Tensor :
41284128 """Wrapper for take_along_axis"""
41294129 res = paddle .take_along_axis (input , index , dim , broadcast = False )
41304130 if out is not None :
@@ -4193,6 +4193,25 @@ def _gather_wrapper(
41934193 return res
41944194
41954195
4196+ @overload
4197+ def gather (
4198+ x : Tensor ,
4199+ index : Tensor ,
4200+ axis : Tensor | int | None = None ,
4201+ name : str | None = None ,
4202+ out : Tensor | None = None ,
4203+ ) -> Tensor : ...
4204+
4205+
4206+ @overload
4207+ def gather (
4208+ input : Tensor ,
4209+ dim : int ,
4210+ index : Tensor ,
4211+ out : Tensor | None = None ,
4212+ ) -> Tensor : ...
4213+
4214+
41964215def gather (* args : Any , ** kwargs : Any ) -> Tensor :
41974216 """
41984217 This function has two functionalities, depending on the parameters passed:
@@ -4442,6 +4461,27 @@ def _scatter_inplace_wrapper(
44424461 return _C_ops .scatter_ (x , index , updates , overwrite )
44434462
44444463
4464+ @overload
4465+ def scatter_ (
4466+ x : Tensor ,
4467+ index : Tensor ,
4468+ updates : Tensor ,
4469+ overwrite : bool = True ,
4470+ name : str | None = None ,
4471+ ) -> Tensor : ...
4472+
4473+
4474+ @overload
4475+ def scatter_ (
4476+ input : Tensor ,
4477+ dim : int ,
4478+ index : Tensor ,
4479+ src : Tensor | None = None ,
4480+ reduce : str | None = None ,
4481+ value : Tensor | None = None ,
4482+ ) -> Tensor : ...
4483+
4484+
44454485@inplace_apis_in_dygraph_only
44464486def scatter_ (* args : Any , ** kwargs : Any ) -> Tensor :
44474487 """
@@ -4513,7 +4553,7 @@ def _put_along_axis_wrapper(
45134553 reduce : str | None = None ,
45144554 out : Tensor | None = None ,
45154555 value : Tensor | None = None ,
4516- ):
4556+ ) -> Tensor :
45174557 """A PyTorch Compatible wrapper for put_along_axis
45184558 This API is not directly available for users. One can only call this API via torch.Tensor.scatter or torch.scatter
45194559 """
@@ -4538,6 +4578,29 @@ def _put_along_axis_wrapper(
45384578 return res
45394579
45404580
4581+ @overload
4582+ def scatter (
4583+ x : Tensor ,
4584+ index : Tensor ,
4585+ updates : Tensor ,
4586+ overwrite : bool = True ,
4587+ name : str | None = None ,
4588+ out : Tensor | None = None ,
4589+ ) -> Tensor : ...
4590+
4591+
4592+ @overload
4593+ def scatter (
4594+ input : Tensor ,
4595+ dim : int ,
4596+ index : Tensor ,
4597+ src : Tensor | None = None ,
4598+ reduce : str | None = None ,
4599+ out : Tensor | None = None ,
4600+ value : Tensor | None = None ,
4601+ ) -> Tensor : ...
4602+
4603+
45414604def scatter (* args : Any , ** kwargs : Any ) -> Tensor :
45424605 """
45434606
0 commit comments