-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgskfold.py
More file actions
190 lines (176 loc) · 8.3 KB
/
gskfold.py
File metadata and controls
190 lines (176 loc) · 8.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from sklearn.model_selection._split import _BaseKFold
from sklearn.utils import check_random_state
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import column_or_1d
from collections import defaultdict
import warnings
import numpy as np
class StratifiedGroupKFold(_BaseKFold):
    """Stratified K-Folds iterator variant with non-overlapping groups.

    This cross-validation object is a variation of StratifiedKFold that
    attempts to return stratified folds with non-overlapping groups. The
    folds are made by preserving the percentage of samples for each class.

    The same group will not appear in two different folds (the number of
    distinct groups has to be at least equal to the number of folds).

    The difference between GroupKFold and StratifiedGroupKFold is that
    the former attempts to create balanced folds such that the number of
    distinct groups is approximately the same in each fold, whereas
    StratifiedGroupKFold attempts to create folds which preserve the
    percentage of samples for each class as much as possible given the
    constraint of non-overlapping groups between splits.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.
        This implementation can only shuffle groups that have approximately the
        same y distribution, no global shuffle will be performed.

    random_state : int or RandomState instance, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold for each class.
        Otherwise, leave `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedGroupKFold
    >>> X = np.ones((17, 2))
    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
    >>> cv = StratifiedGroupKFold(n_splits=3)
    >>> for train_idxs, test_idxs in cv.split(X, y, groups):
    ...     print("TRAIN:", groups[train_idxs])
    ...     print("      ", y[train_idxs])
    ...     print(" TEST:", groups[test_idxs])
    ...     print("      ", y[test_idxs])
    TRAIN: [1 1 2 2 4 5 5 5 5 8 8]
           [0 0 1 1 1 0 0 0 0 0 0]
     TEST: [3 3 3 6 6 7]
           [1 1 1 0 0 0]
    TRAIN: [3 3 3 4 5 5 5 5 6 6 7]
           [1 1 1 1 0 0 0 0 0 0 0]
     TEST: [1 1 2 2 8 8]
           [0 0 1 1 0 0]
    TRAIN: [1 1 2 2 3 3 3 6 6 7 8 8]
           [0 0 1 1 1 1 1 0 0 0 0 0]
     TEST: [4 5 5 5 5]
           [1 0 0 0 0]

    Notes
    -----
    The implementation is designed to:

    * Mimic the behavior of StratifiedKFold as much as possible for trivial
      groups (e.g. when each group contains only one sample).
    * Be invariant to class label: relabelling ``y = ["Happy", "Sad"]`` to
      ``y = [1, 0]`` should not change the indices generated.
    * Stratify based on samples as much as possible while keeping
      non-overlapping groups constraint. That means that in some cases when
      there is a small number of groups containing a large number of samples
      the stratification will not be possible and the behavior will be close
      to GroupKFold.

    See also
    --------
    StratifiedKFold: Takes class information into account to build folds which
        retain class distributions (for binary or multiclass classification
        tasks).

    GroupKFold: K-fold iterator variant with non-overlapping groups.
    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _iter_test_indices(self, X, y, groups):
        """Yield, fold by fold, the sample indices that form each test set.

        Groups are assigned greedily, one at a time, to the fold that keeps
        the per-class sample proportions most even across folds.

        Parameters
        ----------
        X : object
            Not used by this method.
        y : array-like
            Class labels; must be a binary or multiclass target.
        groups : array-like
            Group label of each sample. All samples that share a label end
            up in the same test fold.

        Yields
        ------
        test_indices : list of int
            Positions of the samples belonging to the current test fold.

        Raises
        ------
        ValueError
            If ``groups`` is None, if ``y`` is not a binary or multiclass
            target, or if ``n_splits`` exceeds the size of every class.
        """
        # Implementation is based on this kaggle kernel:
        # https://www.kaggle.com/jakubwasikowski/stratified-group-k-fold-cross-validation
        # and is subject to Apache 2.0 License. You may obtain a copy of the
        # License at http://www.apache.org/licenses/LICENSE-2.0
        # Changelist:
        # - Refactored function to a class following scikit-learn KFold
        #   interface.
        # - Added heuristic for assigning group to the least populated fold in
        #   cases when all other criteria are equal
        # - Switch from using python ``Counter`` to ``np.unique`` to get class
        #   distribution
        # - Added scikit-learn checks for input: checking that target is binary
        #   or multiclass, checking passed random state, checking that number
        #   of splits is less than number of members in each class, checking
        #   that least populated class has more members than there are splits.
        # - Shuffle a permutation of group indices instead of the rows of the
        #   per-group count matrix, so every group keeps its own class counts.
        if groups is None:
            # Without this guard np.unique(None) would treat the input as a
            # single one-element group and produce meaningless folds.
            raise ValueError("The 'groups' parameter should not be None.")

        rng = check_random_state(self.random_state)
        y = np.asarray(y)
        type_of_target_y = type_of_target(y)
        allowed_target_types = ("binary", "multiclass")
        if type_of_target_y not in allowed_target_types:
            raise ValueError(
                "Supported target types are: {}. Got {!r} instead.".format(
                    allowed_target_types, type_of_target_y
                )
            )

        y = column_or_1d(y)
        _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)
        if np.all(self.n_splits > y_cnt):
            raise ValueError(
                "n_splits=%d cannot be greater than the"
                " number of members in each class." % (self.n_splits)
            )
        n_smallest_class = np.min(y_cnt)
        if self.n_splits > n_smallest_class:
            warnings.warn(
                "The least populated class in y has only %d"
                " members, which is less than n_splits=%d."
                % (n_smallest_class, self.n_splits),
                UserWarning,
            )
        n_classes = len(y_cnt)

        _, groups_inv, groups_cnt = np.unique(
            groups, return_inverse=True, return_counts=True
        )
        n_groups = len(groups_cnt)

        # y_counts_per_group[g, c] = number of samples of class c in group g.
        y_counts_per_group = np.zeros((n_groups, n_classes))
        for class_idx, group_idx in zip(y_inv, groups_inv):
            y_counts_per_group[group_idx, class_idx] += 1

        y_counts_per_fold = np.zeros((self.n_splits, n_classes))
        groups_per_fold = defaultdict(set)

        # Shuffle the *order* in which groups are visited, not the count
        # matrix itself: row g must keep describing group g because the
        # group index is what links the counts back to ``groups_inv``.
        group_order = np.arange(n_groups)
        if self.shuffle:
            rng.shuffle(group_order)

        # Stable sort to keep shuffled order for groups with the same
        # class distribution variance
        group_std = np.std(y_counts_per_group, axis=1)
        sorted_groups_idx = group_order[
            np.argsort(-group_std[group_order], kind="mergesort")
        ]

        for group_idx in sorted_groups_idx:
            group_y_counts = y_counts_per_group[group_idx]
            best_fold = self._find_best_fold(
                y_counts_per_fold=y_counts_per_fold,
                y_cnt=y_cnt,
                group_y_counts=group_y_counts,
            )
            y_counts_per_fold[best_fold] += group_y_counts
            groups_per_fold[best_fold].add(group_idx)

        for i in range(self.n_splits):
            fold_groups = groups_per_fold[i]
            test_indices = [
                idx
                for idx, group_idx in enumerate(groups_inv)
                if group_idx in fold_groups
            ]
            yield test_indices

    def _find_best_fold(self, y_counts_per_fold, y_cnt, group_y_counts):
        """Return the index of the fold that should receive the given group.

        Each fold is tried in turn: the group's counts are added to it, the
        standard deviation across folds of every class's share is computed,
        and the counts are removed again. The fold with the lowest mean
        standard deviation wins; when two folds score (nearly) the same, the
        one currently holding fewer samples is preferred.

        Parameters
        ----------
        y_counts_per_fold : ndarray of shape (n_splits, n_classes)
            Per-class sample counts already assigned to each fold. It is
            modified temporarily during evaluation and restored before
            returning.
        y_cnt : ndarray of shape (n_classes,)
            Total number of samples of each class.
        group_y_counts : ndarray of shape (n_classes,)
            Per-class sample counts of the group being placed.

        Returns
        -------
        best_fold : int
            Index of the selected fold.
        """
        best_fold = None
        min_eval = np.inf
        min_samples_in_fold = np.inf
        # Loop-invariant: class totals as a row vector for broadcasting.
        class_totals = y_cnt.reshape(1, -1)
        for i in range(self.n_splits):
            # Tentatively place the group in fold i ...
            y_counts_per_fold[i] += group_y_counts
            # Summarise the distribution over classes in each proposed fold
            std_per_class = np.std(y_counts_per_fold / class_totals, axis=0)
            # ... and undo the placement so the next fold sees a clean state.
            y_counts_per_fold[i] -= group_y_counts
            fold_eval = np.mean(std_per_class)
            samples_in_fold = np.sum(y_counts_per_fold[i])
            # Strictly better score, or a tie broken by the emptier fold.
            is_current_fold_better = fold_eval < min_eval or (
                np.isclose(fold_eval, min_eval)
                and samples_in_fold < min_samples_in_fold
            )
            if is_current_fold_better:
                min_eval = fold_eval
                min_samples_in_fold = samples_in_fold
                best_fold = i
        return best_fold