Commit 11c5793

Drop intermediate conv outputs and VOC FeatureExtractor.

1 parent c9dcf41 commit 11c5793

3 files changed: +39 -166 lines changed

scripts/clf_voc07.py

Lines changed: 29 additions & 14 deletions
@@ -15,7 +15,6 @@
 
 from virtex.config import Config
 from virtex.factories import PretrainingModelFactory, DownstreamDatasetFactory
-from virtex.models.downstream import FeatureExtractor
 from virtex.utils.checkpointing import CheckpointManager
 from virtex.utils.common import common_parser, common_setup

@@ -36,10 +35,6 @@
 
 # fmt: off
 group = parser.add_argument_group("Checkpointing")
-group.add_argument(
-    "--layer", choices=["layer1", "layer2", "layer3", "layer4", "avgpool"],
-    default="avgpool", help="Evaluate features extracted from this layer."
-)
 parser.add_argument(
     "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
     default="virtex", help="""How to initialize weights:
@@ -161,9 +156,12 @@ def main(_A: argparse.Namespace):
         torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
         strict=False,
     )
+    # Set ``ITERATION`` to a dummy value.
+    ITERATION = 0
 
-    model = FeatureExtractor(model, layer_name=_A.layer, flatten_and_normalize=True)
-    model = model.to(device).eval()
+    # Transfer model to GPU and set to eval mode. This is a torchvision model
+    # and it returns features as ``(batch_size, 2048, 7, 7)``.
+    model = model.visual.cnn.to(device).eval()
 
     # -------------------------------------------------------------------------
     # EXTRACT FEATURES FOR TRAINING SVMs

@@ -180,13 +178,33 @@ def main(_A: argparse.Namespace):
     for batch in tqdm(train_dataloader, desc="Extracting train features:"):
         features = model(batch["image"].to(device))
 
+        # Global average pool features. Assume the tensor is in NCHW format.
+        if len(features.size()) > 2:
+            features = features.view(features.size(0), features.size(1), -1)
+
+            # shape: (batch_size, visual_feature_size)
+            features = features.mean(dim=-1)
+
+        # shape: (batch_size, visual_feature_size)
+        features = features.view(features.size(0), -1)
+
+        # L2-normalize the global average pooled features.
+        features = features / torch.norm(features, dim=-1).unsqueeze(-1)
+
         features_train.append(features.cpu())
         targets_train.append(batch["label"])
 
     # Similarly extract test features.
     for batch in tqdm(test_dataloader, desc="Extracting test features:"):
         features = model(batch["image"].to(device))
 
+        if len(features.size()) > 2:
+            features = features.view(features.size(0), features.size(1), -1)
+            features = features.mean(dim=-1)
+
+        features = features.view(features.size(0), -1)
+        features = features / torch.norm(features, dim=-1).unsqueeze(-1)
+
         features_test.append(features.cpu())
         targets_test.append(batch["label"])

@@ -226,13 +244,10 @@ def main(_A: argparse.Namespace):
 
     # Test set mAP for each class, for features from every layer.
     test_map = torch.tensor(pool_output).mean()
-    logger.info(f"mAP: {test_map}")
-
-    # Tensorboard logging only when _A.weight_init == "virtex"
-    if _A.weight_init == "virtex":
-        tensorboard_writer.add_scalars(
-            "metrics/voc07_clf", {f"{_A.layer}_mAP": test_map}, ITERATION
-        )
+    logger.info(f"Iteration: {ITERATION}, mAP: {test_map}")
+    tensorboard_writer.add_scalars(
+        "metrics/voc07_clf", {"mAP": test_map}, ITERATION
+    )
 
 
 if __name__ == "__main__":
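
For reference, the post-processing that both new loops perform maps conv features of shape ``(N, C, H, W)`` to unit-length vectors of shape ``(N, C)``: collapse the spatial grid, average over it, then L2-normalize. A minimal standalone sketch, assuming only PyTorch; the helper name is hypothetical and not part of the script:

```python
import torch

def pool_and_normalize(features: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the loops above.
    if features.dim() > 2:
        # (N, C, H, W) -> (N, C, H*W) -> global average pool -> (N, C)
        features = features.view(features.size(0), features.size(1), -1)
        features = features.mean(dim=-1)

    # Flatten any remaining dims and L2-normalize each feature vector,
    # which is what the linear SVMs trained downstream expect.
    features = features.view(features.size(0), -1)
    return features / torch.norm(features, dim=-1, keepdim=True)

feats = pool_and_normalize(torch.randn(4, 2048, 7, 7))
print(feats.shape)         # torch.Size([4, 2048])
print(feats.norm(dim=-1))  # every row has unit length
```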

virtex/models/downstream.py

Lines changed: 0 additions & 70 deletions
This file was deleted.
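
The deleted file defined the ``FeatureExtractor`` wrapper that ``scripts/clf_voc07.py`` no longer imports. Its implementation is not shown in this diff; as orientation only, a wrapper with this interface is commonly built from a forward hook on a named layer plus optional flatten-and-normalize. Everything below is an assumption, not the removed code:

```python
import torch
from torch import nn
import torchvision

class FeatureExtractor(nn.Module):
    """Hypothetical reconstruction: capture one named submodule's output."""

    def __init__(self, model: nn.Module, layer_name: str,
                 flatten_and_normalize: bool = False):
        super().__init__()
        self.model = model
        self.flatten_and_normalize = flatten_and_normalize
        self._features = None
        # Stash the requested layer's output whenever the model runs.
        dict(self.model.named_modules())[layer_name].register_forward_hook(
            lambda module, inputs, output: setattr(self, "_features", output)
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        self.model(image)
        features = self._features
        if self.flatten_and_normalize:
            features = features.view(features.size(0), -1)
            features = features / torch.norm(features, dim=-1, keepdim=True)
        return features

# Example: layer4 features of a ResNet-50, flattened and L2-normalized.
extractor = FeatureExtractor(
    torchvision.models.resnet50(), "layer4", flatten_and_normalize=True
)
print(extractor(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 100352])
```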

virtex/modules/visual_backbones.py

Lines changed: 10 additions & 82 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Union
+from typing import Any, Dict
 
 import torch
 from torch import nn

@@ -17,50 +17,6 @@ def __init__(self, visual_feature_size: int):
         self.visual_feature_size = visual_feature_size
 
 
-class BlindVisualBackbone(VisualBackbone):
-    r"""
-    A visual backbone which cannot see the image. It always outputs a tensor
-    filled with constant value.
-
-    Parameters
-    ----------
-    visual_feature_size: int, optional (default = 2048)
-        Size of the last dimension (channels) of output from forward pass.
-    bias_value: float, optional (default = 1.0)
-        Constant value to fill in the output tensor.
-    """
-
-    def __init__(self, visual_feature_size: int = 2048, bias_value: float = 1.0):
-        super().__init__(visual_feature_size)
-
-        # We never update the bias because a blind model cannot learn anything
-        # about the image. Add an axis for proper broadcasting.
-        self._bias = nn.Parameter(
-            torch.full((1, self.visual_feature_size), fill_value=bias_value),
-            requires_grad=False,
-        )
-
-    def forward(self, image: torch.Tensor) -> torch.Tensor:
-        r"""
-        Compute visual features for a batch of input images. Since this model
-        is *blind*, output will always be constant.
-
-        Parameters
-        ----------
-        image: torch.Tensor
-            Batch of input images. A tensor of shape
-            ``(batch_size, 3, height, width)``.
-
-        Returns
-        -------
-        torch.Tensor
-            Output visual features, filled with :attr:`bias_value`. A tensor of
-            shape ``(batch_size, visual_feature_size)``.
-        """
-        batch_size = image.size(0)
-        return self._bias.repeat(batch_size, 1)
-
-
 class TorchvisionVisualBackbone(VisualBackbone):
     r"""
     A visual backbone from `Torchvision model zoo

@@ -91,7 +47,8 @@ def __init__(
         self.cnn = getattr(torchvision.models, name)(
             pretrained, zero_init_residual=True
         )
-        # Do nothing after the final residual stage.
+        # Remove global average pooling and fc layer.
+        self.cnn.avgpool = nn.Identity()
         self.cnn.fc = nn.Identity()
 
         # Freeze all weights if specified.

@@ -100,12 +57,7 @@ def __init__(
                 param.requires_grad = False
             self.cnn.eval()
 
-        # Keep a list of intermediate layer names.
-        self._stage_names = [f"layer{i}" for i in range(1, 5)]
-
-    def forward(
-        self, image: torch.Tensor, return_intermediate_outputs: bool = False
-    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+    def forward(self, image: torch.Tensor) -> torch.Tensor:
         r"""
         Compute visual features for a batch of input images.
 

@@ -114,41 +66,17 @@ def forward(
         image: torch.Tensor
             Batch of input images. A tensor of shape
             ``(batch_size, 3, height, width)``.
-        return_intermediate_outputs: bool, optional (default = False)
-            Whether to return features extracted from all intermediate stages or
-            just the last one. This can only be set ``True`` when using a
-            ResNet-like model.
 
         Returns
         -------
-        Union[torch.Tensor, Dict[str, torch.Tensor]]
-            - If ``return_intermediate_outputs = False``, this will be a tensor
-              of shape ``(batch_size, channels, height, width)``, for example
-              it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50 (``layer4``).
-
-            - If ``return_intermediate_outputs = True``, this will be a dict
-              with keys ``{"layer1", "layer2", "layer3", "layer4", "avgpool"}``
-              containing features from all intermediate layers and global
-              average pooling layer.
+        torch.Tensor
+            A tensor of shape ``(batch_size, channels, height, width)``, for
+            example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.
         """
 
-        # Iterate through the modules in sequence and collect feature
-        # vectors for last layers in each stage.
-        intermediate_outputs: Dict[str, torch.Tensor] = {}
-        for idx, (name, layer) in enumerate(self.cnn.named_children()):
-            out = layer(image) if idx == 0 else layer(out)
-            if name in self._stage_names:
-                intermediate_outputs[name] = out
-
-        # Add pooled spatial features.
-        intermediate_outputs["avgpool"] = torch.mean(
-            intermediate_outputs["layer4"], dim=[2, 3]
-        )
-        if return_intermediate_outputs:
-            return intermediate_outputs
-        else:
-            # shape: (batch_size, channels, height, width)
-            return intermediate_outputs["layer4"]
+        # shape: (batch_size, channels, height, width)
+        # [ResNet-50: (b, 2048, 7, 7)]
+        return self.cnn(image)
 
     def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
         r"""
