Skip to content

Commit d228d3c

Browse files
authored
Merge pull request #80 from VectorInstitute/multi-label
bug fix in basic client
2 parents fd9b041 + 18ddf3e commit d228d3c

File tree

6 files changed

+69
-20
lines changed

6 files changed

+69
-20
lines changed

examples/feature_alignment_example/client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
from logging import INFO
33
from pathlib import Path
4-
from typing import Sequence, Tuple
4+
from typing import List, Sequence, Tuple, Union
55

66
import flwr as fl
77
import numpy as np
@@ -22,9 +22,14 @@
2222

2323
class Mimic3TabularDataClient(TabularDataClient):
2424
def __init__(
25-
self, data_path: Path, metrics: Sequence[Metric], device: torch.device, id_column: str, target_column: str
25+
self,
26+
data_path: Path,
27+
metrics: Sequence[Metric],
28+
device: torch.device,
29+
id_column: str,
30+
targets: Union[str, List[str]],
2631
) -> None:
27-
super().__init__(data_path, metrics, device, id_column, target_column)
32+
super().__init__(data_path, metrics, device, id_column, targets)
2833

2934
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
3035
batch_size = self.narrow_config_type(config, "batch_size", int)
@@ -92,7 +97,7 @@ def get_data_frame(self, config: Config) -> pd.DataFrame:
9297
log(INFO, f"Server Address: {args.server_address}")
9398

9499
# ham_id is the id column and LOSgroupNum is the target column.
95-
client = Mimic3TabularDataClient(data_path, [Accuracy("accuracy")], DEVICE, "hadm_id", "LOSgroupNum")
100+
client = Mimic3TabularDataClient(data_path, [Accuracy("accuracy")], DEVICE, "hadm_id", ["LOSgroupNum"])
96101
# This call demonstrates how the user may specify a particular sklearn pipeline for a specific feature.
97102
client.preset_specific_pipeline("NumNotes", MaxAbsScaler())
98103
fl.client.start_numpy_client(server_address=args.server_address, client=client)

fl4health/clients/basic_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def setup_client(self, config: Config) -> None:
557557

558558
self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config)
559559

560-
self.intialized = True
560+
self.initialized = True
561561

562562
def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
563563
"""

fl4health/clients/tabular_data_client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.pipeline import Pipeline
1010

1111
from fl4health.clients.basic_client import BasicClient
12-
from fl4health.feature_alignment.constants import FEATURE_INFO, FORMAT_SPECIFIED, INPUT_DIMENSION, OUTPUT_DIMENSION
12+
from fl4health.feature_alignment.constants import FEATURE_INFO, INPUT_DIMENSION, OUTPUT_DIMENSION, SOURCE_SPECIFIED
1313
from fl4health.feature_alignment.tab_features_info_encoder import TabularFeaturesInfoEncoder
1414
from fl4health.feature_alignment.tab_features_preprocessor import TabularFeaturesPreprocessor
1515
from fl4health.utils.metrics import Metric
@@ -43,17 +43,17 @@ def setup_client(self, config: Config) -> None:
4343
Initialize the client by encoding the information of its tabular data
4444
and initializing the corresponding TabularFeaturesPreprocessor.
4545
46-
config[FORMAT_SPECIFIED] indicates whether the server has obtained
46+
config[SOURCE_SPECIFIED] indicates whether the server has obtained
4747
the source of information to perform feature alignment.
4848
If it is True, it means the server has obtained such information
4949
(either a priori or by polling a client).
5050
So the client will encode that information and use it instead
5151
to perform feature preprocessing.
5252
"""
53-
format_specified = self.narrow_config_type(config, FORMAT_SPECIFIED, bool)
53+
source_specified = self.narrow_config_type(config, SOURCE_SPECIFIED, bool)
5454
self.df = self.get_data_frame(config)
5555

56-
if format_specified:
56+
if source_specified:
5757
# Since the server has obtained its source of information,
5858
# the client will encode that instead.
5959
self.tabular_features_info_encoder = TabularFeaturesInfoEncoder.from_json(
@@ -74,7 +74,7 @@ def setup_client(self, config: Config) -> None:
7474
# that the first dimension is the number of rows.
7575
self.input_dimension = self.aligned_features.shape[1]
7676
self.output_dimension = self.tabular_features_info_encoder.get_target_dimension()
77-
log(INFO, f"input dimension: {self.input_dimension}, output_dimension: {self.output_dimension}")
77+
log(INFO, f"input dimension: {self.input_dimension}, target dimension: {self.output_dimension}")
7878

7979
super().setup_client(config)
8080

@@ -111,8 +111,8 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]:
111111
"""
112112
if not self.initialized:
113113
self.setup_client(config)
114-
format_specified = self.narrow_config_type(config, FORMAT_SPECIFIED, bool)
115-
if not format_specified:
114+
source_specified = self.narrow_config_type(config, SOURCE_SPECIFIED, bool)
115+
if not source_specified:
116116
return {
117117
FEATURE_INFO: self.tabular_features_info_encoder.to_json(),
118118
}

fl4health/feature_alignment/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
# FORMAT_SPECIFIED indicates whether the server has the "source of truth"
1111
# to be used for feature alignment.
12-
FORMAT_SPECIFIED = "format_specified"
12+
SOURCE_SPECIFIED = "source_specified"
1313

1414
# FEATURE_INFO refers to the encoded feature information (source of truth).
1515
FEATURE_INFO = "feature_info"

fl4health/server/tabular_feature_alignment_server.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from fl4health.feature_alignment.constants import (
1414
CURRENT_SERVER_ROUND,
1515
FEATURE_INFO,
16-
FORMAT_SPECIFIED,
1716
INPUT_DIMENSION,
1817
OUTPUT_DIMENSION,
18+
SOURCE_SPECIFIED,
1919
)
2020
from fl4health.feature_alignment.tab_features_info_encoder import TabularFeaturesInfoEncoder
2121
from fl4health.reporting.fl_wanb import ServerWandBReporter
@@ -70,11 +70,11 @@ def __init__(
7070
self.tab_features_info = tabular_features_source_of_truth
7171
self.config = config
7272
self.initialize_parameters = initialize_parameters
73-
self.format_info_gathered = False
73+
self.source_info_gathered = False
7474
self.dimension_info: Dict[str, int] = {}
7575
# ensure that self.strategy has type BasicFedAvg so its on_fit_config_fn can be specified.
7676
assert isinstance(self.strategy, BasicFedAvg)
77-
self.strategy.on_fit_config_fn = partial(fit_config, self.config, self.format_info_gathered)
77+
self.strategy.on_fit_config_fn = partial(fit_config, self.config, self.source_info_gathered)
7878

7979
def _set_dimension_info(self, input_dimension: int, output_dimension: int) -> None:
8080
self.dimension_info[INPUT_DIMENSION] = input_dimension
@@ -110,9 +110,9 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
110110

111111
# the feature information is sent to clients through the config parameter.
112112
self.config[FEATURE_INFO] = feature_info_source
113-
self.format_info_gathered = True
113+
self.source_info_gathered = True
114114

115-
self.strategy.on_fit_config_fn = partial(fit_config, self.config, self.format_info_gathered)
115+
self.strategy.on_fit_config_fn = partial(fit_config, self.config, self.source_info_gathered)
116116

117117
# Now the server waits until feature alignment is performed on the clients' side
118118
# and subsequently requests the input and output dimensions, which are needed for initializing
@@ -158,7 +158,7 @@ def poll_clients_for_dimension_info(self, timeout: Optional[float]) -> Tuple[int
158158
return input_dimension, output_dimension
159159

160160

161-
def fit_config(config: Config, format_specified: bool, current_server_round: int) -> Config:
162-
config[FORMAT_SPECIFIED] = format_specified
161+
def fit_config(config: Config, source_specified: bool, current_server_round: int) -> Config:
162+
config[SOURCE_SPECIFIED] = source_specified
163163
config[CURRENT_SERVER_ROUND] = current_server_round
164164
return config
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from pathlib import Path
2+
from typing import Tuple
3+
4+
import torch
5+
import torch.nn as nn
6+
from flwr.common.typing import Config
7+
from torch.nn.modules.loss import _Loss
8+
from torch.optim import Optimizer
9+
from torch.utils.data import DataLoader
10+
from torch.utils.data.dataset import TensorDataset
11+
12+
from fl4health.clients.basic_client import BasicClient
13+
from fl4health.utils.metrics import Accuracy
14+
from tests.test_utils.models_for_test import LinearModel
15+
16+
17+
class TestingClient(BasicClient):
18+
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
19+
train_loader = DataLoader(TensorDataset(torch.ones((4, 4)), torch.ones((4))))
20+
val_loader = DataLoader(TensorDataset(torch.ones((4, 4)), torch.ones((4))))
21+
return train_loader, val_loader
22+
23+
def get_criterion(self, config: Config) -> _Loss:
24+
return torch.nn.CrossEntropyLoss()
25+
26+
def get_optimizer(self, config: Config) -> Optimizer:
27+
return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
28+
29+
def get_model(self, config: Config) -> nn.Module:
30+
return LinearModel().to(self.device)
31+
32+
33+
def test_setup_client() -> None:
34+
client = TestingClient(data_path=Path(""), metrics=[Accuracy()], device=torch.device("cpu"))
35+
client.setup_client({})
36+
assert client.parameter_exchanger is not None
37+
assert client.model is not None
38+
assert client.optimizer is not None
39+
assert client.train_loader is not None
40+
assert client.val_loader is not None
41+
assert client.num_train_samples is not None
42+
assert client.num_val_samples is not None
43+
assert client.learning_rate is not None
44+
assert client.initialized

0 commit comments

Comments
 (0)