Skip to content

Commit a92994f

Browse files
fix: torch.dsplit frontend (#28904)
1 parent b20eb45 commit a92994f

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,11 @@ def diagonal_scatter(input, src, offset=0, dim1=0, dim2=1):
102102

103103
@to_ivy_arrays_and_back
104104
def dsplit(input, indices_or_sections, /):
105-
if isinstance(indices_or_sections, (list, tuple, ivy.Array)):
106-
indices_or_sections = (
107-
ivy.diff(indices_or_sections, prepend=[0], append=[input.shape[2]])
108-
.astype(ivy.int8)
109-
.to_list()
105+
if input.ndim < 3:
106+
raise ValueError(
107+
f"dsplit requires a tensor with at least 3 dimensions, but got a tensor with {input.ndim}"
110108
)
111-
return tuple(ivy.dsplit(input, indices_or_sections))
109+
return tensor_split(input, indices_or_sections, dim=2)
112110

113111

114112
@to_ivy_arrays_and_back

ivy_tests/test_ivy/test_frontends/test_torch/test_indexing_slicing_joining_mutating_ops.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# global
2-
import random
3-
4-
from hypothesis import strategies as st
2+
from hypothesis import assume, strategies as st
53
import math
6-
4+
import random
75

86
# local
97
import ivy

0 commit comments

Comments
 (0)