-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathcreate.py
More file actions
92 lines (79 loc) · 2.87 KB
/
create.py
File metadata and controls
92 lines (79 loc) · 2.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower Datasets command line interface `create` command."""
from pathlib import Path
from typing import Annotated
import click
import typer
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
def create(
    dataset_name: Annotated[
        str,
        typer.Argument(
            help="Hugging Face dataset identifier (e.g., 'ylecun/mnist').",
        ),
    ],
    num_partitions: Annotated[
        int,
        typer.Option(
            "--num-partitions",
            min=1,
            help="Number of partitions to create.",
        ),
    ] = 10,
    out_dir: Annotated[
        Path,
        typer.Option(
            "--out-dir",
            help="Output directory for the federated dataset.",
        ),
    ] = Path("./federated_dataset"),
) -> None:
    """Create a federated dataset and save each partition in a sub-directory.

    This command is used to generate federated datasets for demo purposes and
    currently supports only IID partitioning (`IidPartitioner`) applied to the
    dataset's ``train`` split. Partition ``i`` is saved with ``save_to_disk``
    to ``<out_dir>/partition_<i>``.

    Parameters
    ----------
    dataset_name : str
        Hugging Face dataset identifier passed to `FederatedDataset`.
    num_partitions : int
        Number of IID partitions to create; must be positive.
    out_dir : Path
        Directory receiving one sub-directory per partition. If it already
        exists, the user is asked to confirm before it is overwritten; on
        refusal the command returns without writing anything.

    Raises
    ------
    click.ClickException
        If `num_partitions` is not positive, or if `out_dir` exists but is not
        a directory.
    """
    # Validate number of partitions. Typer already enforces `min=1` on the CLI,
    # but the function can also be called directly from Python.
    if num_partitions <= 0:
        raise click.ClickException("--num-partitions must be a positive integer.")

    # Handle output directory
    if out_dir.exists():
        # Partitions are written into sub-directories, so an existing regular
        # file can never be used. Fail before prompting instead of letting
        # `mkdir` raise FileExistsError after the user has confirmed.
        if not out_dir.is_dir():
            raise click.ClickException(f"{out_dir} exists and is not a directory.")
        overwrite = typer.confirm(
            typer.style(
                f"\n💬 {out_dir} already exists, do you want to overwrite it?",
                fg=typer.colors.MAGENTA,
                bold=True,
            ),
            default=False,
        )
        if not overwrite:
            return

    out_dir.mkdir(parents=True, exist_ok=True)

    # Create data partitioner
    partitioner = IidPartitioner(num_partitions=num_partitions)
    # Create the federated dataset; only the "train" split is partitioned
    fds = FederatedDataset(dataset=dataset_name, partitioners={"train": partitioner})

    # Load partitions and save them to disk
    for partition_id in range(num_partitions):
        partition = fds.load_partition(partition_id=partition_id)
        partition.save_to_disk(out_dir / f"partition_{partition_id}")

    # Only after every partition was written successfully: drop partition
    # folders left over from an earlier run, so `out_dir` holds exactly
    # `num_partitions` partitions.
    _remove_stale_partition_dirs(out_dir, num_partitions)

    typer.secho(
        f"🎊 Created {num_partitions} partitions for {dataset_name} in {out_dir}",
        fg=typer.colors.GREEN,
        bold=True,
    )


def _remove_stale_partition_dirs(out_dir: Path, num_partitions: int) -> None:
    """Delete partition sub-directories of `out_dir` not written by this run.

    A sub-directory is treated as a partition folder when its name is
    ``partition_`` followed by decimal digits (the layout `create` writes).
    Those other than ``partition_0`` .. ``partition_<num_partitions - 1>``
    are removed. Any other file or folder in `out_dir`, and symlinks, are left
    untouched.
    """
    import shutil  # Local import: only needed for this clean-up step

    prefix = "partition_"
    written = {f"{prefix}{i}" for i in range(num_partitions)}
    for entry in out_dir.iterdir():
        name = entry.name
        if name in written:
            continue
        if not (name.startswith(prefix) and name[len(prefix) :].isdecimal()):
            continue
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)