3D fcnn #68

Open · wants to merge 5 commits into master
86 changes: 53 additions & 33 deletions atomai/nets/blocks.py
@@ -20,7 +20,7 @@ class ConvBlock(nn.Module):

     Args:
         ndim:
-            Data dimensionality (1D or 2D)
+            Data dimensionality (1D, 2D, or 3D)
         nb_layers:
             Number of layers in the block
         input_channels:
@@ -53,9 +53,9 @@ def __init__(self,
         Initializes module parameters
         """
         super(ConvBlock, self).__init__()
-        if not 0 < ndim < 3:
-            raise AssertionError("ndim must be equal to 1 or 2")
-        conv = nn.Conv2d if ndim == 2 else nn.Conv1d
+        if not 0 < ndim < 4:
+            raise AssertionError("ndim must be equal to 1, 2, or 3")
+        conv = get_conv(ndim)
         block = []
         for idx in range(nb_layers):
             input_channels = output_channels if idx > 0 else input_channels
@@ -68,10 +68,7 @@ def __init__(self,
                 block.append(nn.Dropout(dropout_))
             block.append(nn.LeakyReLU(negative_slope=lrelu_a))
             if batch_norm:
-                if ndim == 2:
-                    block.append(nn.BatchNorm2d(output_channels))
-                else:
-                    block.append(nn.BatchNorm1d(output_channels))
+                block.append(get_BatchNorm(ndim, output_channels))
         self.block = nn.Sequential(*block)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -80,6 +77,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         output = self.block(x)
         return output
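
With get_conv and get_BatchNorm dispatching on ndim, the same block now covers volumetric data. A minimal usage sketch, assuming the positional order (ndim, nb_layers, input_channels, output_channels) given in the docstring above; the remaining constructor defaults are taken on faith:

import torch

block = ConvBlock(3, 2, input_channels=1, output_channels=16)
vol = torch.randn(4, 1, 8, 32, 32)   # (batch, channels, depth, height, width)
out = block(vol)                     # every conv layer is nn.Conv3d via get_conv(3)
print(out.shape[1])                  # 16 output channels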



class UpsampleBlock(nn.Module):
@@ -90,15 +88,15 @@ class UpsampleBlock(nn.Module):

     Args:
         ndim:
-            Data dimensionality (1D or 2D)
+            Data dimensionality (1D, 2D, or 3D)
         input_channels:
             Number of input channels for the block
         output_channels:
             Number of the output channels for the block
         scale_factor:
             Scale factor for upsampling
         mode:
-            Upsampling mode. Select between "bilinear" and "nearest"
+            Upsampling mode. Select between "bilinear", "nearest", and "trilinear" (3D data)
     """
     def __init__(self,
                  ndim: int,
@@ -110,14 +108,14 @@ def __init__(self,
         Initializes module parameters
         """
         super(UpsampleBlock, self).__init__()
-        if not any([mode == 'bilinear', mode == 'nearest']):
+        if not any([mode == 'bilinear', mode == 'nearest', mode == 'trilinear']):
             raise NotImplementedError(
-                "use 'bilinear' or 'nearest' for upsampling mode")
-        if not 0 < ndim < 3:
-            raise AssertionError("ndim must be equal to 1 or 2")
-        conv = nn.Conv2d if ndim == 2 else nn.Conv1d
+                "use 'trilinear', 'bilinear', or 'nearest' for upsampling mode")
+        if not 0 < ndim < 4:
+            raise AssertionError("ndim must be equal to 1, 2, or 3")
+        conv = get_conv(ndim)
         self.scale_factor = scale_factor
-        self.mode = mode if ndim == 2 else "nearest"
+        self.mode = get_interpolate_mode(ndim)
         self.conv = conv(
             input_channels, output_channels,
             kernel_size=1, stride=1, padding=0)
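
For reference, the mode that get_interpolate_mode (defined at the bottom of this file) returns for 3D plugs straight into torch's functional interpolation, which is presumably what this block's forward pass (not shown in the diff) calls:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8, 8)    # volumetric feature map (N, C, D, H, W)
y = F.interpolate(x, scale_factor=2, mode="trilinear", align_corners=False)
print(y.shape)                    # torch.Size([1, 4, 16, 16, 16])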
@@ -137,7 +135,7 @@ class ResBlock(nn.Module):

     Args:
         ndim:
-            Data dimensionality (1D or 2D)
+            Data dimensionality (1D, 2D, or 3D)
         nb_layers:
             Number of layers in the block
         input_channels:
@@ -170,9 +168,9 @@ def __init__(self,
         Initializes block's parameters
         """
         super(ResBlock, self).__init__()
-        if not 0 < ndim < 3:
-            raise AssertionError("ndim must be equal to 1 or 2")
-        conv = nn.Conv2d if ndim == 2 else nn.Conv1d
+        if not 0 < ndim < 4:
+            raise AssertionError("ndim must be equal to 1, 2, or 3")
+        conv = get_conv(ndim)
         self.lrelu_a = lrelu_a
         self.batch_norm = batch_norm
         self.c0 = conv(input_channels,
@@ -191,9 +189,8 @@ def __init__(self,
                        stride=1,
                        padding=1)
         if batch_norm:
-            bn = nn.BatchNorm2d if ndim == 2 else nn.BatchNorm1d
-            self.bn1 = bn(output_channels)
-            self.bn2 = bn(output_channels)
+            self.bn1 = get_BatchNorm(ndim, output_channels)
+            self.bn2 = get_BatchNorm(ndim, output_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -218,7 +215,7 @@ class ResModule(nn.Module):
     Stitches multiple convolutional blocks with residual connections together

     Args:
-        ndim: Data dimensionality (1D or 2D)
+        ndim: Data dimensionality (1D, 2D, or 3D)
         res_depth: Number of residual blocks in a residual module
         input_channels: Number of filters in the input layer
         output_channels: Number of channels in the output layer
@@ -260,15 +257,15 @@ class DilatedBlock(nn.Module):

     Args:
         ndim:
-            Data dimensionality (1D or 2D)
+            Data dimensionality (1D, 2D, or 3D)
         input_channels:
             Number of input channels for the block
         output_channels:
             Number of the output channels for the block
         dilation_values:
             List of dilation rates for each convolution layer in the block
             (for example, dilation_values = [2, 4, 6] means that the dilated
-            block will 3 layers with dilation values of 2, 4, and 6).
+            block will have 3 layers with dilation values of 2, 4, and 6).
         padding_values:
             Edge padding for each dilated layer. The number of elements in this
             list should be equal to that in the dilated_values list and
@@ -294,9 +291,9 @@ def __init__(self, ndim: int, input_channels: int, output_channels: int,
         Initializes module parameters
         """
         super(DilatedBlock, self).__init__()
-        if not 0 < ndim < 3:
-            raise AssertionError("ndim must be equal to 1 or 2")
-        conv = nn.Conv2d if ndim == 2 else nn.Conv1d
+        if not 0 < ndim < 4:
+            raise AssertionError("ndim must be equal to 1, 2, or 3")
+        conv = get_conv(ndim)
         atrous_module = []
         for idx, (dil, pad) in enumerate(zip(dilation_values, padding_values)):
             input_channels = output_channels if idx > 0 else input_channels
@@ -311,10 +308,7 @@ def __init__(self, ndim: int, input_channels: int, output_channels: int,
                 atrous_module.append(nn.Dropout(dropout_))
             atrous_module.append(nn.LeakyReLU(negative_slope=lrelu_a))
             if batch_norm:
-                if ndim == 2:
-                    atrous_module.append(nn.BatchNorm2d(output_channels))
-                else:
-                    atrous_module.append(nn.BatchNorm1d(output_channels))
+                atrous_module.append(get_BatchNorm(ndim, output_channels))
         self.atrous_module = nn.Sequential(*atrous_module)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -326,3 +320,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = conv_layer(x)
             atrous_layers.append(x.unsqueeze(-1))
         return torch.sum(torch.cat(atrous_layers, dim=-1), dim=-1)
+
+
+def get_conv(ndim: int):
+    """
+    Selects conv block based on dimensions
+    """
+    conv_dict = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
+    return conv_dict[ndim]
+
+
+def get_BatchNorm(ndim: int, output_channels: int):
+    """
+    Selects BatchNorm type based on dimensions and instantiates it
+    with the given number of channels
+    """
+    BatchNorm_dict = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}
+    return BatchNorm_dict[ndim](output_channels)
+
+
+def get_interpolate_mode(ndim: int):
+    """
+    Selects interpolation mode based on dimensions
+    """
+    interpolate_mode_dict = {1: 'nearest', 2: 'bilinear', 3: 'trilinear'}
+    return interpolate_mode_dict[ndim]
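
A quick sanity check of the three dispatch helpers above, together with the stack-and-sum idiom from DilatedBlock.forward (a sketch; it assumes the helpers are importable from atomai.nets.blocks):

import torch

conv = get_conv(3)(1, 8, kernel_size=3, padding=1)   # nn.Conv3d instance
bn = get_BatchNorm(3, 8)                             # nn.BatchNorm3d(8)
assert get_interpolate_mode(3) == "trilinear"

x = torch.randn(2, 1, 8, 16, 16)                     # (N, C, D, H, W)
h = bn(conv(x))                                      # shape preserved: (2, 8, 8, 16, 16)

# Same fusion as DilatedBlock.forward: stack branch outputs on a new
# trailing axis, then sum them away (all branches share one shape)
branches = [h, torch.relu(h)]                        # stand-ins for dilated branches
fused = torch.sum(torch.cat([b.unsqueeze(-1) for b in branches], dim=-1), dim=-1)
print(fused.shape)                                   # torch.Size([2, 8, 8, 16, 16])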
38 changes: 21 additions & 17 deletions atomai/nets/fcnn.py
@@ -32,10 +32,11 @@ class Unet(nn.Module):
             Use batch normalization after each convolutional layer
             (Default: True)
         upsampling_mode:
-            Select between "bilinear" or "nearest" upsampling method.
-            Bilinear is usually more accurate,but adds additional (small)
-            randomness. For full reproducibility, consider using 'nearest'
-            (this assumes that all other sources of randomness are fixed)
+            Select between "bilinear", "nearest", or "trilinear" upsampling
+            method. Bilinear is usually more accurate, but adds additional
+            (small) randomness. Trilinear is used for 3D data. For full
+            reproducibility, consider using 'nearest' (this assumes that all
+            other sources of randomness are fixed)
         with_dilation:
             Use dilated convolutions instead of regular ones in the
             bottleneck layers (Default: False)
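
As a hedged illustration of that reproducibility note (the nb_classes and upsampling_mode keyword names follow this docstring; everything else about the constructor is assumed from the repository):

import torch

torch.manual_seed(0)                                      # fix other sources of randomness
net2d = Unet(nb_classes=3, upsampling_mode="nearest")     # fully reproducible 2D variant
net3d = Unet(nb_classes=3, upsampling_mode="trilinear")   # hypothetical 3D call; assumes
                                                          # this PR threads the mode through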
@@ -158,10 +159,11 @@ class dilnet(nn.Module):
         batch_norm:
             Add batch normalization for each convolutional layer (Default: True)
         upsampling_mode:
-            Select between "bilinear" or "nearest" upsampling method.
-            Bilinear is usually more accurate,but adds additional (small)
-            randomness. For full reproducibility, consider using 'nearest'
-            (this assumes that all other sources of randomness are fixed)
+            Select between "bilinear", "nearest", or "trilinear" upsampling
+            method. Bilinear is usually more accurate, but adds additional
+            (small) randomness. Trilinear is used for 3D data. For full
+            reproducibility, consider using 'nearest' (this assumes that all
+            other sources of randomness are fixed)
         **layers (list):
             List with a number of layers for each block (Default: [3, 3, 3, 3])
     """
@@ -237,10 +239,11 @@ class ResHedNet(nn.Module):
             Number of filters in 1st residual block
             (gets multiplied by 2 in each next block)
         upsampling_mode:
-            Select between "bilinear" or "nearest" upsampling method.
-            Bilinear is usually more accurate,but adds additional (small)
-            randomness. For full reproducibility, consider using 'nearest'
-            (this assumes that all other sources of randomness are fixed)
+            Select between "bilinear", "nearest", or "trilinear" upsampling
+            method. Bilinear is usually more accurate, but adds additional
+            (small) randomness. Trilinear is used for 3D data. For full
+            reproducibility, consider using 'nearest' (this assumes that all
+            other sources of randomness are fixed)
         **layers (list):
             3-element list with a number of residual blocks
             in each segment (Default: [3, 4, 5])
@@ -311,10 +314,11 @@ class SegResNet(nn.Module):
             Use batch normalization after each convolutional layer
             (Default: True)
         upsampling_mode:
-            Select between "bilinear" or "nearest" upsampling method.
-            Bilinear is usually more accurate,but adds additional (small)
-            randomness. For full reproducibility, consider using 'nearest'
-            (this assumes that all other sources of randomness are fixed)
+            Select between "bilinear", "nearest", or "trilinear" upsampling
+            method. Bilinear is usually more accurate, but adds additional
+            (small) randomness. Trilinear is used for 3D data. For full
+            reproducibility, consider using 'nearest' (this assumes that all
+            other sources of randomness are fixed)
         **layers (list):
             3-element list with a number of residual blocks
             in each residual segment (Default: [2, 2])
@@ -377,7 +381,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 def init_fcnn_model(model: Union[Type[nn.Module], str],
-                    nb_classes: int, **kwargs: [bool, int, List]
+                    nb_classes: int, **kwargs: Union[bool, int, List]
                     ) -> Type[nn.Module]:
     """
     Initializes a fully convolutional neural network
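
A final hedged sketch of the factory entry point, using the string shortcut permitted by the Union[Type[nn.Module], str] annotation above; the keyword values are illustrative only:

model = init_fcnn_model("Unet", nb_classes=3,
                        batch_norm=True, upsampling_mode="trilinear")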