1- from unittest .mock import patch
2-
31import pytest
42import torch
53from torchvision .models import resnet18
@@ -30,21 +28,24 @@ def test_evaluate_on_one_task_gives_correct_output(
3028 query_labels ,
3129 expected_correct ,
3230 expected_total ,
31+ mocker ,
3332 ):
34- with patch ("torch.Tensor.cuda" , new = torch .Tensor .cpu ):
35- with patch ("easyfsl.methods.AbstractMetaLearner.forward" ) as mock_forward :
36- with patch ("easyfsl.methods.AbstractMetaLearner.process_support_set" ):
37- mock_forward .return_value = torch .tensor (5 * [[0.25 , 0.75 ]]).cuda ()
38- model = AbstractMetaLearner (resnet18 ())
39- assert (
40- model .evaluate_on_one_task (
41- support_images ,
42- support_labels ,
43- query_images ,
44- query_labels ,
45- )
46- == (expected_correct , expected_total )
47- )
33+ mocker .patch ("torch.Tensor.cuda" , new = torch .Tensor .cpu )
34+ mocker .patch (
35+ "easyfsl.methods.AbstractMetaLearner.forward" ,
36+ return_value = torch .tensor (5 * [[0.25 , 0.75 ]]).cuda (),
37+ )
38+ mocker .patch ("easyfsl.methods.AbstractMetaLearner.process_support_set" )
39+ model = AbstractMetaLearner (resnet18 ())
40+ assert (
41+ model .evaluate_on_one_task (
42+ support_images ,
43+ support_labels ,
44+ query_images ,
45+ query_labels ,
46+ )
47+ == (expected_correct , expected_total )
48+ )
4849
4950
5051# pylint: enable=not-callable
@@ -66,20 +67,20 @@ def test_process_support_set_raises_error_when_not_implemented():
6667
6768class TestAMLValidate :
6869 @staticmethod
69- def test_validate_returns_accuracy ():
70- with patch ("easyfsl.methods.AbstractMetaLearner.evaluate" ) as mock_evaluate :
71- mock_evaluate .return_value = 0.0
72- meta_learner = AbstractMetaLearner (resnet18 ())
73- assert meta_learner .validate (None ) == 0.0
70+ def test_validate_returns_accuracy (mocker ):
71+ mocker .patch ("easyfsl.methods.AbstractMetaLearner.evaluate" , return_value = 0.0 )
72+ meta_learner = AbstractMetaLearner (resnet18 ())
73+ assert meta_learner .validate (None ) == 0.0
7474
7575 @staticmethod
76- def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy ():
77- with patch ("easyfsl.methods.AbstractMetaLearner.evaluate" ) as mock_evaluate :
78- mock_evaluate .return_value = 0.5
79- meta_learner = AbstractMetaLearner (resnet18 ())
80- meta_learner .best_validation_accuracy = 0.1
81- meta_learner .validate (None )
82- assert meta_learner .best_model_state is not None
76+ def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy (
77+ mocker ,
78+ ):
79+ mocker .patch ("easyfsl.methods.AbstractMetaLearner.evaluate" , return_value = 0.5 )
80+ meta_learner = AbstractMetaLearner (resnet18 ())
81+ meta_learner .best_validation_accuracy = 0.1
82+ meta_learner .validate (None )
83+ assert meta_learner .best_model_state is not None
8384
8485 @staticmethod
8586 @pytest .mark .parametrize (
@@ -91,10 +92,44 @@ def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy():
9192 )
9293 def test_validate_leaves_best_model_state_if_it_has_worse_validation_accuracy (
9394 accuracy ,
95+ mocker ,
96+ ):
97+ mocker .patch (
98+ "easyfsl.methods.AbstractMetaLearner.evaluate" , return_value = accuracy
99+ )
100+ meta_learner = AbstractMetaLearner (resnet18 ())
101+ meta_learner .best_validation_accuracy = 0.1
102+ meta_learner .validate (None )
103+ assert meta_learner .best_model_state is None
104+
105+ @staticmethod
106+ @pytest .mark .parametrize (
107+ "n_train_episodes,validation_frequency,expected_number_of_validations" ,
108+ [
109+ (5 , 1 , 5 ),
110+ (5 , 5 , 1 ),
111+ (5 , 6 , 0 ),
112+ (5 , 3 , 1 ),
113+ (6 , 3 , 2 ),
114+ ],
115+ )
116+ def test_validation_occurs_when_expected (
117+ n_train_episodes , validation_frequency , expected_number_of_validations , mocker
94118 ):
95- with patch ("easyfsl.methods.AbstractMetaLearner.evaluate" ) as mock_evaluate :
96- mock_evaluate .return_value = accuracy
97- meta_learner = AbstractMetaLearner (resnet18 ())
98- meta_learner .best_validation_accuracy = 0.1
99- meta_learner .validate (None )
100- assert meta_learner .best_model_state is None
119+ mocker .patch (
120+ "easyfsl.methods.AbstractMetaLearner.fit_on_task" , return_value = 0.0
121+ )
122+ mocker .patch ("easyfsl.methods.AbstractMetaLearner.validate" )
123+ spy_validate = mocker .spy (AbstractMetaLearner , "validate" )
124+
125+ meta_learner = AbstractMetaLearner (resnet18 ())
126+ train_loader = n_train_episodes * [(None , None , None , None , None )]
127+
128+ meta_learner .fit (
129+ train_loader = train_loader ,
130+ optimizer = None ,
131+ val_loader = True ,
132+ validation_frequency = validation_frequency ,
133+ )
134+
135+ assert spy_validate .call_count == expected_number_of_validations
0 commit comments