|
1 |
| -from models.coil import COIL |
2 |
| -from models.der import DER |
3 |
| -from models.ewc import EWC |
4 |
| -from models.finetune import Finetune |
5 |
| -from models.foster import FOSTER |
6 |
| -from models.gem import GEM |
7 |
| -from models.icarl import iCaRL |
8 |
| -from models.lwf import LwF |
9 |
| -from models.replay import Replay |
10 |
| -from models.bic import BiC |
11 |
| -from models.podnet import PODNet |
12 |
| -from models.rmm import RMM_FOSTER, RMM_iCaRL |
13 |
| -from models.ssre import SSRE |
14 |
| -from models.wa import WA |
15 |
| -from models.fetril import FeTrIL |
16 |
| -from models.pa2s import PASS |
17 |
| -from models.il2a import IL2A |
18 |
| -from models.memo import MEMO |
19 |
| -from models.beef_iso import BEEFISO |
20 |
| -from models.simplecil import SimpleCIL |
21 |
| - |
22 | 1 | def get_model(model_name, args):
|
23 | 2 | name = model_name.lower()
|
24 | 3 | if name == "icarl":
|
| 4 | + from models.icarl import iCaRL |
25 | 5 | return iCaRL(args)
|
26 | 6 | elif name == "bic":
|
| 7 | + from models.bic import BiC |
27 | 8 | return BiC(args)
|
28 | 9 | elif name == "podnet":
|
| 10 | + from models.podnet import PODNet |
29 | 11 | return PODNet(args)
|
30 | 12 | elif name == "lwf":
|
| 13 | + from models.lwf import LwF |
31 | 14 | return LwF(args)
|
32 | 15 | elif name == "ewc":
|
| 16 | + from models.ewc import EWC |
33 | 17 | return EWC(args)
|
34 | 18 | elif name == "wa":
|
| 19 | + from models.wa import WA |
35 | 20 | return WA(args)
|
36 | 21 | elif name == "der":
|
| 22 | + from models.der import DER |
37 | 23 | return DER(args)
|
38 | 24 | elif name == "finetune":
|
| 25 | + from models.finetune import Finetune |
39 | 26 | return Finetune(args)
|
40 | 27 | elif name == "replay":
|
| 28 | + from models.replay import Replay |
41 | 29 | return Replay(args)
|
42 | 30 | elif name == "gem":
|
| 31 | + from models.gem import GEM |
43 | 32 | return GEM(args)
|
44 | 33 | elif name == "coil":
|
| 34 | + from models.coil import COIL |
45 | 35 | return COIL(args)
|
46 | 36 | elif name == "foster":
|
| 37 | + from models.foster import FOSTER |
47 | 38 | return FOSTER(args)
|
48 | 39 | elif name == "rmm-icarl":
|
| 40 | + from models.rmm import RMM_FOSTER, RMM_iCaRL |
49 | 41 | return RMM_iCaRL(args)
|
50 | 42 | elif name == "rmm-foster":
|
| 43 | + from models.rmm import RMM_FOSTER, RMM_iCaRL |
51 | 44 | return RMM_FOSTER(args)
|
52 | 45 | elif name == "fetril":
|
| 46 | + from models.fetril import FeTrIL |
53 | 47 | return FeTrIL(args)
|
54 | 48 | elif name == "pass":
|
| 49 | + from models.pa2s import PASS |
55 | 50 | return PASS(args)
|
56 | 51 | elif name == "il2a":
|
| 52 | + from models.il2a import IL2A |
57 | 53 | return IL2A(args)
|
58 | 54 | elif name == "ssre":
|
| 55 | + from models.ssre import SSRE |
59 | 56 | return SSRE(args)
|
60 |
| - elif name == "memo": |
| 57 | + elif name == "memo": |
| 58 | + from models.memo import MEMO |
61 | 59 | return MEMO(args)
|
62 | 60 | elif name == "beefiso":
|
| 61 | + from models.beef_iso import BEEFISO |
63 | 62 | return BEEFISO(args)
|
64 | 63 | elif name == "simplecil":
|
| 64 | + from models.simplecil import SimpleCIL |
65 | 65 | return SimpleCIL(args)
|
66 | 66 | else:
|
67 | 67 | assert 0
|
0 commit comments