@@ -335,26 +335,16 @@ def _masked_fill_helper(draw):
335
335
@st .composite
336
336
def _repeat_helper (draw ):
337
337
shape = draw (
338
- helpers .get_shape (
339
- min_num_dims = 1 , max_num_dims = 5 , min_dim_size = 2 , max_dim_size = 10
340
- )
341
- )
342
-
343
- input_dtype , x = draw (
344
- helpers .dtype_and_values (
345
- available_dtypes = helpers .get_dtypes ("valid" ),
346
- shape = shape ,
347
- )
338
+ st .shared (helpers .get_shape (min_dim_size = 0 , max_num_dims = 8 ), key = "value_shape" )
348
339
)
349
-
350
340
repeats = draw (
351
341
st .lists (
352
- st .integers (min_value = 1 , max_value = 5 ),
342
+ st .integers (min_value = 0 , max_value = 5 ),
353
343
min_size = len (shape ),
354
- max_size = 5 ,
344
+ max_size = 8 ,
355
345
)
356
346
)
357
- return input_dtype , x , repeats
347
+ return repeats
358
348
359
349
360
350
@st .composite
@@ -11330,11 +11320,18 @@ def test_torch_remainder_(
11330
11320
class_tree = CLASS_TREE ,
11331
11321
init_tree = "torch.tensor" ,
11332
11322
method_name = "repeat" ,
11333
- dtype_x_repeats = _repeat_helper (),
11323
+ dtype_and_x = helpers .dtype_and_values (
11324
+ available_dtypes = helpers .get_dtypes ("valid" ),
11325
+ shape = st .shared (
11326
+ helpers .get_shape (min_dim_size = 0 , max_num_dims = 8 ), key = "value_shape"
11327
+ ),
11328
+ ),
11329
+ repeats = _repeat_helper (),
11334
11330
unpack_repeats = st .booleans (),
11335
11331
)
11336
11332
def test_torch_repeat (
11337
- dtype_x_repeats ,
11333
+ dtype_and_x ,
11334
+ repeats ,
11338
11335
unpack_repeats ,
11339
11336
frontend_method_data ,
11340
11337
init_flags ,
@@ -11343,8 +11340,8 @@ def test_torch_repeat(
11343
11340
on_device ,
11344
11341
backend_fw ,
11345
11342
):
11346
- input_dtype , x , repeats = dtype_x_repeats
11347
- if unpack_repeats :
11343
+ input_dtype , x = dtype_and_x
11344
+ if unpack_repeats and len ( repeats ) > 0 :
11348
11345
method_flags .num_positional_args = len (repeats )
11349
11346
method_kwargs = {f"x{ i } " : x_ for i , x_ in enumerate (repeats )}
11350
11347
else :
0 commit comments