-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathparams.py
More file actions
42 lines (41 loc) · 2.06 KB
/
params.py
File metadata and controls
42 lines (41 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse
def _str2bool(value):
    """Convert a command-line string such as "true", "False", "1" or "no" to a bool.

    Used as the argparse ``type`` for boolean options. ``type=bool`` does not work:
    argparse passes the raw string to ``bool()``, and every non-empty string,
    including "False", is truthy.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def get_parser():
    """Build the command-line parser for model, data and training options.

    Boolean options (``--test``, ``--reload``, ``--finetune``, ``--ad``) accept an
    explicit value such as ``--ad false`` or ``--test 1``. A bare flag (``--test``)
    means True.

    Returns:
        argparse.ArgumentParser: a parser that has not yet parsed any arguments.
    """
    parser = argparse.ArgumentParser()

    # Model architecture
    parser.add_argument("--embed_dim", type=int, default=512,
                        help="Embedding layer size")
    parser.add_argument("--n_layers", type=int, default=4,
                        help="Number of Transformer layers")
    parser.add_argument("--n_heads", type=int, default=8,
                        help="Number of Transformer heads")
    parser.add_argument("--n_val", type=int, default=11,
                        help="Number of values")

    # Optimisation
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="Learning rate")

    # Data
    parser.add_argument("--data_path", type=str, default="",
                        help="Data path")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Number of sentences per batch")
    parser.add_argument("--dataset", type=str, default='hcl',
                        help="Reference dataset")

    # Run / bookkeeping
    parser.add_argument("--experiment", type=str, default='2',
                        help="Experiment name")
    # NOTE(review): judging by the name, this is the number of epochs, not a "size".
    parser.add_argument("--n_epochs", type=int, default=1,
                        help="Maximum epoch size")
    parser.add_argument("--log_step", type=int, default=1,
                        help="evaluation step")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--n_workers", type=int, default=8,
                        help="DataLoader workers")

    # Mode switches. _str2bool replaces type=bool, which parsed "False" as True.
    # With nargs="?" and const=True, a bare flag (e.g. `--test`) also means True.
    parser.add_argument("--test", type=_str2bool, nargs="?", const=True,
                        default=False, help="test")
    parser.add_argument("--reload", type=_str2bool, nargs="?", const=True,
                        default=False, help="reload")
    parser.add_argument("--finetune", type=_str2bool, nargs="?", const=True,
                        default=False, help="finetune")
    parser.add_argument("--ad", type=_str2bool, nargs="?", const=True,
                        default=True, help="adversarial embedding")

    parser.add_argument("--pooling", type=str, default='sum',
                        help="embedding pooling")
    return parser