Skip to content

Commit 9e82518

Browse files
committed
tests for ensure_protocol_compliance and add unit tests
1 parent 2d22847 commit 9e82518

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

fl4health/mixins/personalized/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
@wrapt.decorator
10-
def ensure_protocol_compliance(func: Callable, instance: Any | None, *args: Any, **kwargs: Any) -> None:
10+
def ensure_protocol_compliance(func: Callable, instance: Any | None, args: Any, kwargs: Any) -> None:
1111
# validate self is a BasicClient
1212
self = instance
1313
if not isinstance(self, BasicClient):
File renamed without changes.

tests/mixins/personalized/conftest.py

Whitespace-only changes.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from contextlib import nullcontext as does_not_raise
2+
from pathlib import Path
3+
4+
import pytest
5+
import torch
6+
import torch.nn as nn
7+
from flwr.common.typing import Config
8+
from torch.nn.modules.loss import _Loss
9+
from torch.optim import Optimizer
10+
from torch.utils.data import DataLoader, TensorDataset
11+
12+
from fl4health.clients.basic_client import BasicClient
13+
from fl4health.metrics import Accuracy
14+
from fl4health.mixins.personalized.utils import ensure_protocol_compliance
15+
from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
16+
from fl4health.parameter_exchange.parameter_packer import (
17+
ParameterPackerAdaptiveConstraint,
18+
)
19+
20+
21+
def test_ensure_protocol_compliance_does_not_raise() -> None:
22+
# arrange
23+
class MyClient(BasicClient):
24+
def get_model(self, config: Config) -> nn.Module:
25+
return self.model
26+
27+
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
28+
return self.train_loader, self.val_loader
29+
30+
def get_optimizer(self, config: Config) -> Optimizer | dict[str, Optimizer]:
31+
return self.optimizers["global"]
32+
33+
def get_criterion(self, config: Config) -> _Loss:
34+
return torch.nn.CrossEntropyLoss()
35+
36+
@ensure_protocol_compliance
37+
def some_method(self, x: int) -> int:
38+
return x + 1
39+
40+
# setup client
41+
client = MyClient(data_path=Path(""), metrics=[Accuracy()], device=torch.device("cpu"))
42+
client.model = torch.nn.Linear(5, 5)
43+
client.optimizers = {"global": torch.optim.SGD(client.model.parameters(), lr=0.0001)}
44+
client.train_loader = DataLoader(TensorDataset(torch.ones((1000, 28, 28, 1)), torch.ones((1000))))
45+
client.val_loader = DataLoader(TensorDataset(torch.ones((1000, 28, 28, 1)), torch.ones((1000))))
46+
client.parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())
47+
client.initialized = True
48+
client.setup_client({})
49+
50+
# act/assert
51+
with does_not_raise():
52+
client.some_method(2)
53+
54+
55+
def test_ensure_protocol_compliance_does_raise_type_error() -> None:
56+
# arrange
57+
class MyClient:
58+
"""My Client DOES not satisfy the protocol of BasicClient."""
59+
60+
@ensure_protocol_compliance
61+
def some_method(self, x: int) -> int:
62+
return x + 1
63+
64+
client = MyClient()
65+
66+
# act/assert
67+
with pytest.raises(TypeError, match="Protocol requirements not met."):
68+
client.some_method(2)

0 commit comments

Comments
 (0)