Skip to content

Commit 55af4a3

Browse files
committed
fix: improve frontend test for torch.tensor repeat
1 parent 4b6e739 commit 55af4a3

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py

+15-18
Original file line numberDiff line numberDiff line change
@@ -335,26 +335,16 @@ def _masked_fill_helper(draw):
335335
@st.composite
336336
def _repeat_helper(draw):
337337
shape = draw(
338-
helpers.get_shape(
339-
min_num_dims=1, max_num_dims=5, min_dim_size=2, max_dim_size=10
340-
)
341-
)
342-
343-
input_dtype, x = draw(
344-
helpers.dtype_and_values(
345-
available_dtypes=helpers.get_dtypes("valid"),
346-
shape=shape,
347-
)
338+
st.shared(helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape")
348339
)
349-
350340
repeats = draw(
351341
st.lists(
352-
st.integers(min_value=1, max_value=5),
342+
st.integers(min_value=0, max_value=5),
353343
min_size=len(shape),
354-
max_size=5,
344+
max_size=8,
355345
)
356346
)
357-
return input_dtype, x, repeats
347+
return repeats
358348

359349

360350
@st.composite
@@ -11330,11 +11320,18 @@ def test_torch_remainder_(
1133011320
class_tree=CLASS_TREE,
1133111321
init_tree="torch.tensor",
1133211322
method_name="repeat",
11333-
dtype_x_repeats=_repeat_helper(),
11323+
dtype_and_x=helpers.dtype_and_values(
11324+
available_dtypes=helpers.get_dtypes("valid"),
11325+
shape=st.shared(
11326+
helpers.get_shape(min_dim_size=0, max_num_dims=8), key="value_shape"
11327+
),
11328+
),
11329+
repeats=_repeat_helper(),
1133411330
unpack_repeats=st.booleans(),
1133511331
)
1133611332
def test_torch_repeat(
11337-
dtype_x_repeats,
11333+
dtype_and_x,
11334+
repeats,
1133811335
unpack_repeats,
1133911336
frontend_method_data,
1134011337
init_flags,
@@ -11343,8 +11340,8 @@ def test_torch_repeat(
1134311340
on_device,
1134411341
backend_fw,
1134511342
):
11346-
input_dtype, x, repeats = dtype_x_repeats
11347-
if unpack_repeats:
11343+
input_dtype, x = dtype_and_x
11344+
if unpack_repeats and len(repeats) > 0:
1134811345
method_flags.num_positional_args = len(repeats)
1134911346
method_kwargs = {f"x{i}": x_ for i, x_ in enumerate(repeats)}
1135011347
else:

0 commit comments

Comments
 (0)