Skip to content

Commit 4e70710

Browse files
committed
address #21
1 parent abb50a0 commit 4e70710

2 files changed

Lines changed: 7 additions & 1 deletion

File tree

enformer_pytorch/modeling_enformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,14 @@ def __init__(self, dim, pool_size = 2):
136136
super().__init__()
137137
self.pool_size = pool_size
138138
self.pool_fn = Rearrange('b d (n p) -> b d n p', p = pool_size)
139+
139140
self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)
140141

142+
nn.init.dirac_(self.to_attn_logits.weight)
143+
144+
with torch.no_grad():
145+
self.to_attn_logits.weight.mul_(2)
146+
141147
def forward(self, x):
142148
b, _, n = x.shape
143149
remainder = n % self.pool_size

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.6.2',
7+
version = '0.6.4',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)