|
| 1 | +import argparse |
| 2 | +from functools import partial |
| 3 | +from typing import Any, Dict, List, Tuple |
| 4 | + |
| 5 | +import flwr as fl |
| 6 | +from flwr.common.parameter import ndarrays_to_parameters |
| 7 | +from flwr.common.typing import Config, Metrics, Parameters |
| 8 | +from flwr.server.strategy import FedAvg |
| 9 | + |
| 10 | +from examples.models.cnn_model import MnistNet |
| 11 | +from examples.simple_metric_aggregation import metric_aggregation, normalize_metrics |
| 12 | +from fl4health.model_bases.apfl_base import APFLModule |
| 13 | +from fl4health.utils.config import load_config |
| 14 | + |
| 15 | + |
| 16 | +def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: |
| 17 | + # This function is run by the server to aggregate metrics returned by each clients fit function |
| 18 | + # NOTE: The first value of the tuple is number of examples for FedAvg |
| 19 | + total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) |
| 20 | + return normalize_metrics(total_examples, aggregated_metrics) |
| 21 | + |
| 22 | + |
| 23 | +def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: |
| 24 | + # This function is run by the server to aggregate metrics returned by each clients evaluate function |
| 25 | + # NOTE: The first value of the tuple is number of examples for FedAvg |
| 26 | + total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) |
| 27 | + return normalize_metrics(total_examples, aggregated_metrics) |
| 28 | + |
| 29 | + |
| 30 | +def get_initial_model_parameters() -> Parameters: |
| 31 | + # Initializing the model parameters on the server side. |
| 32 | + # Currently uses the Pytorch default initialization for the model parameters. |
| 33 | + initial_model = APFLModule(MnistNet()) |
| 34 | + return ndarrays_to_parameters([val.cpu().numpy() for _, val in initial_model.state_dict().items()]) |
| 35 | + |
| 36 | + |
| 37 | +def fit_config(local_epochs: int, batch_size: int, n_server_rounds: int, current_round: int) -> Config: |
| 38 | + return { |
| 39 | + "local_epochs": local_epochs, |
| 40 | + "batch_size": batch_size, |
| 41 | + "n_server_rounds": n_server_rounds, |
| 42 | + } |
| 43 | + |
| 44 | + |
| 45 | +def main(config: Dict[str, Any]) -> None: |
| 46 | + # This function will be used to produce a config that is sent to each client to initialize their own environment |
| 47 | + fit_config_fn = partial( |
| 48 | + fit_config, |
| 49 | + config["local_epochs"], |
| 50 | + config["batch_size"], |
| 51 | + config["n_server_rounds"], |
| 52 | + ) |
| 53 | + |
| 54 | + # Server performs simple FedAveraging as its server-side optimization strategy |
| 55 | + strategy = FedAvg( |
| 56 | + min_fit_clients=config["n_clients"], |
| 57 | + min_evaluate_clients=config["n_clients"], |
| 58 | + # Server waits for min_available_clients before starting FL rounds |
| 59 | + min_available_clients=config["n_clients"], |
| 60 | + on_fit_config_fn=fit_config_fn, |
| 61 | + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, |
| 62 | + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, |
| 63 | + initial_parameters=get_initial_model_parameters(), |
| 64 | + ) |
| 65 | + |
| 66 | + fl.server.start_server( |
| 67 | + server_address="0.0.0.0:8080", |
| 68 | + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), |
| 69 | + strategy=strategy, |
| 70 | + ) |
| 71 | + |
| 72 | + |
| 73 | +if __name__ == "__main__": |
| 74 | + parser = argparse.ArgumentParser(description="FL Server Main") |
| 75 | + parser.add_argument( |
| 76 | + "--config_path", |
| 77 | + action="store", |
| 78 | + type=str, |
| 79 | + help="Path to configuration file.", |
| 80 | + default="config.yaml", |
| 81 | + ) |
| 82 | + args = parser.parse_args() |
| 83 | + |
| 84 | + config = load_config(args.config_path) |
| 85 | + |
| 86 | + main(config) |
0 commit comments