Skip to content

Commit 1e86b17

Browse files
committed
add support for pretrained + custom in channels. add tresnet encoders
1 parent d8d7b4b commit 1e86b17

File tree

6 files changed

+41
-4
lines changed

6 files changed

+41
-4
lines changed

pytorch_tools/models/efficientnet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pytorch_tools.utils.misc import add_docs_for
2929
from pytorch_tools.utils.misc import make_divisible
3030
from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS
31+
from pytorch_tools.utils.misc import repeat_channels
3132

3233
# avoid overwriting doc string
3334
wraps = partial(wraps, assigned=("__module__", "__name__", "__qualname__", "__annotations__"))
@@ -420,6 +421,8 @@ def _efficientnet(arch, pretrained=None, **kwargs):
420421
)
421422
state_dict["classifier.weight"] = model.state_dict()["classifier.weight"]
422423
state_dict["classifier.bias"] = model.state_dict()["classifier.bias"]
424+
if kwargs.get("in_channels", 3) != 3: # support pretrained for custom input channels
425+
state_dict["conv_stem.weight"] = repeat_channels(state_dict["conv_stem.weight"], kwargs["in_channels"])
423426
model.load_state_dict(state_dict)
424427
patch_bn(model) # adjust epsilon
425428
setattr(model, "pretrained_settings", cfg_settings)

pytorch_tools/models/resnet.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pytorch_tools.modules import bn_from_name
2121
from pytorch_tools.utils.misc import add_docs_for
2222
from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS
23+
from pytorch_tools.utils.misc import repeat_channels
2324

2425
# avoid overwriting doc string
2526
wraps = partial(wraps, assigned=("__module__", "__name__", "__qualname__", "__annotations__"))
@@ -471,6 +472,12 @@ def _resnet(arch, pretrained=None, **kwargs):
471472
# if there is last_linear in state_dict, it's going to be overwritten
472473
state_dict["fc.weight"] = model.state_dict()["last_linear.weight"]
473474
state_dict["fc.bias"] = model.state_dict()["last_linear.bias"]
475+
# support pretrained for custom input channels
476+
# layer0. is needed to support se_resne(x)t weights
477+
if kwargs.get("in_channels", 3) != 3:
478+
old_weights = state_dict.get("conv1.weight")
479+
old_weights = state_dict.get("layer0.conv1.weight") if old_weights is None else old_weights
480+
state_dict["layer0.conv1.weight"] = repeat_channels(old_weights, kwargs["in_channels"])
474481
model.load_state_dict(state_dict)
475482
setattr(model, "pretrained_settings", cfg_settings)
476483
return model

pytorch_tools/models/tresnet.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pytorch_tools.modules import bn_from_name
1616
from pytorch_tools.modules import ABN
1717
from pytorch_tools.utils.misc import add_docs_for
18+
from pytorch_tools.utils.misc import repeat_channels
1819

1920
# avoid overwriting doc string
2021
wraps = partial(wraps, assigned=("__module__", "__name__", "__qualname__", "__annotations__"))
@@ -49,6 +50,8 @@ class TResNet(ResNet):
4950
Activation for normalizion layer. It's reccomended to use `leacky_relu` with `inplaceabn`.
5051
encoder (bool):
5152
Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
53+
NOTE: TResNet first features have resolution 4x times smaller than input, not 2x as all other models.
54+
So it CAN'T be used as encoder in Unet and Linknet models
5255
drop_rate (float):
5356
Dropout probability before classifier, for training. Defaults to 0.0. to 'avg'.
5457
drop_connect_rate (float):
@@ -119,6 +122,9 @@ def __init__(
119122
self._initialize_weights(init_bn0=True)
120123

121124
def load_state_dict(self, state_dict, **kwargs):
125+
if self.encoder:
126+
state_dict.pop("last_linear.weight")
127+
state_dict.pop("last_linear.bias")
122128
nn.Module.load_state_dict(self, state_dict, **kwargs)
123129

124130
# fmt: off
@@ -209,6 +215,8 @@ def _resnet(arch, pretrained=None, **kwargs):
209215
# if there is last_linear in state_dict, it's going to be overwritten
210216
state_dict["last_linear.weight"] = model.state_dict()["last_linear.weight"]
211217
state_dict["last_linear.bias"] = model.state_dict()["last_linear.bias"]
218+
if kwargs.get("in_channels", 3) != 3: # support pretrained for custom input channels
219+
state_dict["conv1.1.weight"] = repeat_channels(state_dict["conv1.1.weight"], kwargs["in_channels"] * 16, 3 * 16)
212220
model.load_state_dict(state_dict)
213221
# need to adjust some parameters to be align with original model
214222
patch_blur_pool(model)

pytorch_tools/segmentation_models/encoders.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,16 @@
2828
"efficientnet_b5": (2048, 128, 64, 40, 24),
2929
"efficientnet_b6": (2304, 144, 72, 40, 32),
3030
"efficientnet_b7": (2560, 160, 80, 48, 32),
31+
"tresnetm": (2048, 1024, 128, 64, 64),
32+
"tresnetl": (2432, 1216, 152, 76, 76),
33+
"tresnetxl": (2656, 1328, 166, 83, 83),
3134
}
3235

3336

3437
def get_encoder(name, **kwargs):
3538
if name not in models.__dict__:
3639
raise ValueError(f"No such encoder: {name}")
3740
kwargs["encoder"] = True
38-
# if 'resne' in name:
39-
# kwargs['dilated'] = True # dilate resnets for better performance
4041
kwargs["pretrained"] = kwargs.pop("encoder_weights")
4142
m = models.__dict__[name](**kwargs)
4243
m.out_shapes = ENCODER_SHAPES[name]

pytorch_tools/utils/misc.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import math
23
import time
34
import torch
45
import random
@@ -203,3 +204,15 @@ def make_divisible(v, divisor=8):
203204
if new_v < 0.9 * v: # ensure round down does not go down by more than 10%.
204205
new_v += divisor
205206
return new_v
207+
208+
def repeat_channels(conv_weights, new_channels, old_channels=3):
209+
"""Repeat channels to match new number of input channels
210+
Args:
211+
conv_weights (torch.Tensor): shape [*, old_channels, *, *]
212+
new_channels (int): desired number of channels
213+
old_channels (int): original number of channels
214+
"""
215+
rep_times = math.ceil(new_channels / old_channels)
216+
new_weights = conv_weights.repeat(1, rep_times, 1, 1)[:, :new_channels, :, :]
217+
new_weights *= old_channels / new_channels # to keep the same output amplitude
218+
return new_weights

tests/models/test_models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def test_custom_in_channels(arch):
4949
with torch.no_grad():
5050
m(torch.ones(2, 5, 128, 128))
5151

52+
@pytest.mark.parametrize("arch", EFFNET_NAMES[:2] + RESNET_NAMES[:2])
53+
def test_pretrained_custom_in_channels(arch):
54+
m = models.__dict__[arch](in_channels=5, pretrained="imagenet")
55+
with torch.no_grad():
56+
m(torch.ones(2, 5, 128, 128))
57+
5258

5359
@pytest.mark.parametrize("arch", TEST_MODEL_NAMES)
5460
def test_inplace_abn(arch):
@@ -73,7 +79,7 @@ def test_dilation(arch, output_stride):
7379
W, H = INP.shape[-2:]
7480
assert res.shape[-2:] == (W // output_stride, H // output_stride)
7581

76-
@pytest.mark.parametrize("arch", TEST_MODEL_NAMES)
82+
@pytest.mark.parametrize("arch", EFFNET_NAMES[:2] + RESNET_NAMES[:2])
7783
def test_drop_connect(arch):
7884
m = models.__dict__[arch](drop_connect_rate=0.2)
7985
_test_forward(m)
@@ -87,7 +93,6 @@ def test_drop_connect(arch):
8793
"efficientnet_b2": 9109994,
8894
"efficientnet_b3": 12233232,
8995
}
90-
# @pytest.mark.parametrize('name, num_params', NUM_PARAMS.values(), ids=list(NUM_PARAMS.keys()))
9196
@pytest.mark.parametrize('name_num_params', zip(NUM_PARAMS.items()))
9297
def test_num_parameters(name_num_params):
9398
name, num_params = name_num_params[0]

0 commit comments

Comments
 (0)