1- from typing import Any , Dict , Union
1+ from typing import Any , Dict
22
33import torch
44from torch import nn
@@ -17,50 +17,6 @@ def __init__(self, visual_feature_size: int):
1717 self .visual_feature_size = visual_feature_size
1818
1919
20- class BlindVisualBackbone (VisualBackbone ):
21- r"""
22- A visual backbone which cannot see the image. It always outputs a tensor
23- filled with constant value.
24-
25- Parameters
26- ----------
27- visual_feature_size: int, optional (default = 2048)
28- Size of the last dimension (channels) of output from forward pass.
29- bias_value: float, optional (default = 1.0)
30- Constant value to fill in the output tensor.
31- """
32-
33- def __init__ (self , visual_feature_size : int = 2048 , bias_value : float = 1.0 ):
34- super ().__init__ (visual_feature_size )
35-
36- # We never update the bias because a blind model cannot learn anything
37- # about the image. Add an axis for proper broadcasting.
38- self ._bias = nn .Parameter (
39- torch .full ((1 , self .visual_feature_size ), fill_value = bias_value ),
40- requires_grad = False ,
41- )
42-
43- def forward (self , image : torch .Tensor ) -> torch .Tensor :
44- r"""
45- Compute visual features for a batch of input images. Since this model
46- is *blind*, output will always be constant.
47-
48- Parameters
49- ----------
50- image: torch.Tensor
51- Batch of input images. A tensor of shape
52- ``(batch_size, 3, height, width)``.
53-
54- Returns
55- -------
56- torch.Tensor
57- Output visual features, filled with :attr:`bias_value`. A tensor of
58- shape ``(batch_size, visual_feature_size)``.
59- """
60- batch_size = image .size (0 )
61- return self ._bias .repeat (batch_size , 1 )
62-
63-
6420class TorchvisionVisualBackbone (VisualBackbone ):
6521 r"""
6622 A visual backbone from `Torchvision model zoo
@@ -91,7 +47,8 @@ def __init__(
9147 self .cnn = getattr (torchvision .models , name )(
9248 pretrained , zero_init_residual = True
9349 )
94- # Do nothing after the final residual stage.
50+ # Remove global average pooling and fc layer.
51+ self .cnn .avgpool = nn .Identity ()
9552 self .cnn .fc = nn .Identity ()
9653
9754 # Freeze all weights if specified.
@@ -100,12 +57,7 @@ def __init__(
10057 param .requires_grad = False
10158 self .cnn .eval ()
10259
103- # Keep a list of intermediate layer names.
104- self ._stage_names = [f"layer{ i } " for i in range (1 , 5 )]
105-
106- def forward (
107- self , image : torch .Tensor , return_intermediate_outputs : bool = False
108- ) -> Union [torch .Tensor , Dict [str , torch .Tensor ]]:
60+ def forward (self , image : torch .Tensor ) -> torch .Tensor :
10961 r"""
11062 Compute visual features for a batch of input images.
11163
@@ -114,41 +66,17 @@ def forward(
11466 image: torch.Tensor
11567 Batch of input images. A tensor of shape
11668 ``(batch_size, 3, height, width)``.
117- return_intermediate_outputs: bool, optional (default = False)
118- Whether to return features extracted from all intermediate stages or
119- just the last one. This can only be set ``True`` when using a
120- ResNet-like model.
12169
12270 Returns
12371 -------
124- Union[torch.Tensor, Dict[str, torch.Tensor]]
125- - If ``return_intermediate_outputs = False``, this will be a tensor
126- of shape ``(batch_size, channels, height, width)``, for example
127- it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50 (``layer4``).
128-
129- - If ``return_intermediate_outputs = True``, this will be a dict
130- with keys ``{"layer1", "layer2", "layer3", "layer4", "avgpool"}``
131- containing features from all intermediate layers and global
132- average pooling layer.
72+ torch.Tensor
73+ A tensor of shape ``(batch_size, channels, height, width)``, for
74+ example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.
13375 """
13476
135- # Iterate through the modules in sequence and collect feature
136- # vectors for last layers in each stage.
137- intermediate_outputs : Dict [str , torch .Tensor ] = {}
138- for idx , (name , layer ) in enumerate (self .cnn .named_children ()):
139- out = layer (image ) if idx == 0 else layer (out )
140- if name in self ._stage_names :
141- intermediate_outputs [name ] = out
142-
143- # Add pooled spatial features.
144- intermediate_outputs ["avgpool" ] = torch .mean (
145- intermediate_outputs ["layer4" ], dim = [2 , 3 ]
146- )
147- if return_intermediate_outputs :
148- return intermediate_outputs
149- else :
150- # shape: (batch_size, channels, height, width)
151- return intermediate_outputs ["layer4" ]
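77+ # intended shape: (batch_size, channels, height, width)
78+ # [ResNet-50: (b, 2048, 7, 7)] NOTE(review): torchvision's ResNet.forward flattens between avgpool and fc, so this may be (b, channels * h * w) - confirm.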
79+ return self .cnn (image )
15280
15381 def detectron2_backbone_state_dict (self ) -> Dict [str , Any ]:
15482 r"""
0 commit comments