|
| 1 | +import logging |
| 2 | +import torch |
| 3 | +from gardening_tools.modules.networks.BaseNet import BaseNet |
| 4 | +from gardening_tools.modules.networks.components.blocks import ResidualBlock |
| 5 | +from gardening_tools.modules.networks.components.encoders import ResidualUNetEncoder |
| 6 | +from gardening_tools.modules.networks.components.heads import ClsRegHead |
1 | 7 | from gardening_tools.modules.networks.resunet import ResidualEncoderUNet |
| 8 | +from torch import nn |
| 9 | +from typing import List, Tuple, Type, Union |
| 10 | + |
| 11 | + |
| 12 | +class ResidualEncoderUNetCLSREG(BaseNet): |
| 13 | + def __init__( |
| 14 | + self, |
| 15 | + input_channels: int, |
| 16 | + output_channels: int, |
| 17 | + dimensions: str, |
| 18 | + kernel_size: int, |
| 19 | + stride: int, |
| 20 | + features_per_stage: list, |
| 21 | + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], |
| 22 | + conv_bias: bool = True, |
| 23 | + encoder_basic_block: Type[ResidualBlock] = ResidualBlock, |
| 24 | + decoder: nn.Module = ClsRegHead, |
| 25 | + norm_op_kwargs={"eps": 1e-05, "affine": True}, |
| 26 | + dropout_op=None, |
| 27 | + dropout_op_kwargs=None, |
| 28 | + nonlin=torch.nn.LeakyReLU, |
| 29 | + nonlin_kwargs={"inplace": True}, |
| 30 | + ): |
| 31 | + super().__init__() |
| 32 | + |
| 33 | + # Extract dropout rates from kwargs |
| 34 | + if dropout_op_kwargs is None: |
| 35 | + dropout_op_kwargs = {} |
| 36 | + |
| 37 | + encoder_dropout_rate = dropout_op_kwargs.get("encoder_dropout_rate", 0.0) |
| 38 | + decoder_dropout_rate = dropout_op_kwargs.get("decoder_dropout_rate", 0.0) |
| 39 | + inplace = dropout_op_kwargs.get("inplace", True) |
| 40 | + |
| 41 | + if dimensions == "2D": |
| 42 | + conv_op = nn.Conv2d |
| 43 | + dropout_op = nn.Dropout2d |
| 44 | + norm_op = nn.InstanceNorm2d |
| 45 | + pool_op = nn.MaxPool2d |
| 46 | + clsreg_pool_op = nn.AdaptiveAvgPool2d |
| 47 | + elif dimensions == "3D": |
| 48 | + conv_op = nn.Conv3d |
| 49 | + dropout_op = nn.Dropout3d |
| 50 | + norm_op = nn.InstanceNorm3d |
| 51 | + pool_op = nn.MaxPool3d |
| 52 | + clsreg_pool_op = nn.AdaptiveAvgPool3d |
| 53 | + else: |
| 54 | + logging.warning("Uuh, dimensions not in ['2D', '3D']") |
| 55 | + |
| 56 | + self.num_classes = output_channels |
| 57 | + |
| 58 | + self.stem_weight_name = "encoder.stem.conv1.conv.weight" |
| 59 | + |
| 60 | + self.encoder = ResidualUNetEncoder( |
| 61 | + input_channels=input_channels, |
| 62 | + features_per_stage=features_per_stage, |
| 63 | + conv_op=conv_op, |
| 64 | + kernel_size=kernel_size, |
| 65 | + stride=stride, |
| 66 | + n_blocks_per_stage=n_blocks_per_stage, |
| 67 | + conv_bias=conv_bias, |
| 68 | + norm_op=norm_op, |
| 69 | + norm_op_kwargs=norm_op_kwargs, |
| 70 | + dropout_op=dropout_op, |
| 71 | + dropout_op_kwargs={"p": encoder_dropout_rate, "inplace": inplace}, |
| 72 | + nonlin=nonlin, |
| 73 | + nonlin_kwargs=nonlin_kwargs, |
| 74 | + block=encoder_basic_block, |
| 75 | + pool_op=pool_op, |
| 76 | + ) |
| 77 | + |
| 78 | + self.decoder = decoder( |
| 79 | + pool_op=clsreg_pool_op, |
| 80 | + input_channels=features_per_stage[-1], |
| 81 | + output_channels=output_channels, |
| 82 | + dropout_rate=decoder_dropout_rate, |
| 83 | + ) |
| 84 | + |
| 85 | + def forward(self, x): |
| 86 | + skips = self.encoder(x) |
| 87 | + return self.decoder(skips) |
| 88 | + |
| 89 | + def forward_with_features(self, x): |
| 90 | + skips = self.encoder(x) |
| 91 | + output = self.decoder(skips) |
| 92 | + return output, skips[-1] |
2 | 93 |
|
3 | 94 |
|
4 | 95 | # Encoder 29M parameters |
@@ -46,6 +137,26 @@ def resenc_unet_b( |
46 | 137 | ) |
47 | 138 |
|
48 | 139 |
|
| 140 | +# Encoder 90M parameters |
| 141 | +# Full model - 90.3 M Total params |
| 142 | +def resenc_unet_b_clsreg( |
| 143 | + input_channels: int = 1, |
| 144 | + output_channels: int = 1, |
| 145 | + dimensions: str = "3D", |
| 146 | + dropout_op_kwargs: dict = None, |
| 147 | +): |
| 148 | + return ResidualEncoderUNetCLSREG( |
| 149 | + dimensions=dimensions, |
| 150 | + input_channels=input_channels, |
| 151 | + output_channels=output_channels, |
| 152 | + features_per_stage=(32, 64, 128, 256, 320, 320), |
| 153 | + stride=2, |
| 154 | + kernel_size=3, |
| 155 | + n_blocks_per_stage=(1, 3, 4, 6, 6, 6), |
| 156 | + dropout_op_kwargs=dropout_op_kwargs, |
| 157 | + ) |
| 158 | + |
| 159 | + |
49 | 160 | # Encoder 345M parameters |
50 | 161 | # Full model 391M parameters |
51 | 162 | def resenc_unet_l( |
|
0 commit comments