Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _determine_partition_id_to_indices_if_needed(
return

# Generate information needed for Dirichlet partitioning
self._unique_classes = self.dataset.unique(self._partition_by)
self._unique_classes = sorted(self.dataset.unique(self._partition_by))
assert self._unique_classes is not None
# This is needed only if self._self_balancing is True (the default option)
self._avg_num_of_samples_per_partition = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,47 @@ def test__determine_partition_id_to_indices(self) -> None:
and len(partitioner._partition_id_to_indices) == num_partitions
)

def test__determine_partition_id_to_indices_if_needed_consistency(
    self,
) -> None:
    """Test that indices are consistently assigned to partition IDs.

    Partition distributions should be consistent regardless of the ordering of
    examples in the dataset. This is important to ensure that partitions from
    train and test partitioners have the same distributions.
    """
    num_partitions = 3
    data = {
        "features": list(range(100)),
        "labels": [i % 3 for i in range(100)],
    }

    dataset1 = Dataset.from_dict(data)
    partitioner1 = DirichletPartitioner(num_partitions, "labels", 0.5, 10)
    partitioner1.dataset = dataset1
    partitioner1.load_partition(0)

    # Build fully independent reversed copies of the columns. A plain
    # ``data.copy()`` is shallow, so calling ``.reverse()`` on its lists would
    # also reverse the list objects still referenced by ``data`` (harmless
    # above only because ``Dataset.from_dict`` already copied the values into
    # Arrow, but a latent aliasing trap for future edits of this test).
    data_reversed = {key: list(reversed(column)) for key, column in data.items()}
    dataset2 = Dataset.from_dict(data_reversed)
    partitioner2 = DirichletPartitioner(num_partitions, "labels", 0.5, 10)
    partitioner2.dataset = dataset2
    partitioner2.load_partition(0)

    classes = partitioner1.dataset.unique("labels")
    for partition_id in range(num_partitions):
        partition1 = partitioner1.load_partition(partition_id)
        partition2 = partitioner2.load_partition(partition_id)

        targets1 = np.array(partition1["labels"])
        targets2 = np.array(partition2["labels"])

        # For each class, the positions of its samples inside the partition
        # must match (order-insensitive) across the two example orderings.
        for k in classes:
            self.assertCountEqual(
                np.nonzero(targets1 == k)[0], np.nonzero(targets2 == k)[0]
            )


class TestDirichletPartitionerFailure(unittest.TestCase):
"""Test DirichletPartitioner failures (exceptions) by incorrect usage."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _create_int_partition_id_to_natural_id(self) -> None:

Natural ids come from the column specified in `partition_by`.
"""
unique_natural_ids = self.dataset.unique(self._partition_by)
unique_natural_ids = sorted(self.dataset.unique(self._partition_by))
self._partition_id_to_natural_id = dict(
zip(range(len(unique_natural_ids)), unique_natural_ids)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,34 @@ def test_cannot_set_partition_id_to_natural_id(self) -> None:
with self.assertRaises(AttributeError):
partitioner.partition_id_to_natural_id = {0: "0"}

def test_consistent_partition_ids(self) -> None:
    """Test that the partition IDs assigned to the natural IDs are consistent."""
    # Two splits whose "clients" column contains the same natural IDs
    # ("a" and "b") but in a different first-occurrence order.
    raw_splits = [
        {
            "features": [1, 2, 3],
            "labels": [0, 0, 1],
            "clients": ["a", "b", "a"],
        },
        {
            "features": [4, 5, 6],
            "labels": [1, 0, 0],
            "clients": ["b", "a", "a"],
        },
    ]

    # Create one partitioner per split and materialize a partition so the
    # partition-id-to-natural-id mapping gets built.
    partitioners = []
    for raw in raw_splits:
        partitioner = NaturalIdPartitioner(partition_by="clients")
        partitioner.dataset = Dataset.from_dict(raw)
        partitioner.load_partition(0)
        partitioners.append(partitioner)

    train_partitioner, test_partitioner = partitioners
    self.assertEqual(
        train_partitioner.partition_id_to_natural_id,
        test_partitioner.partition_id_to_natural_id,
    )


# Run this test module's suite when the file is executed directly.
if __name__ == "__main__":
    unittest.main()