
Commit d00e787

Move support of model types to the all_clip module.
I created the all_clip module in order to have a single place to support all kinds of CLIP models. It is already used in clip-retrieval, and I propose to use it here too.
1 parent 5f23a76 commit d00e787
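
For context, here is roughly what loading a model through all_clip looks like after this change. This is a minimal sketch based on the call introduced in `clip_benchmark/models/__init__.py` below; the `(model, transform, tokenizer)` return shape is an assumption taken from the loader convention this repository's README already documents, not something shown in the diff.

```python
# Minimal sketch (not part of this commit): the all_clip call that
# clip_benchmark.models.load_clip now delegates to. Keyword arguments mirror
# the diff below; the (model, transform, tokenizer) return shape is assumed
# from the loader convention described in this repository's README.
import all_clip

model, transform, tokenizer = all_clip.load_clip(
    clip_model="open_clip:ViT-B-32/laion2b_s34b_b79k",  # "<model_type>:<model_name>/<pretrained>"
    use_jit=True,
    warmup_batch_size=1,
    clip_cache_path=None,  # assumption: None falls back to all_clip's default cache
    device="cpu",
)
```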

File tree

6 files changed: +32 -83 lines changed


README.md

Lines changed: 2 additions & 10 deletions
```diff
@@ -80,16 +80,8 @@ Here is an example of use
 
 ### How to add other CLIP models
 
-Please follow these steps:
-1. Add a identity file to load model in `clip_benchmark/models`
-2. Define a loading function, that returns a tuple (model, transform, tokenizer). Please see `clip_benchmark/models/open_clip.py` as an example.
-3. Add the function into `TYPE2FUNC` in `clip_benchmark/models/__init__.py`
-
-Remarks:
-- The new tokenizer/model must enable to do the following things as https://github.com/openai/CLIP#usage
-  - `tokenizer(texts).to(device)` ... `texts` is a list of string
-  - `model.encode_text(tokenized_texts)` ... `tokenized_texts` is a output from `tokenizer(texts).to(device)`
-  - `model.encode_image(images)` ... `images` is a image tensor by the `transform`
+Please add your model into [all-clip](https://github.com/rom1504/all-clip) and it will be supported into CLIP-benchmark (and in clip-retrieval).
+See [How to add a model type](https://github.com/rom1504/all-clip?tab=readme-ov-file#how-to-add-a-model-type)
 
 
 ### CIFAR-10 example
```
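
For reference, the contract the deleted README steps describe (a per-model-type loader returning `(model, transform, tokenizer)` with an OpenAI-CLIP-style interface) can be sketched as below; every name here is hypothetical, and this is exactly the convention that all-clip now takes over.

```python
# Illustrative sketch of the loader contract the deleted README steps describe;
# load_my_clip and the placeholder bodies are hypothetical, not code from this repo.
def load_my_clip(model_name, pretrained, cache_dir, device):
    model = ...      # must provide encode_text(tokenized_texts) and encode_image(images)
    transform = ...  # maps a PIL image to the tensor fed into encode_image
    tokenizer = ...  # tokenizer(list_of_strings) returns a tensor movable with .to(device)
    return model, transform, tokenizer
```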

clip_benchmark/models/__init__.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -1,14 +1,9 @@
 from typing import Union
 import torch
-from .open_clip import load_open_clip
-from .japanese_clip import load_japanese_clip
+import all_clip
 
-# loading function must return (model, transform, tokenizer)
-TYPE2FUNC = {
-    "open_clip": load_open_clip,
-    "ja_clip": load_japanese_clip
-}
-MODEL_TYPES = list(TYPE2FUNC.keys())
+# see https://github.com/rom1504/all-clip?tab=readme-ov-file#supported-models
+MODEL_TYPES = ["openai_clip", "open_clip", "ja_clip", "hf_clip", "nm"]
 
 
 def load_clip(
@@ -19,5 +14,10 @@ def load_clip(
     device: Union[str, torch.device] = "cuda"
 ):
     assert model_type in MODEL_TYPES, f"model_type={model_type} is invalid!"
-    load_func = TYPE2FUNC[model_type]
-    return load_func(model_name=model_name, pretrained=pretrained, cache_dir=cache_dir, device=device)
+    return all_clip.load_clip(
+        clip_model=model_type+":"+model_name+"/"+pretrained,
+        use_jit=True,
+        warmup_batch_size=1,
+        clip_cache_path=cache_dir,
+        device=device,
+    )
```
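
A hedged usage sketch of the updated wrapper follows. The keyword names come from the removed `load_func(...)` call and the hunk above, but the full `load_clip` signature is not visible in the diff, so treat the call shape (and the `(model, transform, tokenizer)` return, taken from the README's loader convention) as assumptions.

```python
# Sketch of calling the updated wrapper; argument names are taken from the
# diff above, but the full load_clip signature is outside the hunk, so this
# is an assumption rather than the definitive API.
from clip_benchmark.models import load_clip

model, transform, tokenizer = load_clip(
    model_type="open_clip",    # must be one of MODEL_TYPES
    model_name="ViT-B-32",
    pretrained="laion2b_s34b_b79k",
    cache_dir=None,            # forwarded to all_clip as clip_cache_path
    device="cpu",
)
# Internally this composes clip_model="open_clip:ViT-B-32/laion2b_s34b_b79k"
# and forwards it to all_clip.load_clip.
```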

clip_benchmark/models/japanese_clip.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

clip_benchmark/models/open_clip.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,3 +6,4 @@ open_clip_torch>=0.2.1
 pycocoevalcap
 webdataset>=0.2.31
 transformers
+all_clip>=1.0.0,<2
```

tests/test_clip_benchmark.py

Lines changed: 19 additions & 1 deletion
```diff
@@ -6,6 +6,7 @@
 from clip_benchmark.cli import run
 import logging
 import torch
+import pytest
 
 class base_args:
     dataset="dummy"
@@ -109,10 +110,27 @@ class linear_probe_args:
     custom_classname_file=None
     distributed=False
 
-def test_base():
+
+def test_linear_probe():
     if torch.cuda.is_available():
         run(linear_probe_args)
     else:
         logging.warning("GPU acceleration is required for linear evaluation to ensure optimal performance and efficiency.")
+
+
+@pytest.mark.parametrize(
+    "full_model_name",
+    [
+        "openai_clip:ViT-B/32",
+        "open_clip:ViT-B-32/laion2b_s34b_b79k",
+        "hf_clip:patrickjohncyh/fashion-clip",
+    ],
+)
+def test_base(full_model_name):
+    model_type, model_name = full_model_name.split(":")
+    model, pretrained = model_name.split("/")
+    base_args.model_type = model_type
+    base_args.model = model
+    base_args.pretrained = pretrained
     os.environ["CUDA_VISIBLE_DEVICES"] = ""
     run(base_args)
```
