-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheic_for_generative.py
More file actions
29 lines (26 loc) · 1.08 KB
/
eic_for_generative.py
File metadata and controls
29 lines (26 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import sys
from argparse import ArgumentParser
def _strip_method_arg(argv):
    """Remove every ``--method`` option (and its value) from *argv* in place.

    Handles both the ``--method NAME`` and ``--method=NAME`` spellings and
    does nothing when the option is absent (e.g. argparse used the default).
    The list is mutated in place so callers can pass ``sys.argv`` directly.
    """
    i = 0
    while i < len(argv):
        token = argv[i]
        if token == '--method':
            # Drop the flag and the value that follows it.
            del argv[i:i + 2]
        elif token.startswith('--method='):
            del argv[i]
        else:
            i += 1


def _snip_loss_is_positive(argv):
    """Return True if *argv* sets ``--snip_loss`` to a number greater than 0.

    Accepts ``--snip_loss VALUE`` and ``--snip_loss=VALUE``; when the option is
    repeated the last occurrence wins, mirroring argparse. A missing value or
    a non-numeric value yields False.
    """
    value = None
    for i, token in enumerate(argv):
        if token == '--snip_loss' and i + 1 < len(argv):
            value = argv[i + 1]
        elif token.startswith('--snip_loss='):
            value = token.split('=', 1)[1]
    if value is None:
        return False
    try:
        # argv entries are strings; convert before comparing numerically.
        return float(value) > 0
    except ValueError:
        return False


def main(args):
    """Dispatch to the training entry point selected by ``args.method``.

    The selected ``train`` function is called with no arguments after the
    ``--method`` option has been removed from ``sys.argv`` (presumably because
    the trainer parses ``sys.argv`` itself and does not know ``--method`` --
    confirm against the trainer modules).

    For ``'snip'`` a warning is printed when ``--snip_loss`` is missing or not
    a positive number; training still proceeds, as before.

    Args:
        args: Namespace with a ``method`` attribute, one of ``'e2esr'``,
            ``'snip'`` or ``'sr4mdl'``.

    Raises:
        ValueError: If ``args.method`` is not a known method.
    """
    method = args.method
    if method == 'e2esr':
        from src.generative.e2esr.train import train
    elif method == 'snip':
        # NOTE(review): SNIP reuses the e2esr trainer; presumably --snip_loss
        # switches it into SNIP mode -- confirm this is intentional.
        from src.generative.e2esr.train import train
        if not _snip_loss_is_positive(sys.argv):
            print("Set --snip_loss to train snip!")
    elif method == 'sr4mdl':
        from src.generative.sr4mdl.train import train
    else:
        raise ValueError(f"Unknown method: {args.method}")
    # Safe even when --method was not given (the default path).
    _strip_method_arg(sys.argv)
    train()
if __name__ == '__main__':
    # allow_abbrev=False: argparse would otherwise accept abbreviations such as
    # ``--meth snip``, but main() only strips the full ``--method`` spelling
    # from sys.argv, so the abbreviated flag would be left for the trainer.
    parser = ArgumentParser(allow_abbrev=False)
    parser.add_argument(
        '--method',
        type=str,
        choices=['e2esr', 'snip', 'sr4mdl'],
        default='e2esr',
        help='generative method to train',
    )
    # Only --method is parsed here; unrecognised options are ignored by this
    # parser and remain in sys.argv for the selected trainer.
    args, _ = parser.parse_known_args()
    main(args)