
Commit 7a65022

Apply backend.result_type to cumprod, cumsum, nonzero, power, take, take_along_axis, tensordot, tile, trace, transpose, tril, triu, vdot, vstack, where (#18831)
* Apply `backend.result_type` to `cumprod`, `cumsum`, `take`, `take_along_axis` and `tensordot`
* Apply `backend.result_type` to `tile`, `transpose`, `tril`, `triu`, `true_divide`, `vdot` and `vstack`
* Apply `backend.result_type` to `nonzero`, `power` and `where`
* Apply `backend.result_type` to `trace`
* Remove useless uint32 condition in torch's trace
1 parent 866b745 commit 7a65022
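
For context on what the commit applies: `backend.result_type` resolves a common dtype ahead of the underlying op, following a JAX-style promotion lattice, so the ops listed above return the same dtype on every backend. A hedged illustration, not part of the diff (it assumes a Keras 3 build where `keras.backend` re-exports `result_type`):

from keras import backend

# Promotion resolves a common dtype before the op runs.
print(backend.result_type("int8", "float32"))  # float32
print(backend.result_type("bool", "int32"))    # int32
print(backend.result_type("uint8", "int8"))    # int16, the smallest type
                                               # holding both ranges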

6 files changed: +664, -63 lines

keras/backend/jax/numpy.py

Lines changed: 6 additions & 4 deletions
@@ -731,7 +731,11 @@ def tile(x, repeats):
 
 
 def trace(x, offset=0, axis1=0, axis2=1):
-    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
+    x = convert_to_tensor(x)
+    dtype = None
+    if standardize_dtype(x.dtype) == "bool":
+        dtype = "int32"
+    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
 
 
 def tri(N, M=None, k=0, dtype=None):
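
A minimal check of the hunk above, assuming `jax` is installed: only bool inputs need an explicit dtype, since `jnp.trace` already applies JAX's own promotion for every other input dtype, which is why `dtype` defaults to `None`.

import jax.numpy as jnp

x = jnp.eye(3, dtype=bool)
print(jnp.trace(x, dtype="int32").dtype)  # int32, as forced above
print(jnp.trace(jnp.eye(3)).dtype)        # float32, the unchanged path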
@@ -766,9 +770,7 @@ def divide(x1, x2):
 
 
 def true_divide(x1, x2):
-    x1 = convert_to_tensor(x1)
-    x2 = convert_to_tensor(x2)
-    return jnp.true_divide(x1, x2)
+    return divide(x1, x2)
 
 
 def power(x1, x2):

keras/backend/numpy/numpy.py

Lines changed: 61 additions & 11 deletions
@@ -366,12 +366,18 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
 
 def cumprod(x, axis=None, dtype=None):
     axis = tuple(axis) if isinstance(axis, list) else axis
-    return np.cumprod(x, axis=axis, dtype=dtype or x.dtype)
+    dtype = dtypes.result_type(dtype or x.dtype)
+    if dtype == "bool":
+        dtype = "int32"
+    return np.cumprod(x, axis=axis, dtype=dtype)
 
 
 def cumsum(x, axis=None, dtype=None):
     axis = tuple(axis) if isinstance(axis, list) else axis
-    return np.cumsum(x, axis=axis, dtype=dtype or x.dtype)
+    dtype = dtypes.result_type(dtype or x.dtype)
+    if dtype == "bool":
+        dtype = "int32"
+    return np.cumsum(x, axis=axis, dtype=dtype)
 
 
 def diag(x, k=0):
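
A hedged sanity check of the bool handling above, in plain NumPy: bool accumulations compute in int32, matching what `dtypes.result_type` yields for bool, instead of staying bool.

import numpy as np

x = np.array([True, True, False, True])
print(np.cumsum(x, dtype="int32"))   # [1 2 2 3]
print(np.cumprod(x, dtype="int32"))  # [1 1 0 0]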
@@ -470,12 +476,12 @@ def greater_equal(x1, x2):
 
 
 def hstack(xs):
-    xs = tree.map_structure(convert_to_tensor, xs)
-    dtypes_to_resolve = []
-    for x in xs:
-        dtypes_to_resolve.append(x.dtype)
-    dtype = dtypes.result_type(*dtypes_to_resolve)
-    xs = tree.map_structure(lambda x: x.astype(dtype), xs)
+    dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
+    if len(dtype_set) > 1:
+        dtype = dtypes.result_type(*dtype_set)
+        xs = tree.map_structure(
+            lambda x: convert_to_tensor(x).astype(dtype), xs
+        )
     return np.hstack(xs)
 
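The `dtype_set` pattern above (shared by `vstack` below and by the TensorFlow backend's `hstack` and `stack`) relies on `getattr`: Python scalars have no `.dtype`, so their Python type stands in, which `dtypes.result_type` can treat as a weak type. A small sketch in plain NumPy:

import numpy as np

xs = [np.ones(2, dtype="float16"), 1.0]  # a tensor plus a Python float
dtype_set = set(getattr(x, "dtype", type(x)) for x in xs)
print(dtype_set)  # {dtype('float16'), <class 'float'>}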

@@ -680,7 +686,7 @@ def ndim(x):
 
 
 def nonzero(x):
-    return np.nonzero(x)
+    return tuple(indices.astype("int32") for indices in np.nonzero(x))
 
 
 def not_equal(x1, x2):
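
A quick illustration of the `nonzero` change: `np.nonzero` returns platform-int index arrays (usually int64), and the cast standardizes them to int32 so indices match across backends.

import numpy as np

idx = tuple(i.astype("int32") for i in np.nonzero(np.eye(2)))
print([i.dtype for i in idx])  # [dtype('int32'), dtype('int32')]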
@@ -864,6 +870,11 @@ def tanh(x):
 
 def tensordot(x1, x2, axes=2):
     axes = tuple(axes) if isinstance(axes, list) else axes
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    x1 = x1.astype(dtype)
+    x2 = x2.astype(dtype)
     return np.tensordot(x1, x2, axes=axes)
 
@@ -878,7 +889,15 @@ def tile(x, repeats):
 def trace(x, offset=0, axis1=0, axis2=1):
     axis1 = tuple(axis1) if isinstance(axis1, list) else axis1
     axis2 = tuple(axis2) if isinstance(axis2, list) else axis2
-    return np.trace(x, offset=offset, axis1=axis1, axis2=axis2)
+    x = convert_to_tensor(x)
+    dtype = standardize_dtype(x.dtype)
+    if dtype == "int64":
+        dtype = "int64"
+    elif dtype == "uint32":
+        dtype = "uint32"
+    else:
+        dtype = dtypes.result_type(dtype, "int32")
+    return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
 
 
 def tri(N, M=None, k=0, dtype=None):
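
The dtype ladder above leaves int64 and uint32 untouched and promotes everything else against int32, so bool and small ints widen while floats pass through unchanged. A hedged check in plain NumPy:

import numpy as np

print(np.trace(np.eye(3, dtype="int8"), dtype="int32").dtype)  # int32
print(np.trace(np.eye(3, dtype="int64")).dtype)                # int64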
@@ -895,15 +914,36 @@ def triu(x, k=0):
 
 
 def vdot(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    x1 = x1.astype(dtype)
+    x2 = x2.astype(dtype)
     return np.vdot(x1, x2)
 
 
 def vstack(xs):
+    dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
+    if len(dtype_set) > 1:
+        dtype = dtypes.result_type(*dtype_set)
+        xs = tree.map_structure(
+            lambda x: convert_to_tensor(x).astype(dtype), xs
+        )
     return np.vstack(xs)
 
 
 def where(condition, x1, x2):
     if x1 is not None and x2 is not None:
+        if not isinstance(x1, (int, float)):
+            x1 = convert_to_tensor(x1)
+        if not isinstance(x2, (int, float)):
+            x2 = convert_to_tensor(x2)
+        dtype = dtypes.result_type(
+            getattr(x1, "dtype", type(x1)),
+            getattr(x2, "dtype", type(x2)),
+        )
+        x1 = convert_to_tensor(x1, dtype)
+        x2 = convert_to_tensor(x2, dtype)
         return np.where(condition, x1, x2)
     else:
         return np.where(condition)
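
The `isinstance`/`getattr` handling in `where` above (and in `power` below) keeps Python scalars weakly typed, so a float32 tensor paired with a plain `1.0` stays float32 rather than being promoted to float64 as NumPy alone would do. A hedged check through the public ops API, assuming a Keras build that includes this commit:

import numpy as np

from keras import ops

cond = np.array([True, False])
x = np.array([1.0, 2.0], dtype="float32")
print(ops.where(cond, x, 0.0).dtype)  # float32, not float64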
@@ -925,10 +965,20 @@ def divide(x1, x2):
 
 
 def true_divide(x1, x2):
-    return np.true_divide(x1, x2)
+    return divide(x1, x2)
 
 
 def power(x1, x2):
+    if not isinstance(x1, (int, float)):
+        x1 = convert_to_tensor(x1)
+    if not isinstance(x2, (int, float)):
+        x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(
+        getattr(x1, "dtype", type(x1)),
+        getattr(x2, "dtype", type(x2)),
+    )
+    x1 = convert_to_tensor(x1, dtype)
+    x2 = convert_to_tensor(x2, dtype)
     return np.power(x1, x2)
 

keras/backend/tensorflow/numpy.py

Lines changed: 85 additions & 19 deletions
@@ -590,11 +590,17 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
 
 
 def cumprod(x, axis=None, dtype=None):
-    return tfnp.cumprod(x, axis=axis, dtype=dtype or x.dtype)
+    dtype = dtypes.result_type(dtype or x.dtype)
+    if dtype == "bool":
+        dtype = "int32"
+    return tfnp.cumprod(x, axis=axis, dtype=dtype)
 
 
 def cumsum(x, axis=None, dtype=None):
-    return tfnp.cumsum(x, axis=axis, dtype=dtype or x.dtype)
+    dtype = dtypes.result_type(dtype or x.dtype)
+    if dtype == "bool":
+        dtype = "int32"
+    return tfnp.cumsum(x, axis=axis, dtype=dtype)
 
 
 def diag(x, k=0):
@@ -743,12 +749,10 @@ def greater_equal(x1, x2):
 
 
 def hstack(xs):
-    xs = tf.nest.map_structure(convert_to_tensor, xs)
-    dtypes_to_resolve = []
-    for x in xs:
-        dtypes_to_resolve.append(x.dtype)
-    dtype = dtypes.result_type(*dtypes_to_resolve)
-    xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs)
+    dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
+    if len(dtype_set) > 1:
+        dtype = dtypes.result_type(*dtype_set)
+        xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
     return tfnp.hstack(xs)
 
@@ -1031,7 +1035,9 @@ def ndim(x):
 
 
 def nonzero(x):
-    return tfnp.nonzero(x)
+    return tf.nest.map_structure(
+        lambda indices: tf.cast(indices, "int32"), tfnp.nonzero(x)
+    )
 
 
 def not_equal(x1, x2):
@@ -1292,12 +1298,10 @@ def split(x, indices_or_sections, axis=0):
 
 
 def stack(x, axis=0):
-    x = tf.nest.map_structure(convert_to_tensor, x)
-    dtypes_to_resolve = []
-    for a in x:
-        dtypes_to_resolve.append(a.dtype)
-    dtype = dtypes.result_type(*dtypes_to_resolve)
-    x = tf.nest.map_structure(lambda a: tf.cast(a, dtype), x)
+    dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
+    if len(dtype_set) > 1:
+        dtype = dtypes.result_type(*dtype_set)
+        x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x)
     return tfnp.stack(x, axis=axis)
 
@@ -1366,7 +1370,14 @@ def tanh(x):
 
 
 def tensordot(x1, x2, axes=2):
-    return tfnp.tensordot(x1, x2, axes=axes)
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    # TODO: tfnp.tensordot only supports float types
+    compute_dtype = dtypes.result_type(result_dtype, float)
+    x1 = tf.cast(x1, compute_dtype)
+    x2 = tf.cast(x2, compute_dtype)
+    return tf.cast(tfnp.tensordot(x1, x2, axes=axes), dtype=result_dtype)
 
 
 @sparse.elementwise_unary
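
The compute-then-cast pattern above works around tfnp ops that only accept floats: promote both operands to a float compute dtype, run the op, and cast back to the promoted result dtype. A minimal sketch in plain TensorFlow (the int8 inputs and float32 compute dtype are illustrative assumptions):

import tensorflow as tf

x1 = tf.ones((2, 2), dtype="int8")
x2 = tf.ones((2, 2), dtype="int8")
a = tf.cast(x1, "float32")  # compute dtype
b = tf.cast(x2, "float32")
out = tf.cast(tf.tensordot(a, b, axes=2), "int8")  # back to result dtype
print(out.dtype)  # int8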
@@ -1394,7 +1405,15 @@ def tile(x, repeats):
 
 
 def trace(x, offset=0, axis1=0, axis2=1):
-    return tfnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
+    x = convert_to_tensor(x)
+    dtype = standardize_dtype(x.dtype)
+    if dtype == "int64":
+        dtype = "int64"
+    elif dtype == "uint32":
+        dtype = "uint32"
+    else:
+        dtype = dtypes.result_type(dtype, "int32")
+    return tfnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
 
 
 def tri(N, M=None, k=0, dtype=None):
@@ -1403,22 +1422,54 @@ def tri(N, M=None, k=0, dtype=None):
 
 
 def tril(x, k=0):
+    x = convert_to_tensor(x)
+    # TODO: tfnp.tril doesn't support bool
+    if standardize_dtype(x.dtype) == "bool":
+        x = tf.cast(x, "uint8")
+        return tf.cast(tfnp.tril(x, k=k), "bool")
     return tfnp.tril(x, k=k)
 
 
 def triu(x, k=0):
+    x = convert_to_tensor(x)
+    # TODO: tfnp.triu doesn't support bool
+    if standardize_dtype(x.dtype) == "bool":
+        x = tf.cast(x, "uint8")
+        return tf.cast(tfnp.triu(x, k=k), "bool")
     return tfnp.triu(x, k=k)
 
 
 def vdot(x1, x2):
-    return tfnp.vdot(x1, x2)
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    # TODO: tfnp.vdot only supports float types
+    compute_dtype = dtypes.result_type(result_dtype, float)
+    x1 = tf.cast(x1, compute_dtype)
+    x2 = tf.cast(x2, compute_dtype)
+    return tf.cast(tfnp.vdot(x1, x2), result_dtype)
 
 
 def vstack(xs):
+    dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
+    if len(dtype_set) > 1:
+        dtype = dtypes.result_type(*dtype_set)
+        xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
     return tfnp.vstack(xs)
 
 
 def where(condition, x1, x2):
+    if x1 is not None and x2 is not None:
+        if not isinstance(x1, (int, float)):
+            x1 = convert_to_tensor(x1)
+        if not isinstance(x2, (int, float)):
+            x2 = convert_to_tensor(x2)
+        dtype = dtypes.result_type(
+            getattr(x1, "dtype", type(x1)),
+            getattr(x2, "dtype", type(x2)),
+        )
+        x1 = convert_to_tensor(x1, dtype)
+        x2 = convert_to_tensor(x2, dtype)
     return tfnp.where(condition, x1, x2)
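The `tril`/`triu` workaround above takes the triangle on a uint8 view of the bool input and casts back. A minimal check against `tf.experimental.numpy` (the `tfnp` alias used in this file), assuming TensorFlow is installed:

import tensorflow as tf

tfnp = tf.experimental.numpy

x = tf.ones((3, 3), dtype="bool")
lower = tf.cast(tfnp.tril(tf.cast(x, "uint8")), "bool")
print(lower.dtype)  # bool, True on and below the diagonal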

@@ -1440,10 +1491,25 @@ def divide(x1, x2):
 
 @sparse.elementwise_division
 def true_divide(x1, x2):
-    return tfnp.true_divide(x1, x2)
+    return divide(x1, x2)
 
 
 def power(x1, x2):
+    if not isinstance(x1, (int, float)):
+        x1 = convert_to_tensor(x1)
+    if not isinstance(x2, (int, float)):
+        x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(
+        getattr(x1, "dtype", type(x1)),
+        getattr(x2, "dtype", type(x2)),
+    )
+    # TODO: tfnp.power doesn't support uint* types
+    if "uint" in dtype:
+        x1 = convert_to_tensor(x1, "int32")
+        x2 = convert_to_tensor(x2, "int32")
+        return tf.cast(tfnp.power(x1, x2), dtype)
+    x1 = convert_to_tensor(x1, dtype)
+    x2 = convert_to_tensor(x2, dtype)
     return tfnp.power(x1, x2)
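The uint branch of `power` above is the same dodge once more, since `tfnp.power` rejects uint* inputs: compute in int32, then cast back to the promoted uint dtype. A hedged sketch (the int32 constants stand in for uint8 data, an assumption for the demo):

import tensorflow as tf

tfnp = tf.experimental.numpy

x1 = tf.constant([2, 3], dtype="int32")  # stand-in for uint8 input
x2 = tf.constant([3, 2], dtype="int32")
print(tf.cast(tfnp.power(x1, x2), "uint8"))  # [8 9], dtype uint8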
