-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathvertical_even_partitioner.py
More file actions
220 lines (195 loc) · 8.45 KB
/
vertical_even_partitioner.py
File metadata and controls
220 lines (195 loc) · 8.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""VerticalEvenPartitioner class."""
# flake8: noqa: E501
# pylint: disable=C0301, R0902, R0913
from typing import Literal, Optional, Union
import numpy as np
import datasets
from flwr_datasets.partitioner.partitioner import Partitioner
from flwr_datasets.partitioner.vertical_partitioner_utils import (
_add_active_party_columns,
_init_optional_str_or_list_str,
_list_split,
)
class VerticalEvenPartitioner(Partitioner):
    """Partitioner that splits features (columns) evenly into vertical partitions.

    Enables selection of "active party" column(s) and placement into
    a specific partition or creation of a new partition just for it.
    Also enables dropping columns and sharing specified columns across
    all partitions.

    Parameters
    ----------
    num_partitions : int
        Number of partitions to create.
    active_party_columns : Optional[Union[str, list[str]]]
        Column(s) (typically representing labels) associated with the
        "active party" (which can be the server).
    active_party_columns_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
        Determines how to assign the active party columns:

        - `"add_to_first"`: Append active party columns to the first partition.
        - `"add_to_last"`: Append active party columns to the last partition.
        - `"create_as_first"`: Create a new partition at the start containing only these columns.
        - `"create_as_last"`: Create a new partition at the end containing only these columns.
        - `"add_to_all"`: Append active party columns to all partitions.
        - int: Append active party columns to the specified partition index.
    drop_columns : Optional[Union[str, list[str]]]
        Columns to remove entirely from the dataset before partitioning.
    shared_columns : Optional[Union[str, list[str]]]
        Columns to duplicate into every partition after initial partitioning.
    shuffle : bool
        Whether to shuffle the order of columns before partitioning.
    seed : Optional[int]
        Random seed for shuffling columns. Has no effect if `shuffle=False`.

    Examples
    --------
    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.partitioner import VerticalEvenPartitioner
    >>>
    >>> partitioner = VerticalEvenPartitioner(
    ...     num_partitions=3,
    ...     active_party_columns="income",
    ...     active_party_columns_mode="add_to_last",
    ...     shuffle=True,
    ...     seed=42
    ... )
    >>> fds = FederatedDataset(
    ...     dataset="scikit-learn/adult-census-income",
    ...     partitioners={"train": partitioner}
    ... )
    >>> partitions = [fds.load_partition(i) for i in range(fds.partitioners["train"].num_partitions)]
    >>> print([partition.column_names for partition in partitions])
    """

    def __init__(  # pylint: disable=R0917
        self,
        num_partitions: int,
        active_party_columns: Optional[Union[str, list[str]]] = None,
        active_party_columns_mode: Union[
            Literal[
                "add_to_first",
                "add_to_last",
                "create_as_first",
                "create_as_last",
                "add_to_all",
            ],
            int,
        ] = "add_to_last",
        drop_columns: Optional[Union[str, list[str]]] = None,
        shared_columns: Optional[Union[str, list[str]]] = None,
        shuffle: bool = True,
        seed: Optional[int] = 42,
    ) -> None:
        super().__init__()
        self._num_partitions = num_partitions
        # Normalize str -> [str] and None -> [] so the membership tests in
        # _determine_partitions_if_needed can treat all column parameters
        # uniformly as lists.
        self._active_party_columns = _init_optional_str_or_list_str(
            active_party_columns
        )
        self._active_party_columns_mode = active_party_columns_mode
        self._drop_columns = _init_optional_str_or_list_str(drop_columns)
        self._shared_columns = _init_optional_str_or_list_str(shared_columns)
        self._shuffle = shuffle
        self._seed = seed
        self._rng = np.random.default_rng(seed=self._seed)
        # Lazily computed column-to-partition assignment; filled in on the
        # first call that needs the dataset.
        self._partition_columns: Optional[list[list[str]]] = None
        self._partitions_determined = False
        self._validate_parameters_in_init()

    def _determine_partitions_if_needed(self) -> None:
        """Compute the column-to-partition assignment once, on first use.

        Raises
        ------
        ValueError
            If no dataset has been set, or if a shared/active party column
            is not present in the dataset.
        """
        if self._partitions_determined:
            return
        if self.dataset is None:
            raise ValueError("No dataset is set for this partitioner.")
        all_columns = list(self.dataset.column_names)
        self._validate_parameters_while_partitioning(
            all_columns, self._shared_columns, self._active_party_columns
        )
        # Exclude dropped, shared and active party columns from the even
        # split; shared and active party columns are re-attached below.
        columns = [column for column in all_columns if column not in self._drop_columns]
        columns = [column for column in columns if column not in self._shared_columns]
        columns = [
            column for column in columns if column not in self._active_party_columns
        ]
        if self._shuffle:
            self._rng.shuffle(columns)
        partition_columns = _list_split(columns, self._num_partitions)
        partition_columns = _add_active_party_columns(
            self._active_party_columns,
            self._active_party_columns_mode,
            partition_columns,
        )
        # Add shared columns to all partitions
        for partition in partition_columns:
            for column in self._shared_columns:
                partition.append(column)
        self._partition_columns = partition_columns
        self._partitions_determined = True

    def load_partition(self, partition_id: int) -> datasets.Dataset:
        """Load a partition based on the partition index.

        Parameters
        ----------
        partition_id : int
            The index that corresponds to the requested partition.

        Returns
        -------
        dataset_partition : Dataset
            Single partition of a dataset.

        Raises
        ------
        ValueError
            If `partition_id` is outside the valid range.
        """
        self._determine_partitions_if_needed()
        assert self._partition_columns is not None
        if partition_id < 0 or partition_id >= len(self._partition_columns):
            raise ValueError(f"Invalid partition_id {partition_id}.")
        columns = self._partition_columns[partition_id]
        return self.dataset.select_columns(columns)

    @property
    def num_partitions(self) -> int:
        """Number of partitions.

        Note: this can exceed the `num_partitions` given to `__init__` when
        `active_party_columns_mode` is `"create_as_first"` or
        `"create_as_last"` (an extra partition is created for the active
        party columns).
        """
        self._determine_partitions_if_needed()
        assert self._partition_columns is not None
        return len(self._partition_columns)

    def _validate_parameters_in_init(self) -> None:
        """Validate constructor arguments that do not require the dataset.

        Raises
        ------
        ValueError
            If `num_partitions` is < 1 or the mode is not a valid string/int.
        """
        if self._num_partitions < 1:
            # Fixed error message: it previously referenced a nonexistent
            # `column_distribution` parameter.
            raise ValueError("`num_partitions` must be >= 1.")
        valid_modes = {
            "add_to_first",
            "add_to_last",
            "create_as_first",
            "create_as_last",
            "add_to_all",
        }
        if not (
            isinstance(self._active_party_columns_mode, int)
            or self._active_party_columns_mode in valid_modes
        ):
            # Fixed error message: parameter is `active_party_columns_mode`
            # (plural), matching the __init__ signature.
            raise ValueError(
                "`active_party_columns_mode` must be an int or one of "
                "'add_to_first', 'add_to_last', 'create_as_first', 'create_as_last', "
                "'add_to_all'."
            )

    def _validate_parameters_while_partitioning(
        self,
        all_columns: list[str],
        shared_columns: list[str],
        active_party_column: Union[str, list[str]],
    ) -> None:
        """Check that shared and active party columns exist in the dataset.

        Parameters
        ----------
        all_columns : list[str]
            All column names present in the dataset.
        shared_columns : list[str]
            Columns requested to be shared across partitions.
        active_party_column : Union[str, list[str]]
            Active party column(s); a bare string is treated as one column.

        Raises
        ------
        ValueError
            If any shared or active party column is missing from the dataset.
        """
        if isinstance(active_party_column, str):
            active_party_column = [active_party_column]
        # Shared columns existence check
        for column in shared_columns:
            if column not in all_columns:
                raise ValueError(f"Shared column '{column}' not found in the dataset.")
        # Active party columns existence check
        for column in active_party_column:
            if column not in all_columns:
                raise ValueError(
                    f"Active party column '{column}' not found in the dataset."
                )