diff --git a/datasets/flwr_datasets/cli/__init__.py b/datasets/flwr_datasets/cli/__init__.py new file mode 100644 index 000000000000..c50081801fde --- /dev/null +++ b/datasets/flwr_datasets/cli/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower Datasets command line interface.""" diff --git a/datasets/flwr_datasets/cli/app.py b/datasets/flwr_datasets/cli/app.py new file mode 100644 index 000000000000..1e59ef34a857 --- /dev/null +++ b/datasets/flwr_datasets/cli/app.py @@ -0,0 +1,57 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower Datasets command line interface.""" + + +import typer +from typer.main import get_command + +from flwr_datasets.common.version import package_version + +from .create import create + +app = typer.Typer( + help=typer.style( + "flwr-datasets is the Flower Datasets command line interface.", + fg=typer.colors.BRIGHT_YELLOW, + bold=True, + ), + no_args_is_help=True, + context_settings={"help_option_names": ["-h", "--help"]}, +) + +app.command()(create) + +typer_click_object = get_command(app) + + +@app.callback(invoke_without_command=True) +def main( + version: bool = typer.Option( + None, + "-V", + "--version", + is_eager=True, + help="Show the version and exit.", + ), +) -> None: + """Flower Datasets CLI.""" + if version: + typer.secho(f"Flower Datasets version: {package_version}", fg="blue") + raise typer.Exit() + + +if __name__ == "__main__": + app() diff --git a/datasets/flwr_datasets/cli/create.py b/datasets/flwr_datasets/cli/create.py new file mode 100644 index 000000000000..c3533ffe1008 --- /dev/null +++ b/datasets/flwr_datasets/cli/create.py @@ -0,0 +1,109 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Flower Datasets command line interface `create` command.""" + + +from pathlib import Path +from typing import Annotated + +import click +import typer + +from datasets.load import DatasetNotFoundError +from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner + + +def create( + dataset_name: Annotated[ + str, + typer.Argument( + help="Hugging Face dataset identifier (e.g., 'ylecun/mnist').", + ), + ], + num_partitions: Annotated[ + int, + typer.Option( + "--num-partitions", + help="Number of partitions to create. Must be a positive integer", + ), + ] = 10, + out_dir: Annotated[ + Path, + typer.Option( + "--out-dir", + help="Output directory for the federated dataset.", + ), + ] = Path("./federated_dataset"), +) -> None: + """Create a federated dataset and save each partition in a sub-directory. + + This command is used to generate federated datasets + for demo purposes and currently supports only IID + partitioning `IidPartitioner`. 
+ """ + # Validate number of partitions + if num_partitions <= 0: + raise click.ClickException("--num-partitions must be a positive integer.") + + # Handle output directory + if out_dir.exists(): + overwrite = typer.confirm( + typer.style( + f"\nšŸ’¬ {out_dir} already exists, do you want to override it?", + fg=typer.colors.MAGENTA, + bold=True, + ), + default=False, + ) + if not overwrite: + return + + out_dir.mkdir(parents=True, exist_ok=True) + + # Create data partitioner + partitioner = IidPartitioner(num_partitions=num_partitions) + + try: + # Create the federated dataset + fds = FederatedDataset( + dataset=dataset_name, + partitioners={"train": partitioner}, + ) + + # Load partitions and save them to disk + for partition_id in range(num_partitions): + partition = fds.load_partition(partition_id=partition_id) + partition_out_dir = out_dir / f"partition_{partition_id}" + partition.save_to_disk(partition_out_dir) + + except DatasetNotFoundError as err: + raise click.ClickException( + f"Dataset '{dataset_name}' could not be found on the Hugging Face Hub, " + "or network access is unavailable. " + "Please verify the dataset identifier and your internet connection." + ) from err + + except Exception as ex: # pylint: disable=broad-exception-caught + raise click.ClickException( + "An unexpected error occurred while creating the federated dataset. " + f"Please try again or check the logs for more details: {str(ex)}" + ) from ex + + typer.secho( + f"šŸŽŠ Created {num_partitions} partitions for '{dataset_name}' in '{out_dir.absolute()}'", + fg=typer.colors.GREEN, + bold=True, + ) diff --git a/datasets/flwr_datasets/cli/create_test.py b/datasets/flwr_datasets/cli/create_test.py new file mode 100644 index 000000000000..709b0cb82a70 --- /dev/null +++ b/datasets/flwr_datasets/cli/create_test.py @@ -0,0 +1,217 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for Flower Datasets command line interface `create` command.""" + + +import re +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import click +import pytest +import typer + +from datasets.load import DatasetNotFoundError + +from . import create as create_module +from .create import create + + +class _FakePartition: + """Fake dataset partition used to capture save-to-disk calls.""" + + def __init__(self, saved_dirs: list[Path]) -> None: + """Initialize the fake partition.""" + self._saved_dirs = saved_dirs + + def save_to_disk(self, out_dir: Path) -> None: + """Record the output directory instead of writing to disk.""" + self._saved_dirs.append(out_dir) + + +class _FakeFederatedDataset: + """Fake FederatedDataset that records partition loading behavior.""" + + def __init__(self, calls: dict[str, Any]) -> None: + """Initialize the fake federated dataset.""" + self._calls = calls + + def load_partition(self, *, partition_id: int) -> _FakePartition: + """Simulate loading a partition and record calls.""" + self._calls.setdefault("loaded_ids", []).append(partition_id) + return _FakePartition(self._calls.setdefault("saved_dirs", [])) + + +def test_create_raises_on_non_positive_num_partitions(tmp_path: Path) -> None: + """Ensure `create` fails when `num_partitions` is 
not a positive integer.""" + with pytest.raises(click.ClickException, match="positive integer"): + create(dataset_name="user/ds", num_partitions=0, out_dir=tmp_path) + + +def test_create_raises_click_exception_when_dataset_load_fails( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Ensure `create` raises a user-friendly error when the dataset is. + + missing/unreachable. + """ + # Create a unique dataset name + out_dir = tmp_path / "out" + dataset_name = "does-not-exist/dataset" + + # Avoid overwrite prompt path + monkeypatch.setattr(Path, "exists", lambda _self: False) + + # Avoid touching the real filesystem in this unit test + monkeypatch.setattr(Path, "mkdir", lambda _self, **_kwargs: None) + + # Mock partitioner + monkeypatch.setattr( + create_module, + "IidPartitioner", + lambda *, num_partitions: SimpleNamespace(num_partitions=num_partitions), + ) + + # Ensure the command handles DatasetNotFoundError specifically + def _raise_fds( + *, dataset: str, partitioners: dict[str, object] + ) -> _FakeFederatedDataset: + raise DatasetNotFoundError() + + monkeypatch.setattr(create_module, "FederatedDataset", _raise_fds) + + expected_msg = ( + f"Dataset '{dataset_name}' could not be found on the Hugging Face Hub, " + "or network access is unavailable. " + "Please verify the dataset identifier and your internet connection." 
+ ) + + with pytest.raises(click.ClickException, match=re.escape(expected_msg)): + create(dataset_name=dataset_name, num_partitions=2, out_dir=out_dir) + + +@dataclass(frozen=True) +class _CreateCase: + """Single parametrized case for `create` output-directory behavior tests.""" + + out_dir_exists: bool + user_overwrite: bool | None + expect_runs: bool + expect_confirm_calls: int + num_partitions: int = 3 + + +@pytest.mark.parametrize( + "case", + [ + _CreateCase( + out_dir_exists=False, + user_overwrite=None, + expect_runs=True, + expect_confirm_calls=0, + ), + _CreateCase( + out_dir_exists=True, + user_overwrite=False, + expect_runs=False, + expect_confirm_calls=1, + ), + _CreateCase( + out_dir_exists=True, + user_overwrite=True, + expect_runs=True, + expect_confirm_calls=1, + ), + ], +) +def test_create_partitions_save_behavior( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + case: _CreateCase, +) -> None: + """Test `create` behavior depending on whether the output directory exists.""" + out_dir = tmp_path / "out" + calls: dict[str, Any] = {} + confirm_calls: list[str] = [] + mkdir_calls: list[Path] = [] + + def _exists(self: Path) -> bool: + """Simulate existence of the output directory.""" + return case.out_dir_exists and self == out_dir + + def _confirm(message: str, default: bool = False) -> bool: + """Simulate user response to overwrite confirmation.""" + del default # unused + confirm_calls.append(message) + assert ( + case.user_overwrite is not None + ), "confirm should not be called in this scenario" + return case.user_overwrite + + def _mkdir(self: Path, parents: bool = False, exist_ok: bool = False) -> None: + """Record directory creation attempts.""" + del parents, exist_ok # unused + mkdir_calls.append(self) + + monkeypatch.setattr(Path, "exists", _exists) + monkeypatch.setattr(typer, "confirm", _confirm) + monkeypatch.setattr(Path, "mkdir", _mkdir) + + if case.expect_runs: + + def _fake_partitioner(*, num_partitions: int) -> SimpleNamespace: 
+ """Record partitioner initialization.""" + calls["partitioner_num_partitions"] = num_partitions + return SimpleNamespace(num_partitions=num_partitions) + + def _fake_fds( + *, dataset: str, partitioners: dict[str, object] + ) -> _FakeFederatedDataset: + """Record dataset creation and return a fake federated dataset.""" + calls["dataset"] = dataset + calls["partitioners"] = partitioners + return _FakeFederatedDataset(calls) + + monkeypatch.setattr(create_module, "IidPartitioner", _fake_partitioner) + monkeypatch.setattr(create_module, "FederatedDataset", _fake_fds) + else: + + def _fail_partitioner(**_: object) -> None: + raise AssertionError("IidPartitioner should not be called") + + def _fail_fds(**_: object) -> None: + raise AssertionError("FederatedDataset should not be called") + + monkeypatch.setattr(create_module, "IidPartitioner", _fail_partitioner) + monkeypatch.setattr(create_module, "FederatedDataset", _fail_fds) + + create(dataset_name="user/ds", num_partitions=case.num_partitions, out_dir=out_dir) + + assert len(confirm_calls) == case.expect_confirm_calls + + if not case.expect_runs: + assert not mkdir_calls + return + + assert mkdir_calls == [out_dir] + assert calls["partitioner_num_partitions"] == case.num_partitions + assert calls["dataset"] == "user/ds" + assert "train" in calls["partitioners"] + assert calls["loaded_ids"] == list(range(case.num_partitions)) + assert calls["saved_dirs"] == [ + out_dir / f"partition_{i}" for i in range(case.num_partitions) + ] diff --git a/datasets/pyproject.toml b/datasets/pyproject.toml index 25e54e1223b4..79b9d4b41ca3 100644 --- a/datasets/pyproject.toml +++ b/datasets/pyproject.toml @@ -49,6 +49,10 @@ classifiers = [ packages = [{ include = "flwr_datasets", from = "./" }] exclude = ["./**/*_test.py"] +[tool.poetry.scripts] +# `flwr-datasets` CLI +flwr-datasets = "flwr_datasets.cli.app:app" + [tool.poetry.dependencies] python = "^3.10" numpy = ">=1.26.0,<3.0.0" @@ -57,6 +61,7 @@ pillow = { version = ">=6.2.1", 
optional = true } soundfile = { version = ">=0.12.1", optional = true } librosa = { version = ">=0.10.0.post2", optional = true } tqdm = "^4.66.1" +rich = "^13.5.0" matplotlib = "^3.7.5" seaborn = "^0.13.0" torch = { version = ">=2.8.0", optional = true, python = ">=3.10,<3.14" }