-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquery_methods.py
More file actions
executable file
·152 lines (124 loc) · 4.32 KB
/
query_methods.py
File metadata and controls
executable file
·152 lines (124 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/bin/python3
from utils import *
import numpy.linalg as la
import heapq as hp
import collections as cl
from scipy.spatial import KDTree
from lsh import *
import numpy.random as rd
def second_closest_ratio(h, max_ratio):
    """Ratio test between the nearest neighbour and the nearest rival group.

    ``h`` is a sequence of ``(distance, image)`` pairs, presumably sorted by
    increasing distance (TODO confirm for every search function). The first
    pair is taken as the best match; the "rival" is the first later pair
    whose image belongs to a different ``group_id``.

    @param h : indexable sequence of (distance, image) pairs
    @param max_ratio : the best distance must be <= max_ratio * rival distance
    @return : True when the ratio test passes; False when ``h`` is empty or
              when every neighbour belongs to the same group as the best one
    """
    if len(h) == 0:
        return False
    best_dist, best_im = h[0]
    # Distance of the first neighbour from another group (None if there is none).
    rival_dist = next(
        (dist for dist, im in h if best_im.group_id != im.group_id), None
    )
    if rival_dist is None:
        return False
    return best_dist <= max_ratio * rival_dist
# Generic nearest-neighbour voting query over a database of image descriptors.
def query(
    data: "Database",
    query_im: "Image",
    search_func,
    im_k: int = 5,
    descr_k: int = 20,
    verbose=False,
    weight=lambda x: 1,
    snd_closest_ratio=True,
    max_ratio: float = 0.75,
    ignore_self=True,
    vote_for_class=False,
):
    """Rank database images (or classes) by nearest-neighbour descriptor votes.

    For every descriptor of ``query_im``, ``search_func`` returns up to
    ``descr_k`` ``(distance, image)`` neighbours. Each kept neighbour adds
    ``weight(distance)`` to the score of its image, or of its ``group_id``
    when ``vote_for_class`` is true.

    @param data : the database to search in (passed through to search_func)
    @param query_im : the query image; each entry of ``query_im.descr`` is searched
    @param search_func : callable ``(data, descr, k)`` returning an iterable of
                         ``(distance, image)`` pairs for one descriptor
    @param im_k : number of best-ranked images/classes to return
    @param descr_k : maximum number of neighbours considered per descriptor
    @param verbose : wrap the descriptor loop in a tqdm progress bar
    @param weight : function of the distance giving a neighbour's vote weight
    @param snd_closest_ratio : skip descriptors that fail second_closest_ratio
    @param max_ratio : threshold passed to second_closest_ratio
    @param ignore_self : drop neighbours whose image ``id`` equals ``query_im.id``
    @param vote_for_class : accumulate votes per ``group_id`` instead of per image
    @return : list of ``(image_or_group_id, score)`` sorted by decreasing score,
              at most ``im_k`` entries long
    """
    histogram = dict()
    if verbose:
        it = tqdm(query_im.descr, desc="Calcul des plus proches voisins")
    else:
        it = query_im.descr
    for query_descr in it:
        # Materialise once: search functions may return one-shot iterators
        # (e.g. zip), which second_closest_ratio (len/indexing) cannot handle
        # and which it would otherwise consume before the vote loop.
        h = list(search_func(data, query_descr, descr_k))
        if ignore_self:
            h = [pair for pair in h if pair[1].id != query_im.id]
        # Skip this descriptor if it is not discriminative enough.
        if snd_closest_ratio and not second_closest_ratio(h, max_ratio):
            continue
        for dist, im in h:
            # Read and write the SAME key; the previous version read `im` but
            # wrote `im.group_id`, so per-class votes never accumulated.
            key = im.group_id if vote_for_class else im
            histogram[key] = histogram.get(key, 0) + weight(dist)
    return sorted(histogram.items(), key=lambda x: x[1], reverse=True)[:im_k]
def basic_search(data: Database, query_descr, descr_k: int):
    """Brute-force neighbour search over every descriptor of ``data``.

    Thin adapter giving ``basic_search_base`` the ``(data, descr, k)``
    signature that ``query`` expects from a search function.
    """
    all_descriptors = data.iter_descr()
    return basic_search_base(all_descriptors, query_point=query_descr, k=descr_k)
def query_on_tree(data: "Database", tree, descr_k, query_descr):
    """Return the ``descr_k`` nearest database descriptors of one query descriptor.

    @param data : database mapping a descriptor row index to its image via
                  ``image_of_descr_index``
    @param tree : a ``scipy.spatial.KDTree`` built over the database descriptors
    @param descr_k : number of neighbours requested
    @param query_descr : a single query descriptor (one point)
    @return : list of ``(euclidean_distance, image)`` pairs, nearest first; may
              hold fewer than ``descr_k`` entries when the tree has fewer points
    """
    import numpy as np

    dists, idxs = tree.query(query_descr, k=descr_k, p=2)
    # KDTree.query squeezes the neighbour axis when k == 1 and returns
    # scalars; restore a 1-D array so the pairing below works for every k.
    dists = np.atleast_1d(dists)
    idxs = np.atleast_1d(idxs)
    # Missing neighbours (k larger than the number of points) come back with
    # distance inf and index tree.n; drop them instead of looking up an
    # out-of-range descriptor.
    return [
        (dist, data.image_of_descr_index(idx))
        for dist, idx in zip(dists, idxs)
        if np.isfinite(dist)
    ]
def kd_tree_search_func_gen(data: "Database", verbose=False):
    """Build a k-d tree over all descriptors of ``data`` and return a search function.

    @param data : database whose ``to_array()`` gives the descriptor matrix
    @param verbose : print progress messages while the tree is built
                     (previously ignored; now gated like in ``init_lsh``)
    @return : callable ``(data, query_descr, descr_k)`` usable as
              ``search_func`` in ``query``; its first argument is ignored
              because the tree is bound to the ``data`` given here
    """
    if verbose:
        print("Building tree...")
    points = data.to_array()
    tree = KDTree(points, leafsize=10, balanced_tree=True)
    if verbose:
        print("Tree built succesfully !")
    return lambda _, query_descr, descr_k: query_on_tree(
        data, tree, descr_k, query_descr
    )
def init_lsh(
    data: Database, verbose=False, nb_tables=10, nb_fun_per_table=2, r=10
):
    """Build and preprocess an LSH index over ``data``.

    Note: increasing ``nb_fun_per_table`` makes the number of buckets grow
    very quickly.

    @param data : database to index
    @param verbose : print progress messages and forward verbosity to preprocessing
    @param nb_tables : number of hash tables
    @param nb_fun_per_table : number of hash functions per table
    @param r : parameter forwarded to ``Lsh`` as ``r``
    @return : ``(lsh_index, search_func)`` where ``search_func`` has the
              ``(data, query_descr, descr_k)`` signature expected by ``query``
              and ignores its first argument
    """
    index = Lsh(nb_fun_per_table=nb_fun_per_table, nb_tables=nb_tables, r=r)
    if verbose:
        print("Preprosessing ...")
    index.preprocess(data, verbose=verbose)
    if verbose:
        print("Preprossesing finished !")

    def search(_, query_descr, descr_k):
        return index.query_knn(descr_k, query_descr)

    return index, search
if __name__ == "__main__":
    import sys

    # Usage: query_methods.py <database_path>
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(sys.argv) != 2:
        sys.exit(f"Usage: {sys.argv[0]} <database_path>")
    datapath = sys.argv[1]

    d = Database(datapath, auto_init=True, verbose=True)
    print("Nombre de points du nuage : ", d.taille_nuage())
    # Use a random database image as the query.
    query_im = rd.choice(d.images)
    print("Classe de l'image recherchée : ", query_im.group_id)

    # Search 1: LSH index.
    # (Alternatives: kd_tree_search_func_gen(d, verbose=True) or basic_search.)
    lsh_index, search_f = init_lsh(
        d,
        verbose=True,
    )
    print(lsh_index.nb_buck_per_table())
    result = query(
        d,
        query_im,
        search_f,
        im_k=1,
        # NOTE: known issue -- if the number of LSH tables is >= descr_k,
        # no neighbours are found.
        descr_k=50,
        verbose=True,
        weight=lambda x: 1 / (x + 0.001),
        snd_closest_ratio=False,
        ignore_self=True,
    )
    # The vote histogram can be empty (no neighbour found), so guard the lookup.
    if result:
        print(result[0][0].group_id)
    else:
        print("No match found with LSH search")

    # Search 2: exact k-d tree search with the second-closest ratio test.
    search_f = kd_tree_search_func_gen(d, verbose=True)
    result = query(
        d,
        query_im,
        search_f,
        im_k=5,
        descr_k=20,
        verbose=True,
        weight=lambda x: 1 / (x + 0.001),
        snd_closest_ratio=True,
    )
    for im, score in result:
        print(im.name, score)