-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathfinetune_track.py
More file actions
79 lines (60 loc) · 2.08 KB
/
Copy pathfinetune_track.py
File metadata and controls
79 lines (60 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from dotenv import load_dotenv
# set the cache path in .env, then uncomment the next line to load it
# load_dotenv()
from enformer_pytorch import Enformer
from tf_bind_transformer import AdapterModel, BigWigTrainer
# training constants
BATCH_SIZE = 1
GRAD_ACCUM_STEPS = 8
# gradients are accumulated over GRAD_ACCUM_STEPS batches per optimizer step,
# so the effective batch size is BATCH_SIZE * GRAD_ACCUM_STEPS = 1 * 8 = 8
EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRAD_ACCUM_STEPS
LEARNING_RATE = 1e-4 # Deepmind used 1e-4 for fine-tuning of Enformer
VALIDATE_EVERY = 250
GRAD_CLIP_MAX_NORM = 1.5

# data locations (relative to the current working directory)
TFACTOR_FOLDER = './tfactor.fastas'
HUMAN_FASTA_FILE_PATH = './hg38.ml.fa'
MOUSE_FASTA_FILE_PATH = './mm10.ml.fa'
HUMAN_LOCI_PATH = './chip_atlas/human_sequences.bed'
MOUSE_LOCI_PATH = './chip_atlas/mouse_sequences.bed'
BIGWIG_PATH = './chip_atlas/bigwig'
BIGWIG_TRACKS_ONLY_PATH = './chip_atlas/bigwig_tracks_only'
ANNOT_FILE_PATH = './chip_atlas/annot.tab'

# passed as target_length to both Enformer and the trainer
# NOTE(review): presumably the number of output bins per sequence — confirm
TARGET_LENGTH = 896
# passed to the trainer as held_out_targets (a list, despite the singular name)
HELD_OUT_TARGET = ['GATA2']
# Load pretrained Enformer weights from the 'EleutherAI/enformer-official-rough'
# checkpoint, configured with target_length = TARGET_LENGTH.
enformer = Enformer.from_pretrained(
    'EleutherAI/enformer-official-rough',
    target_length = TARGET_LENGTH,
)

# Species-specific fine-tuning output heads.
# NOTE(review): values are presumably the number of output tracks per species
# head — confirm against AdapterModel and the bigwig data.
species_output_heads = {'human': 12, 'mouse': 24}

# Wrap Enformer in the adapter model. The settings enable amino-acid embeddings
# (encoder 'esm') and free-text context (embedded with 'mean_pool').
# The wrapped model is then moved to the GPU, so a CUDA device is required.
model = AdapterModel(
    enformer = enformer,
    use_aa_embeds = True,
    aa_embed_encoder = 'esm',
    use_free_text_context = True,
    free_text_embed_method = 'mean_pool',
    finetune_output_heads = species_output_heads,
).cuda()
# Fine-tuning trainer wrapping `model`, configured for human and mouse data.
trainer = BigWigTrainer(
    model,
    # genome FASTA and loci BED file, per species
    human_fasta_file = HUMAN_FASTA_FILE_PATH,
    human_loci_path = HUMAN_LOCI_PATH,
    mouse_fasta_file = MOUSE_FASTA_FILE_PATH,
    mouse_loci_path = MOUSE_LOCI_PATH,
    # transcription-factor FASTA folder (the same folder is used for both species)
    human_factor_fasta_folder = TFACTOR_FOLDER,
    mouse_factor_fasta_folder = TFACTOR_FOLDER,
    # bigwig folders, annotation table, and target/hold-out settings
    bigwig_folder_path = BIGWIG_PATH,
    bigwig_tracks_only_folder_path = BIGWIG_TRACKS_ONLY_PATH,
    annot_file_path = ANNOT_FILE_PATH,
    target_length = TARGET_LENGTH,
    held_out_targets = HELD_OUT_TARGET,
    # optimization settings
    lr = LEARNING_RATE,
    batch_size = BATCH_SIZE,
    shuffle = True,
    grad_accum_every = GRAD_ACCUM_STEPS,
    grad_clip_norm = GRAD_CLIP_MAX_NORM,
    validate_every = VALIDATE_EVERY,
)
# Train indefinitely: the loop has no exit condition, so stop the process
# manually (e.g. Ctrl-C). Each trainer() call is expected to run one training step.
# NOTE(review): trainer()'s return value is discarded — presumably loss/metrics;
# consider logging it.
while True:
    trainer()