
Commit 0e0a6ed

Merge pull request #204 from VectorInstitute/progress_bar
Added progress bar and fixed bugs
2 parents dc574f7 + 70e0f7d commit 0e0a6ed

File tree

10 files changed: +256 additions, -113 deletions


examples/nnunet_example/client.py

Lines changed: 1 addition & 0 deletions

@@ -76,6 +76,7 @@ def main(
         device=DEVICE,
         metrics=[dice],
         data_path=dataset_path,  # Argument not actually used by nnUNetClient
+        progress_bar=True,
     )
 
     start_client(server_address=server_address, client=client.to_client())
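
For context, the `progress_bar=True` argument added above turns on a tqdm progress bar around the client's training and validation loops (see `maybe_progress_bar` in fl4health/clients/basic_client.py below). The following is a minimal standalone sketch of the same wrapping pattern, not the library code itself; it uses only tqdm and omits the flwr-colored `bar_format`:

```python
from typing import Iterable

from tqdm import tqdm


def maybe_progress_bar(iterable: Iterable, progress_bar: bool) -> Iterable:
    # Return the iterable untouched when the flag is off; otherwise wrap it in a
    # tqdm bar configured similarly to the client code in this PR.
    if not progress_bar:
        return iterable
    return tqdm(iterable, leave=True, ascii=" >=", unit="steps", dynamic_ncols=True)


# Example: iterating a stand-in "data loader" with the bar enabled.
for _ in maybe_progress_bar(range(1000), progress_bar=True):
    pass
```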

fl4health/clients/basic_client.py

Lines changed: 122 additions & 37 deletions

@@ -3,16 +3,17 @@
 from enum import Enum
 from logging import INFO
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
 from flwr.client import NumPyClient
-from flwr.common.logger import log
+from flwr.common.logger import LOG_COLORS, log
 from flwr.common.typing import Config, NDArrays, Scalar
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 
 from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
@@ -41,6 +42,7 @@ def __init__(
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
         checkpointer: Optional[ClientCheckpointModule] = None,
         metrics_reporter: Optional[MetricsReporter] = None,
+        progress_bar: bool = False,
     ) -> None:
         """
         Base FL Client with functionality to train, evaluate, log, report and checkpoint.
@@ -59,12 +61,16 @@
                 None.
             metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
                 during the execution. Defaults to an instance of MetricsReporter with default init parameters.
+            progress_bar (bool): Whether or not to display a progress bar
+                during client training and validation. Uses tqdm. Defaults to
+                False
         """
 
         self.data_path = data_path
         self.device = device
         self.metrics = metrics
         self.checkpointer = checkpointer
+        self.progress_bar = progress_bar
 
         self.client_name = generate_hash()
@@ -363,7 +369,35 @@ def _should_evaluate_after_fit(self, evaluate_after_fit: bool) -> bool:
         )
         return evaluate_after_fit or pre_aggregation_checkpointing_enabled
 
-    def _handle_logging(
+    def _log_header_str(
+        self,
+        current_round: Optional[int] = None,
+        current_epoch: Optional[int] = None,
+        logging_mode: LoggingMode = LoggingMode.TRAIN,
+    ) -> None:
+        """
+        Logs a header string. By default this is logged at the beginning of each local
+        epoch or at the beginning of the round if training by steps
+
+        Args:
+            current_round (Optional[int], optional): The current FL round (i.e., current
+                server round). Defaults to None.
+            current_epoch (Optional[int], optional): The current epoch of local
+                training. Defaults to None.
+        """
+
+        log_str = f"Current FL Round: {str(current_round)}\t" if current_round is not None else ""
+        log_str += f"Current Epoch: {str(current_epoch)}" if current_epoch is not None else ""
+
+        # Maybe add client specific info to initial log string
+        client_str, _ = self.get_client_specific_logs(current_round, current_epoch, logging_mode)
+
+        log_str += client_str
+
+        log(INFO, "")  # For aesthetics
+        log(INFO, log_str)
+
+    def _log_results(
         self,
         loss_dict: Dict[str, float],
         metrics_dict: Dict[str, Scalar],
@@ -382,23 +416,15 @@ def _handle_logging(
             current_epoch (Optional[int]): The current epoch of local training.
             logging_mode (LoggingMode): The logging mode (Training, Validation, or Testing).
         """
-        log(INFO, "")  # An empty log line for aesthetics
+        _, client_logs = self.get_client_specific_logs(current_round, current_epoch, logging_mode)
 
-        initial_log_str = f"Current FL Round: {str(current_round)}\t" if current_round is not None else ""
-        initial_log_str += f"Current Epoch: {str(current_epoch)}" if current_epoch is not None else ""
-
-        # Maybe add client specific info to initial log string
-        client_str, client_logs = self.get_client_specific_logs()
-        initial_log_str += client_str
-
-        if initial_log_str != "":
-            log(INFO, initial_log_str)
-        self.add_to_initial_log_str = ""  # Reset variable
-
-        # Log loss/losses
+        # Get Metric Prefix
         metric_prefix = logging_mode.value
-        log(INFO, f"Client {metric_prefix} Losses:")
-        [log(INFO, f"\t {key}: {str(val)}") for key, val in loss_dict.items()]
+
+        # Log losses if any were provided
+        if len(loss_dict.keys()) > 0:
+            log(INFO, f"Client {metric_prefix} Losses:")
+            [log(INFO, f"\t {key}: {str(val)}") for key, val in loss_dict.items()]
 
         # Log metrics if any
         if len(metrics_dict.keys()) > 0:
@@ -409,21 +435,32 @@ def _handle_logging(
         if len(client_logs) > 0:
             [log(level.value, msg) for level, msg in client_logs]
 
-    def get_client_specific_logs(self) -> Tuple[str, List[Tuple[LogLevel, str]]]:
+    def get_client_specific_logs(
+        self, current_round: Optional[int], current_epoch: Optional[int], logging_mode: LoggingMode
+    ) -> Tuple[str, List[Tuple[LogLevel, str]]]:
         """
-        This function can be overriden to provide any client specific
+        This function can be overridden to provide any client specific
         information to the basic client logging. For example, perhaps a client
-        uses an LR scheduler and wants the LR to be logged each epoch. The
-        logging is called at the end of either every epoch for
-        train_by_epochs, or the end of the server round for train_by_steps
+        uses an LR scheduler and wants the LR to be logged each epoch. Called at the
+        beginning and end of each server round or local epoch. Also called at the end
+        of validation/testing.
+
+        Args:
+            current_round (Optional[int]): The current FL round (i.e., current
+                server round).
+            current_epoch (Optional[int]): The current epoch of local training.
+            logging_mode (LoggingMode): The logging mode (Training,
+                Validation, or Testing).
 
         Returns:
-            Optional[str]: A string to append to the initial log string that
-                typically announces the current server round and current epoch
+            Optional[str]: A string to append to the header log string that
+                typically announces the current server round and current epoch at the
+                beginning of each round or local epoch.
             Optional[List[Tuple[LogLevel, str]]]]: A list of tuples where the
                 first element is a LogLevel as defined in fl4health.utils.
                 typing and the second element is a string message. Each item
-                in the list will be logged when self._handle_logging is called
+                in the list will be logged at the end of each server round or epoch.
+                Elements will also be logged at the end of validation/testing.
         """
         return "", []
@@ -612,13 +649,16 @@ def train_by_epochs(
             Loss is a dictionary of one or more losses that represent the different components of the loss.
         """
         self.model.train()
-        local_step = 0
+        steps_this_round = 0  # Reset number of steps this round
         for local_epoch in range(epochs):
             self.train_metric_manager.clear()
             self.train_loss_meter.clear()
+            # Print initial log string on epoch start
+            self._log_header_str(current_round, local_epoch)
             # update before epoch hook
             self.update_before_epoch(epoch=local_epoch)
-            for input, target in self.train_loader:
+            for input, target in self.maybe_progress_bar(self.train_loader):
+                self.update_before_step(steps_this_round)
                 # Assume first dimension is batch size. Sampling iterators (such as Poisson batch sampling), can
                 # construct empty batches. We skip the iteration if this occurs.
                 if self.is_empty_batch(input):
@@ -630,14 +670,14 @@ def train_by_epochs(
                 losses, preds = self.train_step(input, target)
                 self.train_loss_meter.update(losses)
                 self.update_metric_manager(preds, target, self.train_metric_manager)
-                self.update_after_step(local_step)
+                self.update_after_step(steps_this_round)
                 self.total_steps += 1
-                local_step += 1
+                steps_this_round += 1
             metrics = self.train_metric_manager.compute()
             loss_dict = self.train_loss_meter.compute().as_dict()
 
             # Log results and maybe report via WANDB
-            self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch)
+            self._log_results(loss_dict, metrics, current_round, local_epoch)
             self._handle_reporting(loss_dict, metrics, current_round=current_round)
 
         # Return final training metrics
@@ -663,7 +703,11 @@ def train_by_steps(
 
         self.train_loss_meter.clear()
         self.train_metric_manager.clear()
-        for step in range(steps):
+        self._log_header_str(current_round)
+        for step in self.maybe_progress_bar(range(steps)):
+
+            self.update_before_step(step)
+
             try:
                 input, target = next(train_iterator)
             except StopIteration:
@@ -690,7 +734,7 @@
         metrics = self.train_metric_manager.compute()
 
         # Log results and maybe report via WANDB
-        self._handle_logging(loss_dict, metrics, current_round=current_round)
+        self._log_results(loss_dict, metrics, current_round)
         self._handle_reporting(loss_dict, metrics, current_round=current_round)
 
         return loss_dict, metrics
@@ -720,7 +764,7 @@ def _validate_or_test(
         metric_manager.clear()
         loss_meter.clear()
         with torch.no_grad():
-            for input, target in loader:
+            for input, target in self.maybe_progress_bar(loader):
                 input = self._move_data_to_device(input)
                 target = self._move_data_to_device(target)
                 losses, preds = self.val_step(input, target)
@@ -730,7 +774,7 @@
         # Compute losses and metrics over validation set
         loss_dict = loss_meter.compute().as_dict()
         metrics = metric_manager.compute()
-        self._handle_logging(loss_dict, metrics, logging_mode=logging_mode)
+        self._log_results(loss_dict, metrics, logging_mode=logging_mode)
 
         return loss_dict["checkpoint"], metrics
@@ -1074,19 +1118,31 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None:
         aggregation.
 
         Args:
-            local_steps (int): The number of steps in the local training.
+            local_steps (int): The number of steps so far in the round in the local
+                training.
             loss_dict (Dict[str, float]): A dictionary of losses from local training.
         """
         pass
 
+    def update_before_step(self, step: int) -> None:
+        """
+        Hook method called before local train step.
+
+        Args:
+            step (int): The local training step that was most recently
+                completed. Resets only at the end of the round.
+        """
+        pass
+
     def update_after_step(self, step: int) -> None:
         """
         Hook method called after local train step on client. step is an integer that represents
         the local training step that was most recently completed. For example, used by the APFL
         method to update the alpha value after a training a step.
 
         Args:
-            step (int): The step number in local training that was most recently completed.
+            step (int): The step number in local training that was most recently
+                completed. Resets only at the end of the round.
         """
         pass
@@ -1100,3 +1156,32 @@ def update_before_epoch(self, epoch: int) -> None:
             epoch (int): Integer representing the epoch about to begin
         """
         pass
+
+    def maybe_progress_bar(self, iterable: Iterable) -> Iterable:
+        """
+        Used to print progress bars during client training and validation. If
+        self.progress_bar is false, just returns the original input iterable
+        without modifying it.
+
+        Args:
+            iterable (Iterable): The iterable to wrap
+
+        Returns:
+            Iterable: an iterator which acts exactly like the original
+                iterable, but prints a dynamically updating progress bar every
+                time a value is requested. Or the original iterable if
+                self.progress_bar is False
+        """
+        if not self.progress_bar:
+            return iterable
+        else:
+            # Create a clean looking tqdm instance that matches the flwr logging
+            kwargs = {
+                "leave": True,
+                "ascii": " >=",
+                # "desc": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']} ",
+                "unit": "steps",
+                "dynamic_ncols": True,
+                "bar_format": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']}" + " : {l_bar}{bar}{r_bar}",
+            }
+            return tqdm(iterable, **kwargs)
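
The expanded get_client_specific_logs signature and the new update_before_step hook give subclasses round, epoch, and logging-mode context for custom logging (the LR-scheduler example mentioned in the docstring). Below is a hedged sketch of how a downstream client might override them; the import paths, the subclass name, and the self.lr_scheduler attribute are assumptions rather than part of this PR, and the usual model/data/optimizer setup methods are omitted:

```python
from typing import List, Optional, Tuple

# Import paths below are assumptions based on the modules touched in this PR.
from fl4health.clients.basic_client import BasicClient, LoggingMode
from fl4health.utils.typing import LogLevel


class LrLoggingClient(BasicClient):
    """Illustrative subclass; get_model/get_optimizer/data loader methods are omitted for brevity."""

    def get_client_specific_logs(
        self, current_round: Optional[int], current_epoch: Optional[int], logging_mode: LoggingMode
    ) -> Tuple[str, List[Tuple[LogLevel, str]]]:
        # Append the current learning rate to the header string logged at the start of each
        # local epoch (train mode only). self.lr_scheduler is a hypothetical attribute that
        # the subclass would set up itself (any torch LR scheduler exposes get_last_lr()).
        if logging_mode == LoggingMode.TRAIN and current_epoch is not None:
            return f"\tLR: {self.lr_scheduler.get_last_lr()[0]}", []
        return "", []

    def update_before_step(self, step: int) -> None:
        # New hook called before every local training step; `step` resets only at the
        # end of the round.
        pass
```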

fl4health/clients/flash_client.py

Lines changed: 2 additions & 1 deletion

@@ -65,6 +65,7 @@ def train_by_epochs(
         for local_epoch in range(epochs):
             self.train_metric_manager.clear()
             self.train_loss_meter.clear()
+            self._log_header_str(current_round, local_epoch)
             for input, target in self.train_loader:
                 if self.is_empty_batch(input):
                     log(INFO, "Empty batch generated by data loader. Skipping step.")
@@ -83,7 +84,7 @@ def train_by_epochs(
             loss_dict = self.train_loss_meter.compute().as_dict()
             current_loss, _ = self.validate()
 
-            self._handle_logging(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch)
+            self._log_results(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch)
             self._handle_reporting(loss_dict, metrics, current_round=current_round)
 
             if self.gamma is not None and previous_loss - current_loss < self.gamma / (local_epoch + 1):
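
Taken together with the basic_client.py changes, the intended per-epoch logging order is: a header via _log_header_str at the start of each local epoch, per-step hooks around each batch, then a summary via _log_results at the end of the epoch. The following is a self-contained outline of that order, not the library code; plain Python logging stands in for the flwr logger and all names are illustrative:

```python
# Outline of the per-epoch logging/hook order introduced by this PR.
import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s : %(message)s")
logger = logging.getLogger("client")


def train_by_epochs_outline(epochs: int, steps_per_epoch: int, current_round: int) -> None:
    steps_this_round = 0  # resets only when this round ends
    for local_epoch in range(epochs):
        # corresponds to self._log_header_str(current_round, local_epoch)
        logger.info("Current FL Round: %s\tCurrent Epoch: %s", current_round, local_epoch)
        for _ in range(steps_per_epoch):
            # update_before_step(steps_this_round), train_step(...), update_after_step(steps_this_round)
            steps_this_round += 1
        # corresponds to self._log_results(loss_dict, metrics, current_round, local_epoch)
        logger.info("Epoch %s training losses and metrics would be logged here", local_epoch)


train_by_epochs_outline(epochs=2, steps_per_epoch=5, current_round=1)
```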
