 from enum import Enum
 from logging import INFO
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
 from flwr.client import NumPyClient
-from flwr.common.logger import log
+from flwr.common.logger import LOG_COLORS, log
 from flwr.common.typing import Config, NDArrays, Scalar
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from tqdm import tqdm

 from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
@@ -41,6 +42,7 @@ def __init__(
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
         checkpointer: Optional[ClientCheckpointModule] = None,
         metrics_reporter: Optional[MetricsReporter] = None,
+        progress_bar: bool = False,
     ) -> None:
         """
         Base FL Client with functionality to train, evaluate, log, report and checkpoint.
@@ -59,12 +61,16 @@ def __init__(
                 None.
             metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
                 during the execution. Defaults to an instance of MetricsReporter with default init parameters.
+            progress_bar (bool): Whether or not to display a progress bar
+                during client training and validation. Uses tqdm. Defaults
+                to False.
         """

         self.data_path = data_path
         self.device = device
         self.metrics = metrics
         self.checkpointer = checkpointer
+        self.progress_bar = progress_bar

         self.client_name = generate_hash()
@@ -363,7 +369,35 @@ def _should_evaluate_after_fit(self, evaluate_after_fit: bool) -> bool:
         )
         return evaluate_after_fit or pre_aggregation_checkpointing_enabled

-    def _handle_logging(
+    def _log_header_str(
+        self,
+        current_round: Optional[int] = None,
+        current_epoch: Optional[int] = None,
+        logging_mode: LoggingMode = LoggingMode.TRAIN,
+    ) -> None:
+        """
+        Logs a header string. By default, this is logged at the beginning of each local
+        epoch, or at the beginning of the round if training by steps.
+
+        Args:
+            current_round (Optional[int], optional): The current FL round (i.e., the
+                current server round). Defaults to None.
+            current_epoch (Optional[int], optional): The current epoch of local
+                training. Defaults to None.
+            logging_mode (LoggingMode, optional): The logging mode (Training,
+                Validation, or Testing). Defaults to LoggingMode.TRAIN.
+        """
+
+        log_str = f"Current FL Round: {str(current_round)}\t" if current_round is not None else ""
+        log_str += f"Current Epoch: {str(current_epoch)}" if current_epoch is not None else ""
+
+        # Maybe add client specific info to the header log string
+        client_str, _ = self.get_client_specific_logs(current_round, current_epoch, logging_mode)
+
+        log_str += client_str
+
+        log(INFO, "")  # For aesthetics
+        log(INFO, log_str)
+
+    def _log_results(
         self,
         loss_dict: Dict[str, float],
         metrics_dict: Dict[str, Scalar],
@@ -382,23 +416,15 @@ def _handle_logging(
             current_epoch (Optional[int]): The current epoch of local training.
             logging_mode (LoggingMode): The logging mode (Training, Validation, or Testing).
         """
-        log(INFO, "")  # An empty log line for aesthetics
+        _, client_logs = self.get_client_specific_logs(current_round, current_epoch, logging_mode)

-        initial_log_str = f"Current FL Round: {str(current_round)}\t" if current_round is not None else ""
-        initial_log_str += f"Current Epoch: {str(current_epoch)}" if current_epoch is not None else ""
-
-        # Maybe add client specific info to initial log string
-        client_str, client_logs = self.get_client_specific_logs()
-        initial_log_str += client_str
-
-        if initial_log_str != "":
-            log(INFO, initial_log_str)
-            self.add_to_initial_log_str = ""  # Reset variable
-
-        # Log loss/losses
+        # Get metric prefix
         metric_prefix = logging_mode.value
-        log(INFO, f"Client {metric_prefix} Losses:")
-        [log(INFO, f"\t{key}: {str(val)}") for key, val in loss_dict.items()]
+
+        # Log losses if any were provided
+        if len(loss_dict.keys()) > 0:
+            log(INFO, f"Client {metric_prefix} Losses:")
+            [log(INFO, f"\t{key}: {str(val)}") for key, val in loss_dict.items()]

         # Log metrics if any
         if len(metrics_dict.keys()) > 0:
@@ -409,21 +435,32 @@ def _handle_logging(
         if len(client_logs) > 0:
             [log(level.value, msg) for level, msg in client_logs]

-    def get_client_specific_logs(self) -> Tuple[str, List[Tuple[LogLevel, str]]]:
+    def get_client_specific_logs(
+        self, current_round: Optional[int], current_epoch: Optional[int], logging_mode: LoggingMode
+    ) -> Tuple[str, List[Tuple[LogLevel, str]]]:
         """
-        This function can be overriden to provide any client specific
+        This function can be overridden to provide any client specific
         information to the basic client logging. For example, perhaps a client
-        uses an LR scheduler and wants the LR to be logged each epoch. The
-        logging is called at the end of either every epoch for
-        train_by_epochs, or the end of the server round for train_by_steps
+        uses an LR scheduler and wants the LR to be logged each epoch. Called at the
+        beginning and end of each server round or local epoch. Also called at the end
+        of validation/testing.
+
+        Args:
+            current_round (Optional[int]): The current FL round (i.e., the current
+                server round).
+            current_epoch (Optional[int]): The current epoch of local training.
+            logging_mode (LoggingMode): The logging mode (Training,
+                Validation, or Testing).

         Returns:
-            Optional[str]: A string to append to the initial log string that
-                typically announces the current server round and current epoch
+            Optional[str]: A string to append to the header log string that
+                typically announces the current server round and current epoch at the
+                beginning of each round or local epoch.
             Optional[List[Tuple[LogLevel, str]]]: A list of tuples where the
                 first element is a LogLevel as defined in fl4health.utils.
                 typing and the second element is a string message. Each item
-                in the list will be logged when self._handle_logging is called
+                in the list will be logged at the end of each server round or epoch.
+                Elements will also be logged at the end of validation/testing.
         """
         return "", []

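The LR-scheduler example mentioned in the docstring might look roughly like this in a subclass (a sketch; self.scheduler is an assumed attribute, not part of the base client):

    from fl4health.utils.typing import LogLevel  # enum referenced in the docstring above

    class LrLoggingClient(BasicClient):
        """Hypothetical subclass that surfaces the current LR in the header logs."""

        def get_client_specific_logs(
            self, current_round: Optional[int], current_epoch: Optional[int], logging_mode: LoggingMode
        ) -> Tuple[str, List[Tuple[LogLevel, str]]]:
            if logging_mode == LoggingMode.TRAIN:
                # self.scheduler is assumed to be a torch LR scheduler configured elsewhere
                return f" Current LR: {self.scheduler.get_last_lr()[0]:.2e}", []
            return "", []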
@@ -612,13 +649,16 @@ def train_by_epochs(
         Loss is a dictionary of one or more losses that represent the different components of the loss.
         """
         self.model.train()
-        local_step = 0
+        steps_this_round = 0  # Reset number of steps this round
         for local_epoch in range(epochs):
             self.train_metric_manager.clear()
             self.train_loss_meter.clear()
+            # Print the header log string on epoch start
+            self._log_header_str(current_round, local_epoch)
             # update before epoch hook
             self.update_before_epoch(epoch=local_epoch)
-            for input, target in self.train_loader:
+            for input, target in self.maybe_progress_bar(self.train_loader):
+                self.update_before_step(steps_this_round)
                 # Assume first dimension is batch size. Sampling iterators (such as Poisson batch sampling) can
                 # construct empty batches. We skip the iteration if this occurs.
                 if self.is_empty_batch(input):
@@ -630,14 +670,14 @@ def train_by_epochs(
                 losses, preds = self.train_step(input, target)
                 self.train_loss_meter.update(losses)
                 self.update_metric_manager(preds, target, self.train_metric_manager)
-                self.update_after_step(local_step)
+                self.update_after_step(steps_this_round)
                 self.total_steps += 1
-                local_step += 1
+                steps_this_round += 1
             metrics = self.train_metric_manager.compute()
             loss_dict = self.train_loss_meter.compute().as_dict()

             # Log results and maybe report via WANDB
-            self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch)
+            self._log_results(loss_dict, metrics, current_round, local_epoch)
             self._handle_reporting(loss_dict, metrics, current_round=current_round)

         # Return final training metrics
@@ -663,7 +703,11 @@ def train_by_steps(

         self.train_loss_meter.clear()
         self.train_metric_manager.clear()
-        for step in range(steps):
+        self._log_header_str(current_round)
+        for step in self.maybe_progress_bar(range(steps)):
+            self.update_before_step(step)
             try:
                 input, target = next(train_iterator)
             except StopIteration:
@@ -690,7 +734,7 @@ def train_by_steps(
         metrics = self.train_metric_manager.compute()

         # Log results and maybe report via WANDB
-        self._handle_logging(loss_dict, metrics, current_round=current_round)
+        self._log_results(loss_dict, metrics, current_round)
         self._handle_reporting(loss_dict, metrics, current_round=current_round)

         return loss_dict, metrics
@@ -720,7 +764,7 @@ def _validate_or_test(
         metric_manager.clear()
         loss_meter.clear()
         with torch.no_grad():
-            for input, target in loader:
+            for input, target in self.maybe_progress_bar(loader):
                 input = self._move_data_to_device(input)
                 target = self._move_data_to_device(target)
                 losses, preds = self.val_step(input, target)
@@ -730,7 +774,7 @@ def _validate_or_test(
         # Compute losses and metrics over validation set
         loss_dict = loss_meter.compute().as_dict()
         metrics = metric_manager.compute()
-        self._handle_logging(loss_dict, metrics, logging_mode=logging_mode)
+        self._log_results(loss_dict, metrics, logging_mode=logging_mode)

         return loss_dict["checkpoint"], metrics

@@ -1074,19 +1118,31 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> N
         aggregation.

         Args:
-            local_steps (int): The number of steps in the local training.
+            local_steps (int): The number of steps taken so far in the round of
+                local training.
             loss_dict (Dict[str, float]): A dictionary of losses from local training.
         """
         pass

+    def update_before_step(self, step: int) -> None:
+        """
+        Hook method called before each local train step.
+
+        Args:
+            step (int): The local training step that is about to be performed.
+                Resets only at the end of the round.
+        """
+        pass
+
     def update_after_step(self, step: int) -> None:
         """
         Hook method called after local train step on client. step is an integer that represents
         the local training step that was most recently completed. For example, used by the APFL
         method to update the alpha value after training a step.

         Args:
-            step (int): The step number in local training that was most recently completed.
+            step (int): The step number in local training that was most recently
+                completed. Resets only at the end of the round.
         """
         pass
@@ -1100,3 +1156,32 @@ def update_before_epoch(self, epoch: int) -> None:
             epoch (int): Integer representing the epoch about to begin
         """
         pass
+
+    def maybe_progress_bar(self, iterable: Iterable) -> Iterable:
+        """
+        Used to print progress bars during client training and validation. If
+        self.progress_bar is False, just returns the original input iterable
+        without modifying it.
+
+        Args:
+            iterable (Iterable): The iterable to wrap.
+
+        Returns:
+            Iterable: An iterator that acts exactly like the original iterable but
+                prints a dynamically updating progress bar every time a value is
+                requested, or the original iterable if self.progress_bar is False.
+        """
+        if not self.progress_bar:
+            return iterable
+        else:
+            # Create a clean-looking tqdm instance that matches the flwr logging format
+            kwargs = {
+                "leave": True,
+                "ascii": " >=",
+                "unit": "steps",
+                "dynamic_ncols": True,
+                "bar_format": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']} " + ": {l_bar}{bar}{r_bar}",
+            }
+            return tqdm(iterable, **kwargs)
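
A quick sketch of the pass-through behavior, assuming a constructed client instance named client:

    batches = [1, 2, 3]

    client.progress_bar = False
    assert client.maybe_progress_bar(batches) is batches  # returned untouched

    client.progress_bar = True
    for batch in client.maybe_progress_bar(batches):  # tqdm-wrapped iterator
        pass  # renders a bar prefixed to match flwr's INFO log lines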