
Commit 2450258

committed
add BasicClientProtocol
1 parent dd346b6 · commit 2450258

File tree

1 file changed (+171, -1 lines)


fl4health/clients/basic_client.py

Lines changed: 171 additions & 1 deletion
@@ -2,7 +2,7 @@
 from collections.abc import Iterator, Sequence
 from logging import INFO
 from pathlib import Path
-from typing import Any
+from typing import Any, Protocol, runtime_checkable
 
 import torch
 import torch.nn as nn
@@ -1337,3 +1337,173 @@ def _load_client_state(self) -> bool:
             self.lr_schedulers[key].load_state_dict(client_state["lr_schedulers_state"][key])
 
         return True
+
+
+@runtime_checkable
+class BasicClientProtocol(Protocol):
+    data_path: str
+    device: torch.device
+    metrics: Sequence[Metric]
+    progress_bar: bool
+    client_name: str
+    state_checkpoint_name: str
+    checkpoint_and_state_module: ClientCheckpointAndStateModule | None
+    reports_manager: ReportsManager
+    initialized: bool
+
+    # Loss and Metric management
+    train_loss_meter: LossMeter
+    val_loss_meter: LossMeter
+    train_metric_manager: MetricManager
+    val_metric_manager: MetricManager
+    test_loss_meter: LossMeter
+    test_metric_manager: MetricManager
+
+    # Optional variable to store the weights that the client was initialized with during each round of training
+    initial_weights: NDArrays | None
+
+    total_steps: int
+    total_epochs: int
+
+    # Attributes to be initialized in setup_client
+    parameter_exchanger: ParameterExchanger
+    model: nn.Module
+    optimizers: dict[str, torch.optim.Optimizer]
+    train_loader: DataLoader
+    val_loader: DataLoader
+    test_loader: DataLoader | None
+    num_train_samples: int
+    num_val_samples: int
+    num_test_samples: int | None
+    learning_rate: float | None
+
+    # Users can set an early stopper for the client by instantiating the EarlyStopper class
+    # and setting its patience and interval_steps attributes. The early stopper will be used to
+    # stop training if the validation loss does not improve for a certain number of steps.
+    early_stopper: EarlyStopper | None
+    # The config can contain a num_validation_steps key, which sets an upper bound on the number of
+    # validation steps taken. If not specified, no upper bound is enforced. Note that when this is
+    # specified in the config, the validation set is not guaranteed to be the same across rounds for clients.
+    num_validation_steps: int | None
+    # NOTE: These iterators are of type _BaseDataLoaderIter, which is not importable... so we're forced to use
+    # Iterator
+    train_iterator: Iterator | None
+    val_iterator: Iterator | None
+
+    def get_parameters(self, config: Config) -> NDArrays: ...
+
+    def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None: ...
+
+    def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> None: ...
+
+    def shutdown(self) -> None: ...
+
+    def process_config(self, config: Config) -> tuple[int | None, int | None, int, bool, bool]: ...
+
+    def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: ...
+
+    def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: ...
+
+    def get_client_specific_logs(
+        self,
+        current_round: int | None,
+        current_epoch: int | None,
+        logging_mode: LoggingMode,
+    ) -> tuple[str, list[tuple[LogLevel, str]]]: ...
+
+    def get_client_specific_reports(self) -> dict[str, Any]: ...
+
+    def update_metric_manager(
+        self,
+        preds: TorchPredType,
+        target: TorchTargetType,
+        metric_manager: MetricManager,
+    ) -> None: ...
+
+    def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]: ...
+
+    def val_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[EvaluationLosses, TorchPredType]: ...
+
+    def train_by_epochs(
+        self,
+        epochs: int,
+        current_round: int | None = None,
+    ) -> tuple[dict[str, float], dict[str, Scalar]]: ...
+
+    def train_by_steps(
+        self,
+        steps: int,
+        current_round: int | None = None,
+    ) -> tuple[dict[str, float], dict[str, Scalar]]: ...
+
+    def _validate_by_steps(
+        self, loss_meter: LossMeter, metric_manager: MetricManager, include_losses_in_metrics: bool = False
+    ) -> tuple[float, dict[str, Scalar]]: ...
+
+    def _fully_validate_or_test(
+        self,
+        loader: DataLoader,
+        loss_meter: LossMeter,
+        metric_manager: MetricManager,
+        logging_mode: LoggingMode = LoggingMode.VALIDATION,
+        include_losses_in_metrics: bool = False,
+    ) -> tuple[float, dict[str, Scalar]]: ...
+
+    def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict[str, Scalar]]: ...
+
+    def get_properties(self, config: Config) -> dict[str, Scalar]: ...
+
+    def setup_client(self, config: Config) -> None: ...
+
+    def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: ...
+
+    def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]: ...
+
+    def compute_loss_and_additional_losses(
+        self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]: ...
+
+    def compute_training_loss(
+        self,
+        preds: TorchPredType,
+        features: TorchFeatureType,
+        target: TorchTargetType,
+    ) -> TrainingLosses: ...
+
+    def compute_evaluation_loss(
+        self,
+        preds: TorchPredType,
+        features: TorchFeatureType,
+        target: TorchTargetType,
+    ) -> EvaluationLosses: ...
+
+    def set_optimizer(self, config: Config) -> None: ...
+
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, ...]: ...
+
+    def get_test_data_loader(self, config: Config) -> DataLoader | None: ...
+
+    def transform_target(self, target: TorchTargetType) -> TorchTargetType: ...
+
+    def get_criterion(self, config: Config) -> _Loss: ...
+
+    def get_optimizer(self, config: Config) -> Optimizer | dict[str, Optimizer]: ...
+
+    def get_model(self, config: Config) -> nn.Module: ...
+
+    def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None: ...
+
+    def update_lr_schedulers(self, step: int | None = None, epoch: int | None = None) -> None: ...
+
+    def update_before_train(self, current_server_round: int) -> None: ...
+
+    def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: ...
+
+    def update_before_step(self, step: int, current_round: int | None = None) -> None: ...
+
+    def update_after_step(self, step: int, current_round: int | None = None) -> None: ...
+
+    def update_before_epoch(self, epoch: int) -> None: ...
+
+    def transform_gradients(self, losses: TrainingLosses) -> None: ...
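
Because BasicClientProtocol is a Protocol, it is a structural type: any object that exposes all of the attributes and methods above satisfies it, whether or not it inherits from BasicClient, and the @runtime_checkable decorator additionally allows isinstance checks, which only verify that the named members exist, not their types or signatures. The sketch below illustrates that behavior with a reduced protocol; SupportsFit and MinimalClient are hypothetical names invented for illustration and are not part of this commit.

# Minimal sketch (illustrative only, not from this commit): a reduced
# protocol that mimics how BasicClientProtocol can be used for structural
# typing and runtime checks.
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsFit(Protocol):
    client_name: str

    def fit(self, parameters: list, config: dict) -> tuple: ...


class MinimalClient:
    # Does not inherit from the protocol; it matches it structurally.
    def __init__(self) -> None:
        self.client_name = "client_0"

    def fit(self, parameters: list, config: dict) -> tuple:
        return parameters, 1, {}


client = MinimalClient()
# isinstance only checks that the protocol's members are present on the
# object; it does not validate signatures or attribute types.
assert isinstance(client, SupportsFit)

In practice, this lets helper code accept anything client-like (test doubles, mixin-based clients) without requiring an inheritance relationship with BasicClient.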
