Commit 83063af

Merge pull request #32 from VectorInstitute/other_metrics
Other metrics
2 parents 9fbe30f + 1981983 commit 83063af

14 files changed: +89 −14 lines


examples/apfl_example/server.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def main(config: Dict[str, Any]) -> None:
         action="store",
         type=str,
         help="Path to configuration file.",
-        default="config.yaml",
+        default="examples/apfl_example/config.yaml",
     )
     args = parser.parse_args()
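
The same one-line change is applied to each example server below: the default config path now points at the example's own config.yaml relative to the repository root, so the servers can be launched from the root without passing a path explicitly. The following sketch reconstructs the surrounding argparse call for illustration only; the flag name, parser description, and YAML loading are assumptions, not code copied from the repository.

    import argparse

    import yaml  # assumption: the examples load their config.yaml with PyYAML

    parser = argparse.ArgumentParser(description="FL Server Main")  # description is illustrative
    parser.add_argument(
        "--config_path",  # hypothetical flag name; only the kwargs below appear in the diff
        action="store",
        type=str,
        help="Path to configuration file.",
        # Repo-root-relative default, matching the new value in this diff.
        default="examples/apfl_example/config.yaml",
    )
    args = parser.parse_args()

    # Load the selected configuration file into a plain dict.
    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)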

examples/basic_example/server.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def main(config: Dict[str, Any]) -> None:
         action="store",
         type=str,
         help="Path to configuration file.",
-        default="config.yaml",
+        default="examples/basic_example/config.yaml",
     )
     args = parser.parse_args()

examples/fedopt_example/server.py

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ def main(config: Dict[str, Any]) -> None:
         action="store",
         type=str,
         help="Path to configuration file.",
-        default="config.yaml",
+        default="examples/fedopt_example/config.yaml",
     )
     args = parser.parse_args()

examples/fedprox_example/server.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
         action="store",
         type=str,
         help="Path to configuration file.",
-        default="config.yaml",
+        default="examples/fedprox_example/config.yaml",
     )
     parser.add_argument(
         "--server_address",

examples/fenda_example/server.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def main(config: Dict[str, Any]) -> None:
         action="store",
         type=str,
         help="Path to configuration file.",
-        default="config.yaml",
+        default="examples/fenda_example/config.yaml",
     )
     args = parser.parse_args()

examples/fl_plus_local_ft_example/server.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def main(config: Dict[str, Any]) -> None:
         action="store",
         type=str,
         help="Path to configuration file.",
-        default="config.yaml",
+        default="examples/fl_plus_local_ft_example/config.yaml",
     )
     args = parser.parse_args()

examples/scaffold_example/README.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# APFL Federated Learning Example
+# SCAFFOLD Federated Learning Example
 This is an example of [Stochastic Controlled Averaging for Federated Learning](https://arxiv.org/pdf/1910.06378.pdf) (SCAFFOLD). SCAFFOLD is a popular method for federated learning in situations where data across clients is heterogeneous (non-IID). In these cases, FedAvg suffers from client drift, resulting in unstable and slow
 convergence. To surmount this, SCAFFOLD uses control variates to correct for client drift during local updates. This is shown to decrease the number of communication rounds required compared to other approaches to federated learning such as FedAvg.
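
As a quick illustration of the control-variate idea the README describes, the sketch below shows how a SCAFFOLD-style client can shift each raw gradient by the difference between the server control variate c and its own control variate c_i before the optimizer step. This is a minimal sketch following the SCAFFOLD paper, not the fl4health implementation; all names and signatures are illustrative.

    from typing import List

    import torch
    from torch import nn


    def scaffold_corrected_step(
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        loss: torch.Tensor,
        c_server: List[torch.Tensor],
        c_client: List[torch.Tensor],
    ) -> None:
        """One drift-corrected local step: parameters move along grad + c - c_i."""
        optimizer.zero_grad()
        loss.backward()
        with torch.no_grad():
            for param, c, c_i in zip(model.parameters(), c_server, c_client):
                if param.grad is not None:
                    # Shift the raw gradient by (c - c_i) so the optimizer step
                    # counteracts client drift, as in Karimireddy et al. (2020).
                    param.grad.add_(c - c_i)
        optimizer.step()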

examples/scaffold_example/client.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def setup_client(self, config: Config) -> None:
         learning_rate_local = self.narrow_config_type(config, "learning_rate_local", float)

         self.learning_rate_local = learning_rate_local
-        self.model: nn.Module = MnistNet()
+        self.model: nn.Module = MnistNet().to(self.device)
         self.criterion = torch.nn.CrossEntropyLoss()
         self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate_local)
         sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75)
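
The change above moves the freshly constructed model onto the client's device at setup time. Purely as an illustration (not code from this repository), the snippet below shows the device-consistency rule that motivates it: parameters and input batches must live on the same device for the forward and backward passes to run.

    import torch
    from torch import nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = nn.Linear(784, 10).to(device)  # stand-in for MnistNet(); moved to the device once
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)  # illustrative learning rate

    batch = torch.randn(32, 784).to(device)  # batches are sent to the same device during training
    logits = model(batch)  # works because parameters and inputs share a device
    loss = nn.functional.cross_entropy(logits, torch.randint(0, 10, (32,), device=device))
    loss.backward()
    optimizer.step()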

examples/scaffold_example/server.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def main(config: Dict[str, Any]) -> None:
         action="store",
         type=str,
         help="Path to configuration file.",
-        default="config.yaml",
+        default="examples/scaffold_example/config.yaml",
     )
     args = parser.parse_args()

fl4health/model_bases/apfl_base.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def update_alpha(self) -> None:
         assert local_grad is not None and global_grad is not None
         dif = local_p - global_p
         grad = torch.tensor(self.alpha) * local_grad + torch.tensor(1.0 - self.alpha) * global_grad
-        grad_alpha += torch.mul(dif, grad).sum().detach().numpy()
+        grad_alpha += torch.mul(dif, grad).sum().detach().cpu().numpy()

         # This update constant of 0.02 is not referenced in the paper
         # but is present in the official implementation and other ones I have seen
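
Here the alpha gradient is accumulated as a NumPy scalar from the inner product of (local_p - global_p) with the alpha-blended gradient. The added .cpu() matters whenever the model runs on a GPU, because NumPy arrays live in host memory. A minimal illustration of that conversion rule, independent of this repository:

    import torch

    t = torch.tensor([1.0, 2.0, 3.0])
    if torch.cuda.is_available():
        t = t.cuda()
        # Calling t.numpy() here would raise a TypeError along the lines of
        # "can't convert cuda:0 device type tensor to numpy; use Tensor.cpu() first".

    # .detach() drops the autograd graph, .cpu() copies to host memory if needed,
    # and .numpy() then yields an ndarray on both CPU and CUDA tensors.
    value = t.detach().cpu().numpy().sum()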
