Skip to content

Commit 9558a3b

Browse files
committed
fix: test_torch_diag_embed
1 parent 78135d4 commit 9558a3b

File tree

2 files changed

+45
-47
lines changed

2 files changed

+45
-47
lines changed

ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
import scipy.linalg
55
import numpy as np
6-
from hypothesis import strategies as st, assume, settings, HealthCheck
6+
from hypothesis import assume, strategies as st
77

88
# local
99
import ivy
@@ -450,49 +450,6 @@ def test_torch_det(
450450
)
451451

452452

453-
@handle_frontend_test(
454-
fn_tree="torch.diag_embed",
455-
dtype_and_values=helpers.dtype_and_values(
456-
available_dtypes=helpers.get_dtypes("float"),
457-
shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"),
458-
),
459-
dims_and_offsets=helpers.dims_and_offset(
460-
shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"),
461-
ensure_dim_unique=True,
462-
),
463-
)
464-
@settings(suppress_health_check=list(HealthCheck))
465-
def test_torch_diag_embed(
466-
*,
467-
dtype_and_values,
468-
dims_and_offsets,
469-
test_flags,
470-
on_device,
471-
fn_tree,
472-
frontend,
473-
backend_fw,
474-
):
475-
input_dtype, value = dtype_and_values
476-
dim1, dim2, offset = dims_and_offsets
477-
num_of_dims = len(np.shape(value[0]))
478-
if dim1 < 0:
479-
assume(dim1 + num_of_dims != dim2)
480-
if dim2 < 0:
481-
assume(dim1 != dim2 + num_of_dims)
482-
helpers.test_frontend_function(
483-
input_dtypes=input_dtype,
484-
backend_to_test=backend_fw,
485-
test_flags=test_flags,
486-
frontend=frontend,
487-
fn_tree=fn_tree,
488-
on_device=on_device,
489-
input=value[0],
490-
offset=offset,
491-
dim1=dim1,
492-
dim2=dim2,
493-
)
494-
495-
496453
# eig
497454
# TODO: Test for all valid dtypes once ivy.eig supports complex data types
498455
@handle_frontend_test(

ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# global
2+
from hypothesis import assume, strategies as st, settings, HealthCheck
3+
import hypothesis.extra.numpy as nph
24
import math
3-
45
import numpy as np
5-
from hypothesis import assume, strategies as st
6-
import hypothesis.extra.numpy as nph
76

87
# local
98
import ivy
@@ -902,6 +901,48 @@ def test_torch_diag(
902901
)
903902

904903

904+
@handle_frontend_test(
905+
fn_tree="torch.diag_embed",
906+
dtype_and_values=helpers.dtype_and_values(
907+
available_dtypes=helpers.get_dtypes("float"),
908+
shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"),
909+
),
910+
dims_and_offsets=helpers.dims_and_offset(
911+
shape=st.shared(helpers.get_shape(min_num_dims=1, max_num_dims=2), key="shape"),
912+
ensure_dim_unique=True,
913+
),
914+
)
915+
@settings(suppress_health_check=list(HealthCheck))
916+
def test_torch_diag_embed(
917+
*,
918+
dtype_and_values,
919+
dims_and_offsets,
920+
test_flags,
921+
on_device,
922+
fn_tree,
923+
frontend,
924+
backend_fw,
925+
):
926+
input_dtype, value = dtype_and_values
927+
dim1, dim2, offset = dims_and_offsets
928+
num_of_dims = len(np.shape(value[0])) + 1
929+
norm_dim1 = dim1 if dim1 >= 0 else dim1 + num_of_dims
930+
norm_dim2 = dim2 if dim2 >= 0 else dim2 + num_of_dims
931+
assume(norm_dim1 != norm_dim2)
932+
helpers.test_frontend_function(
933+
input_dtypes=input_dtype,
934+
backend_to_test=backend_fw,
935+
test_flags=test_flags,
936+
frontend=frontend,
937+
fn_tree=fn_tree,
938+
on_device=on_device,
939+
input=value[0],
940+
offset=offset,
941+
dim1=dim1,
942+
dim2=dim2,
943+
)
944+
945+
905946
# diagflat
906947
@handle_frontend_test(
907948
fn_tree="torch.diagflat",

0 commit comments

Comments
 (0)