 from pytorch_tools.modules import bn_from_name
 from pytorch_tools.modules import ABN
 from pytorch_tools.utils.misc import add_docs_for
+from pytorch_tools.utils.misc import repeat_channels

 # avoid overwriting doc string
 wraps = partial(wraps, assigned=("__module__", "__name__", "__qualname__", "__annotations__"))
@@ -49,6 +50,8 @@ class TResNet(ResNet):
             Activation for normalization layer. It's recommended to use `leaky_relu` with `inplaceabn`.
         encoder (bool):
             Flag to overwrite forward pass to return 5 tensors with different resolutions. Defaults to False.
+            NOTE: the first TResNet feature map is 4x smaller than the input, not 2x as in all other models,
+            so it CAN'T be used as an encoder in Unet and Linknet models.
         drop_rate (float):
             Dropout probability before classifier, for training. Defaults to 0.0.
         drop_connect_rate (float):
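
The NOTE above is the key caveat for encoder use: because of TResNet's 4x4 SpaceToDepth stem, even the first feature map already has stride 4. A minimal sketch of what that means in practice (the `tresnet_m` constructor name and the exact output strides are assumptions here, not confirmed by this diff):

import torch
from pytorch_tools.models import tresnet_m  # assumed constructor; any TResNet variant would do

encoder = tresnet_m(encoder=True)
features = encoder(torch.randn(1, 3, 224, 224))
for f in features:
    print(tuple(f.shape))
# Roughly expected strides: 4, 4, 8, 16, 32 (56, 56, 28, 14, 7 for a 224 input),
# whereas ResNet-style encoders give 2, 4, 8, 16, 32. Unet and Linknet decoders
# rely on that stride-2 first skip connection, hence the restriction.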
@@ -119,6 +122,9 @@ def __init__(
         self._initialize_weights(init_bn0=True)

     def load_state_dict(self, state_dict, **kwargs):
+        if self.encoder:
+            state_dict.pop("last_linear.weight")
+            state_dict.pop("last_linear.bias")
         nn.Module.load_state_dict(self, state_dict, **kwargs)

     # fmt: off
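
This override is what makes ImageNet checkpoints loadable in encoder mode: the encoder variant has no classifier head, so a strict load would otherwise fail on the unexpected `last_linear.*` keys. The same idea in isolation, as a hedged sketch (the helper name is hypothetical, not part of the library):

def strip_classifier_keys(state_dict: dict) -> dict:
    # drop head weights so a head-less encoder can load a full-model checkpoint
    return {k: v for k, v in state_dict.items() if not k.startswith("last_linear.")}

# usage: encoder.load_state_dict(strip_classifier_keys(checkpoint))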
@@ -209,6 +215,8 @@ def _resnet(arch, pretrained=None, **kwargs):
         # if there is last_linear in state_dict, it's going to be overwritten
         state_dict["last_linear.weight"] = model.state_dict()["last_linear.weight"]
         state_dict["last_linear.bias"] = model.state_dict()["last_linear.bias"]
+        if kwargs.get("in_channels", 3) != 3:  # support pretrained for custom input channels
+            state_dict["conv1.1.weight"] = repeat_channels(state_dict["conv1.1.weight"], kwargs["in_channels"] * 16, 3 * 16)
         model.load_state_dict(state_dict)
         # need to adjust some parameters to align with the original model
         patch_blur_pool(model)
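
The new branch adapts pretrained weights to a non-RGB input. `conv1.1` is the convolution that follows the 4x4 SpaceToDepth rearrangement, so its input channel count is `in_channels * 16` rather than `in_channels`, which is why both arguments carry the `* 16` factor. A hedged sketch of what a channel-repeat helper of this kind typically does (the actual `repeat_channels` in `pytorch_tools.utils.misc` may differ in details, e.g. it may also rescale the weights):

import math
import torch

def repeat_channels_sketch(conv_weight: torch.Tensor, new_channels: int, old_channels: int) -> torch.Tensor:
    # conv_weight has shape (out_channels, old_channels, kH, kW);
    # tile along the input-channel dimension, then trim to the requested width
    repeats = math.ceil(new_channels / old_channels)
    return conv_weight.repeat(1, repeats, 1, 1)[:, :new_channels]

# e.g. for a single-channel input: repeat_channels_sketch(w, 1 * 16, 3 * 16)
# turns a (out, 48, kH, kW) stem weight into (out, 16, kH, kW).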