Skip to content

Commit 13221f9

Browse files
authored
Small fixes in EasySet and AbstractMetaLearner
* Fix best validation accuracy update * Sort data instances for each class in EasySet * Move switch to train mode inside fit_on_task() * Add citation file * Make AbstractMetaLearner.fit() return average loss * Add EasySet.number_of_classes() * Fix python version in CI linter * Fix linter version * Bump version: 0.2.1 → 0.2.2
1 parent afb3155 commit 13221f9

File tree

11 files changed

+52
-13
lines changed

11 files changed

+52
-13
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.2.1
2+
current_version = 0.2.2
33
commit = True
44
tag = False
55

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
lint:
3434
working_directory: ~/project
3535
docker:
36-
- image: cimg/python:3.7
36+
- image: cimg/python:3.7.10
3737
steps:
3838
- run:
3939
name: Set python 3.7.10

CITATION.cff

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# This CITATION.cff file was generated with cffinit.
2+
# Visit https://bit.ly/cffinit to generate yours today!
3+
4+
cff-version: 1.2.0
5+
title: easyfsl
6+
message: >-
7+
If you use easyfsl in your research, please cite it
8+
using these metadata.
9+
type: software
10+
authors:
11+
- given-names: Etienne
12+
family-names: Bennequin
13+
14+
affiliation: Université Paris-Saclay
15+
repository-code: 'https://github.com/sicara/easy-few-shot-learning'
16+
abstract: >-
17+
Ready-to-use code and tutorial notebooks to boost
18+
your way into few-shot image classification.
19+
license: MIT

dev_requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ black>=20.8b1
22
loguru>=0.5.3
33
matplotlib>=3.3.4
44
pandas>=1.1.0
5-
pylint>=2.7.0
6-
pytest>=6.2.2
7-
pytest-mock>=3.6.1
5+
pylint==2.8.2
6+
pytest==6.2.2
7+
pytest-mock==3.6.1
88
torch>=1.7.1
99
torchvision>=0.8.2
1010
tqdm>=4.56.0

easyfsl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
for few-shot learning experiences.
66
"""
77

8-
__version__ = "0.2.1"
8+
__version__ = "0.2.2"

easyfsl/data_tools/easy_set.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def load_specs(specs_file: Path) -> dict:
6161
if specs_file.suffix != ".json":
6262
raise ValueError("EasySet requires specs in a JSON file.")
6363

64-
specs = json.load(open(specs_file, "r"))
64+
with open(specs_file, "r") as file:
65+
specs = json.load(file)
6566

6667
if "class_names" not in specs.keys() or "class_roots" not in specs.keys():
6768
raise ValueError(
@@ -124,7 +125,7 @@ def list_data_instances(class_roots: List[str]) -> (List[str], List[int]):
124125
for class_id, class_root in enumerate(class_roots):
125126
class_images = [
126127
str(image_path)
127-
for image_path in Path(class_root).glob("*")
128+
for image_path in sorted(Path(class_root).glob("*"))
128129
if image_path.is_file()
129130
]
130131
images += class_images
@@ -155,3 +156,6 @@ def __getitem__(self, item: int):
155156

156157
def __len__(self) -> int:
157158
return len(self.labels)
159+
160+
def number_of_classes(self):
161+
return len(self.class_names)

easyfsl/methods/abstract_meta_learner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import abstractmethod
22
from pathlib import Path
3+
from statistics import mean
34
from typing import Union
45

56
import torch
@@ -162,6 +163,7 @@ def fit_on_task(
162163
Returns:
163164
the value of the classification loss (for reporting purposes)
164165
"""
166+
self.train()
165167
optimizer.zero_grad()
166168
self.process_support_set(support_images.cuda(), support_labels.cuda())
167169
classification_scores = self(query_images.cuda())
@@ -178,7 +180,7 @@ def fit(
178180
optimizer: optim.Optimizer,
179181
val_loader: DataLoader = None,
180182
validation_frequency: int = 1000,
181-
):
183+
) -> float:
182184
"""
183185
Train the model on few-shot classification tasks.
184186
Args:
@@ -187,11 +189,12 @@ def fit(
187189
val_loader: loads data from the validation set in the shape of few-shot classification
188190
tasks
189191
validation_frequency: number of training episodes between two validations
192+
Returns:
193+
average loss
190194
"""
191195
log_update_frequency = 10
192196

193197
all_loss = []
194-
self.train()
195198
with tqdm(
196199
enumerate(train_loader), total=len(train_loader), desc="Meta-Training"
197200
) as tqdm_train:
@@ -222,6 +225,8 @@ def fit(
222225
if (episode_index + 1) % validation_frequency == 0:
223226
self.validate(val_loader)
224227

228+
return mean(all_loss)
229+
225230
def validate(self, val_loader: DataLoader) -> float:
226231
"""
227232
Validate the model on the validation set.
@@ -237,6 +242,7 @@ def validate(self, val_loader: DataLoader) -> float:
237242
if validation_accuracy > self.best_validation_accuracy:
238243
print("Best validation accuracy so far!")
239244
self.best_model_state = self.state_dict()
245+
self.best_validation_accuracy = validation_accuracy
240246

241247
return validation_accuracy
242248

easyfsl/methods/matching_networks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import torch
7-
import torch.nn as nn
7+
from torch import nn
88
from easyfsl.methods import AbstractMetaLearner
99

1010

easyfsl/methods/relation_networks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import torch
7-
import torch.nn as nn
7+
from torch import nn
88
from easyfsl.methods import AbstractMetaLearner
99
from easyfsl.utils import compute_prototypes
1010

easyfsl/tests/methods/abstract_meta_learner_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,16 @@ def test_validate_updates_best_model_state_if_it_has_best_validation_accuracy(
8282
meta_learner.validate(None)
8383
assert meta_learner.best_model_state is not None
8484

85+
@staticmethod
86+
def test_validate_updates_best_accuracy_if_it_has_best_validation_accuracy(
87+
mocker,
88+
):
89+
mocker.patch("easyfsl.methods.AbstractMetaLearner.evaluate", return_value=0.5)
90+
meta_learner = AbstractMetaLearner(resnet18())
91+
meta_learner.best_validation_accuracy = 0.1
92+
meta_learner.validate(None)
93+
assert meta_learner.best_validation_accuracy == 0.5
94+
8595
@staticmethod
8696
@pytest.mark.parametrize(
8797
"accuracy",

0 commit comments

Comments
 (0)