@@ -135,7 +135,7 @@ def _reduce_impl(input: "tripy.Tensor", kind: Reduce.Kind, dim: Union[int, Seque
135
135
@export .public_api (document_under = "operations/functions" )
136
136
@constraints .dtypes (
137
137
constraints = {"input" : "T1" , constraints .RETURN_VALUE : "T1" },
138
- variables = {"T1" : ["float32" , "int32" , "float16" , "bfloat16" ]},
138
+ variables = {"T1" : ["float32" , "int32" , "int64" , " float16" , "bfloat16" ]},
139
139
)
140
140
def sum (
141
141
input : "tripy.Tensor" , dim : Optional [Union [int , Sequence [int ]]] = None , keepdim : bool = False
@@ -232,7 +232,7 @@ def any(
232
232
@export .public_api (document_under = "operations/functions" )
233
233
@constraints .dtypes (
234
234
constraints = {"input" : "T1" , constraints .RETURN_VALUE : "T1" },
235
- variables = {"T1" : ["float32" , "int32" , "float16" , "bfloat16" ]},
235
+ variables = {"T1" : ["float32" , "int32" , "int64" , " float16" , "bfloat16" ]},
236
236
)
237
237
def max (
238
238
input : "tripy.Tensor" , dim : Optional [Union [int , Sequence [int ]]] = None , keepdim : bool = False
@@ -265,7 +265,7 @@ def max(
265
265
@export .public_api (document_under = "operations/functions" )
266
266
@constraints .dtypes (
267
267
constraints = {"input" : "T1" , constraints .RETURN_VALUE : "T1" },
268
- variables = {"T1" : ["float32" , "int32" , "float16" , "bfloat16" ]},
268
+ variables = {"T1" : ["float32" , "int32" , "int64" , " float16" , "bfloat16" ]},
269
269
)
270
270
def prod (
271
271
input : "tripy.Tensor" , dim : Optional [Union [int , Sequence [int ]]] = None , keepdim : bool = False
@@ -313,7 +313,7 @@ def mean_impl(tensor: "tripy.Tensor", dim: Union[int, Sequence] = None, keepdim:
313
313
@export .public_api (document_under = "operations/functions" )
314
314
@constraints .dtypes (
315
315
constraints = {"input" : "T1" , constraints .RETURN_VALUE : "T1" },
316
- variables = {"T1" : ["float32" , "int32" , "float16" , "bfloat16" ]},
316
+ variables = {"T1" : ["float32" , "int32" , "int64" , " float16" , "bfloat16" ]},
317
317
)
318
318
def mean (
319
319
input : "tripy.Tensor" , dim : Optional [Union [int , Sequence [int ]]] = None , keepdim : bool = False
@@ -413,7 +413,7 @@ def _arg_min_max_impl(tensor: "tripy.Tensor", kind: ArgMinMax.Kind, dim: Optiona
413
413
@export .public_api (document_under = "operations/functions" )
414
414
@constraints .dtypes (
415
415
constraints = {"input" : "T1" , constraints .RETURN_VALUE : "T2" },
416
- variables = {"T1" : ["float32" , "float16" , "bfloat16" , "int32" , "bool" , "int8" ], "T2" : ["int32" ]},
416
+ variables = {"T1" : ["float32" , "float16" , "bfloat16" , "int32" ], "T2" : ["int32" ]},
417
417
)
418
418
def argmax (input : "tripy.Tensor" , dim : Optional [int ] = None , keepdim : bool = False ) -> "tripy.Tensor" :
419
419
"""
@@ -445,7 +445,7 @@ def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = Fal
445
445
@export .public_api (document_under = "operations/functions" )
446
446
@constraints .dtypes (
447
447
constraints = {"input" : "T1" , constraints .RETURN_VALUE : "T2" },
448
- variables = {"T1" : ["float32" , "float16" , "bfloat16" , "int32" , "bool" , "int8" ], "T2" : ["int32" ]},
448
+ variables = {"T1" : ["float32" , "float16" , "bfloat16" , "int32" ], "T2" : ["int32" ]},
449
449
)
450
450
def argmin (input : "tripy.Tensor" , dim : Optional [int ] = None , keepdim : bool = False ) -> "tripy.Tensor" :
451
451
"""
0 commit comments