-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathsklearn_test.py
More file actions
94 lines (78 loc) · 3.38 KB
/
sklearn_test.py
File metadata and controls
94 lines (78 loc) · 3.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import unittest
import numpy as np
from parameterized import parameterized_class
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from flwr_datasets import FederatedDataset
# Using parameterized testing, two different sets of preprocessing:
# 1. Without scaling.
# 2. With standard scaling.
@parameterized_class(
[
{"dataset_name": "mnist", "preprocessing": None},
{"dataset_name": "mnist", "preprocessing": StandardScaler()},
]
)
class FdsWithSKLearn(unittest.TestCase):
"""Test Flower Datasets with Scikit-learn's Logistic Regression."""
dataset_name = ""
preprocessing = None
def _get_partition_data(self):
"""Retrieve partition data."""
partition_id = 0
fds = FederatedDataset(dataset=self.dataset_name, partitioners={"train": 10})
partition = fds.load_partition(partition_id, "train")
partition.set_format("numpy")
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
X_train, y_train = partition_train_test["train"][:]["image"], partition_train_test[
"train"][:]["label"]
X_test, y_test = partition_train_test["test"][:]["image"], partition_train_test[
"test"][:]["label"]
X_train = X_train.reshape(-1, 28 * 28)
X_test = X_test.reshape(-1, 28 * 28)
if self.preprocessing:
self.preprocessing.fit(X_train)
X_train = self.preprocessing.transform(X_train)
X_test = self.preprocessing.transform(X_test)
return X_train, X_test, y_train, y_test
def test_data_shape(self):
"""Test if the data shape is maintained after preprocessing."""
X_train, _, _, _ = self._get_partition_data()
self.assertEqual(X_train.shape, (4_800, 28 * 28))
def test_X_train_type(self):
"""Test if the data type is correct."""
X_train, _, _, _ = self._get_partition_data()
self.assertIsInstance(X_train, np.ndarray)
def test_y_train_type(self):
"""Test if the data type is correct."""
_, _, y_train, _ = self._get_partition_data()
self.assertIsInstance(y_train, np.ndarray)
def test_X_test_type(self):
"""Test if the data type is correct."""
_, X_test, _, _ = self._get_partition_data()
self.assertIsInstance(X_test, np.ndarray)
def test_y_test_type(self):
"""Test if the data type is correct."""
_, _, _, y_test = self._get_partition_data()
self.assertIsInstance(y_test, np.ndarray)
def test_train_classifier(self):
"""Test if the classifier trains without errors."""
X_train, X_test, y_train, y_test = self._get_partition_data()
try:
clf = LogisticRegression()
clf.fit(X_train, y_train)
except Exception as e:
self.fail(f"Fitting Logistic Regression raised {type(e)} unexpectedly!")
def test_predict_from_classifier(self):
"""Test if the classifier predicts without errors."""
X_train, X_test, y_train, y_test = self._get_partition_data()
clf = LogisticRegression()
clf.fit(X_train, y_train)
try:
_ = clf.predict(X_test)
except Exception as e:
self.fail(
f"Predicting using Logistic Regression model raised {type(e)} "
f"unexpectedly!")
if __name__ == '__main__':
unittest.main()