Skip to content

Commit eea6b50

Browse files
authored
test: add test cases for int8 vector (#41957)
Signed-off-by: binbin lv <[email protected]>
1 parent 1735f55 commit eea6b50

File tree

9 files changed

+257
-39
lines changed

9 files changed

+257
-39
lines changed

tests/python_client/base/client_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ def init_collection_general(self, prefix="test", insert_data=False, nb=ct.defaul
368368
# Unlike dense vectors, sparse vectors cannot create flat index.
369369
if DataType.SPARSE_FLOAT_VECTOR.name in vector_name:
370370
collection_w.create_index(vector_name, ct.default_sparse_inverted_index)
371+
elif vector_data_type == DataType.INT8_VECTOR:
372+
collection_w.create_index(vector_name, ct.int8_vector_index)
371373
else:
372374
collection_w.create_index(vector_name, ct.default_flat_index)
373375

tests/python_client/base/client_v2_base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,4 +1100,26 @@ def transfer_replica(self, client, source_group, target_group, collection_name,
11001100
source_group=source_group, target_group=target_group,
11011101
collection_name=collection_name, num_replicas=num_replicas,
11021102
**kwargs).run()
1103+
return res, check_result
1104+
1105+
@trace()
def create_field_schema(self, client, name, data_type, desc='', timeout=None, check_task=None, check_items=None, **kwargs):
    """Invoke ``client.create_field_schema`` via ``api_request`` and check the response.

    When ``timeout`` is None the module-level ``TIMEOUT`` is used; the effective
    timeout is forwarded to both ``api_request`` and ``ResponseChecker`` through
    ``kwargs``.

    Returns a ``(res, check_result)`` tuple: the first element returned by
    ``api_request`` and the value produced by ``ResponseChecker(...).run()``.
    """
    if timeout is None:
        timeout = TIMEOUT
    kwargs["timeout"] = timeout

    # Read the name from this method's own frame so the checker is told which
    # wrapper produced the response.
    func_name = sys._getframe().f_code.co_name
    request = [client.create_field_schema, name, data_type, desc]
    res, check = api_request(request, **kwargs)
    checker = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs)
    check_result = checker.run()
    return res, check_result
1115+
1116+
@trace()
def add_collection_field(self, client, collection_name, field_schema, timeout=None, check_task=None, check_items=None, **kwargs):
    """Invoke ``client.add_collection_field`` via ``api_request`` and check the response.

    When ``timeout`` is None the module-level ``TIMEOUT`` is used; the effective
    timeout is forwarded to both ``api_request`` and ``ResponseChecker`` through
    ``kwargs``.

    Returns a ``(res, check_result)`` tuple: the first element returned by
    ``api_request`` and the value produced by ``ResponseChecker(...).run()``.
    """
    if timeout is None:
        timeout = TIMEOUT
    kwargs["timeout"] = timeout

    # Read the name from this method's own frame so the checker is told which
    # wrapper produced the response.
    func_name = sys._getframe().f_code.co_name
    request = [client.add_collection_field, collection_name, field_schema]
    res, check = api_request(request, **kwargs)
    checker = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs)
    check_result = checker.run()
    return res, check_result

tests/python_client/check/func_check.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from common import common_func as cf
88
from common.common_type import CheckTasks, Connect_Object_Name
99
# from common.code_mapping import ErrorCode, ErrorMessage
10-
from pymilvus import Collection, Partition, ResourceGroupInfo
10+
from pymilvus import Collection, Partition, ResourceGroupInfo, DataType
1111
import check.param_check as pc
12-
12+
import numpy as np
13+
from ml_dtypes import bfloat16
1314

1415
class Error:
1516
def __init__(self, error):
@@ -259,8 +260,27 @@ def check_describe_collection_property(res, func_name, check_items):
259260
if check_items.get("id_name", "id"):
260261
assert res["fields"][0]["name"] == check_items.get("id_name", "id")
261262
if check_items.get("vector_name", "vector"):
262-
assert res["fields"][1]["name"] == check_items.get("vector_name", "vector")
263+
vector_name_list = []
264+
vector_name_list_expected = check_items.get("vector_name", "vector")
265+
for field in res["fields"]:
266+
if field["type"] in [101, 102, 103, 105]:
267+
vector_name_list.append(field["name"])
268+
if isinstance(vector_name_list_expected, str):
269+
assert vector_name_list[0] == check_items.get("vector_name", "vector")
270+
else:
271+
assert vector_name_list == vector_name_list_expected
263272
if check_items.get("dim", None) is not None:
273+
dim_list = []
274+
# "dim" may be an int (checked against the first vector field) or a list (one entry per vector field);
275+
# a list must follow the order in which the vector fields were added to the schema
276+
dim_list_expected = check_items.get("dim")
277+
for field in res["fields"]:
278+
if field["type"] in [101, 102, 103, 105]:
279+
dim_list.append(field["params"]["dim"])
280+
if isinstance(dim_list_expected, int):
281+
assert dim_list[0] == dim_list_expected
282+
else:
283+
assert dim_list == dim_list_expected
264284
assert res["fields"][1]["params"]["dim"] == check_items.get("dim")
265285
if check_items.get("nullable_fields", None) is not None:
266286
nullable_fields = check_items.get("nullable_fields")
@@ -272,7 +292,7 @@ def check_describe_collection_property(res, func_name, check_items):
272292
assert field["nullable"] is True
273293
assert res["fields"][0]["is_primary"] is True
274294
assert res["fields"][0]["field_id"] == 100 and (res["fields"][0]["type"] == 5 or 21)
275-
assert res["fields"][1]["field_id"] == 101 and res["fields"][1]["type"] == 101
295+
assert res["fields"][1]["field_id"] == 101 and (res["fields"][1]["type"] == 101 or 105)
276296

277297
return True
278298

@@ -540,6 +560,22 @@ def check_query_results(query_res, func_name, check_items):
540560
exp_res = check_items.get("exp_res", None)
541561
with_vec = check_items.get("with_vec", False)
542562
pk_name = check_items.get("pk_name", ct.default_primary_field_name)
563+
vector_type = check_items.get("vector_type", "FLOAT_VECTOR")
564+
if vector_type == DataType.FLOAT16_VECTOR:
565+
for single_exp_res in exp_res:
566+
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
567+
for single_query_result in query_res:
568+
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=np.float16).tolist()
569+
if vector_type == DataType.BFLOAT16_VECTOR:
570+
for single_exp_res in exp_res:
571+
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
572+
for single_query_result in query_res:
573+
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=bfloat16).tolist()
574+
if vector_type == DataType.INT8_VECTOR:
575+
for single_exp_res in exp_res:
576+
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
577+
for single_query_result in query_res:
578+
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=np.int8).tolist()
543579
if exp_res is not None:
544580
if isinstance(query_res, list):
545581
assert pc.equal_entities_list(exp=exp_res, actual=query_res, primary_field=pk_name,

tests/python_client/common/common_func.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -696,17 +696,18 @@ def gen_float_vec_field(name=ct.default_float_vec_field_name, is_primary=False,
696696

697697
if vector_data_type != DataType.SPARSE_FLOAT_VECTOR:
698698
float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=vector_data_type,
699-
description=description, dim=dim,
700-
is_primary=is_primary, **kwargs)
699+
description=description, dim=dim,
700+
is_primary=is_primary, **kwargs)
701701
else:
702702
# no dim for sparse vector
703703
float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.SPARSE_FLOAT_VECTOR,
704-
description=description,
705-
is_primary=is_primary, **kwargs)
704+
description=description,
705+
is_primary=is_primary, **kwargs)
706706

707707
return float_vec_field
708708

709709

710+
710711
def gen_binary_vec_field(name=ct.default_binary_vec_field_name, is_primary=False, dim=ct.default_dim,
711712
description=ct.default_desc, **kwargs):
712713
binary_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.BINARY_VECTOR,
@@ -792,7 +793,8 @@ def gen_default_collection_schema(description=ct.default_desc, primary_field=ct.
792793

793794
if len(multiple_dim_array) != 0:
794795
for other_dim in multiple_dim_array:
795-
fields.append(gen_float_vec_field(gen_unique_str("multiple_vector"), dim=other_dim,
796+
name_prefix = "multiple_vector"
797+
fields.append(gen_float_vec_field(gen_unique_str(name_prefix), dim=other_dim,
796798
vector_data_type=vector_data_type))
797799

798800
schema, _ = ApiCollectionSchemaWrapper().init_collection_schema(fields=fields, description=description,
@@ -1120,6 +1122,32 @@ def gen_schema_multi_string_fields(string_fields):
11201122
return schema
11211123

11221124

1125+
def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
    """Generate ``nb`` vectors for the requested vector data type.

    FLOAT_VECTOR rows are built from ``random.random()`` and, when ``dim > 1``,
    L2-normalized along axis 1 and converted back to a list. Every other
    supported type is delegated to its dedicated generator; for the helpers
    that are indexed with ``[1]`` only that second element is returned.

    Raises:
        Exception: if ``vector_data_type`` is not one of the handled types
            (the problem is logged before raising).
    """
    if vector_data_type == DataType.FLOAT_VECTOR:
        float_rows = [[random.random() for _ in range(dim)] for _ in range(nb)]
        if dim > 1:
            normalized = preprocessing.normalize(float_rows, axis=1, norm='l2')
            float_rows = normalized.tolist()
        return float_rows

    if vector_data_type == DataType.FLOAT16_VECTOR:
        return gen_fp16_vectors(nb, dim)[1]
    if vector_data_type == DataType.BFLOAT16_VECTOR:
        return gen_bf16_vectors(nb, dim)[1]
    if vector_data_type == DataType.SPARSE_FLOAT_VECTOR:
        return gen_sparse_vectors(nb, dim)
    if vector_data_type == ct.text_sparse_vector:
        # Text input for Full Text Search.
        return gen_text_vectors(nb)
    if vector_data_type == DataType.INT8_VECTOR:
        return gen_int8_vectors(nb, dim)[1]
    if vector_data_type == DataType.BINARY_VECTOR:
        return gen_binary_vectors(nb, dim)[1]

    log.error(f"Invalid vector data type: {vector_data_type}")
    raise Exception(f"Invalid vector data type: {vector_data_type}")
1149+
1150+
11231151
def gen_string(nb):
11241152
string_values = [str(random.random()) for _ in range(nb)]
11251153
return string_values
@@ -3141,7 +3169,8 @@ def extract_vector_field_name_list(collection_w):
31413169
if field['type'] == DataType.FLOAT_VECTOR \
31423170
or field['type'] == DataType.FLOAT16_VECTOR \
31433171
or field['type'] == DataType.BFLOAT16_VECTOR \
3144-
or field['type'] == DataType.SPARSE_FLOAT_VECTOR:
3172+
or field['type'] == DataType.SPARSE_FLOAT_VECTOR\
3173+
or field['type'] == DataType.INT8_VECTOR:
31453174
if field['name'] != ct.default_float_vec_field_name:
31463175
vector_name_list.append(field['name'])
31473176

@@ -3295,15 +3324,6 @@ def gen_sparse_vectors(nb, dim=1000, sparse_format="dok", empty_percentage=0):
32953324
]
32963325
return vectors
32973326

3298-
def gen_int8_vectors(num, dim):
3299-
raw_vectors = []
3300-
int8_vectors = []
3301-
for _ in range(num):
3302-
raw_vector = [random.randint(-128, 127) for _ in range(dim)]
3303-
raw_vectors.append(raw_vector)
3304-
int8_vector = np.array(raw_vector, dtype=np.int8)
3305-
int8_vectors.append(int8_vector)
3306-
return raw_vectors, int8_vectors
33073327

33083328
def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
33093329
vectors = []
@@ -3331,6 +3351,17 @@ def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
33313351
return vectors
33323352

33333353

3354+
def gen_int8_vectors(num, dim):
    """Generate ``num`` random int8 vectors, each of length ``dim``.

    Returns a ``(raw_vectors, int8_vectors)`` tuple. ``raw_vectors`` is a list
    of lists of Python ints drawn with ``random.randint(-128, 127)``;
    ``int8_vectors`` holds the same rows as numpy arrays with dtype ``int8``.
    """
    # Draw row by row so the random stream is consumed in row-major order.
    raw_vectors = [[random.randint(-128, 127) for _ in range(dim)] for _ in range(num)]
    int8_vectors = [np.array(row, dtype=np.int8) for row in raw_vectors]
    return raw_vectors, int8_vectors
3363+
3364+
33343365
def gen_text_vectors(nb, language="en"):
33353366

33363367
fake = Faker("en_US")
@@ -3339,6 +3370,7 @@ def gen_text_vectors(nb, language="en"):
33393370
vectors = [" milvus " + fake.text() for _ in range(nb)]
33403371
return vectors
33413372

3373+
33423374
def field_types() -> dict:
33433375
return dict(sorted(dict(DataType.__members__).items(), key=lambda item: item[0], reverse=True))
33443376

tests/python_client/common/common_type.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,11 @@
6868
DataType.BINARY_VECTOR: "HAMMING",
6969
}
7070

71-
all_dense_vector_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR]
72-
all_float_vector_dtypes = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR]
7371

74-
append_vector_type = [DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR]
72+
append_vector_type = [DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
73+
all_dense_vector_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.INT8_VECTOR]
74+
all_float_vector_dtypes = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
75+
all_vector_data_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
7576
default_sparse_vec_field_name = "sparse_vector"
7677
default_partition_name = "_default"
7778
default_resource_group_name = '__default_resource_group'
@@ -254,6 +255,8 @@
254255

255256
inverted_index_algo = ['TAAT_NAIVE', 'DAAT_WAND', 'DAAT_MAXSCORE']
256257

258+
int8_vector_index = ["HNSW"]
259+
257260
default_all_indexes_params = [{}, {"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8},
258261
{"nlist": 128, "refine": 'true', "refine_type": "SQ8"},
259262
{"M": 32, "efConstruction": 360}, {"nlist": 128}, {},

tests/python_client/milvus_client/test_milvus_client_collection.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,10 @@ def test_milvus_client_collection_fast_creation_all_params(self, dim, metric_typ
278278

279279
@pytest.mark.tags(CaseLabel.L0)
280280
@pytest.mark.parametrize("nullable", [True, False])
281-
def test_milvus_client_collection_self_creation_default(self, nullable):
281+
@pytest.mark.parametrize("vector_type", [DataType.FLOAT_VECTOR, DataType.INT8_VECTOR])
282+
def test_milvus_client_collection_self_creation_default(self, nullable, vector_type):
282283
"""
283-
target: test fast create collection normal case
284+
target: test self create collection normal case
284285
method: create collection
285286
expected: create collection with default schema, index, and load successfully
286287
"""
@@ -290,7 +291,7 @@ def test_milvus_client_collection_self_creation_default(self, nullable):
290291
# 1. create collection
291292
schema = self.create_schema(client, enable_dynamic_field=False)[0]
292293
schema.add_field("id_string", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
293-
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
294+
schema.add_field("embeddings", vector_type, dim=dim)
294295
schema.add_field("title", DataType.VARCHAR, max_length=64, is_partition_key=True)
295296
schema.add_field("nullable_field", DataType.INT64, nullable=nullable, default_value=10)
296297
schema.add_field("array_field", DataType.ARRAY, element_type=DataType.INT64, max_capacity=12,
@@ -318,6 +319,46 @@ def test_milvus_client_collection_self_creation_default(self, nullable):
318319
if self.has_collection(client, collection_name)[0]:
319320
self.drop_collection(client, collection_name)
320321

322+
@pytest.mark.tags(CaseLabel.L2)
def test_milvus_client_collection_self_creation_multiple_vectors(self):
    """
    target: test self create collection with multiple vectors
    method: create collection with float, int8, float16 and bfloat16 vector fields,
            each with its own index
    expected: collection is created, describe_collection reports every vector field
              with its dim in schema order, and one index exists per vector field
    """
    client = self._client()
    collection_name = cf.gen_unique_str(prefix)
    dim = 128
    # Single source of truth for the vector fields: (name, data type, dim, metric type).
    # The schema, the index params and the expected values below are all derived from
    # this list so the field names cannot drift apart.
    vector_fields = [
        ("embeddings", DataType.FLOAT_VECTOR, dim, "COSINE"),
        ("embeddings_1", DataType.INT8_VECTOR, dim * 2, "IP"),
        ("embeddings_2", DataType.FLOAT16_VECTOR, dim // 2, "L2"),
        ("embeddings_3", DataType.BFLOAT16_VECTOR, dim // 2, "COSINE"),
    ]
    vector_names = [name for name, _, _, _ in vector_fields]
    # Integer dims (floor division above) so they compare cleanly with the schema values.
    vector_dims = [field_dim for _, _, field_dim, _ in vector_fields]

    # 1. create collection
    schema = self.create_schema(client, enable_dynamic_field=False)[0]
    schema.add_field("id_int64", DataType.INT64, is_primary=True, auto_id=False)
    for name, data_type, field_dim, _ in vector_fields:
        schema.add_field(name, data_type, dim=field_dim)
    index_params = self.prepare_index_params(client)[0]
    for name, _, _, metric_type in vector_fields:
        index_params.add_index(name, metric_type=metric_type)
    self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)

    # 2. verify the collection, its schema and its indexes
    collections = self.list_collections(client)[0]
    assert collection_name in collections
    check_items = {"collection_name": collection_name,
                   "dim": vector_dims,
                   "consistency_level": 0,
                   "enable_dynamic_field": False,
                   "id_name": "id_int64",
                   "vector_name": vector_names}
    self.describe_collection(client, collection_name,
                             check_task=CheckTasks.check_describe_collection_property,
                             check_items=check_items)
    index = self.list_indexes(client, collection_name)[0]
    assert sorted(index) == sorted(vector_names)

    # 3. clean up
    if self.has_collection(client, collection_name)[0]:
        self.drop_collection(client, collection_name)
361+
321362
@pytest.mark.tags(CaseLabel.L1)
322363
def test_milvus_client_array_insert_search(self):
323364
"""

0 commit comments

Comments
 (0)