"""AdaptiveDriftConstrainedMixin"""

import warnings
from logging import INFO, WARN
from typing import Any, Protocol, runtime_checkable

import torch
from flwr.common.logger import log
from flwr.common.typing import Config, NDArrays

from fl4health.clients.basic_client import BasicClient
from fl4health.losses.weight_drift_loss import WeightDriftLoss
from fl4health.mixins.core_protocols import BasicClientProtocol, BasicClientProtocolPreSetup
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint
from fl4health.utils.losses import TrainingLosses
from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType


@runtime_checkable
class AdaptiveDriftConstrainedProtocol(BasicClientProtocol, Protocol):
    loss_for_adaptation: float
    drift_penalty_tensors: list[torch.Tensor] | None
    drift_penalty_weight: float | None
    penalty_loss_function: WeightDriftLoss
    parameter_exchanger: FullParameterExchangerWithPacking[float]

    def compute_penalty_loss(self) -> torch.Tensor: ...  # noqa: E704


class AdaptiveDriftConstrainedMixin:
    def __init__(self, *args: Any, **kwargs: Any):
        """Adaptive Drift Constrained Mixin.

        To be used with `~fl4health.clients.basic_client.BasicClient` in order to add the ability to augment
        training losses with an adaptive weight-drift penalty constraint.

        Raises:
            RuntimeError: when the inheriting class does not satisfy `BasicClientProtocolPreSetup`.
        """
        # Initialize mixin-specific attributes with default values
        self.loss_for_adaptation = 0.1
        self.drift_penalty_tensors = None
        self.drift_penalty_weight = None

        # Call parent's init
        try:
            super().__init__(*args, **kwargs)
        except TypeError:
            # if a parent class doesn't take args/kwargs
            super().__init__()

        # set penalty_loss_function
        if not isinstance(self, BasicClientProtocolPreSetup):
            raise RuntimeError("This object needs to satisfy `BasicClientProtocolPreSetup`.")
        self.penalty_loss_function = WeightDriftLoss(self.device)

    def __init_subclass__(cls, **kwargs: Any):
        """This method is called when a class inherits from AdaptiveDriftConstrainedMixin."""
        super().__init_subclass__(**kwargs)

        # Skip check for other mixins
        if cls.__name__.endswith("Mixin"):
            return

        # Skip validation for dynamically created classes
        if hasattr(cls, "_dynamically_created"):
            return

        # Check at class definition time whether any other base class is a BasicClient
        for base in cls.__bases__:
            if base is not AdaptiveDriftConstrainedMixin and issubclass(base, BasicClient):
                return

        # If we get here, no compatible base was found
        msg = (
            f"Class {cls.__name__} inherits from AdaptiveDriftConstrainedMixin but none of its other "
            "base classes is a BasicClient. This may cause runtime errors."
        )
        log(WARN, msg)
        warnings.warn(msg, RuntimeWarning)

    def get_parameters(self: AdaptiveDriftConstrainedProtocol, config: Config) -> NDArrays:
        """
        Packs the parameters and training loss into a single ``NDArrays`` to be sent to the server for aggregation. If
        the client has not been initialized, this means the server is requesting parameters for initialization and
        just the model parameters are sent. When using the ``FedAvgWithAdaptiveConstraint`` strategy, this should not
        happen, as that strategy requires server-side initialization parameters. However, other strategies may handle
        this case.

        Args:
            config (Config): Configurations to allow for customization of this function's behavior.

        Returns:
            NDArrays: Parameters and training loss packed together into a list of NumPy arrays to be sent to the
                server.
        """
        if not self.initialized:
            log(INFO, "Setting up client and providing full model parameters to the server for initialization")

            # If initialized is False, the server is requesting model parameters from which to initialize all other
            # clients. As such get_parameters is being called before fit or evaluate, so we must call
            # setup_client first.
            self.setup_client(config)

            # Need all parameters even if normally exchanging partial
            return FullParameterExchanger().push_parameters(self.model, config=config)
        else:
            # Make sure the proper components are there
            assert (
                self.model is not None
                and self.parameter_exchanger is not None
                and self.loss_for_adaptation is not None
            )
            model_weights = self.parameter_exchanger.push_parameters(self.model, config=config)

            # Weights and training loss sent to server for aggregation. Training loss is sent because the server will
            # decide to increase or decrease the penalty weight, if adaptivity is turned on.
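            # (The packing layout is delegated to the exchanger's ParameterPackerAdaptiveConstraint; the scalar
            # loss is expected to travel alongside the weight arrays and to be recovered by unpack_parameters on
            # the server side.)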
            packed_params = self.parameter_exchanger.pack_parameters(model_weights, self.loss_for_adaptation)
            return packed_params

    def set_parameters(
        self: AdaptiveDriftConstrainedProtocol, parameters: NDArrays, config: Config, fitting_round: bool
    ) -> None:
        """
        Assumes that the parameters being passed contain model parameters concatenated with a penalty weight. They are
        unpacked for the clients to use in training. In the first fitting round, we assume the full model is being
        initialized and use the ``FullParameterExchanger()`` to set all model weights.

        Args:
            parameters (NDArrays): Parameters have information about model state to be added to the relevant client
                model and also the penalty weight to be applied during training.
            config (Config): The config is sent by the FL server to allow for customization in the function if desired.
            fitting_round (bool): Boolean that indicates whether the current federated learning round is a fitting
                round or an evaluation round. This is used to help determine which parameter exchange should be used
                for pulling parameters. A full parameter exchanger is always used if the current federated learning
                round is the very first fitting round.
        """
        assert self.model is not None and self.parameter_exchanger is not None

        server_model_state, self.drift_penalty_weight = self.parameter_exchanger.unpack_parameters(parameters)
        log(INFO, f"Penalty weight received from the server: {self.drift_penalty_weight}")

        super().set_parameters(server_model_state, config, fitting_round)  # type: ignore[safe-super]

    def compute_training_loss(
        self: AdaptiveDriftConstrainedProtocol,
        preds: TorchPredType,
        features: TorchFeatureType,
        target: TorchTargetType,
    ) -> TrainingLosses:
        """
        Computes training loss given predictions of the model and ground truth data. The drift penalty loss is added
        to the objective.

        Args:
            preds (TorchPredType): Prediction(s) of the model(s) indexed by name. All predictions included in the
                dictionary will be used to compute metrics.
            features (TorchFeatureType): Feature(s) of the model(s) indexed by name.
            target (TorchTargetType): Ground truth data to evaluate predictions against.

        Returns:
            TrainingLosses: An instance of ``TrainingLosses`` containing the backward loss and additional losses
                indexed by name. The additional losses include the penalty loss.
        """
        loss, additional_losses = self.compute_loss_and_additional_losses(preds, features, target)
        if additional_losses is None:
            additional_losses = {}

        additional_losses["loss"] = loss.clone()
        # Adding the vanilla loss to the additional losses to be used by update_after_train for potential adaptation
        additional_losses["loss_for_adaptation"] = loss.clone()

        # Compute the drift penalty loss and store it in the additional losses dictionary.
        penalty_loss = self.compute_penalty_loss()
        additional_losses["penalty_loss"] = penalty_loss.clone()

        return TrainingLosses(backward=loss + penalty_loss, additional_losses=additional_losses)

    def get_parameter_exchanger(self: AdaptiveDriftConstrainedProtocol, config: Config) -> ParameterExchanger:
        """
        Sets up the parameter exchanger to include the appropriate packing functionality. By default, we assume that
        all parameters are exchanged. This can be overridden for other behavior.

        Args:
            config (Config): The config is sent by the FL server to allow for customization in the function if desired.

        Returns:
            ParameterExchanger: Exchanger that can handle packing/unpacking auxiliary server information.
        """
        return FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())

    def update_after_train(
        self: AdaptiveDriftConstrainedProtocol, local_steps: int, loss_dict: dict[str, float], config: Config
    ) -> None:
        """
        Called after training with the number of ``local_steps`` performed over the FL round and the corresponding
        loss dictionary. We use this to store the training loss that we want to use to adapt the penalty weight
        parameter on the server side.

        Args:
            local_steps (int): The number of steps so far in the round in the local training.
            loss_dict (dict[str, float]): A dictionary of losses from local training.
            config (Config): The config from the server.
        """
        assert "loss_for_adaptation" in loss_dict
        # Store the current loss, which is the vanilla loss without the penalty term added in
        self.loss_for_adaptation = loss_dict["loss_for_adaptation"]
        super().update_after_train(local_steps, loss_dict, config)  # type: ignore[safe-super]

    def compute_penalty_loss(self: AdaptiveDriftConstrainedProtocol) -> torch.Tensor:
        """
        Computes the drift loss between the client model and the drift penalty tensors.

        Returns:
            torch.Tensor: Computed penalty loss tensor.
        """
        # Penalty tensors must have been set for these clients.
        assert self.drift_penalty_tensors is not None

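        # The penalty is produced by WeightDriftLoss, which (roughly speaking) measures a weighted distance between
        # the current model weights and the reference tensors received from the server, in the spirit of a FedProx
        # proximal term; see WeightDriftLoss for the exact form.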
        return self.penalty_loss_function(self.model, self.drift_penalty_tensors, self.drift_penalty_weight)


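# Usage sketch (illustrative only): CifarBasicClient below is a hypothetical BasicClient subclass, not a class
# defined in this module. The mixin is listed first so that its get_parameters/set_parameters overrides take
# precedence in the MRO; the factory below builds a class of the same shape dynamically.
#
#     class AdaptiveDriftCifarClient(AdaptiveDriftConstrainedMixin, CifarBasicClient):
#         pass
#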
def apply_adaptive_drift_to_client(client_base_type: type[BasicClient]) -> type[BasicClient]:
    """Dynamically create an adapted client class.

    Args:
        client_base_type (type[BasicClient]): The client class to be combined with `AdaptiveDriftConstrainedMixin`.

    Returns:
        type[BasicClient]: A basic client that has been mixed with `AdaptiveDriftConstrainedMixin`.
    """
    return type(
        f"AdaptiveDrift{client_base_type.__name__}",
        (
            AdaptiveDriftConstrainedMixin,
            client_base_type,
        ),
        {
            # Special flag to bypass validation
            "_dynamically_created": True
        },
    )
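

if __name__ == "__main__":
    # Minimal, illustrative demonstration of the factory above (a sketch, not part of the library's API surface).
    # Real usage would pass a concrete, project-specific BasicClient subclass rather than BasicClient itself, and
    # would then instantiate the resulting class with the usual client constructor arguments.
    example_client_cls = apply_adaptive_drift_to_client(BasicClient)
    print(example_client_cls.__name__)  # "AdaptiveDriftBasicClient"
    print([base.__name__ for base in example_client_cls.__mro__[:3]])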