Skip to content

Commit d3de276

Browse files
author
valhassan
committed
Refactor MultiLevelNeck class in neckhead.py to accept a list of output channels, ensuring compatibility with varying output scales.
1 parent f4535b0 commit d3de276

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

geo_deep_learning/models/dofa/neckhead.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def xavier_init(module, gain=1, bias=0, distribution='normal'):
3535
if hasattr(module, 'bias') and module.bias is not None:
3636
nn.init.constant_(module.bias, bias)
3737

38+
3839
class ConvModule(nn.Module):
3940
def __init__(self,
4041
in_channels,
@@ -78,7 +79,7 @@ class MultiLevelNeck(nn.Module):
7879
7980
Args:
8081
in_channels (List[int]): Number of input channels per scale.
81-
out_channels (int): Number of output channels (used at each scale).
82+
out_channels (List[int]): Number of output channels (used at each scale).
8283
scales (List[float]): Scale factors for each input feature map.
8384
Default: [0.5, 1, 2, 4]
8485
norm_cfg (dict): Config dict for normalization layer. Default: None.
@@ -94,25 +95,26 @@ def __init__(self,
9495
act_cfg=None):
9596
super().__init__()
9697
assert isinstance(in_channels, list)
98+
assert isinstance(out_channels, list)
9799
self.in_channels = in_channels
98100
self.out_channels = out_channels
99101
self.scales = scales
100102
self.num_outs = len(scales)
101103
self.lateral_convs = nn.ModuleList()
102104
self.convs = nn.ModuleList()
103-
for in_channel in in_channels:
105+
for in_channel, out_channel in zip(in_channels, out_channels):
104106
self.lateral_convs.append(
105107
ConvModule(
106108
in_channel,
107-
out_channels,
109+
out_channel,
108110
kernel_size=1,
109111
norm_cfg=norm_cfg,
110112
act_cfg=act_cfg))
111-
for _ in range(self.num_outs):
113+
for out_channel in out_channels:
112114
self.convs.append(
113115
ConvModule(
114-
out_channels,
115-
out_channels,
116+
out_channel,
117+
out_channel,
116118
kernel_size=3,
117119
padding=1,
118120
stride=1,

0 commit comments

Comments
 (0)