Feature extraction in torchvision.models.vit_b_16 #5718

Open
@DavidTorpey

🐛 Describe the bug

Hi

It’s easy enough to obtain output features from the CNNs in torchvision.models by doing this:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18()
feature_extractor = nn.Sequential(*list(model.children())[:-1])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))

However, when I attempt to do this with torchvision.models.vit_b_16:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.vit_b_16()
feature_extractor = nn.Sequential(*list(model.children())[:-1])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))

I get the following error:

AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 14, 14])

Any help would be greatly appreciated.

Versions

Torch version: 1.11.0+cu102
Torchvision version: 0.12.0+cu102

cc @datumbox
