Skip to content

Commit d2dbc21

Browse files
committed
directly set use_tf_gamma on Attention instances
1 parent ece4306 commit d2dbc21

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

enformer_pytorch/modeling_enformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,10 @@ def from_pretrained(name, use_tf_gamma = None, **kwargs):
480480
enformer = Enformer.from_pretrained(name, **kwargs)
481481

482482
if name == 'EleutherAI/enformer-official-rough':
483-
enformer.use_tf_gamma = default(use_tf_gamma, True)
483+
use_tf_gamma = default(use_tf_gamma, True)
484+
485+
for module in enformer.modules():
486+
if isinstance(module, Attention):
487+
module.use_tf_gamma = use_tf_gamma
484488

485489
return enformer

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.8.2',
7+
version = '0.8.3',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)