|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +"""A sub module for splitting data.""" |
14 | 15 |
|
15 | 16 | from numpy import ndarray |
16 | 17 | from pandas.core.frame import DataFrame |
|
29 | 30 |
|
30 | 31 |
|
31 | 32 | def split_data(x, y, config_dict): |
32 | | - """ |
33 | | - Split the data according to the config (i.e normal split or stratify by groups) |
34 | | - """ |
35 | | - |
| 33 | + """Split the data according to the config (i.e normal split or stratify by groups).""" |
36 | 34 | omicLogger.debug("Splitting data...") |
37 | 35 | # Split the data in train and test |
38 | 36 | if config_dict["ml"]["stratify_by_groups"] == "Y": |
@@ -65,7 +63,7 @@ def strat_split( |
65 | 63 | test_size: float = 0.2, |
66 | 64 | seed: int = 29292, |
67 | 65 | ) -> tuple[ndarray, ndarray, ndarray, ndarray]: |
68 | | - """split the data according to stratification |
| 66 | + """Split the data according to stratification. |
69 | 67 |
|
70 | 68 | Parameters |
71 | 69 | ---------- |
@@ -165,16 +163,23 @@ def strat_split( |
165 | 163 |
|
166 | 164 | for train_idx, test_idx in gss.split(x, y, groups): |
167 | 165 | if isinstance(x, DataFrame): |
168 | | - x_train, x_test, y_train, y_test = ( |
| 166 | + x_train, x_test = ( |
169 | 167 | x.iloc[train_idx, :], |
170 | 168 | x.iloc[test_idx, :], |
171 | | - y.iloc[train_idx, :], |
172 | | - y.iloc[test_idx, :], |
173 | 169 | ) |
174 | 170 | else: |
175 | | - x_train, x_test, y_train, y_test = ( |
| 171 | + x_train, x_test = ( |
176 | 172 | x[train_idx], |
177 | 173 | x[test_idx], |
| 174 | + ) |
| 175 | + |
| 176 | + if isinstance(y, DataFrame): |
| 177 | + y_train, y_test = ( |
| 178 | + y.iloc[train_idx, :], |
| 179 | + y.iloc[test_idx, :], |
| 180 | + ) |
| 181 | + else: |
| 182 | + y_train, y_test = ( |
178 | 183 | y[train_idx], |
179 | 184 | y[test_idx], |
180 | 185 | ) |
@@ -226,7 +231,6 @@ def std_split( |
226 | 231 | ValueError |
227 | 232 | is raised if x_full and y_full dont have the same number of rows |
228 | 233 | """ |
229 | | - |
230 | 234 | if not isinstance(test_size, float): |
231 | 235 | raise TypeError(f"test_size must be an float, recieved {type(test_size)}") |
232 | 236 | elif test_size < 0 or test_size > 1: |
|
0 commit comments