Skip to content

Commit f9f597b

Browse files
authored
[v0.10.2] (mishards) Fix mishards search bug (#3169)
* [skip ci]Reverse query result if metric is IP Signed-off-by: yinghao.zou <yinghao.zou@zilliz.com> * [skip ci] Update version check Signed-off-by: yinghao.zou <yinghao.zou@zilliz.com>
1 parent 058cdf0 commit f9f597b

File tree

4 files changed

+20
-8
lines changed

4 files changed

+20
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Please mark all change in change log and use the issue from GitHub
99
- \#2952 Fix the result merging of IVF_PQ IP
1010
- \#2975 Fix config UT failed
1111
- \#3012 If the cache is too small, queries using multiple GPUs will cause to crash
12+
- \#3133 Reverse query result in mishards if metric type is IP
1213

1314
## Feature
1415

shards/mishards/connections.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,16 @@
224224
# connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs)
225225
# return connection
226226

227+
def version_supported(version):
228+
version_pattern = lambda v : ".".join(v.split(".")[:2])
229+
230+
sv_patterns = set()
231+
for supported_version in settings.SERVER_VERSIONS:
232+
sv_patterns.add(version_pattern(supported_version))
233+
234+
v_pattern = version_pattern(version)
235+
return v_pattern in sv_patterns
236+
227237

228238
class ConnectionGroup(topology.TopoGroup):
229239
def __init__(self, name):
@@ -243,7 +253,7 @@ def on_pre_add(self, topo_object):
243253
if not status.OK():
244254
logger.error('Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name))
245255
return False
246-
if version not in settings.SERVER_VERSIONS:
256+
if not version_supported(version):
247257
logger.error('Cannot connect to server of version: {}. Only {} supported'.format(version,
248258
settings.SERVER_VERSIONS))
249259
return False

shards/mishards/service_handler.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kw
2727
self.max_workers = max_workers
2828

2929
def _reduce(self, source_ids, ids, source_diss, diss, k, reverse):
30-
if source_diss[k - 1] <= diss[0]:
30+
sort_f = lambda x, y: x >= y if reverse else lambda x, y: x <= y
31+
if sort_f(source_diss[k - 1], diss[0]):
3132
return source_ids, source_diss
32-
if diss[k - 1] <= source_diss[0]:
33+
if sort_f(diss[k - 1], source_diss[0]):
3334
return ids, diss
3435

3536
source_diss.extend(diss)
3637
diss_t = enumerate(source_diss)
37-
diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k]
38+
diss_m_rst = sorted(diss_t, key=lambda x: x[1], reverse=reverse)[:k]
3839
diss_m_out = [id_ for _, id_ in diss_m_rst]
3940

4041
source_ids.extend(ids)
@@ -149,9 +150,9 @@ def _do_query(self,
149150
params=search_params, _async=True)
150151
futures.append(future)
151152

152-
for f in futures:
153-
ret = f.result(raw=True)
154-
all_topk_results.append(ret)
153+
for f in futures:
154+
ret = f.result(raw=True)
155+
all_topk_results.append(ret)
155156

156157
reverse = collection_meta.metric_type == Types.MetricType.IP
157158
with self.tracer.start_span('do_merge', child_of=p_span):

shards/mishards/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
env.read_env()
1313

1414

15-
SERVER_VERSIONS = ['0.9.0', '0.9.1', '0.10.0', '0.10.1']
15+
SERVER_VERSIONS = ['0.9.x', '0.10.x']
1616
DEBUG = env.bool('DEBUG', False)
1717
MAX_RETRY = env.int('MAX_RETRY', 3)
1818

0 commit comments

Comments
 (0)