[Feature] Add BasicClientProtocol and AdaptiveDriftConstrainedMixin #384
Merged

Commits (22, all by nerdai):
- ccaba39: basic client protocols
- e345284: lint
- febf990: unit test for raise runtime error when protocol not satisfied
- 9c45150: add test for raising type error if protocol compliance check fails
- 9bbdb97: more tests
- 81c1a41: don't cover protocol interfaces
- 3d02767: unit test for test params initialized
- a8dc0de: cr
- e66fde4: davids cr
- 9558d1d: renames and other typos
- 6cf0340: util for dynamic application of mixin
- 0defd8e: typo
- 2ce3083: fix broken test
- 01f00bf: add comment in tests
- 06c08b6: add test for dynamically created class
- cf5d8c4: rm unnecessary type: ignore annotations
- 6e0c353: add unit tests for get params when client uninitalized
- 629c5af: add unit test for set params
- 552a781: nit
- af052a5: rest of david's cr
- 560bf5a: log warning
- a175745: cr
fl4health/mixins/__init__.py (@@ -0,0 +1,3 @@)

from .adaptive_drift_constrained import AdaptiveDriftConstrainedMixin

__all__ = ["AdaptiveDriftConstrainedMixin"]
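With this re-export in place, the mixin can be imported directly from the `fl4health.mixins` package, as in this one-line sketch:

```python
# Shorter import enabled by the package-level re-export above
from fl4health.mixins import AdaptiveDriftConstrainedMixin
```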
fl4health/mixins/adaptive_drift_constrained.py (@@ -0,0 +1,227 @@)

"""AdaptiveDriftConstrainedMixin"""

import warnings
from logging import INFO
from typing import Any, Protocol, runtime_checkable

import torch
from flwr.common.logger import log
from flwr.common.typing import Config, NDArrays

from fl4health.clients.basic_client import BasicClient
from fl4health.losses.weight_drift_loss import WeightDriftLoss
from fl4health.mixins.core import BasicClientProtocol, BasicClientProtocolPreSetup
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint
from fl4health.utils.losses import TrainingLosses
from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType


@runtime_checkable
class AdaptiveDriftConstrainedProtocol(BasicClientProtocol, Protocol):
    loss_for_adaptation: float
    drift_penalty_tensors: list[torch.Tensor] | None
    drift_penalty_weight: float | None
    penalty_loss_function: WeightDriftLoss
    parameter_exchanger: FullParameterExchangerWithPacking[float]

    def compute_penalty_loss(self) -> torch.Tensor: ...  # noqa: E704

    def ensure_protocol_compliance(self) -> None: ...  # noqa: E704


class AdaptiveDriftConstrainedMixin:
    """Adaptive Drift Constrained Mixin

    To be used with `~fl4health.BaseClient` in order to add the ability to compute
    losses via a constrained adaptive drift.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # Initialize mixin-specific attributes with default values
        self.loss_for_adaptation = 0.1
        self.drift_penalty_tensors = None
        self.drift_penalty_weight = None

        # Call parent's init
        try:
            super().__init__(*args, **kwargs)
        except TypeError:
            super().__init__()

        # set penalty_loss_function
        if not isinstance(self, BasicClientProtocolPreSetup):
            raise RuntimeError("This object needs to satisfy `BasicClientProtocolPreSetup`.")
        self.penalty_loss_function = WeightDriftLoss(self.device)

    def __init_subclass__(cls, **kwargs: Any):
        """This method is called when a class inherits from AdaptiveMixin"""
        super().__init_subclass__(**kwargs)

        # Skip check for other mixins
        if cls.__name__.endswith("Mixin"):
            return

        # Skip validation for dynamically created classes
        if hasattr(cls, "_dynamically_created"):
            return

        # Check at class definition time if the parent class satisfies BasicClientProtocol
        for base in cls.__bases__:
            if base is not AdaptiveDriftConstrainedMixin and isinstance(base, BasicClient):
                return

        # If we get here, no compatible base was found
        warnings.warn(
            f"Class {cls.__name__} inherits from AdaptiveMixin but none of its other "
            f"base classes is a BasicClient. This may cause runtime errors.",
            RuntimeWarning,
        )

    def ensure_protocol_compliance(self) -> None:
        """Call this after the object is fully initialized"""
        if not isinstance(self, BasicClient):
            raise TypeError("Protocol requirements not met.")

    def get_parameters(self: AdaptiveDriftConstrainedProtocol, config: Config) -> NDArrays:
        """
        Packs the parameters and training loss into a single ``NDArrays`` to be sent to the server for aggregation. If
        the client has not been initialized, this means the server is requesting parameters for initialization and
        just the model parameters are sent. When using the ``FedAvgWithAdaptiveConstraint`` strategy, this should not
        happen, as that strategy requires server-side initialization parameters. However, other strategies may handle
        this case.

        Args:
            config (Config): Configurations to allow for customization of this functions behavior

        Returns:
            NDArrays: Parameters and training loss packed together into a list of numpy arrays to be sent to the server
        """
        if not self.initialized:
            log(INFO, "Setting up client and providing full model parameters to the server for initialization")

            # If initialized is False, the server is requesting model parameters from which to initialize all other
            # clients. As such get_parameters is being called before fit or evaluate, so we must call
            # setup_client first.
            self.setup_client(config)

            # Need all parameters even if normally exchanging partial
            return FullParameterExchanger().push_parameters(self.model, config=config)
        else:
            # Make sure the proper components are there
            assert (
                self.model is not None
                and self.parameter_exchanger is not None
                and self.loss_for_adaptation is not None
            )
            model_weights = self.parameter_exchanger.push_parameters(self.model, config=config)

            # Weights and training loss sent to server for aggregation. Training loss is sent because server will
            # decide to increase or decrease the penalty weight, if adaptivity is turned on.
            packed_params = self.parameter_exchanger.pack_parameters(model_weights, self.loss_for_adaptation)
            return packed_params

    def set_parameters(
        self: AdaptiveDriftConstrainedProtocol, parameters: NDArrays, config: Config, fitting_round: bool
    ) -> None:
        """
        Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are
        unpacked for the clients to use in training. In the first fitting round, we assume the full model is being
        initialized and use the ``FullParameterExchanger()`` to set all model weights.

        Args:
            parameters (NDArrays): Parameters have information about model state to be added to the relevant client
                model and also the penalty weight to be applied during training.
            config (Config): The config is sent by the FL server to allow for customization in the function if desired.
            fitting_round (bool): Boolean that indicates whether the current federated learning round is a fitting
                round or an evaluation round. This is used to help determine which parameter exchange should be used
                for pulling parameters. A full parameter exchanger is always used if the current federated learning
                round is the very first fitting round.
        """
        assert self.model is not None and self.parameter_exchanger is not None

        server_model_state, self.drift_penalty_weight = self.parameter_exchanger.unpack_parameters(parameters)
        log(INFO, f"Penalty weight received from the server: {self.drift_penalty_weight}")

        super().set_parameters(server_model_state, config, fitting_round)  # type: ignore[safe-super]

    def compute_training_loss(
        self: AdaptiveDriftConstrainedProtocol,
        preds: TorchPredType,
        features: TorchFeatureType,
        target: TorchTargetType,
    ) -> TrainingLosses:
        """
        Computes training loss given predictions of the model and ground truth data. Adds to objective by including
        penalty loss.

        Args:
            preds (TorchPredType): Prediction(s) of the model(s) indexed by name. All predictions included in
                dictionary will be used to compute metrics.
            features: (TorchFeatureType): Feature(s) of the model(s) indexed by name.
            target: (TorchTargetType): Ground truth data to evaluate predictions against.

        Returns:
            TrainingLosses: An instance of ``TrainingLosses`` containing backward loss and additional losses indexed
                by name. Additional losses includes penalty loss.
        """
        loss, additional_losses = self.compute_loss_and_additional_losses(preds, features, target)
        if additional_losses is None:
            additional_losses = {}

        additional_losses["loss"] = loss.clone()
        # adding the vanilla loss to the additional losses to be used by update_after_train for potential adaptation
        additional_losses["loss_for_adaptation"] = loss.clone()

        # Compute the drift penalty loss and store it in the additional losses dictionary.
        penalty_loss = self.compute_penalty_loss()
        additional_losses["penalty_loss"] = penalty_loss.clone()

        return TrainingLosses(backward=loss + penalty_loss, additional_losses=additional_losses)

    def get_parameter_exchanger(self: AdaptiveDriftConstrainedProtocol, config: Config) -> ParameterExchanger:
        """
        Setting up the parameter exchanger to include the appropriate packing functionality.
        By default we assume that we're exchanging all parameters. Can be overridden for other behavior

        Args:
            config (Config): The config is sent by the FL server to allow for customization in the function if desired.

        Returns:
            ParameterExchanger: Exchanger that can handle packing/unpacking auxiliary server information.
        """
        return FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())

    def update_after_train(
        self: AdaptiveDriftConstrainedProtocol, local_steps: int, loss_dict: dict[str, float], config: Config
    ) -> None:
        """
        Called after training with the number of ``local_steps`` performed over the FL round and the corresponding loss
        dictionary. We use this to store the training loss that we want to use to adapt the penalty weight parameter
        on the server side.

        Args:
            local_steps (int): The number of steps so far in the round in the local training.
            loss_dict (dict[str, float]): A dictionary of losses from local training.
            config (Config): The config from the server
        """
        assert "loss_for_adaptation" in loss_dict
        # Store current loss which is the vanilla loss without the penalty term added in
        self.loss_for_adaptation = loss_dict["loss_for_adaptation"]
        super().update_after_train(local_steps, loss_dict, config)  # type: ignore[safe-super]

    def compute_penalty_loss(self: AdaptiveDriftConstrainedProtocol) -> torch.Tensor:
        """
        Computes the drift loss for the client model and drift tensors

        Returns:
            torch.Tensor: Computed penalty loss tensor
        """
        # Penalty tensors must have been set for these clients.
        assert self.drift_penalty_tensors is not None

        return self.penalty_loss_function(self.model, self.drift_penalty_tensors, self.drift_penalty_weight)
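For orientation, here is a minimal sketch of the composition pattern the mixin docstring describes: the mixin is listed before a concrete `BasicClient` subclass so its overrides take precedence in the MRO. `MnistClient` and its trivial overrides are illustrative stand-ins, not part of this PR; the PR's own unit tests, shown further below, follow the same pattern.

```python
# Illustrative sketch only: a concrete BasicClient subclass composed with the mixin.
# MnistClient and its overrides are hypothetical stand-ins for a real client.
from pathlib import Path

import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset

from fl4health.clients.basic_client import BasicClient
from fl4health.metrics import Accuracy
from fl4health.mixins.adaptive_drift_constrained import AdaptiveDriftConstrainedMixin


class MnistClient(BasicClient):
    def get_model(self, config: Config) -> nn.Module:
        return nn.Linear(28 * 28, 10)

    def get_data_loaders(self, config: Config) -> tuple[DataLoader, ...]:
        # Random tensors stand in for a real dataset
        dataset = TensorDataset(torch.rand((32, 28 * 28)), torch.randint(0, 10, (32,)))
        return DataLoader(dataset, batch_size=8), DataLoader(dataset, batch_size=8)

    def get_optimizer(self, config: Config) -> Optimizer | dict[str, Optimizer]:
        return torch.optim.SGD(self.model.parameters(), lr=0.01)

    def get_criterion(self, config: Config) -> _Loss:
        return nn.CrossEntropyLoss()


# Mixin first in the bases so its get_parameters/set_parameters/compute_training_loss overrides win.
class AdaptiveMnistClient(AdaptiveDriftConstrainedMixin, MnistClient):
    pass


client = AdaptiveMnistClient(data_path=Path("."), metrics=[Accuracy()], device=torch.device("cpu"))
```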
fl4health/mixins/core.py (@@ -0,0 +1,66 @@)

from typing import Protocol, runtime_checkable

import torch
import torch.nn as nn
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType


@runtime_checkable
class NumPyClientMinimalProtocol(Protocol):
    """A minimal protocol for NumPyClient with just essential methods."""

    def get_parameters(self, config: dict[str, Scalar]) -> NDArrays:
        pass

    def fit(self, parameters: NDArrays, config: dict[str, Scalar]) -> tuple[NDArrays, int, dict[str, Scalar]]:
        pass

    def evaluate(self, parameters: NDArrays, config: dict[str, Scalar]) -> tuple[float, int, dict[str, Scalar]]:
        pass

    def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
        pass

    def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None:
        pass


@runtime_checkable
class BasicClientProtocolPreSetup(NumPyClientMinimalProtocol, Protocol):
    """A minimal protocol for BasicClient focused on methods."""

    device: torch.device
    initialized: bool

    # Include only methods, not attributes that get initialized later
    def setup_client(self, config: Config) -> None:
        pass

    def get_model(self, config: Config) -> nn.Module:
        pass

    def get_data_loaders(self, config: Config) -> tuple[DataLoader, ...]:
        pass

    def get_optimizer(self, config: Config) -> Optimizer | dict[str, Optimizer]:
        pass

    def get_criterion(self, config: Config) -> _Loss:
        pass

    def compute_loss_and_additional_losses(
        self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]:
        pass


@runtime_checkable
class BasicClientProtocol(BasicClientProtocolPreSetup, Protocol):
    """A minimal protocol for BasicClient focused on methods."""

    model: nn.Module

Resolved review thread on the `model: nn.Module` line:
"Should this also include the optimizer and dataloaders as minimal necessary components?"
"Sure! Added:"
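Because these protocols are decorated with `@runtime_checkable`, `isinstance` checks only verify that the listed members exist on an object; they do not validate signatures or types. A small illustrative sketch, in which `DummyClient` is hypothetical and not part of the PR:

```python
# Illustrative sketch: runtime_checkable protocols support isinstance checks that
# only test for member presence, not for matching signatures.
import torch

from fl4health.mixins.core import BasicClientProtocolPreSetup


class DummyClient:  # hypothetical, not part of the PR
    def __init__(self) -> None:
        self.device = torch.device("cpu")
        self.initialized = False

    # Stub implementations of every member the protocol requires
    def setup_client(self, config): ...
    def get_model(self, config): ...
    def get_data_loaders(self, config): ...
    def get_optimizer(self, config): ...
    def get_criterion(self, config): ...
    def compute_loss_and_additional_losses(self, preds, features, target): ...
    def get_parameters(self, config): ...
    def fit(self, parameters, config): ...
    def evaluate(self, parameters, config): ...
    def set_parameters(self, parameters, config, fitting_round): ...
    def update_after_train(self, local_steps, loss_dict, config): ...


# Passes: every required attribute and method is present, regardless of signatures.
assert isinstance(DummyClient(), BasicClientProtocolPreSetup)
```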
Two empty files are also added in this diff.
New unit test module for the mixin (@@ -0,0 +1,61 @@)

from pathlib import Path

import pytest
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset

from fl4health.clients.basic_client import BasicClient
from fl4health.metrics import Accuracy
from fl4health.mixins.adaptive_drift_constrained import AdaptiveDriftConstrainedMixin, AdaptiveDriftConstrainedProtocol
from fl4health.mixins.core import BasicClientProtocol
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_packer import (
    ParameterPackerAdaptiveConstraint,
)


class _TestBasicClient(BasicClient):
    def get_model(self, config: Config) -> nn.Module:
        return self.model

    def get_data_loaders(self, config: Config) -> tuple[DataLoader, ...]:
        return self.train_loader, self.val_loader

    def get_optimizer(self, config: Config) -> Optimizer | dict[str, Optimizer]:
        return self.optimizers["global"]

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()


class _TestAdaptedClient(AdaptiveDriftConstrainedMixin, _TestBasicClient):
    pass


class _InvalidTestAdaptedClient(AdaptiveDriftConstrainedMixin):
    pass


def test_init() -> None:
    # setup client
    client = _TestAdaptedClient(data_path=Path(""), metrics=[Accuracy()], device=torch.device("cpu"))
    client.model = torch.nn.Linear(5, 5)
    client.optimizers = {"global": torch.optim.SGD(client.model.parameters(), lr=0.0001)}  # type: ignore
    client.train_loader = DataLoader(TensorDataset(torch.ones((1000, 28, 28, 1)), torch.ones((1000))))  # type: ignore
    client.val_loader = DataLoader(TensorDataset(torch.ones((1000, 28, 28, 1)), torch.ones((1000))))  # type: ignore
    client.parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())
    client.initialized = True
    client.setup_client({})

    assert isinstance(client, BasicClientProtocol)
    assert isinstance(client, AdaptiveDriftConstrainedProtocol)


def test_init_raises_value_error_when_basic_client_protocol_not_satisfied() -> None:
    with pytest.raises(RuntimeError, match="This object needs to satisfy `BasicClientProtocolPreSetup`."):
        _InvalidTestAdaptedClient(data_path=Path(""), metrics=[Accuracy()])
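One of the commits adds a utility for dynamic application of the mixin along with a test for dynamically created classes; that utility is not shown in this diff. As a hedged sketch, the `_dynamically_created` escape hatch in `__init_subclass__` is designed to support a pattern like the following, where the helper name is hypothetical:

```python
# Hypothetical sketch (helper name is illustrative; the PR's actual utility is not shown here):
# build a mixed-in client class at runtime and mark it so the __init_subclass__
# base-class validation is skipped.
from fl4health.clients.basic_client import BasicClient
from fl4health.mixins.adaptive_drift_constrained import AdaptiveDriftConstrainedMixin


def apply_adaptive_drift_mixin(client_cls: type[BasicClient]) -> type:
    return type(
        f"AdaptiveDriftConstrained{client_cls.__name__}",
        (AdaptiveDriftConstrainedMixin, client_cls),
        {"_dynamically_created": True},  # seen by __init_subclass__, which then skips validation
    )
```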