Commit ece4306

adopt the fix proposed by @johahi for fixing the small numerical discrepancy in the pretrained model between tensorflow and pytorch #31
1 parent 05a3654 commit ece4306

8 files changed

Lines changed: 67 additions & 32 deletions

MANIFEST.in

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-recursive-include enformer_pytorch *.yml
+include enformer_pytorch/precomputed/tf_gammas.pt
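
The manifest entry matters because `modeling_enformer.py` loads the tensor from a path relative to its own file, so the file has to ship inside the installed package. A minimal sketch of that path resolution using only `pathlib`; the helper name `packaged_path` and the example install location are hypothetical and not part of the commit:

```python
from pathlib import Path


def packaged_path(module_file, *parts):
    """Resolve a data file relative to the directory that holds module_file."""
    # parents[0] is the directory containing the module (equivalent to .parent)
    return Path(module_file).parents[0].joinpath(*parts)


# hypothetical install location, for illustration only
module_file = "/site-packages/enformer_pytorch/modeling_enformer.py"
gammas_path = packaged_path(module_file, "precomputed", "tf_gammas.pt")
print(gammas_path.as_posix())
```

If the data file were missing from the built distribution, a module-level load of this path would fail at import time, which is presumably why the manifest change accompanies the code change.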

README.md

Lines changed: 16 additions & 12 deletions
@@ -119,16 +119,18 @@ Deepmind has released the weights for their tensorflow sonnet Enformer model! I
 
 Update: <a href="https://github.com/jstjohn">John St. John</a> did some work and found that the `enformer-official-rough` model hits the reported marks in the paper - human pearson R of `0.625` for validation, and `0.65` for test.
 
+Update: As of version 0.8.0, if one were to use the `from_pretrained` function to load the pretrained model, it should automatically use precomputed gamma positions to address a difference between tensorflow and pytorch `xlogy`. This should resolve the numerical discrepancy above. If you were to further finetune and not be using the `from_pretrained` function, please make sure to set `use_tf_gamma = True` when using `.from_hparams` to instantiate the `Enformer`
+
 ```bash
 $ pip install enformer-pytorch>=0.5
 ````
 
 Loading the model
 
 ```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
 ```
 
 Quick sanity check on a single human validation point
@@ -143,19 +145,19 @@ This is all made possible thanks to HuggingFace's [custom model](https://hugging
 You can also load, with overriding of the `target_length` parameter, if you are working with shorter sequence lengths
 
 ```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 
-model = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
+model = from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
 
 # do your fine-tuning
 ```
 
 To save on memory during fine-tuning a large Enformer model
 
 ```python
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
+enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
 
 # finetune enformer on a limited budget
 ```
@@ -168,10 +170,10 @@ Fine-tuning on new tracks
 
 ```python
 import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 from enformer_pytorch.finetune import HeadAdapterWrapper
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
 
 model = HeadAdapterWrapper(
     enformer = enformer,
@@ -190,10 +192,10 @@ Finetuning on contextual data (cell type, transcription factor, etc)
 
 ```python
 import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 from enformer_pytorch.finetune import ContextAdapterWrapper
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
 
 model = ContextAdapterWrapper(
     enformer = enformer,
@@ -218,10 +220,10 @@ Finally, there is also a way to use attention aggregation from a set of context
 
 ```python
 import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 from enformer_pytorch.finetune import ContextAttentionAdapterWrapper
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
+enformer = from_pretrained('EleutherAI/enformer-official-rough')
 
 model = ContextAttentionAdapterWrapper(
     enformer = enformer,
@@ -315,6 +317,8 @@ seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)
 
 Special thanks goes out to <a href="https://www.eleuther.ai/">EleutherAI</a> for providing the resources to retrain the model, during a time when the official model from Deepmind had not been released yet.
 
+Thanks also goes out to <a href="johahi">@johahi</a> for finding out that there are numerical differences between the torch and tensorflow implementations of `xlogy`. He provided a fix for this difference, which is adopted in this repository in `v0.8.0`
+
 ## Todo
 
 - [x] script to load weights from trained tensorflow enformer model to pytorch model

enformer_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from enformer_pytorch.config_enformer import EnformerConfig
-from enformer_pytorch.modeling_enformer import Enformer, SEQUENCE_LENGTH, AttentionPool
+from enformer_pytorch.modeling_enformer import Enformer, from_pretrained, SEQUENCE_LENGTH, AttentionPool
 from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval

enformer_pytorch/config_enformer.py

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@ def __init__(
         use_convnext = False,
         num_downsamples = 7, # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
         dim_divisible_by = 128,
+        use_tf_gamma = False,
         **kwargs,
     ):
         self.dim = dim
@@ -32,5 +33,6 @@ def __init__(
         self.use_checkpointing = use_checkpointing
         self.num_downsamples = num_downsamples
         self.dim_divisible_by = dim_divisible_by
-
+        self.use_tf_gamma = use_tf_gamma
+
         super().__init__(**kwargs)

enformer_pytorch/modeling_enformer.py

Lines changed: 43 additions & 14 deletions
@@ -1,4 +1,6 @@
 import math
+from pathlib import Path
+
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
@@ -18,6 +20,13 @@
 SEQUENCE_LENGTH = 196_608
 TARGET_LENGTH = 896
 
+# gamma positions from tensorflow
+# addressing a difference between xlogy results from tensorflow and pytorch
+# solution came from @johahi
+
+DIR = Path(__file__).parents[0]
+TF_GAMMAS = torch.load(str(DIR / "precomputed"/ "tf_gammas.pt"))
+
 # helpers
 
 def exists(val):
# helpers
2231

2332
def exists(val):
@@ -26,6 +35,11 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def always(val):
+    def inner(*args, **kwargs):
+        return val
+    return inner
+
 def map_values(fn, d):
     return {key: fn(values) for key, values in d.items()}
 
@@ -75,30 +89,24 @@ def get_positional_features_gamma(positions, features, seq_len, stddev = None, s
     if not exists(start_mean):
         start_mean = seq_len / features
 
-    # turns out xlogy between tensorflow and torch differs because of the log - thanks to phd student @johahi for finding this!
-    # do everything in float64 here for precision
-
-    dtype = positions.dtype
-    positions = positions.double()
-    mean = torch.linspace(start_mean, seq_len, features, device = positions.device, dtype = torch.float64)
+    mean = torch.linspace(start_mean, seq_len, features, device = positions.device)
 
     mean = mean[None, ...]
     concentration = (mean / stddev) ** 2
     rate = mean / stddev ** 2
 
-    probabilities = gamma_pdf(positions.abs()[..., None], concentration, rate)
+    probabilities = gamma_pdf(positions.float().abs()[..., None], concentration, rate)
     probabilities = probabilities + eps
     outputs = probabilities / torch.amax(probabilities, dim = -1, keepdim = True)
+    return outputs
 
-    return outputs.to(dtype)
-
-def get_positional_embed(seq_len, feature_size, device):
+def get_positional_embed(seq_len, feature_size, device, use_tf_gamma):
     distances = torch.arange(-seq_len + 1, seq_len, device = device)
 
     feature_functions = [
         get_positional_features_exponential,
         get_positional_features_central_mask,
-        get_positional_features_gamma
+        get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
     ]
 
     num_components = len(feature_functions) * 2
@@ -213,7 +221,8 @@ def __init__(
         dim_key = 64,
         dim_value = 64,
         dropout = 0.,
-        pos_dropout = 0.
+        pos_dropout = 0.,
+        use_tf_gamma = False
     ):
         super().__init__()
         self.scale = dim_key ** -0.5
@@ -240,6 +249,10 @@ def __init__(
         self.pos_dropout = nn.Dropout(pos_dropout)
         self.attn_dropout = nn.Dropout(dropout)
 
+        # whether to use tf gamma
+
+        self.use_tf_gamma = use_tf_gamma
+
     def forward(self, x):
         n, h, device = x.shape[-2], self.heads, x.device
 
@@ -253,7 +266,7 @@ def forward(self, x):
 
         content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)
 
-        positions = get_positional_embed(n, self.num_rel_pos_features, device)
+        positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma)
         positions = self.pos_dropout(positions)
         rel_k = self.to_rel_k(positions)
 
@@ -308,6 +321,11 @@ def __init__(self, config):
 
         self.conv_tower = nn.Sequential(*conv_layers)
 
+        # whether to use tensorflow gamma positions
+
+        use_tf_gamma = config.use_tf_gamma
+        self.use_tf_gamma = use_tf_gamma
+
         # transformer
 
         transformer = []
@@ -322,7 +340,8 @@ def __init__(self, config):
                         dim_value = config.dim // config.heads,
                         dropout = config.attn_dropout,
                         pos_dropout = config.pos_dropout,
-                        num_rel_pos_features = config.dim // config.heads
+                        num_rel_pos_features = config.dim // config.heads,
+                        use_tf_gamma = use_tf_gamma
                     ),
                     nn.Dropout(config.dropout_rate)
                 )),
@@ -454,3 +473,13 @@ def forward(
             return out, x
 
         return out
+
+# from pretrained function
+
+def from_pretrained(name, use_tf_gamma = None, **kwargs):
+    enformer = Enformer.from_pretrained(name, **kwargs)
+
+    if name == 'EleutherAI/enformer-official-rough':
+        enformer.use_tf_gamma = default(use_tf_gamma, True)
+
+    return enformer
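
Two small helpers carry most of the switch in this file. `always(val)` wraps a constant so that it can sit in a list of callables next to real feature functions, and `default(use_tf_gamma, True)` turns the `None` sentinel into `True` while leaving an explicit `False` alone. A standalone sketch of both behaviours follows. The body of `exists` is not shown in the diff and is assumed here to be a `None` check, and `computed_features`, `pick_feature_fn`, and the string constant are stand-ins invented for the illustration:

```python
def exists(val):
    # assumed body: the diff only shows the signature
    return val is not None


def default(val, d):
    return val if exists(val) else d


def always(val):
    # returns a callable that ignores its arguments and yields val
    def inner(*args, **kwargs):
        return val
    return inner


def computed_features(positions, features, seq_len):
    # stand-in for a real feature function that computes its output
    return f"computed({features}, {seq_len})"


def pick_feature_fn(use_tf_gamma):
    # same conditional shape as the entry in feature_functions
    return computed_features if not use_tf_gamma else always("precomputed")


# None falls back to True, explicit values are kept
resolved = [default(flag, True) for flag in (None, True, False)]

# every entry is called with the same arguments, constant or not
results = [pick_feature_fn(flag)([0, 1], 4, 8) for flag in resolved]
print(resolved, results)
```

Using `None` rather than `False` as the keyword default is what lets a caller opt out explicitly: `False` survives `default`, while an omitted argument picks up the pretrained-model setting.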

enformer_pytorch/precomputed/tf_gammas.pt

385 KB
Binary file not shown.
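
The binary file in this commit is the tensor loaded as `TF_GAMMAS`, which stands in for the output of `get_positional_features_gamma`. That function sets `concentration = (mean / stddev) ** 2` and `rate = mean / stddev ** 2`. Under that parameterization a gamma distribution has mean `concentration / rate` and standard deviation `sqrt(concentration) / rate`, so the inputs are recovered exactly. A pure-Python check of this, with a log-space density for reference. The `gamma_pdf` body is an assumption, since the diff does not show the repository's version, and `xlogy` here is a plain-Python stand-in rather than the torch or tensorflow function:

```python
import math


def xlogy(x, y):
    # x * log(y), taken to be 0 when x == 0 (stand-in, not torch.xlogy)
    return 0.0 if x == 0 else x * math.log(y)


def gamma_pdf(x, concentration, rate):
    # assumed log-space form of the gamma density
    log_unnormalized = xlogy(concentration - 1.0, x) - rate * x
    log_normalization = math.lgamma(concentration) - concentration * math.log(rate)
    return math.exp(log_unnormalized - log_normalization)


def gamma_params(mean, stddev):
    # same two expressions as in get_positional_features_gamma
    concentration = (mean / stddev) ** 2
    rate = mean / stddev ** 2
    return concentration, rate


mean, stddev = 96.0, 48.0  # arbitrary illustration values
concentration, rate = gamma_params(mean, stddev)

recovered_mean = concentration / rate
recovered_std = math.sqrt(concentration) / rate

# crude Riemann sum with unit steps; should be close to 1 for a valid density
total = sum(gamma_pdf(float(x), concentration, rate) for x in range(1, 2001))
print(recovered_mean, recovered_std, total)
```

Because the density goes through `exp` of a difference of logs, small differences in how a framework evaluates the `xlogy` term can show up in the final features, which is consistent with the motivation given in the code comments.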

setup.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
   name = 'enformer-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.7.7',
+  version = '0.8.2',
   license='MIT',
   description = 'Enformer - Pytorch',
   author = 'Phil Wang',

test_pretrained.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import torch
-from enformer_pytorch import Enformer
+from enformer_pytorch import from_pretrained
 
-enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough').cuda()
+enformer = from_pretrained('EleutherAI/enformer-official-rough').cuda()
 enformer.eval()
 
 data = torch.load('./data/test-sample.pt')
