diff --git a/baselines/verifiableagg/.gitignore b/baselines/verifiableagg/.gitignore
new file mode 100644
index 000000000000..ea11652f7eda
--- /dev/null
+++ b/baselines/verifiableagg/.gitignore
@@ -0,0 +1,3 @@
+artifacts/
+__pycache__/
+*.pyc
diff --git a/baselines/verifiableagg/README.md b/baselines/verifiableagg/README.md
new file mode 100644
index 000000000000..a06ac3b9437d
--- /dev/null
+++ b/baselines/verifiableagg/README.md
@@ -0,0 +1,229 @@
+---
+title: "Verifiable Aggregation Workflow"
+url: https://github.com/rwilliamspbg-ops/Sovereign-Mohawk-Proto
+labels: [verification, aggregation, reproducibility, message-api, synthetic-data]
+dataset: [synthetic]
+---
+
+## Verifiable Aggregation Workflow
+
+> Note: If you use this baseline in your work, please cite Flower and any upstream work that inspired your implementation.
+
+**Paper/Reference:** [Sovereign-Mohawk-Proto](https://github.com/rwilliamspbg-ops/Sovereign-Mohawk-Proto)
+
+**Authors:** Community contribution by rwilliamspbg-ops
+
+**Abstract:** This baseline demonstrates a reproducible federated learning workflow in Flower where standard FedAvg aggregation is augmented with optional server-side verification hooks. At each round, the server recomputes the weighted aggregate from raw client updates, compares it to the strategy output under a configurable tolerance, and records deterministic hashes and verification outcomes in a JSON report.
+
+## About this baseline
+
+**What is implemented:** A minimal Message API Flower baseline with deterministic synthetic data, optional verification checks around aggregation outputs, and benchmark-friendly reporting scripts.
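+
+The check at the heart of this baseline is framework-independent. A minimal sketch in plain NumPy (values are made up, and this is not the strategy's actual code path; see `verifiableagg/strategy.py` in this PR for the real implementation):
+
+```python
+import numpy as np
+
+# Hypothetical client updates as (num-examples weight, parameter vector) pairs.
+updates = [(128.0, np.array([0.1, 0.2])), (64.0, np.array([0.4, -0.2]))]
+
+# Weighted FedAvg aggregate.
+total = sum(w for w, _ in updates)
+aggregate = sum(w * p for w, p in updates) / total
+
+# Independent replay of the same computation.
+replay = np.zeros_like(updates[0][1])
+for w, p in updates:
+    replay += (w / total) * p
+
+# Verification: both results must agree within the configured tolerance.
+max_abs_diff = float(np.max(np.abs(aggregate - replay)))
+assert max_abs_diff <= 1e-6, f"verification failed: {max_abs_diff:.2e}"
+```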
--run-config "num-server-rounds=10 verify-aggregation=true random-seed=2026" + +# Run benchmark helper script (train + report check) +bash run_benchmark.sh +``` + +## Verification Outputs and Reproducibility + +After each run, artifacts are written to the directory set by artifacts-dir (default: artifacts): + +- artifacts/final_model.pt +- artifacts/report.json + +The report includes: + +- Effective run configuration +- Per-round aggregated train/eval metrics +- Per-round verification status (pass/fail) +- Maximum absolute replay difference +- Deterministic SHA256 hash of aggregated parameters per round + +To summarize and validate verification outcomes: + +```bash +python benchmark_report.py --report-path artifacts/report.json +``` + +This command exits non-zero if any round fails verification, which makes it suitable for CI or benchmark automation. + +## Expected Results + +With default settings, all rounds should pass verification with very small numerical replay error (typically near machine precision). Example benchmark output: + +```text +round num_replies max_abs_diff passed +1 8 0.00000000e+00 1 +2 8 0.00000000e+00 1 +3 8 0.00000000e+00 1 +4 8 0.00000000e+00 1 +5 8 0.00000000e+00 1 +All rounds verified. Max observed absolute difference: 0.00000000e+00 +``` diff --git a/baselines/verifiableagg/benchmark_report.py b/baselines/verifiableagg/benchmark_report.py new file mode 100644 index 000000000000..6d4639ebe6b4 --- /dev/null +++ b/baselines/verifiableagg/benchmark_report.py @@ -0,0 +1,51 @@ +"""Summarize verification report and exit non-zero on verification failure.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--report-path", + type=Path, + default=Path("artifacts/report.json"), + help="Path to report generated by the ServerApp.", + ) + return parser.parse_args() + + +def main() -> int: + """Print benchmark-ready verification summary.""" + args = parse_args() + report = json.loads(args.report_path.read_text(encoding="utf-8")) + + verification_rounds = report.get("verification_rounds", []) + if not verification_rounds: + print("No verification rounds found in report.") + return 1 + + failed = [item for item in verification_rounds if not item.get("passed", False)] + + print("round\tnum_replies\tmax_abs_diff\tpassed") + for item in verification_rounds: + print( + f"{item['round']}\t{item['num_replies']}\t" + f"{item['max_abs_diff']:.8e}\t{int(item['passed'])}" + ) + + if failed: + print(f"Verification failed in {len(failed)} rounds.") + return 2 + + max_diff = max(float(item["max_abs_diff"]) for item in verification_rounds) + print(f"All rounds verified. 
+
+**Training Hyperparameters (default):**
+
+| Hyperparameter | Value |
+| --- | --- |
+| num-server-rounds | 5 |
+| fraction-train | 1.0 |
+| fraction-evaluate | 1.0 |
+| local-epochs | 1 |
+| learning-rate | 0.05 |
+| batch-size | 32 |
+| random-seed | 2026 |
+| verify-aggregation | true |
+| verification-tolerance | 1e-6 |
+
+The number of simulated clients is controlled by the federation setting
+`options.num-supernodes` (default: 8), not by `--run-config`.
+
+## Environment Setup
+
+```bash
+# Create the virtual environment
+pyenv virtualenv 3.12.12 verifiableagg
+
+# Activate it
+pyenv activate verifiableagg
+
+# Install the baseline
+pip install -e .
+
+# If you are contributing changes and want to run lint/type checks
+pip install -e ".[dev]"
+```
+
+For the contributor checks used in Flower baselines CI:
+
+```bash
+cd ..
+./dev/test-baseline-structure.sh verifiableagg
+./dev/test-baseline.sh verifiableagg
+```
+
+## Running the Experiments
+
+```bash
+# Run with defaults from pyproject.toml
+flwr run .
+
+# Override selected values from the CLI
+flwr run . \
+  --run-config "num-server-rounds=10 verify-aggregation=true random-seed=2026"
+
+# Run the benchmark helper script (train + report check)
+bash run_benchmark.sh
+```
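+
+Depending on how they are supplied, boolean-like overrides may arrive as real booleans, numbers, or strings; the ServerApp therefore normalizes `verify-aggregation` with the `as_bool` helper from `verifiableagg/utils.py`:
+
+```python
+from verifiableagg.utils import as_bool
+
+# Accepts real booleans as well as common string spellings.
+assert as_bool(True) is True
+assert as_bool("true") is True
+assert as_bool("0") is False
+```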
--run-config "num-server-rounds=5 random-seed=2026 verify-aggregation=true" +python benchmark_report.py --report-path artifacts/report.json diff --git a/baselines/verifiableagg/verifiableagg/__init__.py b/baselines/verifiableagg/verifiableagg/__init__.py new file mode 100644 index 000000000000..7a2a6325b352 --- /dev/null +++ b/baselines/verifiableagg/verifiableagg/__init__.py @@ -0,0 +1 @@ +"""Verifiableagg baseline package.""" diff --git a/baselines/verifiableagg/verifiableagg/client_app.py b/baselines/verifiableagg/verifiableagg/client_app.py new file mode 100644 index 000000000000..6b9f88012ba4 --- /dev/null +++ b/baselines/verifiableagg/verifiableagg/client_app.py @@ -0,0 +1,100 @@ +"""ClientApp for verifiable aggregation baseline.""" + +from __future__ import annotations + +import torch +from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict +from flwr.clientapp import ClientApp + +from verifiableagg.dataset import load_data +from verifiableagg.model import Net, evaluate, train + +app = ClientApp() + + +@app.train() +def train_fn(msg: Message, context: Context) -> Message: + """Train local model and reply with model arrays and metrics.""" + num_features = int(context.run_config["num-features"]) + local_epochs = int(context.run_config["local-epochs"]) + learning_rate = float(context.run_config["learning-rate"]) + batch_size = int(context.run_config["batch-size"]) + num_train_examples = int(context.run_config["num-train-examples"]) + num_val_examples = int(context.run_config["num-val-examples"]) + base_seed = int(context.run_config["random-seed"]) + + partition_id = int(context.node_config["partition-id"]) + num_partitions = int(context.node_config["num-partitions"]) + + trainloader, _ = load_data( + partition_id=partition_id, + num_partitions=num_partitions, + num_train_examples=num_train_examples, + num_val_examples=num_val_examples, + num_features=num_features, + batch_size=batch_size, + base_seed=base_seed, + ) + + model = Net(num_features=num_features) + arrays = msg.content.array_records["arrays"] + model.load_state_dict(arrays.to_torch_state_dict()) + + device = torch.device("cpu") + train_loss = train( + model=model, + trainloader=trainloader, + epochs=local_epochs, + learning_rate=learning_rate, + device=device, + ) + + model_record = ArrayRecord(model.state_dict()) + metrics = MetricRecord( + { + "train_loss": train_loss, + "num-examples": len(trainloader.dataset), + } + ) + content = RecordDict({"arrays": model_record, "metrics": metrics}) + return Message(content=content, reply_to=msg) + + +@app.evaluate() +def evaluate_fn(msg: Message, context: Context) -> Message: + """Evaluate global model on local validation split.""" + num_features = int(context.run_config["num-features"]) + batch_size = int(context.run_config["batch-size"]) + num_train_examples = int(context.run_config["num-train-examples"]) + num_val_examples = int(context.run_config["num-val-examples"]) + base_seed = int(context.run_config["random-seed"]) + + partition_id = int(context.node_config["partition-id"]) + num_partitions = int(context.node_config["num-partitions"]) + + _, valloader = load_data( + partition_id=partition_id, + num_partitions=num_partitions, + num_train_examples=num_train_examples, + num_val_examples=num_val_examples, + num_features=num_features, + batch_size=batch_size, + base_seed=base_seed, + ) + + model = Net(num_features=num_features) + arrays = msg.content.array_records["arrays"] + model.load_state_dict(arrays.to_torch_state_dict()) + + device = torch.device("cpu") 
+
+To summarize and validate verification outcomes:
+
+```bash
+python benchmark_report.py --report-path artifacts/report.json
+```
+
+This command exits non-zero if any round fails verification, which makes it suitable for CI or benchmark automation.
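+
+If you would rather consume the report directly (for example, from a larger pipeline), the same check is only a few lines of Python:
+
+```python
+import json
+from pathlib import Path
+
+# Mirrors the check performed by benchmark_report.py.
+report = json.loads(Path("artifacts/report.json").read_text(encoding="utf-8"))
+rounds = report["verification_rounds"]
+assert rounds and all(r["passed"] for r in rounds), "verification failed"
+print(max(float(r["max_abs_diff"]) for r in rounds))
+```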
outputs.""" + return self.net(x) + + +def train( + model: nn.Module, + trainloader: DataLoader, + epochs: int, + learning_rate: float, + device: torch.device, +) -> float: + """Train model for a number of local epochs and return average loss.""" + model.to(device) + model.train() + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + loss_sum = 0.0 + num_batches = 0 + + for _ in range(epochs): + for x, y in trainloader: + x = x.to(device) + y = y.to(device) + optimizer.zero_grad() + loss = criterion(model(x), y) + loss.backward() + optimizer.step() + loss_sum += float(loss.item()) + num_batches += 1 + + return loss_sum / max(num_batches, 1) + + +def evaluate( + model: nn.Module, valloader: DataLoader, device: torch.device +) -> tuple[float, float]: + """Evaluate model and return loss and accuracy.""" + model.to(device) + model.eval() + criterion = nn.CrossEntropyLoss() + + loss_sum = 0.0 + correct = 0 + total = 0 + + with torch.no_grad(): + for x, y in valloader: + x = x.to(device) + y = y.to(device) + logits = model(x) + loss_sum += float(criterion(logits, y).item()) + preds = torch.argmax(logits, dim=1) + correct += int((preds == y).sum().item()) + total += int(y.shape[0]) + + avg_loss = loss_sum / max(len(valloader), 1) + accuracy = correct / max(total, 1) + return avg_loss, accuracy diff --git a/baselines/verifiableagg/verifiableagg/reporting.py b/baselines/verifiableagg/verifiableagg/reporting.py new file mode 100644 index 000000000000..8b77b1703a2d --- /dev/null +++ b/baselines/verifiableagg/verifiableagg/reporting.py @@ -0,0 +1,16 @@ +"""Utilities for writing reproducibility and verification reports.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + + +def write_json_report(report: dict[str, Any], output_path: Path) -> None: + """Write JSON report, creating parent directories as needed.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text( + json.dumps(report, indent=2, sort_keys=True), + encoding="utf-8", + ) diff --git a/baselines/verifiableagg/verifiableagg/server_app.py b/baselines/verifiableagg/verifiableagg/server_app.py new file mode 100644 index 000000000000..7634b995ebfc --- /dev/null +++ b/baselines/verifiableagg/verifiableagg/server_app.py @@ -0,0 +1,111 @@ +"""ServerApp for verifiable aggregation baseline.""" + +from __future__ import annotations + +import random +from pathlib import Path + +import numpy as np +import torch +from flwr.app import ArrayRecord, Context +from flwr.serverapp import Grid, ServerApp + +from verifiableagg.model import Net +from verifiableagg.reporting import write_json_report +from verifiableagg.strategy import VerifiableFedAvg +from verifiableagg.utils import as_bool + +app = ServerApp() + + +def _set_global_seeds(seed: int) -> None: + """Set global seeds with deterministic CUDA behavior when available.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + if hasattr(torch.backends, "cudnn"): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +@app.main() +def main(grid: Grid, context: Context) -> None: + """Run federated training and write reproducibility report.""" + run_config = context.run_config + + seed = int(run_config["random-seed"]) + _set_global_seeds(seed) + + num_rounds = int(run_config["num-server-rounds"]) + num_features = int(run_config["num-features"]) + 
diff --git a/baselines/verifiableagg/benchmark_report.py b/baselines/verifiableagg/benchmark_report.py
new file mode 100644
index 000000000000..6d4639ebe6b4
--- /dev/null
+++ b/baselines/verifiableagg/benchmark_report.py
@@ -0,0 +1,51 @@
+"""Summarize verification report and exit non-zero on verification failure."""
+
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse command-line arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--report-path",
+        type=Path,
+        default=Path("artifacts/report.json"),
+        help="Path to report generated by the ServerApp.",
+    )
+    return parser.parse_args()
+
+
+def main() -> int:
+    """Print benchmark-ready verification summary."""
+    args = parse_args()
+    report = json.loads(args.report_path.read_text(encoding="utf-8"))
+
+    verification_rounds = report.get("verification_rounds", [])
+    if not verification_rounds:
+        print("No verification rounds found in report.")
+        return 1
+
+    failed = [item for item in verification_rounds if not item.get("passed", False)]
+
+    print("round\tnum_replies\tmax_abs_diff\tpassed")
+    for item in verification_rounds:
+        print(
+            f"{item['round']}\t{item['num_replies']}\t"
+            f"{item['max_abs_diff']:.8e}\t{int(item['passed'])}"
+        )
+
+    if failed:
+        print(f"Verification failed in {len(failed)} round(s).")
+        return 2
+
+    max_diff = max(float(item["max_abs_diff"]) for item in verification_rounds)
+    print(f"All rounds verified. Max observed absolute difference: {max_diff:.8e}")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/baselines/verifiableagg/pyproject.toml b/baselines/verifiableagg/pyproject.toml
new file mode 100644
index 000000000000..d332815aee17
--- /dev/null
+++ b/baselines/verifiableagg/pyproject.toml
@@ -0,0 +1,144 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "verifiableagg"
+version = "1.0.0"
+description = "Verifiable Aggregation Workflow baseline using Flower Message API"
+license = "Apache-2.0"
+dependencies = [
+    "flwr[simulation]>=1.24.0",
+    "numpy>=1.26.0",
+    "torch==2.8.0",
+]
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[project.optional-dependencies]
+dev = [
+    "isort==5.13.2",
+    "black==24.2.0",
+    "docformatter==1.7.5",
+    "mypy==1.8.0",
+    "pylint==3.3.1",
+    "pytest==7.4.4",
+    "pytest-watch==4.2.0",
+    "ruff==0.4.5",
+    "types-requests==2.31.0.20240125",
+]
+
+[tool.isort]
+profile = "black"
+
+[tool.black]
+line-length = 88
+target-version = ["py310", "py311", "py312"]
+
+[tool.pytest.ini_options]
+minversion = "6.2"
+addopts = "-qq"
+
+[tool.mypy]
+ignore_missing_imports = true
+strict = false
+plugins = "numpy.typing.mypy_plugin"
+
+[tool.pylint."MESSAGES CONTROL"]
+disable = "duplicate-code,too-few-public-methods,useless-import-alias"
+good-names = "i,j,k,_,x,y,X,Y,K,N"
+max-args = 10
+max-attributes = 15
+max-locals = 36
+max-branches = 20
+max-statements = 55
+
+[tool.pylint.typecheck]
+generated-members = "numpy.*, torch.*, tensorflow.*"
+
+[[tool.mypy.overrides]]
+module = [
+    "importlib.metadata.*",
+    "importlib_metadata.*",
+]
+follow_imports = "skip"
+follow_imports_for_stubs = true
+disallow_untyped_calls = false
+
+[[tool.mypy.overrides]]
+module = "torch.*"
+follow_imports = "skip"
+follow_imports_for_stubs = true
+
+[tool.docformatter]
+wrap-summaries = 88
+wrap-descriptions = 88
+
+[tool.ruff]
+target-version = "py310"
+line-length = 88
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".hg",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "venv",
+    "proto",
+]
+
+[tool.ruff.lint]
+select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"]
+fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"]
+ignore = ["B024", "B027", "D205", "D209"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "rwilliamspbg-ops"
+
+[tool.flwr.app.components]
+serverapp = "verifiableagg.server_app:app"
+clientapp = "verifiableagg.client_app:app"
+
+[tool.flwr.app.config]
+num-server-rounds = 5
+fraction-train = 1.0
+fraction-evaluate = 1.0
+local-epochs = 1
+learning-rate = 0.05
+batch-size = 32
+num-features = 10
+num-train-examples = 128
+num-val-examples = 64
+random-seed = 2026
+verify-aggregation = true
+verification-tolerance = 1e-6
+artifacts-dir = "artifacts"
+
+[tool.flwr.federations]
+default = "local-simulation"
+
+[tool.flwr.federations.local-simulation]
+options.num-supernodes = 8
+options.backend.client-resources.num-cpus = 1
+options.backend.client-resources.num-gpus = 0.0
diff --git a/baselines/verifiableagg/run_benchmark.sh b/baselines/verifiableagg/run_benchmark.sh
new file mode 100755
index 000000000000..dc6a26f7cfa5
--- /dev/null
+++ b/baselines/verifiableagg/run_benchmark.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+flwr run . \
+  --run-config "num-server-rounds=5 random-seed=2026 verify-aggregation=true"
+python benchmark_report.py --report-path artifacts/report.json
diff --git a/baselines/verifiableagg/verifiableagg/__init__.py b/baselines/verifiableagg/verifiableagg/__init__.py
new file mode 100644
index 000000000000..7a2a6325b352
--- /dev/null
+++ b/baselines/verifiableagg/verifiableagg/__init__.py
@@ -0,0 +1 @@
+"""Verifiableagg baseline package."""
diff --git a/baselines/verifiableagg/verifiableagg/client_app.py b/baselines/verifiableagg/verifiableagg/client_app.py
new file mode 100644
index 000000000000..6b9f88012ba4
--- /dev/null
+++ b/baselines/verifiableagg/verifiableagg/client_app.py
@@ -0,0 +1,100 @@
+"""ClientApp for verifiable aggregation baseline."""
+
+from __future__ import annotations
+
+import torch
+from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
+from flwr.clientapp import ClientApp
+
+from verifiableagg.dataset import load_data
+from verifiableagg.model import Net, evaluate, train
+
+app = ClientApp()
+
+
+@app.train()
+def train_fn(msg: Message, context: Context) -> Message:
+    """Train local model and reply with model arrays and metrics."""
+    num_features = int(context.run_config["num-features"])
+    local_epochs = int(context.run_config["local-epochs"])
+    learning_rate = float(context.run_config["learning-rate"])
+    batch_size = int(context.run_config["batch-size"])
+    num_train_examples = int(context.run_config["num-train-examples"])
+    num_val_examples = int(context.run_config["num-val-examples"])
+    base_seed = int(context.run_config["random-seed"])
+
+    partition_id = int(context.node_config["partition-id"])
+    num_partitions = int(context.node_config["num-partitions"])
+
+    trainloader, _ = load_data(
+        partition_id=partition_id,
+        num_partitions=num_partitions,
+        num_train_examples=num_train_examples,
+        num_val_examples=num_val_examples,
+        num_features=num_features,
+        batch_size=batch_size,
+        base_seed=base_seed,
+    )
+
+    model = Net(num_features=num_features)
+    arrays = msg.content.array_records["arrays"]
+    model.load_state_dict(arrays.to_torch_state_dict())
+
+    device = torch.device("cpu")
+    train_loss = train(
+        model=model,
+        trainloader=trainloader,
+        epochs=local_epochs,
+        learning_rate=learning_rate,
+        device=device,
+    )
+
+    model_record = ArrayRecord(model.state_dict())
+    metrics = MetricRecord(
+        {
+            "train_loss": train_loss,
+            "num-examples": len(trainloader.dataset),
+        }
+    )
+    content = RecordDict({"arrays": model_record, "metrics": metrics})
+    return Message(content=content, reply_to=msg)
+
+
+@app.evaluate()
+def evaluate_fn(msg: Message, context: Context) -> Message:
+    """Evaluate global model on local validation split."""
+    num_features = int(context.run_config["num-features"])
+    batch_size = int(context.run_config["batch-size"])
+    num_train_examples = int(context.run_config["num-train-examples"])
+    num_val_examples = int(context.run_config["num-val-examples"])
+    base_seed = int(context.run_config["random-seed"])
+
+    partition_id = int(context.node_config["partition-id"])
+    num_partitions = int(context.node_config["num-partitions"])
+
+    _, valloader = load_data(
+        partition_id=partition_id,
+        num_partitions=num_partitions,
+        num_train_examples=num_train_examples,
+        num_val_examples=num_val_examples,
+        num_features=num_features,
+        batch_size=batch_size,
+        base_seed=base_seed,
+    )
+
+    model = Net(num_features=num_features)
+    arrays = msg.content.array_records["arrays"]
+    model.load_state_dict(arrays.to_torch_state_dict())
+
+    device = torch.device("cpu")
+    eval_loss, eval_acc = evaluate(model=model, valloader=valloader, device=device)
+
+    metrics = MetricRecord(
+        {
+            "eval_loss": eval_loss,
+            "eval_acc": eval_acc,
+            "num-examples": len(valloader.dataset),
+        }
+    )
+    content = RecordDict({"metrics": metrics})
+    return Message(content=content, reply_to=msg)
diff --git a/baselines/verifiableagg/verifiableagg/dataset.py b/baselines/verifiableagg/verifiableagg/dataset.py
new file mode 100644
index 000000000000..4a26acd97658
--- /dev/null
+++ b/baselines/verifiableagg/verifiableagg/dataset.py
@@ -0,0 +1,77 @@
+"""Deterministic synthetic dataset utilities for verifiable aggregation baseline."""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+
+
+def _make_partition(
+    partition_id: int,
+    num_partitions: int,
+    *,
+    num_train_examples: int,
+    num_val_examples: int,
+    num_features: int,
+    base_seed: int,
+) -> tuple[TensorDataset, TensorDataset]:
+    """Create deterministic train/val tensors for one partition."""
+    total_examples = num_train_examples + num_val_examples
+
+    partition_seed = base_seed + (partition_id * 9973)
+    rng = np.random.default_rng(partition_seed)
+
+    part_position = partition_id / max(num_partitions - 1, 1)
+    shift = np.linspace(-0.3, 0.3, num_features, dtype=np.float32) * (
+        2.0 * part_position - 1.0
+    )
+
+    x = rng.normal(loc=0.0, scale=1.0, size=(total_examples, num_features)).astype(
+        np.float32
+    )
+    x = x + shift
+
+    w = np.linspace(-1.0, 1.0, num_features, dtype=np.float32)
+    logits = x @ w + 0.15 * (partition_id % 3)
+    y = (logits > 0.0).astype(np.int64)
+
+    x_train = torch.from_numpy(x[:num_train_examples])
+    y_train = torch.from_numpy(y[:num_train_examples])
+    x_val = torch.from_numpy(x[num_train_examples:])
+    y_val = torch.from_numpy(y[num_train_examples:])
+
+    trainset = TensorDataset(x_train, y_train)
+    valset = TensorDataset(x_val, y_val)
+    return trainset, valset
+
+
+def load_data(
+    partition_id: int,
+    num_partitions: int,
+    *,
+    num_train_examples: int,
+    num_val_examples: int,
+    num_features: int,
+    batch_size: int,
+    base_seed: int,
+) -> tuple[DataLoader, DataLoader]:
+    """Load deterministic local train and val data loaders for one client."""
+    trainset, valset = _make_partition(
+        partition_id=partition_id,
+        num_partitions=num_partitions,
+        num_train_examples=num_train_examples,
+        num_val_examples=num_val_examples,
+        num_features=num_features,
+        base_seed=base_seed,
+    )
+
+    generator = torch.Generator().manual_seed(base_seed + partition_id)
+    trainloader = DataLoader(
+        trainset,
+        batch_size=batch_size,
+        shuffle=True,
+        generator=generator,
+    )
+    valloader = DataLoader(valset, batch_size=batch_size, shuffle=False)
+    return trainloader, valloader
diff --git a/baselines/verifiableagg/verifiableagg/model.py b/baselines/verifiableagg/verifiableagg/model.py
new file mode 100644
index 000000000000..1c8d9434bdb6
--- /dev/null
+++ b/baselines/verifiableagg/verifiableagg/model.py
@@ -0,0 +1,80 @@
+"""Model and train/eval utilities for verifiable aggregation baseline."""
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+
+class Net(nn.Module):
+    """Small MLP for deterministic synthetic binary classification."""
+
+    def __init__(self, num_features: int) -> None:
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(num_features, 16),
+            nn.ReLU(),
+            nn.Linear(16, 2),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute model outputs."""
+        return self.net(x)
+
+
+def train(
+    model: nn.Module,
+    trainloader: DataLoader,
+    epochs: int,
+    learning_rate: float,
+    device: torch.device,
+) -> float:
+    """Train model for a number of local epochs and return average loss."""
+    model.to(device)
+    model.train()
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
+
+    loss_sum = 0.0
+    num_batches = 0
+
+    for _ in range(epochs):
+        for x, y in trainloader:
+            x = x.to(device)
+            y = y.to(device)
+            optimizer.zero_grad()
+            loss = criterion(model(x), y)
+            loss.backward()
+            optimizer.step()
+            loss_sum += float(loss.item())
+            num_batches += 1
+
+    return loss_sum / max(num_batches, 1)
+
+
+def evaluate(
+    model: nn.Module, valloader: DataLoader, device: torch.device
+) -> tuple[float, float]:
+    """Evaluate model and return loss and accuracy."""
+    model.to(device)
+    model.eval()
+    criterion = nn.CrossEntropyLoss()
+
+    loss_sum = 0.0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for x, y in valloader:
+            x = x.to(device)
+            y = y.to(device)
+            logits = model(x)
+            loss_sum += float(criterion(logits, y).item())
+            preds = torch.argmax(logits, dim=1)
+            correct += int((preds == y).sum().item())
+            total += int(y.shape[0])
+
+    avg_loss = loss_sum / max(len(valloader), 1)
+    accuracy = correct / max(total, 1)
+    return avg_loss, accuracy
diff --git a/baselines/verifiableagg/verifiableagg/reporting.py b/baselines/verifiableagg/verifiableagg/reporting.py
new file mode 100644
index 000000000000..8b77b1703a2d
--- /dev/null
+++ b/baselines/verifiableagg/verifiableagg/reporting.py
@@ -0,0 +1,16 @@
+"""Utilities for writing reproducibility and verification reports."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any
+
+
+def write_json_report(report: dict[str, Any], output_path: Path) -> None:
+    """Write JSON report, creating parent directories as needed."""
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_path.write_text(
+        json.dumps(report, indent=2, sort_keys=True),
+        encoding="utf-8",
+    )
diff --git a/baselines/verifiableagg/verifiableagg/server_app.py b/baselines/verifiableagg/verifiableagg/server_app.py
new file mode 100644
index 000000000000..7634b995ebfc
--- /dev/null
+++ b/baselines/verifiableagg/verifiableagg/server_app.py
@@ -0,0 +1,111 @@
+"""ServerApp for verifiable aggregation baseline."""
+
+from __future__ import annotations
+
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+from flwr.app import ArrayRecord, Context
+from flwr.serverapp import Grid, ServerApp
+
+from verifiableagg.model import Net
+from verifiableagg.reporting import write_json_report
+from verifiableagg.strategy import VerifiableFedAvg
+from verifiableagg.utils import as_bool
+
+app = ServerApp()
+
+
+def _set_global_seeds(seed: int) -> None:
+    """Set global seeds with deterministic CUDA behavior when available."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+    if hasattr(torch.backends, "cudnn"):
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+@app.main()
+def main(grid: Grid, context: Context) -> None:
+    """Run federated training and write reproducibility report."""
+    run_config = context.run_config
+
+    seed = int(run_config["random-seed"])
+    _set_global_seeds(seed)
+
+    num_rounds = int(run_config["num-server-rounds"])
+    num_features = int(run_config["num-features"])
+    fraction_train = float(run_config["fraction-train"])
+    fraction_evaluate = float(run_config["fraction-evaluate"])
+    verify_aggregation = as_bool(run_config["verify-aggregation"])
+    verification_tolerance = float(run_config["verification-tolerance"])
+    artifacts_dir = Path(str(run_config["artifacts-dir"]))
+
+    model = Net(num_features=num_features)
+    initial_arrays = ArrayRecord(model.state_dict())
+
+    strategy = VerifiableFedAvg(
+        fraction_train=fraction_train,
+        fraction_evaluate=fraction_evaluate,
+        min_available_nodes=2,
+        verify_aggregation=verify_aggregation,
+        verification_tolerance=verification_tolerance,
+        weighted_by_key="num-examples",
+        arrayrecord_key="arrays",
+        configrecord_key="config",
+    )
+
+    result = strategy.start(
+        grid=grid,
+        initial_arrays=initial_arrays,
+        num_rounds=num_rounds,
+    )
+
+    artifacts_dir.mkdir(parents=True, exist_ok=True)
+    model_path = artifacts_dir / "final_model.pt"
+    final_arrays = result.arrays if result.arrays is not None else initial_arrays
+    torch.save(final_arrays.to_torch_state_dict(), model_path)
+
+    verification_rounds = [
+        {
+            "round": item.server_round,
+            "num_replies": item.num_replies,
+            "max_abs_diff": item.max_abs_diff,
+            "passed": item.passed,
+            "aggregate_hash": item.aggregate_hash,
+        }
+        for item in strategy.verification_rounds
+    ]
+
+    train_metrics = {
+        str(round_id): dict(metrics)
+        for round_id, metrics in result.train_metrics_clientapp.items()
+    }
+    eval_metrics = {
+        str(round_id): dict(metrics)
+        for round_id, metrics in result.evaluate_metrics_clientapp.items()
+    }
+
+    report = {
+        "baseline": "verifiableagg",
+        "run_config": {key: run_config[key] for key in sorted(run_config)},
+        "train_metrics": train_metrics,
+        "evaluate_metrics": eval_metrics,
+        "verification_rounds": verification_rounds,
+        "artifacts": {"model_path": str(model_path)},
+    }
+
+    report_path = artifacts_dir / "report.json"
+    write_json_report(report=report, output_path=report_path)
+
+    passed_rounds = sum(1 for item in verification_rounds if item["passed"])
+    total_rounds = len(verification_rounds)
+    print(
+        f"Verification summary: {passed_rounds}/{total_rounds} rounds passed. "
+        f"Report: {report_path}"
+    )
diff --git a/baselines/verifiableagg/verifiableagg/strategy.py b/baselines/verifiableagg/verifiableagg/strategy.py
new file mode 100644
index 000000000000..aea5c502cab8
--- /dev/null
+++ b/baselines/verifiableagg/verifiableagg/strategy.py
@@ -0,0 +1,170 @@
+"""Custom FedAvg strategy with optional aggregation verification hooks."""
+
+from __future__ import annotations
+
+import hashlib
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import Any, cast
+
+import numpy as np
+from flwr.app import ArrayRecord, Message, MetricRecord, RecordDict
+from flwr.serverapp.strategy import FedAvg
+from flwr.serverapp.strategy.strategy_utils import aggregate_arrayrecords
+
+
+@dataclass
+class VerificationRound:
+    """Verification metadata for one aggregation round."""
+
+    server_round: int
+    num_replies: int
+    max_abs_diff: float
+    passed: bool
+    aggregate_hash: str
+
+
+def _update_with_length_prefix(digest: Any, data: bytes) -> None:
+    """Update a hash digest using unambiguous length-prefixed framing."""
+    digest.update(len(data).to_bytes(8, byteorder="little", signed=False))
+    digest.update(data)
+
+
+def _hash_arrayrecord(arrays: ArrayRecord) -> str:
+    """Compute a deterministic SHA256 hash of an ArrayRecord."""
+    digest = hashlib.sha256()
+    for key in sorted(arrays.keys()):
+        arr = np.ascontiguousarray(arrays[key].numpy())
+        _update_with_length_prefix(digest, key.encode("utf-8"))
+        _update_with_length_prefix(digest, str(arr.dtype).encode("utf-8"))
+        _update_with_length_prefix(
+            digest,
+            np.asarray(arr.shape, dtype=np.int64).tobytes(),
+        )
+        _update_with_length_prefix(digest, arr.tobytes())
+    return digest.hexdigest()
+
+
+def _recompute_weighted_average(
+    replies: list[RecordDict],
+    weighted_by_key: str,
+    arrayrecord_key: str,
+    metricrecord_key: str,
+) -> dict[str, np.ndarray]:
+    """Recompute weighted average from client replies."""
+    first_array_records = replies[0].array_records
+    if arrayrecord_key in first_array_records:
+        first_arrays = first_array_records[arrayrecord_key]
+    else:
+        first_arrays = next(iter(first_array_records.values()))
+    keys = list(first_arrays.keys())
+
+    sums: dict[str, np.ndarray] = {}
+    total_weight = 0.0
+
+    for reply in replies:
+        metric_records = reply.metric_records
+        if metricrecord_key in metric_records:
+            metrics = metric_records[metricrecord_key]
+        else:
+            metrics = next(iter(metric_records.values()))
+
+        array_records = reply.array_records
+        if arrayrecord_key in array_records:
+            arrays = array_records[arrayrecord_key]
+        else:
+            arrays = next(iter(array_records.values()))
+        weight = float(cast(int | float, metrics[weighted_by_key]))
+        total_weight += weight
+
+        for key in keys:
+            current = arrays[key].numpy().astype(np.float64)
+            if key not in sums:
+                sums[key] = np.zeros_like(current)
+            sums[key] += weight * current
+
+    if total_weight <= 0.0:
+        return {key: np.zeros_like(value) for key, value in sums.items()}
+
+    return {key: value / total_weight for key, value in sums.items()}
+
+
+class VerifiableFedAvg(FedAvg):
+    """FedAvg with optional deterministic post-aggregation verification."""
+
+    def __init__(
+        self,
+        verify_aggregation: bool = True,
+        verification_tolerance: float = 1e-6,
+        metricrecord_key: str = "metrics",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.verify_aggregation = verify_aggregation
+        self.verification_tolerance = verification_tolerance
+        self.metricrecord_key = metricrecord_key
+        self.verification_rounds: list[VerificationRound] = []
+
+    def aggregate_train(
+        self,
+        server_round: int,
+        replies: Iterable[Message],
+    ) -> tuple[ArrayRecord | None, MetricRecord | None]:
+        """Aggregate train replies and optionally verify deterministic replay."""
+        valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
+
+        if not valid_replies:
+            self.verification_rounds.append(
+                VerificationRound(
+                    server_round=server_round,
+                    num_replies=0,
+                    max_abs_diff=0.0,
+                    passed=False,
+                    aggregate_hash="",
+                )
+            )
+            return None, None
+
+        reply_contents = [msg.content for msg in valid_replies]
+
+        arrays = aggregate_arrayrecords(
+            reply_contents,
+            self.weighted_by_key,
+        )
+        metrics = self.train_metrics_aggr_fn(reply_contents, self.weighted_by_key)
+
+        max_abs_diff = 0.0
+        verification_passed = True
+
+        if self.verify_aggregation:
+            replay_avg = _recompute_weighted_average(
+                replies=reply_contents,
+                weighted_by_key=self.weighted_by_key,
+                arrayrecord_key=self.arrayrecord_key,
+                metricrecord_key=self.metricrecord_key,
+            )
+            for key in replay_avg:
+                agg_arr = arrays[key].numpy().astype(np.float64)
+                diff = np.max(np.abs(replay_avg[key] - agg_arr))
+                max_abs_diff = max(max_abs_diff, float(diff))
+            verification_passed = max_abs_diff <= self.verification_tolerance
+
+        aggregate_hash = _hash_arrayrecord(arrays)
+
+        if metrics is None:
+            metrics = MetricRecord()
+        metrics["verification_passed"] = int(verification_passed)
+        metrics["verification_max_abs_diff"] = float(max_abs_diff)
+        metrics["verification_num_replies"] = int(len(valid_replies))
+
+        self.verification_rounds.append(
+            VerificationRound(
+                server_round=server_round,
+                num_replies=len(valid_replies),
+                max_abs_diff=max_abs_diff,
+                passed=verification_passed,
+                aggregate_hash=aggregate_hash,
+            )
+        )
+
+        return arrays, metrics
diff --git a/baselines/verifiableagg/verifiableagg/utils.py b/baselines/verifiableagg/verifiableagg/utils.py
new file mode 100644
index 000000000000..8b4b2f73ac6d
--- /dev/null
+++ b/baselines/verifiableagg/verifiableagg/utils.py
@@ -0,0 +1,18 @@
+"""Shared utility helpers for verifiable aggregation baseline."""
+
+from __future__ import annotations
+
+
+def as_bool(value: object) -> bool:
+    """Convert bool-like run-config values to bool."""
+    if isinstance(value, bool):
+        return value
+    if isinstance(value, int | float):
+        return bool(value)
+    if isinstance(value, str):
+        normalized = value.strip().lower()
+        if normalized in {"1", "true", "yes", "on"}:
+            return True
+        if normalized in {"0", "false", "no", "off"}:
+            return False
+    raise ValueError(f"Cannot parse boolean value from {value!r}")