-
Notifications
You must be signed in to change notification settings - Fork 360
/
Copy pathlightning_protonet_test_notravis.py
68 lines (57 loc) · 1.84 KB
/
lightning_protonet_test_notravis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
import unittest
import learn2learn as l2l
import pytorch_lightning as pl
from learn2learn.utils.lightning import EpisodicBatcher
from learn2learn.algorithms import LightningPrototypicalNetworks
class TestLightningProtoNets(unittest.TestCase):
    """End-to-end check of ``LightningPrototypicalNetworks`` on CIFAR-FS.

    Trains a prototypical network with a PyTorch Lightning ``Trainer`` for a
    fixed number of epochs, then runs the test, validation, and prediction
    loops. CIFAR-FS is requested via ``get_tasksets(..., root="~/data")`` and
    training runs for 20 epochs, so this is a slow integration test rather
    than a unit test.
    """

    def test_protonets(self):
        """Fit a 5-way 5-shot ProtoNet and require at least chance-level test accuracy."""
        meta_batch_size = 4
        max_epochs = 20
        seed = 42
        ways = 5
        shots = 5

        # Seed all RNGs; paired with ``deterministic=True`` on the Trainer
        # below so repeated runs are reproducible.
        pl.seed_everything(seed)

        # Each task samples ``2 * shots`` examples per class. Presumably the
        # episodic module splits these into support and query halves —
        # TODO confirm against LightningPrototypicalNetworks.
        tasksets = l2l.vision.benchmarks.get_tasksets(
            "cifarfs",
            train_samples=2 * shots,
            train_ways=ways,
            test_samples=2 * shots,
            test_ways=ways,
            root="~/data",
        )

        # Only the CNN4 feature extractor (``model.features``) is handed to the
        # ProtoNet module; the rest of the CNN4 model is not used here.
        model = l2l.vision.models.CNN4(ways, embedding_size=32 * 4)
        features = model.features
        protonet = LightningPrototypicalNetworks(features, lr=3e-4)

        # ``epoch_length`` equals ``accumulate_grad_batches`` below, so each
        # epoch appears to produce a single accumulated optimizer step over
        # ``meta_batch_size`` tasks — TODO confirm against EpisodicBatcher.
        episodic_data = EpisodicBatcher(
            tasksets.train,
            tasksets.validation,
            tasksets.test,
            epoch_length=meta_batch_size,
        )

        trainer = pl.Trainer(
            accumulate_grad_batches=meta_batch_size,
            min_epochs=max_epochs,
            max_epochs=max_epochs,
            progress_bar_refresh_rate=0,
            deterministic=True,
            weights_summary=None,
        )
        # Pass the data module by keyword: ``fit``'s second positional
        # parameter is a dataloader slot, not a datamodule.
        trainer.fit(protonet, datamodule=episodic_data)

        # NOTE: ``test_dataloaders`` / ``val_dataloaders`` are the older
        # Lightning keyword spellings (later renamed ``dataloaders``). They are
        # kept to match the other pre-1.7 Trainer options used above.
        results = trainer.test(
            test_dataloaders=tasksets.test,
            verbose=False,
        )
        # 0.20 == 1 / ways, i.e. chance level for 5-way classification.
        # NOTE(review): the key must match the metric name the module logs
        # during its test loop — confirm it is "valid_accuracy".
        self.assertGreaterEqual(results[0]["valid_accuracy"], 0.20)

        trainer.validate(
            val_dataloaders=tasksets.validation,
            verbose=False,
        )

        # ``Trainer.predict`` takes its loader via ``dataloaders`` and has no
        # ``verbose`` option. This is a smoke check that the predict loop runs;
        # its outputs are not inspected.
        trainer.predict(dataloaders=tasksets.validation)
if __name__ == "__main__":
    # Only when executed as a script (not when collected by a test runner):
    # hand control to unittest's command-line entry point.
    unittest.main()