Skip to content

Commit d021c26

Browse files
authored
bugfix: stratification fix (#55)
1 parent f1ca3a8 commit d021c26

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

autoxai4omics/utils/ml/data_split.py

100755100644
Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""A sub module for splitting data."""
1415

1516
from numpy import ndarray
1617
from pandas.core.frame import DataFrame
@@ -29,10 +30,7 @@
2930

3031

3132
def split_data(x, y, config_dict):
32-
"""
33-
Split the data according to the config (i.e normal split or stratify by groups)
34-
"""
35-
33+
"""Split the data according to the config (i.e normal split or stratify by groups)."""
3634
omicLogger.debug("Splitting data...")
3735
# Split the data in train and test
3836
if config_dict["ml"]["stratify_by_groups"] == "Y":
@@ -65,7 +63,7 @@ def strat_split(
6563
test_size: float = 0.2,
6664
seed: int = 29292,
6765
) -> tuple[ndarray, ndarray, ndarray, ndarray]:
68-
"""split the data according to stratification
66+
"""Split the data according to stratification.
6967
7068
Parameters
7169
----------
@@ -165,16 +163,23 @@ def strat_split(
165163

166164
for train_idx, test_idx in gss.split(x, y, groups):
167165
if isinstance(x, DataFrame):
168-
x_train, x_test, y_train, y_test = (
166+
x_train, x_test = (
169167
x.iloc[train_idx, :],
170168
x.iloc[test_idx, :],
171-
y.iloc[train_idx, :],
172-
y.iloc[test_idx, :],
173169
)
174170
else:
175-
x_train, x_test, y_train, y_test = (
171+
x_train, x_test = (
176172
x[train_idx],
177173
x[test_idx],
174+
)
175+
176+
if isinstance(y, DataFrame):
177+
y_train, y_test = (
178+
y.iloc[train_idx, :],
179+
y.iloc[test_idx, :],
180+
)
181+
else:
182+
y_train, y_test = (
178183
y[train_idx],
179184
y[test_idx],
180185
)
@@ -226,7 +231,6 @@ def std_split(
226231
ValueError
227232
is raised if x_full and y_full don't have the same number of rows
228233
"""
229-
230234
if not isinstance(test_size, float):
231235
raise TypeError(f"test_size must be an float, recieved {type(test_size)}")
232236
elif test_size < 0 or test_size > 1:

0 commit comments

Comments
 (0)