|
2 | 2 | from collections.abc import Iterator, Sequence
|
3 | 3 | from logging import INFO
|
4 | 4 | from pathlib import Path
|
5 |
| -from typing import Any |
| 5 | +from typing import Any, Protocol, runtime_checkable |
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 | import torch.nn as nn
|
@@ -1337,3 +1337,173 @@ def _load_client_state(self) -> bool:
|
1337 | 1337 | self.lr_schedulers[key].load_state_dict(client_state["lr_schedulers_state"][key])
|
1338 | 1338 |
|
1339 | 1339 | return True
|
| 1340 | + |
| 1341 | + |
| 1342 | +@runtime_checkable |
| 1343 | +class BasicClientProtocol(Protocol): |
| 1344 | + data_path: str |
| 1345 | + device: torch.device |
| 1346 | + metrics: Sequence[Metric] |
| 1347 | + progress_bar: bool |
| 1348 | + client_name: str |
| 1349 | + state_checkpoint_name: str |
| 1350 | + checkpoint_and_state_module: ClientCheckpointAndStateModule | None |
| 1351 | + reports_manager: ReportsManager |
| 1352 | + initialized: bool |
| 1353 | + |
| 1354 | + # Loss and Metric management |
| 1355 | + train_loss_meter: LossMeter |
| 1356 | + val_loss_meter: LossMeter |
| 1357 | + train_metric_manager: MetricManager |
| 1358 | + val_metric_manager: MetricManager |
| 1359 | + test_loss_meter: LossMeter |
| 1360 | + test_metric_manager: MetricManager |
| 1361 | + |
| 1362 | + # Optional variable to store the weights that the client was initialized with during each round of training |
| 1363 | + initial_weights: NDArrays | None |
| 1364 | + |
| 1365 | + total_steps: int |
| 1366 | + total_epochs: int |
| 1367 | + |
| 1368 | + # Attributes to be initialized in setup_client |
| 1369 | + parameter_exchanger: ParameterExchanger |
| 1370 | + model: nn.Module |
| 1371 | + optimizers: dict[str, torch.optim.Optimizer] |
| 1372 | + train_loader: DataLoader |
| 1373 | + val_loader: DataLoader |
| 1374 | + test_loader: DataLoader | None |
| 1375 | + num_train_samples: int |
| 1376 | + num_val_samples: int |
| 1377 | + num_test_samples: int | None |
| 1378 | + learning_rate: float | None |
| 1379 | + |
| 1380 | + # User can set the early stopper for the client by instantiating the EarlyStopper class |
| 1381 | + # and setting the patience and interval_steps attributes. The early stopper will be used to |
| 1382 | + # stop training if the validation loss does not improve for a certain number of steps. |
| 1383 | + early_stopper: EarlyStopper | None |
| 1384 | + # Config can contain num_validation_steps key, which determines an upper bound |
| 1385 | + # for the validation steps taken. If not specified, no upper bound will be enforced. |
| 1386 | + # By specifying this in the config we cannot guarantee the validation set is the same |
| 1387 | + # across rounds for clients. |
| 1388 | + num_validation_steps: int | None |
| 1389 | + # NOTE: These iterators are of type _BaseDataLoaderIter, which is not importable...so we're forced to use |
| 1390 | + # Iterator |
| 1391 | + train_iterator: Iterator | None |
| 1392 | + val_iterator: Iterator | None |
| 1393 | + |
| 1394 | + def get_parameters(self, config: Config) -> NDArrays: ... |
| 1395 | + |
| 1396 | + def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None: ... |
| 1397 | + |
| 1398 | + def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> None: ... |
| 1399 | + |
| 1400 | + def shutdown(self) -> None: ... |
| 1401 | + |
| 1402 | + def process_config(self, config: Config) -> tuple[int | None, int | None, int, bool, bool]: ... |
| 1403 | + |
| 1404 | + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: ... |
| 1405 | + |
| 1406 | + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: ... |
| 1407 | + |
| 1408 | + def get_client_specific_logs( |
| 1409 | + self, |
| 1410 | + current_round: int | None, |
| 1411 | + current_epoch: int | None, |
| 1412 | + logging_mode: LoggingMode, |
| 1413 | + ) -> tuple[str, list[tuple[LogLevel, str]]]: ... |
| 1414 | + |
| 1415 | + def get_client_specific_reports(self) -> dict[str, Any]: ... |
| 1416 | + |
| 1417 | + def update_metric_manager( |
| 1418 | + self, |
| 1419 | + preds: TorchPredType, |
| 1420 | + target: TorchTargetType, |
| 1421 | + metric_manager: MetricManager, |
| 1422 | + ) -> None: ... |
| 1423 | + |
| 1424 | + def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]: ... |
| 1425 | + |
| 1426 | + def val_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[EvaluationLosses, TorchPredType]: ... |
| 1427 | + |
| 1428 | + def train_by_epochs( |
| 1429 | + self, |
| 1430 | + epochs: int, |
| 1431 | + current_round: int | None = None, |
| 1432 | + ) -> tuple[dict[str, float], dict[str, Scalar]]: ... |
| 1433 | + |
| 1434 | + def train_by_steps( |
| 1435 | + self, |
| 1436 | + steps: int, |
| 1437 | + current_round: int | None = None, |
| 1438 | + ) -> tuple[dict[str, float], dict[str, Scalar]]: ... |
| 1439 | + |
| 1440 | + def _validate_by_steps( |
| 1441 | + self, loss_meter: LossMeter, metric_manager: MetricManager, include_losses_in_metrics: bool = False |
| 1442 | + ) -> tuple[float, dict[str, Scalar]]: ... |
| 1443 | + |
| 1444 | + def _fully_validate_or_test( |
| 1445 | + self, |
| 1446 | + loader: DataLoader, |
| 1447 | + loss_meter: LossMeter, |
| 1448 | + metric_manager: MetricManager, |
| 1449 | + logging_mode: LoggingMode = LoggingMode.VALIDATION, |
| 1450 | + include_losses_in_metrics: bool = False, |
| 1451 | + ) -> tuple[float, dict[str, Scalar]]: ... |
| 1452 | + |
| 1453 | + def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict[str, Scalar]]: ... |
| 1454 | + |
| 1455 | + def get_properties(self, config: Config) -> dict[str, Scalar]: ... |
| 1456 | + |
| 1457 | + def setup_client(self, config: Config) -> None: ... |
| 1458 | + |
| 1459 | + def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: ... |
| 1460 | + |
| 1461 | + def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]: ... |
| 1462 | + |
| 1463 | + def compute_loss_and_additional_losses( |
| 1464 | + self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType |
| 1465 | + ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]: ... |
| 1466 | + |
| 1467 | + def compute_training_loss( |
| 1468 | + self, |
| 1469 | + preds: TorchPredType, |
| 1470 | + features: TorchFeatureType, |
| 1471 | + target: TorchTargetType, |
| 1472 | + ) -> TrainingLosses: ... |
| 1473 | + |
| 1474 | + def compute_evaluation_loss( |
| 1475 | + self, |
| 1476 | + preds: TorchPredType, |
| 1477 | + features: TorchFeatureType, |
| 1478 | + target: TorchTargetType, |
| 1479 | + ) -> EvaluationLosses: ... |
| 1480 | + |
| 1481 | + def set_optimizer(self, config: Config) -> None: ... |
| 1482 | + |
| 1483 | + def get_data_loaders(self, config: Config) -> tuple[DataLoader, ...]: ... |
| 1484 | + |
| 1485 | + def get_test_data_loader(self, config: Config) -> DataLoader | None: ... |
| 1486 | + |
| 1487 | + def transform_target(self, target: TorchTargetType) -> TorchTargetType: ... |
| 1488 | + |
| 1489 | + def get_criterion(self, config: Config) -> _Loss: ... |
| 1490 | + |
| 1491 | + def get_optimizer(self, config: Config) -> Optimizer | dict[str, Optimizer]: ... |
| 1492 | + |
| 1493 | + def get_model(self, config: Config) -> nn.Module: ... |
| 1494 | + |
| 1495 | + def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None: ... |
| 1496 | + |
| 1497 | + def update_lr_schedulers(self, step: int | None = None, epoch: int | None = None) -> None: ... |
| 1498 | + |
| 1499 | + def update_before_train(self, current_server_round: int) -> None: ... |
| 1500 | + |
| 1501 | + def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: ... |
| 1502 | + |
| 1503 | + def update_before_step(self, step: int, current_round: int | None = None) -> None: ... |
| 1504 | + |
| 1505 | + def update_after_step(self, step: int, current_round: int | None = None) -> None: ... |
| 1506 | + |
| 1507 | + def update_before_epoch(self, epoch: int) -> None: ... |
| 1508 | + |
| 1509 | + def transform_gradients(self, losses: TrainingLosses) -> None: ... |
0 commit comments