Skip to content

Commit 60156fd

Browse files
chinasaur and copybara-github
authored and committed
Externalize training example sampling utilities for BrainState.
PiperOrigin-RevId: 862056907
1 parent 94eebf8 commit 60156fd

File tree

2 files changed

+342
-0
lines changed

2 files changed

+342
-0
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# coding=utf-8
2+
# Copyright 2026 The Google Research Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Utils for selecting consistent dataset splits across experiments."""
16+
17+
from collections.abc import Sequence
18+
import dataclasses
19+
from typing import Self
20+
import numpy as np
21+
22+
23+
@dataclasses.dataclass
class DatasetMultiSplit:
  """Holds an arbitrary number of parallel dataset splits.

  The two lists are parallel: split i is the pair
  (sample_id_splits[i], label_splits[i]), with the two arrays aligned
  element-wise.
  """

  # Per-split arrays of sample IDs.
  sample_id_splits: list[np.ndarray]
  # Per-split label arrays, aligned element-wise with sample_id_splits.
  label_splits: list[np.ndarray]
27+
28+
29+
@dataclasses.dataclass
class DatasetSplit:
  """A train/valid/test partition of a dataset for ML experiments.

  The `*_ids` arrays identify the examples in each split, and the
  `*_labels` arrays are aligned element-wise with the corresponding IDs.
  """

  train_ids: np.ndarray
  valid_ids: np.ndarray
  test_ids: np.ndarray
  train_labels: np.ndarray
  valid_labels: np.ndarray
  test_labels: np.ndarray

  def upsampled(self, upsample_factor: int, dataset_len: int) -> Self:
    """Returns this split replicated `upsample_factor` times.

    Replica k has every sample ID shifted by k * dataset_len, keeping IDs
    unique across replicas; the label arrays are repeated unchanged.

    Args:
      upsample_factor: Number of replicas to produce.
      dataset_len: Length of the underlying dataset, used as the ID offset
        between consecutive replicas.

    Returns:
      A new DatasetSplit covering all replicas.
    """
    offsets = [rep * dataset_len for rep in range(upsample_factor)]
    return DatasetSplit(
        np.concatenate([self.train_ids + off for off in offsets]),
        np.concatenate([self.valid_ids + off for off in offsets]),
        np.concatenate([self.test_ids + off for off in offsets]),
        np.concatenate([self.train_labels] * upsample_factor),
        np.concatenate([self.valid_labels] * upsample_factor),
        np.concatenate([self.test_labels] * upsample_factor),
    )
57+
58+
59+
def concatenate_splits(
    splits: Sequence[DatasetSplit], dataset_lengths: Sequence[int]
) -> DatasetSplit:
  """Concatenates DatasetSplits, offsetting IDs by preceding dataset lengths.

  Args:
    splits: The per-dataset splits to merge, in order.
    dataset_lengths: The length of each underlying dataset. Split i's IDs are
      shifted by the total length of datasets 0..i-1.

  Returns:
    A single DatasetSplit covering the concatenated datasets.
  """
  # Offset for dataset i is the running sum of all preceding lengths.
  offsets, running = [], 0
  for length in dataset_lengths:
    offsets.append(running)
    running += length

  def _merged(attr: str, shift: bool) -> np.ndarray:
    # Gather the named field from every split, shifting IDs but not labels.
    parts = [
        getattr(split, attr) + (off if shift else 0)
        for split, off in zip(splits, offsets)
    ]
    return np.concatenate(parts)

  return DatasetSplit(
      _merged("train_ids", True),
      _merged("valid_ids", True),
      _merged("test_ids", True),
      _merged("train_labels", False),
      _merged("valid_labels", False),
      _merged("test_labels", False),
  )
82+
83+
84+
def split_indices_by_labels(
    labels: Sequence[int], ratios: Sequence[float],
    rng: np.random.RandomState) -> list[np.ndarray]:
  """Low-level function to generate arbitrary splits balanced by labels.

  Args:
    labels: The data labels to balance splits by.
    ratios: The ratios of the splits. A final implicit split will be included,
      so e.g. passing ratios=[0.8, 0.1] will result in an 80/10/10 percent
      split. (If ratios adds up to >=1 then the trailing splits will be empty.)
    rng: A np.random.RandomState to use for splitting.

  Returns:
    The indices into labels for each split (total len(ratios) + 1). This can be
    used to index into e.g. example IDs as well.
  """
  # Convert explicitly so `labels == label` below broadcasts element-wise even
  # for plain Python lists, rather than relying on np.unique's scalar type.
  labels = np.asarray(labels)
  split_indices = []
  for label in np.unique(labels):
    label_indices = np.flatnonzero(labels == label)
    rng.shuffle(label_indices)
    # Splits are rounded this way for backward compatibility.
    n = len(label_indices)
    boundaries = np.cumsum([int(ratio * n) for ratio in ratios])
    split_indices.append(np.split(label_indices, boundaries))

  # Transpose [per-label][per-split] -> [per-split], merging the labels.
  # np.concatenate (not the NumPy 2.0-only np.concat alias) keeps this
  # compatible with NumPy 1.x.
  return [np.concatenate(parts) for parts in zip(*split_indices)]
110+
111+
112+
def split_dataset_by_ratios(
    sample_ids: Sequence[int], seed: int, ratios: Sequence[float],
    labels: Sequence[int] | None = None,
) -> DatasetMultiSplit:
  """Splits dataset and labels by given ratios, balanced by labels.

  Args:
    sample_ids: IDs to identify examples, e.g. cell ids
    seed: random seed
    ratios: The ratios of the splits. A final implicit split will be included,
      so e.g. passing ratios=[0.8, 0.1] will result in an 80/10/10 percent
      split. (If ratios adds up to >=1 then the trailing splits will be empty.)
    labels: A label array of the same length as sample_ids. When passed, the
      samples for each label are distributed among the splits according to
      their ratios.

  Returns:
    DatasetMultiSplit
  """
  if len(np.unique(sample_ids)) != len(sample_ids):
    raise ValueError("Found repeated sample ids")

  num_samples = len(sample_ids)
  if labels is None:
    # No labels given: treat everything as a single balanced group.
    label_arr = np.zeros(num_samples, dtype=int)
  else:
    if len(labels) != num_samples:
      raise ValueError("labels must be of the same length as sample_ids")
    label_arr = np.array(labels, dtype=int)

  # Sort by sample id so the sampled splits are reproducible even when the
  # same samples arrive in a different order.
  id_arr = np.array(sample_ids, dtype=int)
  order = np.argsort(id_arr)
  id_arr = id_arr[order]
  label_arr = label_arr[order]

  split_indices = split_indices_by_labels(
      label_arr, ratios, np.random.RandomState(seed))
  return DatasetMultiSplit(
      [id_arr[idx] for idx in split_indices],
      [label_arr[idx] for idx in split_indices],
  )
153+
154+
155+
def split_dataset(
    sample_ids: Sequence[int], seed: int, train_ratio: float,
    valid_ratio: float = 0, labels: Sequence[int] | None = None,
) -> DatasetSplit:
  """Splits dataset into train / valid / test splits.

  Args:
    sample_ids: IDs to identify examples, e.g. cell ids
    seed: random seed
    train_ratio: ratio of training examples to sample (0-1)
    valid_ratio: ratio of validation examples to sample (0-1)
    labels: Optional label array of the same length as sample_ids. When passed,
      the samples for each label are distributed among the splits according to
      their ratios.

  Returns:
    DatasetSplit

  Raises:
    ValueError: If train_ratio + valid_ratio exceeds 1.
  """
  if train_ratio + valid_ratio > 1:
    # The constraint is on the sum of the ratios, so say so in the message.
    raise ValueError(
        "train_ratio + valid_ratio must be <= 1: "
        f"{train_ratio}, {valid_ratio}"
    )
  # The test ratio is implicit: 1 - train_ratio - valid_ratio.
  ratios = train_ratio, valid_ratio
  split = split_dataset_by_ratios(sample_ids, seed, ratios, labels)
  return DatasetSplit(
      train_ids=split.sample_id_splits[0],
      valid_ids=split.sample_id_splits[1],
      test_ids=split.sample_id_splits[2],
      train_labels=split.label_splits[0],
      valid_labels=split.label_splits[1],
      test_labels=split.label_splits[2],
  )
188+
189+
190+
def cross_validation_split_dataset(
    sample_ids: Sequence[int], seed: int, num_splits: int,
    labels: Sequence[int] | None = None) -> DatasetMultiSplit:
  """Splits dataset into num_splits, optionally balanced by labels.

  Args:
    sample_ids: IDs to identify examples, e.g. cell ids
    seed: random seed
    num_splits: The number of splits to produce; typically 5- or 10-fold.
      Must be >= 1.
    labels: Optional label array of the same length as sample_ids. When passed,
      the samples for each label are distributed among the splits according to
      their ratios.

  Returns:
    DatasetMultiSplit

  Raises:
    ValueError: If num_splits < 1.
  """
  # Validate explicitly rather than letting 1.0 / num_splits raise an opaque
  # ZeroDivisionError (or a negative count silently produce no ratios).
  if num_splits < 1:
    raise ValueError(f"num_splits must be >= 1, got {num_splits}")
  # Only num_splits - 1 explicit ratios; the final split is implicit.
  ratios = [1.0 / num_splits] * (num_splits - 1)
  return split_dataset_by_ratios(sample_ids, seed, ratios, labels)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# coding=utf-8
2+
# Copyright 2026 The Google Research Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Tests for sampling module."""
16+
17+
from connectomics.brainstate import sampling
18+
import numpy as np
19+
from google3.testing.pybase import googletest
20+
21+
22+
class SamplingTest(googletest.TestCase):
  """Seed-pinned golden-value tests for the sampling split utilities.

  All expected arrays below are pinned to np.random.RandomState(22222);
  changing the seed or the shuffle order breaks these tests by design.
  """

  def test_split_indices_by_labels(self):
    # Two balanced label groups of five; ratio 0.8 keeps 4 of each group in
    # split 0 and leaves 1 of each in the implicit final split.
    labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    ratios = [0.8]
    rng = np.random.RandomState(22222)
    splits = sampling.split_indices_by_labels(labels, ratios, rng)
    np.testing.assert_array_equal(splits[0], [4, 0, 3, 2, 9, 6, 8, 5])
    np.testing.assert_array_equal(splits[1], [1, 7])

  def test_empty_split(self):
    # A 0.0 ratio yields an empty middle split without shifting the others.
    labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    ratios = [0.8, 0.0]
    rng = np.random.RandomState(22222)
    splits = sampling.split_indices_by_labels(labels, ratios, rng)
    np.testing.assert_array_equal(splits[0], [4, 0, 3, 2, 9, 6, 8, 5])
    np.testing.assert_array_equal(splits[1], [])
    np.testing.assert_array_equal(splits[2], [1, 7])

  def test_split_dataset(self):
    sample_ids = range(10)
    seed = 22222
    train_ratio = 0.7
    valid_ratio = 0.1  # Test 0.2 implicit.
    split = sampling.split_dataset(sample_ids, seed, train_ratio, valid_ratio)
    np.testing.assert_array_equal(split.train_ids, [3, 5, 9, 4, 6, 7, 0])
    np.testing.assert_array_equal(split.valid_ids, [8])
    np.testing.assert_array_equal(split.test_ids, [2, 1])
    # Without labels every sample falls into the single implicit label 0.
    np.testing.assert_array_equal(split.train_labels, [0, 0, 0, 0, 0, 0, 0])
    np.testing.assert_array_equal(split.valid_labels, [0])
    np.testing.assert_array_equal(split.test_labels, [0, 0])

    # Results should be balanced by labels.
    labels = [1, 1, 1, 2, 2, 2, 2, 2, 2, 2]
    split = sampling.split_dataset(
        sample_ids, seed, train_ratio, valid_ratio, labels)
    np.testing.assert_array_equal(split.train_ids, [2, 0, 7, 4, 6, 8])
    np.testing.assert_array_equal(split.valid_ids, [])
    np.testing.assert_array_equal(split.test_ids, [1, 9, 3, 5])
    np.testing.assert_array_equal(split.train_labels, [1, 1, 2, 2, 2, 2])
    np.testing.assert_array_equal(split.valid_labels, [])
    np.testing.assert_array_equal(split.test_labels, [1, 2, 2, 2])

  def test_upsample(self):
    sample_ids = range(10)
    labels = [1, 1, 1, 2, 2, 2, 2, 2, 2, 2]
    seed = 22222
    train_ratio = 0.7
    valid_ratio = 0.1  # Test 0.2 implicit.
    split = sampling.split_dataset(
        sample_ids, seed, train_ratio, valid_ratio, labels
    )
    # Replica 1's IDs are the originals shifted by dataset_len=10; labels
    # repeat unchanged.
    upsampled = split.upsampled(upsample_factor=2, dataset_len=10)
    np.testing.assert_array_equal(
        upsampled.train_ids, [2, 0, 7, 4, 6, 8, 12, 10, 17, 14, 16, 18]
    )
    np.testing.assert_array_equal(upsampled.valid_ids, [])
    np.testing.assert_array_equal(
        upsampled.test_ids, [1, 9, 3, 5, 11, 19, 13, 15]
    )
    np.testing.assert_array_equal(
        upsampled.train_labels, [1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2]
    )
    np.testing.assert_array_equal(upsampled.valid_labels, [])
    np.testing.assert_array_equal(
        upsampled.test_labels, [1, 2, 2, 2, 1, 2, 2, 2]
    )

  def test_concatenate_splits(self):
    sample_ids = range(10)
    seed = 22222
    train_ratio = 0.7
    valid_ratio = 0.1  # Test 0.2 implicit.
    split = sampling.split_dataset(sample_ids, seed, train_ratio, valid_ratio)

    labels = [1, 1, 1, 2, 2, 2, 2, 2, 2, 2]
    split2 = sampling.split_dataset(
        sample_ids, seed, train_ratio, valid_ratio, labels
    )

    # The second split's IDs are offset by the first dataset's length (10).
    dataset_lengths = [10, 10]
    concat = sampling.concatenate_splits([split, split2], dataset_lengths)
    np.testing.assert_array_equal(
        concat.train_ids, [3, 5, 9, 4, 6, 7, 0, 12, 10, 17, 14, 16, 18]
    )
    np.testing.assert_array_equal(concat.valid_ids, [8])
    np.testing.assert_array_equal(concat.test_ids, [2, 1, 11, 19, 13, 15])
    np.testing.assert_array_equal(
        concat.train_labels, [0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
    )
    np.testing.assert_array_equal(concat.valid_labels, [0])
    np.testing.assert_array_equal(concat.test_labels, [0, 0, 1, 2, 2, 2])

  def test_cross_validation_split_dataset(self):
    sample_ids = range(10)
    seed = 22222
    # 5-fold: five equal splits of two samples each.
    num_splits = 5
    splits = sampling.cross_validation_split_dataset(
        sample_ids, seed, num_splits).sample_id_splits
    np.testing.assert_array_equal(splits[0], [3, 5])
    np.testing.assert_array_equal(splits[1], [9, 4])
    np.testing.assert_array_equal(splits[2], [6, 7])
    np.testing.assert_array_equal(splits[3], [0, 8])
    np.testing.assert_array_equal(splits[4], [2, 1])

    # 2-fold: one explicit 50% split plus the implicit remainder.
    num_splits = 2
    splits = sampling.cross_validation_split_dataset(
        sample_ids, seed, num_splits).sample_id_splits
    np.testing.assert_array_equal(splits[0], [3, 5, 9, 4, 6])
    np.testing.assert_array_equal(splits[1], [7, 0, 8, 2, 1])
132+
133+
134+
if __name__ == "__main__":
  # Delegate to the googletest runner when this file is executed directly.
  googletest.main()

0 commit comments

Comments
 (0)