Skip to content

Commit afb3155

Browse files
authored
Merge pull request #12 from sicara/fix-validation
Fix call to validation during training (0.2.0 => 0.2.1)
2 parents 8387628 + 966bf34 commit afb3155

File tree

7 files changed

+81
-44
lines changed

7 files changed

+81
-44
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.2.0
2+
current_version = 0.2.1
33
commit = True
44
tag = False
55

dev_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ matplotlib>=3.3.4
44
pandas>=1.1.0
55
pylint>=2.7.0
66
pytest>=6.2.2
7+
pytest-mock>=3.6.1
78
torch>=1.7.1
89
torchvision>=0.8.2
910
tqdm>=4.56.0

easyfsl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
for few-shot learning experiences.
66
"""
77

8-
__version__ = "0.2.0"
8+
__version__ = "0.2.1"

easyfsl/methods/abstract_meta_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def fit(
219219

220220
# Validation
221221
if val_loader:
222-
if episode_index + 1 % validation_frequency == 0:
222+
if (episode_index + 1) % validation_frequency == 0:
223223
self.validate(val_loader)
224224

225225
def validate(self, val_loader: DataLoader) -> float:

easyfsl/tests/data_tools/easy_set_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,10 @@ class TestEasySetListDataInstances:
9292
)
9393
],
9494
)
95-
def test_list_data_instances_returns_expected_values(class_roots, images, labels):
96-
with patch("pathlib.Path.glob") as mock_glob:
97-
mock_glob.return_value = [Path("a.png"), Path("b.png")]
98-
with patch("pathlib.Path.is_file") as mock_is_file:
99-
mock_is_file.return_value = True
100-
assert (images, labels) == EasySet.list_data_instances(class_roots)
95+
def test_list_data_instances_returns_expected_values(
96+
class_roots, images, labels, mocker
97+
):
98+
mocker.patch("pathlib.Path.glob", return_value=[Path("a.png"), Path("b.png")])
99+
mocker.patch("pathlib.Path.is_file", return_value=True)
100+
101+
assert (images, labels) == EasySet.list_data_instances(class_roots)

easyfsl/tests/methods/abstract_meta_learner_test.py

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from unittest.mock import patch
2-
31
import pytest
42
import torch
53
from torchvision.models import resnet18
@@ -30,21 +28,24 @@ def test_evaluate_on_one_task_gives_correct_output(
3028
query_labels,
3129
expected_correct,
3230
expected_total,
31+
mocker,
3332
):
34-
with patch("torch.Tensor.cuda", new=torch.Tensor.cpu):
35-
with patch("easyfsl.methods.AbstractMetaLearner.forward") as mock_forward:
36-
with patch("easyfsl.methods.AbstractMetaLearner.process_support_set"):
37-
mock_forward.return_value = torch.tensor(5 * [[0.25, 0.75]]).cuda()
38-
model = AbstractMetaLearner(resnet18())
39-
assert (
40-
model.evaluate_on_one_task(
41-
support_images,
42-
support_labels,
43-
query_images,
44-
query_labels,
45-
)
46-
== (expected_correct, expected_total)
47-
)
33+
mocker.patch("torch.Tensor.cuda", new=torch.Tensor.cpu)
34+
mocker.patch(
35+
"easyfsl.methods.AbstractMetaLearner.forward",
36+
return_value=torch.tensor(5 * [[0.25, 0.75]]).cuda(),
37+
)
38+
mocker.patch("easyfsl.methods.AbstractMetaLearner.process_support_set")
39+
model = AbstractMetaLearner(resnet18())
40+
assert (
41+
model.evaluate_on_one_task(
42+
support_images,
43+
support_labels,
44+
query_images,
45+
query_labels,
46+
)
47+
== (expected_correct, expected_total)
48+
)
4849

4950

5051
# pylint: enable=not-callable
@@ -66,20 +67,20 @@ def test_process_support_set_raises_error_when_not_implemented():
6667

6768
class TestAMLValidate:
6869
@staticmethod
69-
def test_validate_returns_accuracy():
70-
with patch("easyfsl.methods.AbstractMetaLearner.evaluate") as mock_evaluate:
71-
mock_evaluate.return_value = 0.0
72-
meta_learner = AbstractMetaLearner(resnet18())
73-
assert meta_learner.validate(None) == 0.0
70+
def test_validate_returns_accuracy(mocker):
71+
mocker.patch("easyfsl.methods.AbstractMetaLearner.evaluate", return_value=0.0)
72+
meta_learner = AbstractMetaLearner(resnet18())
73+
assert meta_learner.validate(None) == 0.0
7474

7575
@staticmethod
76-
def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy():
77-
with patch("easyfsl.methods.AbstractMetaLearner.evaluate") as mock_evaluate:
78-
mock_evaluate.return_value = 0.5
79-
meta_learner = AbstractMetaLearner(resnet18())
80-
meta_learner.best_validation_accuracy = 0.1
81-
meta_learner.validate(None)
82-
assert meta_learner.best_model_state is not None
76+
def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy(
77+
mocker,
78+
):
79+
mocker.patch("easyfsl.methods.AbstractMetaLearner.evaluate", return_value=0.5)
80+
meta_learner = AbstractMetaLearner(resnet18())
81+
meta_learner.best_validation_accuracy = 0.1
82+
meta_learner.validate(None)
83+
assert meta_learner.best_model_state is not None
8384

8485
@staticmethod
8586
@pytest.mark.parametrize(
@@ -91,10 +92,44 @@ def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy():
9192
)
9293
def test_validate_leaves_best_model_state_if_it_has_worse_validation_accuracy(
9394
accuracy,
95+
mocker,
96+
):
97+
mocker.patch(
98+
"easyfsl.methods.AbstractMetaLearner.evaluate", return_value=accuracy
99+
)
100+
meta_learner = AbstractMetaLearner(resnet18())
101+
meta_learner.best_validation_accuracy = 0.1
102+
meta_learner.validate(None)
103+
assert meta_learner.best_model_state is None
104+
105+
@staticmethod
106+
@pytest.mark.parametrize(
107+
"n_train_episodes,validation_frequency,expected_number_of_validations",
108+
[
109+
(5, 1, 5),
110+
(5, 5, 1),
111+
(5, 6, 0),
112+
(5, 3, 1),
113+
(6, 3, 2),
114+
],
115+
)
116+
def test_validation_occurs_when_expected(
117+
n_train_episodes, validation_frequency, expected_number_of_validations, mocker
94118
):
95-
with patch("easyfsl.methods.AbstractMetaLearner.evaluate") as mock_evaluate:
96-
mock_evaluate.return_value = accuracy
97-
meta_learner = AbstractMetaLearner(resnet18())
98-
meta_learner.best_validation_accuracy = 0.1
99-
meta_learner.validate(None)
100-
assert meta_learner.best_model_state is None
119+
mocker.patch(
120+
"easyfsl.methods.AbstractMetaLearner.fit_on_task", return_value=0.0
121+
)
122+
mocker.patch("easyfsl.methods.AbstractMetaLearner.validate")
123+
spy_validate = mocker.spy(AbstractMetaLearner, "validate")
124+
125+
meta_learner = AbstractMetaLearner(resnet18())
126+
train_loader = n_train_episodes * [(None, None, None, None, None)]
127+
128+
meta_learner.fit(
129+
train_loader=train_loader,
130+
optimizer=None,
131+
val_loader=True,
132+
validation_frequency=validation_frequency,
133+
)
134+
135+
assert spy_validate.call_count == expected_number_of_validations

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name="easyfsl",
10-
version="0.2.0",
10+
version="0.2.1",
1111
description="Ready-to-use PyTorch code to boost your way into few-shot image classification",
1212
long_description=long_description,
1313
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)