-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
tensorflow
: add a few TensorFlow functions
#13364
base: main
Are you sure you want to change the base?
Changes from 7 commits
f80f9c2
06f5932
89d1bcb
4e678f2
3dd5cbf
43304b7
5f233ec
61dc7a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator, Sequence | |
from contextlib import contextmanager | ||
from enum import Enum | ||
from types import TracebackType | ||
from typing import Any, Generic, NoReturn, TypeVar, overload | ||
from typing import Any, Generic, Literal, NoReturn, TypeVar, overload | ||
from typing_extensions import ParamSpec, Self | ||
|
||
from google.protobuf.message import Message | ||
|
@@ -20,7 +20,16 @@ from tensorflow import ( | |
math as math, | ||
types as types, | ||
) | ||
from tensorflow._aliases import AnyArray, DTypeLike, ShapeLike, Slice, TensorCompatible | ||
from tensorflow._aliases import ( | ||
AnyArray, | ||
DTypeLike, | ||
IntArray, | ||
ScalarTensorCompatible, | ||
ShapeLike, | ||
Slice, | ||
TensorCompatible, | ||
UIntTensorCompatible, | ||
) | ||
from tensorflow.autodiff import GradientTape as GradientTape | ||
from tensorflow.core.protobuf import struct_pb2 | ||
from tensorflow.dtypes import * | ||
|
@@ -56,6 +65,7 @@ from tensorflow.math import ( | |
reduce_min as reduce_min, | ||
reduce_prod as reduce_prod, | ||
reduce_sum as reduce_sum, | ||
round as round, | ||
sigmoid as sigmoid, | ||
sign as sign, | ||
sin as sin, | ||
|
@@ -403,4 +413,22 @@ def ones_like( | |
input: RaggedTensor, dtype: DTypeLike | None = None, name: str | None = None, layout: Layout | None = None | ||
) -> RaggedTensor: ... | ||
def reshape(tensor: TensorCompatible, shape: ShapeLike | Tensor, name: str | None = None) -> Tensor: ... | ||
def pad( | ||
tensor: TensorCompatible, | ||
paddings: Tensor | IntArray | Iterable[Iterable[int]], | ||
mode: Literal["CONSTANT", "constant", "REFLECT", "reflect", "SYMMETRIC", "symmectric"] = "CONSTANT", | ||
constant_values: ScalarTensorCompatible = 0, | ||
name: str | None = None, | ||
) -> Tensor: ... | ||
def shape(input: TensorCompatible, out_type: DTypeLike | None = None, name: str | None = None) -> Tensor: ... | ||
hoel-bagard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def where( | ||
condition: TensorCompatible, x: TensorCompatible | None = None, y: TensorCompatible | None = None, name: str | None = None | ||
) -> Tensor: ... | ||
def gather_nd( | ||
params: TensorCompatible, | ||
indices: UIntTensorCompatible, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be
https://www.tensorflow.org/api_docs/python/tf/gather_nd#args There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
batch_dims: UIntTensorCompatible = 0, | ||
name: str | None = None, | ||
bad_indices_policy: Literal["", "DEFAULT", "ERROR", "IGNORE"] = "", | ||
) -> Tensor: ... | ||
def __getattr__(name: str) -> Incomplete: ... |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from tensorflow import Tensor | ||
from tensorflow._aliases import DTypeLike, TensorCompatible | ||
|
||
def hamming_window( | ||
window_length: TensorCompatible, periodic: bool | TensorCompatible = True, dtype: DTypeLike = ..., name: str | None = None | ||
) -> Tensor: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are the last two necessary? The docs just specify
A Tensor of type int32.
https://www.tensorflow.org/api_docs/python/tf/pad#args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation is more restrictive than what is actually accepted by the function. Removing the last two will likely cause a lot of false positives.
I opened three of the example links, and all 3 of them use lists instead of Tensors.