Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions datasets/flwr_datasets/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower Datasets command line interface."""
57 changes: 57 additions & 0 deletions datasets/flwr_datasets/cli/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower Datasets command line interface."""


import typer
from typer.main import get_command

from flwr_datasets.common.version import package_version

from .create import create

# Root Typer application behind the `flwr-datasets` console script.
app = typer.Typer(
    help=typer.style(
        "flwr-datasets is the Flower Datasets command line interface.",
        fg=typer.colors.BRIGHT_YELLOW,
        bold=True,
    ),
    no_args_is_help=True,
    context_settings={"help_option_names": ["-h", "--help"]},
)

# Register `create` as a sub-command of the app.
app.command()(create)


# The callback is what attaches `-V/--version` to the top-level command. It is
# invoked even when no sub-command is given so that `flwr-datasets -V` works.
@app.callback(invoke_without_command=True)
def main(
    version: bool = typer.Option(
        None,
        "-V",
        "--version",
        is_eager=True,
        help="Show the version and exit.",
    ),
) -> None:
    """Flower Datasets CLI."""
    if version:
        typer.secho(f"Flower Datasets version: {package_version}", fg="blue")
        raise typer.Exit()


# Build the Click object only after the callback above has been registered.
# `get_command` snapshots the app at call time: with a single registered command
# and no callback, Typer produces that command on its own instead of a group, so
# an earlier call would yield a bare `create` command without `-V/--version`
# rather than the `flwr-datasets` group that `app()` actually serves.
typer_click_object = get_command(app)


if __name__ == "__main__":
    app()
92 changes: 92 additions & 0 deletions datasets/flwr_datasets/cli/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower Datasets command line interface `create` command."""


from pathlib import Path
from typing import Annotated

import click
import typer

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner


def create(
    dataset_name: Annotated[
        str,
        typer.Argument(
            help="Hugging Face dataset identifier (e.g., 'ylecun/mnist').",
        ),
    ],
    num_partitions: Annotated[
        int,
        typer.Option(
            "--num-partitions",
            min=1,
            help="Number of partitions to create.",
        ),
    ] = 10,
    out_dir: Annotated[
        Path,
        typer.Option(
            "--out-dir",
            help="Output directory for the federated dataset.",
        ),
    ] = Path("./federated_dataset"),
) -> None:
    """Create a federated dataset and save each partition in a sub-directory.

    This command is used to generate federated datasets
    for demo purposes and currently supports only IID
    partitioning `IidPartitioner`.
    """
    # NOTE: the docstring above is shown to users as the command's help text, so
    # maintainer-facing details live in comments instead.
    #
    # Behavior summary:
    #   * raises `click.ClickException` when `num_partitions` is not positive or
    #     when `out_dir` points at an existing regular file;
    #   * returns without creating anything when `out_dir` exists and the user
    #     declines the confirmation prompt;
    #   * otherwise saves partition `i` of the "train" split under
    #     `out_dir / f"partition_{i}"` for every `i` in `range(num_partitions)`.

    # `min=1` on the option is only enforced by CLI parsing; this check also
    # protects direct Python callers of `create`.
    if num_partitions <= 0:
        raise click.ClickException("--num-partitions must be a positive integer.")

    # A regular file can never serve as the output directory. Without this
    # guard the user would be asked to confirm and `mkdir` would then fail
    # with an unhandled `FileExistsError`.
    if out_dir.is_file():
        raise click.ClickException(
            f"--out-dir '{out_dir}' is an existing file, not a directory."
        )

    # Ask before reusing an existing output directory.
    # NOTE(review): nothing below removes what is already inside `out_dir`;
    # entries that this run does not write (for example `partition_<id>`
    # directories left by an earlier run with more partitions) stay on disk.
    if out_dir.exists():
        overwrite = typer.confirm(
            typer.style(
                f"\n💬 {out_dir} already exists, do you want to overwrite it?",
                fg=typer.colors.MAGENTA,
                bold=True,
            ),
            default=False,
        )
        if not overwrite:
            return

    out_dir.mkdir(parents=True, exist_ok=True)

    # Partition the "train" split with the IID partitioner.
    partitioner = IidPartitioner(num_partitions=num_partitions)
    fds = FederatedDataset(dataset=dataset_name, partitioners={"train": partitioner})

    # Load every partition and save it in its own sub-directory.
    for partition_id in range(num_partitions):
        partition = fds.load_partition(partition_id=partition_id)
        partition.save_to_disk(out_dir / f"partition_{partition_id}")

    typer.secho(
        f"🎊 Created {num_partitions} partitions for {dataset_name} in {out_dir}",
        fg=typer.colors.GREEN,
        bold=True,
    )
177 changes: 177 additions & 0 deletions datasets/flwr_datasets/cli/create_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for Flower Datasets command line interface `create` command."""


from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import Any

import click
import pytest
import typer

from . import create as create_module
from .create import create


class _FakePartition:
    """Stand-in for a dataset partition that logs where it would be saved."""

    def __init__(self, saved_dirs: list[Path]) -> None:
        """Keep a reference to the caller-owned list used as the save log."""
        self._saved_dirs = saved_dirs

    def save_to_disk(self, out_dir: Path) -> None:
        """Append `out_dir` to the shared log; nothing is written to disk."""
        save_log = self._saved_dirs
        save_log.append(out_dir)


class _FakeFederatedDataset:
    """Stand-in for `FederatedDataset` that logs which partitions are loaded."""

    def __init__(self, calls: dict[str, Any]) -> None:
        """Keep a reference to the caller-owned dict used as the call log."""
        self._calls = calls

    def load_partition(self, *, partition_id: int) -> _FakePartition:
        """Log the requested id and hand back a fake partition.

        The id is appended under the "loaded_ids" key; the returned partition
        records its save targets in the list stored under "saved_dirs". Both
        lists are created on first use.
        """
        loaded_ids = self._calls.setdefault("loaded_ids", [])
        loaded_ids.append(partition_id)
        saved_dirs = self._calls.setdefault("saved_dirs", [])
        return _FakePartition(saved_dirs)


def test_create_raises_on_non_positive_num_partitions(tmp_path: Path) -> None:
    """Check that a partition count of zero is rejected with a `ClickException`."""
    rejects_count = pytest.raises(click.ClickException, match="positive integer")
    with rejects_count:
        create(dataset_name="user/ds", num_partitions=0, out_dir=tmp_path)


@dataclass(frozen=True)
class _CreateCase:
    """Single parametrized case for `create` output-directory behavior tests."""

    # Whether the patched `Path.exists` reports the output directory as present.
    out_dir_exists: bool
    # Answer returned by the patched `typer.confirm`; `None` marks cases in
    # which the confirmation prompt must not be shown at all.
    user_overwrite: bool | None
    # Whether `create` is expected to go on and build/save the partitions.
    expect_runs: bool
    # Number of times the confirmation prompt is expected to be shown.
    expect_confirm_calls: int
    # Partition count passed to `create`.
    num_partitions: int = 3


@pytest.mark.parametrize(
    "case",
    [
        # Output directory is absent: no prompt, partitions are saved.
        _CreateCase(
            out_dir_exists=False,
            user_overwrite=None,
            expect_runs=True,
            expect_confirm_calls=0,
        ),
        # Output directory is present and the user declines: nothing happens.
        _CreateCase(
            out_dir_exists=True,
            user_overwrite=False,
            expect_runs=False,
            expect_confirm_calls=1,
        ),
        # Output directory is present and the user accepts: partitions are saved.
        _CreateCase(
            out_dir_exists=True,
            user_overwrite=True,
            expect_runs=True,
            expect_confirm_calls=1,
        ),
    ],
)
def test_create_partitions_save_behavior(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    case: _CreateCase,
) -> None:
    """Check how `create` reacts to a missing or pre-existing output directory.

    Filesystem access, the confirmation prompt, the partitioner and the
    federated dataset are all replaced by recording stubs, so the test only
    observes which collaborators `create` calls and with what arguments.
    """
    target_dir = tmp_path / "out"
    recorded: dict[str, Any] = {}
    prompts: list[str] = []
    made_dirs: list[Path] = []

    def _exists(self: Path) -> bool:
        """Report only the target directory as existing, and only on demand."""
        if not case.out_dir_exists:
            return False
        return self == target_dir

    def _confirm(message: str, default: bool = False) -> bool:
        """Log the prompt and answer it as the current case dictates."""
        del default  # unused
        prompts.append(message)
        answer = case.user_overwrite
        assert answer is not None, "confirm should not be called in this scenario"
        return answer

    def _mkdir(self: Path, parents: bool = False, exist_ok: bool = False) -> None:
        """Log the directory instead of creating it."""
        del parents, exist_ok  # unused
        made_dirs.append(self)

    def _fake_partitioner(*, num_partitions: int) -> SimpleNamespace:
        """Log the requested partition count and return a minimal stand-in."""
        recorded["partitioner_num_partitions"] = num_partitions
        return SimpleNamespace(num_partitions=num_partitions)

    def _fake_fds(
        *, dataset: str, partitioners: dict[str, object]
    ) -> _FakeFederatedDataset:
        """Log the constructor arguments and return a fake federated dataset."""
        recorded["dataset"] = dataset
        recorded["partitioners"] = partitioners
        return _FakeFederatedDataset(recorded)

    def _forbidden(label: str) -> Any:
        """Build a keyword-only stub that fails the test as soon as it is called."""

        def _fail(**_: object) -> None:
            raise AssertionError(f"{label} should not be called")

        return _fail

    monkeypatch.setattr(Path, "exists", _exists)
    monkeypatch.setattr(typer, "confirm", _confirm)
    monkeypatch.setattr(Path, "mkdir", _mkdir)

    # When `create` is expected to stop early, any attempt to build the
    # partitioner or the dataset is an error.
    partitioner_stub: Any
    fds_stub: Any
    if case.expect_runs:
        partitioner_stub = _fake_partitioner
        fds_stub = _fake_fds
    else:
        partitioner_stub = _forbidden("IidPartitioner")
        fds_stub = _forbidden("FederatedDataset")
    monkeypatch.setattr(create_module, "IidPartitioner", partitioner_stub)
    monkeypatch.setattr(create_module, "FederatedDataset", fds_stub)

    create(
        dataset_name="user/ds",
        num_partitions=case.num_partitions,
        out_dir=target_dir,
    )

    assert len(prompts) == case.expect_confirm_calls

    if not case.expect_runs:
        assert made_dirs == []
        return

    partition_ids = list(range(case.num_partitions))
    assert made_dirs == [target_dir]
    assert recorded["partitioner_num_partitions"] == case.num_partitions
    assert recorded["dataset"] == "user/ds"
    assert "train" in recorded["partitioners"]
    assert recorded["loaded_ids"] == partition_ids
    assert recorded["saved_dirs"] == [
        target_dir / f"partition_{i}" for i in partition_ids
    ]
5 changes: 5 additions & 0 deletions datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ classifiers = [
packages = [{ include = "flwr_datasets", from = "./" }]
exclude = ["./**/*_test.py"]

[tool.poetry.scripts]
# `flwr-datasets` CLI
flwr-datasets = "flwr_datasets.cli.app:app"

[tool.poetry.dependencies]
python = "^3.10"
numpy = ">=1.26.0,<3.0.0"
Expand All @@ -57,6 +61,7 @@ pillow = { version = ">=6.2.1", optional = true }
soundfile = { version = ">=0.12.1", optional = true }
librosa = { version = ">=0.10.0.post2", optional = true }
tqdm = "^4.66.1"
rich = "^13.5.0"
matplotlib = "^3.7.5"
seaborn = "^0.13.0"
torch = { version = ">=2.8.0", optional = true, python = ">=3.10,<3.14" }
Expand Down