Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions apps/pe/clip_benchmark/metrics/zeroshot_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,13 @@ def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=Tr
texts = [template.format(c=classname) for template in templates]
else:
raise ValueError("templates must be a list or a dict")
texts = tokenizer(texts).to(device) # tokenize
class_embeddings = model.encode_text(texts)
if hasattr(model, 'encode_text'):
texts = tokenizer(texts).to(device) # tokenize
class_embeddings = model.encode_text(texts)
else:
# siglip
texts = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device) # tokenize
class_embeddings = model.text_model(**texts).pooler_output
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
Expand Down Expand Up @@ -139,7 +144,11 @@ def run_classification(
if video_dataset:
image_features = model.encode_video(images)
else:
image_features = model.encode_image(images)
if hasattr(model, 'encode_image'):
image_features = model.encode_image(images)
else:
# siglip models
image_features = model.vision_model(pixel_values=images).pooler_output

image_features = F.normalize(image_features, dim=-1)
logits = 100.0 * image_features @ classifier
Expand Down
17 changes: 2 additions & 15 deletions apps/pe/clip_benchmark/tasks/wds_benchmarks.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
# image classification
wds/wds_imagenet1k
wds/wds_imagenetv2
wds/wds_imagenet-a
wds/wds_imagenet-r
wds/wds_imagenet_sketch

# image retrieval
wds/wds_mscoco_captions
wds/wds_flickr30k

# video classification
k400_val

# video retrieval
msrvtt
wds/wds_cub-200-sam3_test
wds/wds_fgvc-aircraft-sam3_test
Binary file added apps/pe/docs/assets/test_image.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
172 changes: 168 additions & 4 deletions apps/pe/docs/pe_demo.ipynb

Large diffs are not rendered by default.

201 changes: 201 additions & 0 deletions apps/pe/docs/pe_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import os, sys
import json
import argparse

import torch
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
from PIL import Image
import open_clip

sys.path.append('../../../')
sys.path.append('../')
# import decord

# Pick the compute device once, up front: prefer CUDA when a GPU is present.
_use_gpu = torch.cuda.is_available()
print('GPU is available. Use GPU for this demo' if _use_gpu else 'Use CPU for this demo')
device = torch.device("cuda" if _use_gpu else "cpu")

import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms
from clip_benchmark.datasets.builder import build_wds_dataset
from clip_benchmark.metrics import zeroshot_classification


# Perception Encoder (PE) checkpoints this script can evaluate directly.
AVAILABLE_PE_MODELS = ['PE-Core-G14-448', 'PE-Core-L14-336', 'PE-Core-B16-224']
# open_clip model names evaluated through the open_clip code path.
# NOTE: fixed a duplicated entry — 'ViT-SO400M-14-SigLIP-384' appeared twice.
AVAILABLE_SIGLIP_MODELS = [
    # SigLIP2
    "ViT-B-32-SigLIP2-256",
    "ViT-B-16-SigLIP2",
    "ViT-B-16-SigLIP2-256",
    "ViT-B-16-SigLIP2-384",
    "ViT-B-16-SigLIP2-512",
    "ViT-L-16-SigLIP2-256",
    "ViT-L-16-SigLIP2-384",
    "ViT-L-16-SigLIP2-512",
    "ViT-SO400M-14-SigLIP2",
    "ViT-SO400M-14-SigLIP2-378",
    "ViT-SO400M-16-SigLIP2-256",
    "ViT-SO400M-16-SigLIP2-384",
    "ViT-SO400M-16-SigLIP2-512",
    "ViT-gopt-16-SigLIP2-256",
    "ViT-gopt-16-SigLIP2-384",
    # SigLIPv1
    'ViT-SO400M-14-SigLIP-384',
    'ViT-B-16-SigLIP',
    'ViT-B-16-SigLIP-256',
    'ViT-B-16-SigLIP-i18n-256',
    'ViT-B-16-SigLIP-384',
    'ViT-B-16-SigLIP-512',
    'ViT-L-16-SigLIP-256',
    'ViT-L-16-SigLIP-384',
    'ViT-SO400M-14-SigLIP',
    # other open_clip model
    'ViT-B-32',
]

# WebDataset classification sets that live under the shared pe_datasets root.
# The commented-out line preserves the full benchmark list for easy toggling.
# WDS_DATASETS = ['wds_cars', 'wds_cifar10', 'wds_country211', 'wds_dollar_street', 'wds_fairface', 'wds_fgvc_aircraft', 'wds_food101', 'wds_geode', 'wds_gtsrb', 'wds_imagenet-a', 'wds_imagenet-o', 'wds_imagenet-r', 'wds_imagenet_sketch', 'wds_imagenetv2', 'wds_inaturalist', 'wds_mnist', 'wds_objectnet', 'wds_renderedsst2', 'wds_stl10', 'wds_sun397', 'wds_utkface', 'wds_voc2007', 'wds_vtab-caltech101', 'wds_vtab-cifar100', 'wds_vtab-clevr_closest_object_distance', 'wds_vtab-clevr_count_all', 'wds_vtab-dtd', 'wds_vtab-eurosat', 'wds_vtab-flowers', 'wds_vtab-kitti_closest_vehicle_distance', 'wds_vtab-pcam', 'wds_vtab-pets', 'wds_vtab-resisc45', 'wds_vtab-svhn', 'wds_wilds-camelyon17', 'wds_wilds-fmow', 'wds_wilds-iwildcam']
WDS_DATASETS = ['wds_inaturalist', 'wds_mnist', 'wds_objectnet', 'wds_renderedsst2', 'wds_stl10', 'wds_sun397', 'wds_utkface', 'wds_voc2007', 'wds_vtab-caltech101', 'wds_vtab-cifar100', 'wds_vtab-clevr_closest_object_distance', 'wds_vtab-clevr_count_all', 'wds_vtab-dtd', 'wds_vtab-eurosat', 'wds_vtab-flowers', 'wds_vtab-kitti_closest_vehicle_distance', 'wds_vtab-pcam', 'wds_vtab-pets', 'wds_vtab-resisc45', 'wds_vtab-svhn', 'wds_wilds-camelyon17', 'wds_wilds-fmow', 'wds_wilds-iwildcam']


def parse_args(args):
    """Parse command-line arguments for the zero-shot evaluation script.

    Args:
        args: list of raw argument strings, typically ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_name`` (str | None), ``bs`` (int)
        and ``workers`` (int).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default=None, help="model name.")
    # parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Dataset specifier. See data.py.")
    parser.add_argument("--bs", type=int, default=256, help="Eval batch size.")
    # Fixed typo in the user-facing help text: "Dataloder" -> "Dataloader".
    parser.add_argument("--workers", type=int, default=8, help="Dataloader workers.")
    return parser.parse_args(args)


args = parse_args(sys.argv[1:])

# Directory holding the <LANG>_*.json files packaged with the CLIP benchmark.
_CLIP_BENCHMARK_DATASET_DIR = '/home/pengchuanzhang/GitHub/perception_models/apps/pe/clip_benchmark/datasets'


def _load_default_json(filename):
    # Load a JSON config from the CLIP-benchmark dataset dir; None when absent.
    path = os.path.join(_CLIP_BENCHMARK_DATASET_DIR, filename)
    if not os.path.exists(path):
        return None
    with open(path, "r") as f:
        return json.load(f)


# Load <LANG>_classnames.json (packaged with CLIP benchmark that are used by default)
default_classnames = _load_default_json("en_classnames.json")
# Load <LANG>_zeroshot_classification_templates.json (packaged with CLIP benchmark that are used by default)
default_templates = _load_default_json("en_zeroshot_classification_templates.json")


model_name = args.model_name
if model_name in AVAILABLE_PE_MODELS:
    model = pe.CLIP.from_config(model_name, pretrained=True)  # Downloads from HF
    model = model.to(device)

    preprocess = transforms.get_image_transform(model.image_size)
    tokenizer = transforms.get_text_tokenizer(model.context_length)
elif model_name in AVAILABLE_SIGLIP_MODELS:
    # BUGFIX: SigLIP / SigLIP2 checkpoints in open_clip are published under the
    # 'webli' pretrained tag; 'laion2b_s34b_b79k' only exists for the plain
    # ViT-B-32 LAION model, so loading SigLIP with it would fail.
    pretrained_tag = 'laion2b_s34b_b79k' if model_name == 'ViT-B-32' else 'webli'
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained_tag)
    model = model.to(device)
    model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
    tokenizer = open_clip.get_tokenizer(model_name)
else:
    raise ValueError(f"Not supported model: {model_name}!")


batch_size = args.bs
num_workers = args.workers
# Datasets to evaluate; swap in WDS_DATASETS for the full benchmark sweep.
to_evaluate_datasets = ['wds_fgvc-aircraft', 'wds_cub-200', 'imagenet1k']
# to_evaluate_datasets = WDS_DATASETS
for dataset_name in to_evaluate_datasets:
    print(f"Run inference on {dataset_name}...")
    if dataset_name.startswith("wds"):
        # Resolve the WebDataset shard directory for this dataset.
        if dataset_name in WDS_DATASETS:
            data_root = f'/fsx-onevision/pengchuanzhang/datasets/pe_datasets/wds/{dataset_name}_test'
        elif dataset_name == 'wds_fgvc-aircraft':
            data_root = '/fsx-onevision/pengchuanzhang/datasets/pe_datasets/wds/wds_fgvc-aircraft-sam3_test'
        elif dataset_name == 'wds_cub-200':
            data_root = '/fsx-onevision/pengchuanzhang/datasets/pe_datasets/wds/wds_cub-200-sam3_test'
        else:
            raise ValueError(f"Dataset {dataset_name} not supported yet!")
        dataset = build_wds_dataset(
            dataset_name, preprocess, split="test", data_dir=data_root
        )
        # WebDataset pipelines batch internally, so the DataLoader itself
        # must not re-batch (batch_size=None).
        dataloader = torch.utils.data.DataLoader(
            dataset.batched(batch_size),
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
        )
        # Classnames/templates ship with the wds dataset when it supports
        # zero-shot classification.
        zeroshot_templates = (
            dataset.templates if hasattr(dataset, "templates") else None
        )
        classnames = dataset.classes if hasattr(dataset, "classes") else None
        assert (
            zeroshot_templates is not None and classnames is not None
        ), "Dataset does not support classification"
    elif dataset_name == 'imagenet1k':
        # Plain ImageFolder over the ImageNet-1k validation split; classnames
        # and prompt templates come from the CLIP-benchmark defaults.
        data_root = '/fsx-onevision/shared/data/imagenet_full_size'
        dataset = ImageFolder(
            root=os.path.join(data_root, "val"), transform=preprocess,
        )
        dataset.classes = default_classnames["imagenet1k"]
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        zeroshot_templates = default_templates["imagenet1k"]
        classnames = default_classnames["imagenet1k"]
    else:
        raise ValueError(f"Dataset {dataset_name} not supported yet!")

    # Build the zero-shot text classifier (one embedding per class), then run
    # image/video inference against it.
    classifier = zeroshot_classification.zero_shot_classifier(
        model, tokenizer, classnames, zeroshot_templates, device, amp=True
    )

    logits, target = zeroshot_classification.run_classification(
        model,
        classifier,
        dataloader,
        device,
        amp=True,
    )

    # NOTE: removed unused `pred = logits.argmax(axis=1)` — accuracy() computes
    # its own top-k predictions.
    (acc1,) = zeroshot_classification.accuracy(logits, target, topk=(1,))
    print("Top1 accuracy: ", acc1)

    # Persist raw outputs so metrics can be recomputed without rerunning.
    output_root = f'/fsx-onevision/pengchuanzhang/output/pe_evals/{model_name}'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_root, exist_ok=True)
    save_file = os.path.join(output_root, f"{dataset_name}.pt")
    torch.save(
        {
            "classifier": classifier,
            "logits": logits,
            "target": target,
            "acc1": acc1
        },
        save_file
    )


# metrics = zeroshot_classification.evaluate(
#     model,
#     dataloader,
#     tokenizer,
#     classnames,
#     zeroshot_templates,
#     device=device,
# )
# print(metrics)
5 changes: 3 additions & 2 deletions core/vision_encoder/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
def get_image_transform(
image_size: int,
center_crop: bool = False,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR # We used bilinear during training
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, # We used bilinear during training
to_RGB: bool = True
):
if center_crop:
crop = [
Expand All @@ -20,7 +21,7 @@ def get_image_transform(
]

return T.Compose(crop + [
T.Lambda(lambda x: x.convert("RGB")),
T.Lambda(lambda x: x.convert("RGB") if to_RGB else x),
T.ToTensor(),
T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
])
Expand Down
19 changes: 19 additions & 0 deletions perception_models.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Metadata-Version: 2.4
Name: perception_models
Version: 1.0.0
Summary: Occhi package.
Home-page: https://github.com/facebookresearch/perception_models
Author: Meta AI Research, FAIR
License: FAIR Noncommercial Research License
Classifier: Programming Language :: Python :: 3
Classifier: License :: Other/Proprietary License
Requires-Python: >=3.10
License-File: LICENSE.PE
License-File: LICENSE.PLM
Dynamic: author
Dynamic: classifier
Dynamic: home-page
Dynamic: license
Dynamic: license-file
Dynamic: requires-python
Dynamic: summary
8 changes: 8 additions & 0 deletions perception_models.egg-info/SOURCES.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
LICENSE.PE
LICENSE.PLM
README.md
setup.py
perception_models.egg-info/PKG-INFO
perception_models.egg-info/SOURCES.txt
perception_models.egg-info/dependency_links.txt
perception_models.egg-info/top_level.txt
1 change: 1 addition & 0 deletions perception_models.egg-info/dependency_links.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions perception_models.egg-info/top_level.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

32 changes: 0 additions & 32 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,32 +0,0 @@
numpy==2.1.2
omegaconf==2.3.0
msgspec==0.19.0
rouge-score==0.1.2
sacrebleu==2.5.1
sentencepiece==0.2.0
tiktoken==0.9.0
blobfile==3.0.0
wandb==0.19.8
viztracer==1.0.3
lm-eval==0.4.8
scipy==1.15.2
pynvml==12.0.0
orjson==3.10.15
einops==0.8.1
pillow==11.0.0
pyahocorasick==2.1.0
iopath==0.1.10
torchdata==0.11.0
torchcodec==0.1.0
timm==1.0.15
decord==0.6.0
opencv-python==4.11.0.86
pycocoevalcap==1.2
scikit-learn==1.6.1
scipy==1.15.2
sentencepiece==0.2.0
tokenizers==0.21.1
webdataset==0.2.111
fsspec
datatrove
ftfy
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
"License :: Other/Proprietary License",
],
license="FAIR Noncommercial Research License",
python_requires=">=3.11",
python_requires=">=3.10",
)