Skip to content

Commit 243151b

Browse files
committed
automatically use SyncBatchNorm if doing distributed training
1 parent dce5709 commit 243151b

2 files changed

Lines changed: 12 additions & 3 deletions

File tree

enformer_pytorch/modeling_enformer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch import nn, einsum
66
import torch.nn.functional as F
7+
import torch.distributed as dist
78
from torch.utils.checkpoint import checkpoint_sequential
89

910
from einops import rearrange, reduce
@@ -53,6 +54,12 @@ def _round(x):
5354
def log(t, eps = 1e-20):
5455
return torch.log(t.clamp(min = eps))
5556

57+
# maybe sync batchnorm, for distributed training
58+
59+
def MaybeSyncBatchnorm(is_distributed = None):
60+
is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
61+
return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
62+
5663
# losses and metrics
5764

5865
def poisson_loss(pred, target):
@@ -204,9 +211,11 @@ def forward(self, x):
204211

205212
return x[:, -trim:trim]
206213

207-
def ConvBlock(dim, dim_out = None, kernel_size = 1):
214+
def ConvBlock(dim, dim_out = None, kernel_size = 1, is_distributed = None):
215+
batchnorm_klass = MaybeSyncBatchnorm(is_distributed = is_distributed)
216+
208217
return nn.Sequential(
209-
nn.BatchNorm1d(dim),
218+
batchnorm_klass(dim),
210219
GELU(),
211220
nn.Conv1d(dim, default(dim_out, dim), kernel_size, padding = kernel_size // 2)
212221
)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'enformer-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.8.5',
7+
version = '0.8.6',
88
license='MIT',
99
description = 'Enformer - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)