-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconstants.py
52 lines (44 loc) · 1.17 KB
/
constants.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# File: constants
# ---------------
# Central home for the experiment constants (hyperparameters, data splits,
# logging cadence). Since we will support multiple methods, these might be
# better moved into per-method YAML config files.
# --- MAML: optimization ---
# Total number of outer-loop training iterations.
NUM_TRAIN_ITERATIONS: int = 15_000
# Inner-loop (per-task adaptation) update steps.
NUM_INNER_UPDATES: int = 1
# Step sizes for the outer (meta) and inner (adaptation) updates.
OUTER_LR: float = 1e-3
INNER_LR: float = 3e-3
# --- Baseline training ---
BASELINE_LR: float = 3e-3
BASELINE_BATCH_SIZE: int = 32
BASELINE_NUM_EPOCHS: int = 10
# Reporting / validation cadence; presumably counted in training steps,
# confirm against the baseline training loop.
BASELINE_REPORT_TRAIN_EVERY: int = 500
BASELINE_VAL_EVERY: int = 1_000
# --- Model architecture ---
EMBED_DIM: int = 128
HIDDEN_DIM: int = 128
NUM_BLOCKS: int = 3
# NOTE(review): multi-head attention typically requires
# EMBED_DIM % NUM_HEADS == 0 (here 128 / 8 = 16) — keep in mind when tuning.
NUM_HEADS: int = 8
# Presumably one token per MIDI pitch (0-127); TODO confirm with the tokenizer.
VOCAB_SIZE: int = 128
# --- Data loading ---
# Examples per task in the support (adaptation) and query (evaluation) sets.
NUM_SUPPORT: int = 5
NUM_QUERY: int = 5
# Presumably the number of tasks per meta-batch; confirm with the loader.
META_BATCH_SIZE: int = 32
# Presumably data-loader worker processes.
NUM_WORKERS: int = 4
CONTEXT_LEN: int = 120
TEST_PREFIX_LEN: int = 0

# --- Sampling / generation ---
GENERATION_LENGTH: int = 120
# Tied to CONTEXT_LEN so the two stay in sync when the context window changes.
CONDITION_LENGTH: int = CONTEXT_LEN
TEMPERATURE: float = 1.0
# --- Genre splits ---
# Genre names used for the training / validation / test splits; the three
# lists are disjoint. Spellings (e.g. "Pop_Rock" vs "New Age") are kept
# verbatim — presumably they must match the dataset's labels exactly.
TRAIN_SPLIT: list = [
    "Vocal",
    "Folk",
    "Pop_Rock",
    "International",
    "Electronic",
    "New Age",
]
VAL_SPLIT: list = ["RnB", "Blues", "Latin"]
TEST_SPLIT: list = ["Country", "Reggae", "Jazz"]
# --- Evaluation, logging and checkpointing ---
# Presumably the names of the selectable model classes; confirm with the
# model factory.
MODEL_TYPES: list = [
    "SimpleLSTM",
    "SimpleTransformer",
    "ReferenceTransformer",
]
REPORT_TRAIN_EVERY: int = 50
EVALUATE_EVERY: int = 100
SAVE_CHECKPOINT_EVERY: int = 1_000
TESTING_ITERATIONS: int = 1_000