-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathretrieval.py
More file actions
71 lines (66 loc) · 2.95 KB
/
Copy pathretrieval.py
File metadata and controls
71 lines (66 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from huggingface_hub import snapshot_download
from clip_retrieval.clip_back import load_clip_indices, KnnService, ClipOptions
from collections import defaultdict
import os
import glob
import shutil
import random
class FeatureRetriever:
def __init__(self,
num_images=50,
imgs_per_dir=15,
force_download=False):
if force_download or not os.path.exists("./clip"):
print("Downloading clip resources")
rand_num = random.randint(0, 100000)
tmp_dir = f"./tmp_{rand_num}"
snapshot_download(repo_type="dataset", repo_id="wendlerc/sdxl-unbox-clip-indices", cache_dir=tmp_dir)
clip_dirs = glob.glob(f"{tmp_dir}/**/down_10_5120", recursive=True)
if len(clip_dirs) > 0:
shutil.copytree(clip_dirs[0].replace("down_10_5120", ""), "./clip", dirs_exist_ok=True)
shutil.rmtree(tmp_dir)
else:
ValueError("Could not find clip indices in the downloaded repo.")
# Initialize CLIP service
clip_options = ClipOptions(
indice_folder="currently unused by knn.query()",
clip_model="ViT-B/32", #"open_clip:ViT-H-14",
enable_hdf5=False,
enable_faiss_memory_mapping=True,
columns_to_return=["image_path", "similarity"],
reorder_metadata_by_ivf_index=False,
enable_mclip_option=False,
use_jit=False,
use_arrow=False,
provide_safety_model=False,
provide_violence_detector=False,
provide_aesthetic_embeddings=False,
)
self.names = ["down.2.1", "mid.0", "up.0.0", "up.0.1"]
self.paths = ["./clip/down_10_5120/indices_paths.json",
"./clip/mid_10_5120/indices_paths.json",
"./clip/up0_10_5120/indices_paths.json",
"./clip/up_10_5120/indices_paths.json",]
self.knn_service = {}
for name, path in zip(self.names, self.paths):
resources = load_clip_indices(path, clip_options)
self.knn_service[name] = KnnService(clip_resources=resources)
self.num_images = num_images
self.imgs_per_dir = imgs_per_dir
def query_text(self, query, block):
if block not in self.names:
raise ValueError(f"Block must be one of {self.names}")
results = self.knn_service[block].query(
text_input=query,
num_images=self.num_images,
num_result_ids=self.num_images,
deduplicate=True,
)
feat_sims = defaultdict(list)
feat_scores = {}
for result in results:
feature_id = result["image_path"].split("/")[-2]
feat_sims[feature_id] += [result["similarity"]]
for fid, sims in feat_sims.items():
feat_scores[fid] = (sum(sims) / len(sims)) * (len(sims)/self.imgs_per_dir)
return dict(sorted(feat_scores.items(), key=lambda item: -item[1]))