|
| 1 | +# Copyright 2025 Flower Labs GmbH. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Continuous partitioner class that works with Hugging Face Datasets.""" |
| 16 | + |
| 17 | + |
| 18 | +# pylint: disable=R0913, R0917 |
| 19 | +from typing import Optional |
| 20 | + |
| 21 | +import numpy as np |
| 22 | + |
| 23 | +from datasets import Dataset |
| 24 | +from flwr_datasets.partitioner.partitioner import Partitioner |
| 25 | + |
| 26 | + |
class ContinuousPartitioner(
    Partitioner
):  # pylint: disable=too-many-instance-attributes
    r"""Partitioner based on a real-valued dataset property with adjustable strictness.

    This partitioner enables non-IID partitioning by sorting the dataset according to a
    continuous (i.e., real-valued, not categorical) property and introducing controlled
    noise to adjust the level of heterogeneity.

    To interpolate between IID and non-IID partitioning, a `strictness` parameter
    (𝜎 ∈ [0, 1]) blends a standardized property vector (z ∈ ℝⁿ) with Gaussian noise
    (ε ~ 𝒩(0, I)), producing blended scores:

    .. math::

        b = \sigma \cdot z + (1 - \sigma) \cdot ε

    Samples are then sorted by `b` to assign them to partitions. When `strictness` is 0,
    partitioning is purely random (IID), while a value of 1 strictly follows the
    property ranking (strongly non-IID).

    Parameters
    ----------
    num_partitions : int
        Number of partitions to create.
    partition_by : str
        Name of the continuous feature to partition the dataset on.
    strictness : float
        Controls how strongly the feature influences partitioning (0 = iid, 1 = non-iid).
    shuffle : bool
        Whether to shuffle the indices within each partition (default: True).
    seed : Optional[int]
        Random seed for reproducibility.

    Examples
    --------
    >>> from datasets import Dataset
    >>> import numpy as np
    >>> import pandas as pd
    >>> from flwr_datasets.partitioner import ContinuousPartitioner
    >>> import matplotlib.pyplot as plt
    >>>
    >>> # Create synthetic data
    >>> df = pd.DataFrame({
    >>>     "continuous": np.linspace(0, 10, 10_000),
    >>>     "category": np.random.choice([0, 1, 2, 3], size=10_000)
    >>> })
    >>> hf_dataset = Dataset.from_pandas(df)
    >>>
    >>> # Partition dataset
    >>> partitioner = ContinuousPartitioner(
    >>>     num_partitions=5,
    >>>     partition_by="continuous",
    >>>     strictness=0.7,
    >>>     shuffle=True
    >>> )
    >>> partitioner.dataset = hf_dataset
    >>>
    >>> # Plot partitions
    >>> plt.figure(figsize=(10, 6))
    >>> for i in range(5):
    >>>     plt.hist(
    >>>         partitioner.load_partition(i)["continuous"],
    >>>         bins=64,
    >>>         alpha=0.5,
    >>>         label=f"Partition {i}"
    >>>     )
    >>> plt.legend()
    >>> plt.xlabel("Continuous Value")
    >>> plt.ylabel("Frequency")
    >>> plt.title("Partition distributions")
    >>> plt.grid(True)
    >>> plt.show()
    """

    def __init__(
        self,
        num_partitions: int,
        partition_by: str,
        strictness: float,
        shuffle: bool = True,
        seed: Optional[int] = 42,
    ) -> None:
        super().__init__()
        # Validate eagerly so misconfiguration fails at construction time,
        # not on the first (lazy) partitioning call.
        if not 0 <= strictness <= 1:
            raise ValueError("`strictness` must be between 0 and 1")
        if num_partitions <= 0:
            raise ValueError("`num_partitions` must be greater than 0")

        self._num_partitions = num_partitions
        self._partition_by = partition_by
        self._strictness = strictness
        self._shuffle = shuffle
        self._seed = seed
        self._rng = np.random.default_rng(seed)

        # Lazy initialization: indices are computed on first access because
        # `self.dataset` may not be assigned yet at construction time.
        self._partition_id_to_indices: dict[int, list[int]] = {}
        self._partition_id_to_indices_determined = False

    def load_partition(self, partition_id: int) -> Dataset:
        """Load a single partition based on the partition index.

        Parameters
        ----------
        partition_id : int
            The index that corresponds to the requested partition.

        Returns
        -------
        dataset_partition : Dataset
            A single dataset partition.
        """
        self._check_and_generate_partitions_if_needed()
        return self.dataset.select(self._partition_id_to_indices[partition_id])

    @property
    def num_partitions(self) -> int:
        """Total number of partitions."""
        self._check_and_generate_partitions_if_needed()
        return self._num_partitions

    @property
    def partition_id_to_indices(self) -> dict[int, list[int]]:
        """Mapping from partition ID to dataset indices."""
        self._check_and_generate_partitions_if_needed()
        return self._partition_id_to_indices

    def _check_and_generate_partitions_if_needed(self) -> None:
        """Lazy evaluation of the partitioning logic."""
        if self._partition_id_to_indices_determined:
            return

        if self._num_partitions > self.dataset.num_rows:
            raise ValueError(
                "Number of partitions must be less than or equal to number of dataset samples."
            )

        # Extract property values. NOTE: NumPy converts `None` entries to NaN
        # when casting to a float dtype, so the NaN check below also catches
        # missing (None) values. (The previous `np.any(property_values is None)`
        # clause was a no-op: `array is None` is a scalar `False`.)
        property_values = np.array(self.dataset[self._partition_by], dtype=np.float32)

        # Check for missing values (None or NaN)
        if np.isnan(property_values).any():
            raise ValueError(
                f"The column '{self._partition_by}' contains None or NaN values, "
                f"which are not supported by {self.__class__.__qualname__}. "
                "Please clean or filter your dataset before partitioning."
            )

        if self._strictness == 0:
            # Pure IID: scores are noise only. Skipping standardization here
            # avoids dividing by a (near-)zero std for constant columns, which
            # would otherwise yield inf/NaN and, via `0 * inf = nan`, make every
            # blended score NaN (degenerate, non-random partitions).
            blended_values = self._rng.normal(loc=0, scale=1, size=len(property_values))
        else:
            # Standardize so the property contributes on the same scale as the
            # unit-variance Gaussian noise it is blended with.
            std = np.std(property_values)
            if std < 1e-6:
                raise ValueError(
                    f"Cannot standardize column '{self._partition_by}' "
                    f"because it has near-zero std (std={std}). "
                    "All values are nearly identical, which prevents meaningful non-IID partitioning. "
                    "To resolve this, choose a different partition property "
                    "or set strictness to 0 to enable IID partitioning."
                )

            standardized_values = (property_values - np.mean(property_values)) / std

            # Blend noise: b = sigma * z + (1 - sigma) * eps
            noise = self._rng.normal(loc=0, scale=1, size=len(standardized_values))
            blended_values = (
                self._strictness * standardized_values + (1 - self._strictness) * noise
            )

        # Sort by blended score; contiguous chunks of the sorted order become
        # the partitions (np.array_split balances sizes when not divisible).
        sorted_indices = np.argsort(blended_values)
        partition_indices = np.array_split(sorted_indices, self._num_partitions)

        for pid, indices in enumerate(partition_indices):
            indices_list = indices.tolist()
            if self._shuffle:
                # Shuffle within each partition so index order carries no signal.
                self._rng.shuffle(indices_list)
            self._partition_id_to_indices[pid] = indices_list

        self._partition_id_to_indices_determined = True
0 commit comments