
Commit 30140f5

fix: move encoder upsampler to multi-task-decoder for easier weight saving
1 parent fb4b833 commit 30140f5

10 files changed: +80 -186 lines changed


cellseg_models_pytorch/decoders/multitask_decoder.py

Lines changed: 29 additions & 11 deletions
@@ -7,6 +7,7 @@
 
 from cellseg_models_pytorch.decoders.long_skips import StemSkip
 from cellseg_models_pytorch.decoders.unet_decoder import UnetDecoder
+from cellseg_models_pytorch.encoders.encoder_upsampler import EncoderUpsampler
 from cellseg_models_pytorch.models.base._initialization import (
     initialize_decoder,
     initialize_head,
@@ -36,8 +37,7 @@ def __init__(
         decoders: Tuple[str, ...],
         heads: Dict[str, Dict[str, int]],
         out_channels: Tuple[int, ...],
-        enc_channels: Tuple[int, ...],
-        enc_reductions: Tuple[int, ...],
+        enc_feature_info: Tuple[Dict[str, Any], ...],
         n_layers: Tuple[int, ...],
         n_blocks: Tuple[int, ...],
         stage_kws: Tuple[Dict[str, Any], ...],
@@ -59,10 +59,8 @@ def __init__(
             out_channels (Tuple[int, ...]):
                 Tuple of output channels for each decoder stage. The length of the tuple
                 should be equal to the number of enc_channels.
-            enc_channels (Tuple[int, ...]):
-                Tuple of encoder channels.
-            enc_reductions (Tuple[int, ...]):
-                Tuple of encoder reduction factors.
+            enc_feature_info (Tuple[Dict[str, Any], ...]):
+                Tuple of encoder feature info dicts. Basically timm.model.feature_info
             n_layers (Tuple[int, ...]):
                 Tuple of number of conv layers in each decoder stage.
             n_blocks (Tuple[int, ...]):
@@ -87,15 +85,30 @@ def __init__(
         self._check_head_args(heads, decoders)
         self._check_decoder_args(decoders)
         self._check_depth(
-            len(enc_channels),
+            len(n_blocks),
             {
-                "n_blocks": n_blocks,
                 "n_layers": n_layers,
                 "out_channels": out_channels,
-                "enc_reductions": enc_reductions,
+                "enc_feature_info": enc_feature_info,
             },
         )
 
+        # get the reduction factors and out channels of the encoder
+        self.enc_feature_info = enc_feature_info[::-1]  # bottleneck first
+        enc_reductions = tuple([inf["reduction"] for inf in self.enc_feature_info])
+        enc_channels = tuple([inf["num_chs"] for inf in self.enc_feature_info])
+
+        # initialize feature upsampler if encoder is a vision transformer
+        self.encoder_upsampler = None
+        if all(elem == enc_reductions[0] for elem in enc_reductions):
+            self.encoder_upsampler = EncoderUpsampler(
+                feature_info=enc_feature_info,
+                out_channels=out_channels,
+            )
+            self.enc_feature_info = self.encoder_upsampler.feature_info  # bottleneck 1st
+            enc_reductions = tuple([inf["reduction"] for inf in self.enc_feature_info])
+            enc_channels = tuple([inf["num_chs"] for inf in self.enc_feature_info])
+
         # style
         self.make_style = None
         if style_channels is not None:
@@ -194,14 +207,19 @@ def forward(
 
         Parameters:
             enc_feats (Tuple[torch.Tensor, ...]):
-                Tuple containing encoder feature tensors.
+                Tuple containing encoder feature tensors. Assumes that the deepest
+                (i.e. bottleneck) feature is the last element of the tuple.
             x_in (torch.Tensor, default=None):
                 Optional (the input image) tensor for stem skip connection.
 
         Returns:
             Tuple[Dict[str, List[torch.Tensor]], Dict[str, torch.Tensor]]:
                 The output of the seg heads.
         """
+        enc_feats = enc_feats[::-1]  # bottleneck first
+        if self.encoder_upsampler is not None:
+            enc_feats = self.encoder_upsampler(enc_feats)
+
         style = self.forward_style(enc_feats[0])
         dec_feats = self.forward_features(enc_feats, style)
 
@@ -211,7 +229,7 @@ def forward(
 
         out = self.forward_heads(dec_feats)
 
-        return dec_feats, out
+        return enc_feats, dec_feats, out
 
     def initialize(self) -> None:
         """Initialize the decoders and segmentation heads."""

cellseg_models_pytorch/encoders/encoder.py

Lines changed: 14 additions & 29 deletions
@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn
 
-from .encoder_upsampler import EncoderUpsampler
 from .timm_encoder import TimmEncoder
 
 __all__ = ["Encoder"]
@@ -14,26 +13,22 @@ def __init__(
         self,
         timm_encoder_name: str,
         timm_encoder_out_indices: Tuple[int, ...],
-        pixel_decoder_out_channels: Tuple[int, ...],
         timm_encoder_pretrained: bool = True,
         timm_extra_kwargs: Dict[str, Any] = {},
     ) -> None:
         """Wrap timm encoders to one class.
 
-        Parameters
-        ----------
-        timm_encoder_name : str
-            Name of the encoder. If the name is in `TR_ENCODERS.keys()`, a transformer
-            will be used. Otherwise, a timm encoder will be used.
-        timm_encoder_out_indices : Tuple[int], optional
-            Indices of the output features.
-        pixel_decoder_out_channels : Tuple[int], optional
-            Number of output channels at each upsampling stage.
-        timm_encoder_pretrained : bool, optional, default=False
-            If True, load pretrained timm weights, by default False.
-        timm_extra_kwargs : Dict[str, Any], optional, default={}
-            Key-word arguments for any `timm` based encoder. These arguments are
-            used in `timm.create_model(**kwargs)` function call.
+        Parameters:
+            timm_encoder_name (str):
+                Name of the encoder. If the name is in `TR_ENCODERS.keys()`, a transformer
+                will be used. Otherwise, a timm encoder will be used.
+            timm_encoder_out_indices (Tuple[int, ...]):
+                Indices of the output features.
+            timm_encoder_pretrained (bool, default=True):
+                If True, load pretrained timm weights.
+            timm_extra_kwargs (Dict[str, Any], default={}):
+                Key-word arguments for any `timm` based encoder. These arguments are
+                used in `timm.create_model(**kwargs)` function call.
         """
         super().__init__()
 
@@ -45,23 +40,13 @@ def __init__(
             extra_kwargs=timm_extra_kwargs,
         )
 
-        # initialize feature upsampler if encoder is a vision transformer
-        feature_info = self.encoder.feature_info
-        reductions = [finfo["reduction"] for finfo in feature_info]
-        if all(element == reductions[0] for element in reductions):
-            self.encoder = EncoderUpsampler(
-                backbone=self.encoder,
-                out_channels=pixel_decoder_out_channels,
-            )
-            feature_info = self.encoder.feature_info
-
-        self.out_channels = [f["num_chs"] for f in self.encoder.feature_info][::-1]
-        self.feature_info = self.encoder.feature_info[::-1]
+        self.out_channels = [f["num_chs"] for f in self.encoder.feature_info]
+        self.feature_info = self.encoder.feature_info  # bottleneck last element
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """Forward pass of the encoder and return all the features."""
         output, feats = self.encoder(x)
-        return output, feats[::-1]
+        return output, feats  # bottleneck feature is the last element
 
     def freeze_encoder(self) -> None:
         """Freeze the parameters of the encoder."""

cellseg_models_pytorch/encoders/encoder_upsampler.py

Lines changed: 20 additions & 21 deletions
@@ -64,7 +64,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class EncoderUpsampler(nn.Module):
     def __init__(
         self,
-        backbone: nn.Module,
+        feature_info: Tuple[dict, ...],
         out_channels: Tuple[int, ...],
     ) -> None:
         """Feature upsampler for transformer-like backbones.
@@ -75,28 +75,27 @@ def __init__(
         are two. Builds an image-pyramid like structure.
 
         Parameters:
-            backbone (nn.Module):
-                Backbone network that extracts features.
+            feature_info (Tuple[dict, ...]):
+                timm feature info of the backbone. Assumes that the feature info dicts
+                are in bottleneck first order, i.e. the deepest encoder block first.
+                For example: [
+                    {'module': 'blocks.8', 'num_chs': 1024, 'reduction': 16},
+                    {'module': 'blocks.4', 'num_chs': 1024, 'reduction': 16},
+                ]
             out_channels (Tuple[int, ...]):
                 Number of channels in the output tensor of each upsampling block.
                 Defaults to None.
         """
-        print(out_channels, backbone.feature_info)
         super().__init__()
-        if len(out_channels) != len(backbone.feature_info):
+        if len(out_channels) != len(feature_info):
             raise ValueError(
                 "`out_channels` must have the same len as the `backbone.feature_info.`"
-                f"Got {len(out_channels)} and {len(backbone.feature_info)} respectively."
+                f"Got {len(out_channels)} and {len(feature_info)} respectively."
             )
 
-        self.backbone = backbone
         self.out_channels = out_channels
         self.feature_info = []
 
-        # flip the feature info so that we start building the
-        # upsampling blocks from the bottleneck layer
-        feature_info = backbone.feature_info[::-1]
-
         # bottleneck layer
         self.bottleneck = nn.Conv2d(
             in_channels=feature_info[0]["num_chs"],
@@ -144,17 +143,17 @@ def __init__(
             )
             self.up_blocks[f"up{i + 1}"] = nn.Sequential(*up_blocks)
 
-        # flip the feature info back to the original order to match the top-down
-        # order of timm feature_info. (high to low res)
-        self.feature_info = self.feature_info[::-1]
-
-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
-        # get the features from the backbone
-        final_feat, feats = self.backbone(x)
+    def forward(self, feats: Tuple[torch.Tensor]) -> Tuple[torch.Tensor, ...]:
+        """Forward pass of the encoder upsampler.
 
-        # flip the features so that we start from the bottleneck (low res)
-        feats = feats[::-1]
+        Parameters:
+            feats (Tuple[torch.Tensor]):
+                Tuple of features from the backbone in bottleneck first order, i.e. the
+                bottleneck (deepest) feature is the first element in the tuple.
 
+        Returns:
+            Tuple[torch.Tensor, ...]: Tuple of upsampled features in hi-to-lo res order.
+        """
         # bottleneck feature
         up_feat = self.bottleneck(feats[0])
         intermediate_features = [up_feat]
@@ -164,4 +163,4 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
             up_feat = self.up_blocks[f"up{i + 1}"](feat)
             intermediate_features.append(up_feat)
 
-        return final_feat, tuple(intermediate_features[::-1])  # feats in top-down order
+        return tuple(intermediate_features)  # hi-to-lo res order
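With the backbone argument removed, EncoderUpsampler is a pure module over feature tensors: construct it from bottleneck-first feature_info dicts and call it on the features directly. A minimal sketch under the docstring's assumptions; shapes and channel counts are illustrative:

    import torch

    from cellseg_models_pytorch.encoders.encoder_upsampler import EncoderUpsampler

    # Bottleneck-first feature info, as the new constructor expects.
    feature_info = (
        {"module": "blocks.8", "num_chs": 1024, "reduction": 16},
        {"module": "blocks.4", "num_chs": 1024, "reduction": 16},
    )
    upsampler = EncoderUpsampler(feature_info=feature_info, out_channels=(256, 128))

    # ViT-like features: equal spatial size at every stage, bottleneck first.
    feats = (torch.randn(1, 1024, 16, 16), torch.randn(1, 1024, 16, 16))
    up_feats = upsampler(feats)  # tuple of upsampled features, one per stage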

cellseg_models_pytorch/encoders/timm_encoder.py

Lines changed: 0 additions & 83 deletions
@@ -69,86 +69,3 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         offset = len(intermediates) - len(self.encoder.feature_info)
 
         return final_feat, [intermediates[i + offset] for i in self.out_indices]
-
-
-# class TimmEncoder(nn.Module):
-#     def __init__(
-#         self,
-#         name: str,
-#         pretrained: bool = True,
-#         checkpoint_path: str = None,
-#         in_channels: int = 3,
-#         depth: int = 4,
-#         out_indices: List[int] = None,
-#         **kwargs,
-#     ) -> None:
-#         """Import any encoder from timm package.
-
-#         Parameters
-#         ----------
-#         name : str
-#             Name of the encoder.
-#         pretrained : bool, optional
-#             If True, load pretrained weights, by default True.
-#         checkpoint_path : str, optional
-#             Path to the checkpoint file, by default None. If not None, overrides
-#             the `pretrained` argument.
-#         in_channels : int, optional
-#             Number of input channels, by default 3.
-#         depth : int, optional
-#             Number of output features, by default 4.
-#         out_indices : List[int], optional
-#             Indices of the output features, by default None. If None,
-#             out_indices is set to range(len(depth)). Overrides the `depth` argument.
-#         **kwargs : Dict[str, Any]
-#             Key-word arguments for any `timm` based encoder. These arguments are
-#             used in `timm.create_model(**kwargs)` function call.
-#         """
-#         super().__init__()
-
-#         # set out_indices
-#         self.out_indices = out_indices
-#         if out_indices is None:
-#             self.out_indices = tuple(range(depth))
-
-#         # set checkpoint_path
-#         if checkpoint_path is None:
-#             checkpoint_path = ""
-
-#         # create the timm model
-#         try:
-#             self.backbone = timm.create_model(
-#                 name,
-#                 pretrained=pretrained,
-#                 checkpoint_path=checkpoint_path,
-#                 in_chans=in_channels,
-#                 features_only=True,
-#                 out_indices=self.out_indices,
-#                 **kwargs,
-#             )
-#         except (AttributeError, RuntimeError) as err:
-#             print(err)
-#             raise RuntimeError(
-#                 f"timm backbone: {name} is not supported due to missing "
-#                 "features_only argument implementation in timm-package."
-#             )
-#         except IndexError as err:
-#             print(err)
-#             raise IndexError(
-#                 f"It's possible that the given depth: {depth} is too large for "
-#                 f"the given backbone: {name}. Try passing a smaller `depth` argument "
-#                 "or a different backbone."
-#             )
-
-#         # set in_channels and out_channels
-#         self.in_channels = in_channels
-#         self.out_channels = tuple(self.backbone.feature_info.channels()[::-1])
-#         if out_indices is not None:
-#             self.out_channels = tuple(self.out_channels[i] for i in self.out_indices)
-
-#         self.feature_info = self.backbone.feature_info.info[:depth][::-1]
-
-#     def forward(self, x: torch.Tensor, **kwargs) -> List[torch.Tensor]:
-#         """Forward pass of the encoder and return all the features."""
-#         features = self.backbone(x)
-#         return features[::-1]

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 2 additions & 7 deletions
@@ -156,20 +156,15 @@ def __init__(
         self.encoder = Encoder(
             timm_encoder_name=enc_name,
             timm_encoder_out_indices=enc_out_indices,
-            pixel_decoder_out_channels=out_channels,
             timm_encoder_pretrained=enc_pretrain,
             timm_extra_kwargs=encoder_kws,
         )
 
-        # get the reduction factors for the encoder
-        enc_reductions = tuple([inf["reduction"] for inf in self.encoder.feature_info])
-
         self.decoder = MultiTaskDecoder(
             decoders=decoders,
             heads=heads,
             out_channels=out_channels,
-            enc_channels=self.encoder.out_channels,
-            enc_reductions=enc_reductions,
+            enc_feature_info=self.encoder.feature_info,
             n_layers=n_layers,
             n_blocks=n_blocks,
             stage_kws=stage_kws,
@@ -208,7 +203,7 @@ def forward(
             outputs (segmentations) dict.
         """
         enc_output, feats = self.encoder.forward(x)
-        dec_feats, out = self.decoder.forward(feats, x)
+        feats, dec_feats, out = self.decoder.forward(feats, x)
 
         if return_feats:
             return enc_output, feats, dec_feats, out
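This is where the commit message pays off: because the EncoderUpsampler now lives inside MultiTaskDecoder, its weights travel with the decoder's state dict instead of being tangled into the encoder. A hedged sketch; the model variable and file name are illustrative:

    import torch

    # Saving the decoder alone now also captures the upsampler weights,
    # since the upsampler is a submodule of MultiTaskDecoder after this commit.
    torch.save(model.decoder.state_dict(), "decoder_with_upsampler.pth")

    # Restoring is symmetric.
    model.decoder.load_state_dict(torch.load("decoder_with_upsampler.pth"))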
