Skip to content

Commit 5d4b28e

Browse files
committed
test: add test cases for int8 vector
Signed-off-by: binbin lv <[email protected]>
1 parent f20e085 commit 5d4b28e

File tree

8 files changed

+248
-36
lines changed

8 files changed

+248
-36
lines changed

tests/python_client/base/client_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ def init_collection_general(self, prefix="test", insert_data=False, nb=ct.defaul
352352
# 4 create default index if specified
353353
if is_index:
354354
# This condition will be removed after auto index feature
355-
if is_binary:
355+
if vector_data_type == DataType.INT8_VECTOR:
356+
collection_w.create_index(ct.default_float_vec_field_name, ct.int8_vector_index)
357+
elif is_binary:
356358
collection_w.create_index(ct.default_binary_vec_field_name, ct.default_bin_flat_index)
357359
elif vector_data_type == DataType.SPARSE_FLOAT_VECTOR:
358360
for vector_name in vector_name_list:

tests/python_client/base/client_v2_base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,4 +1100,26 @@ def transfer_replica(self, client, source_group, target_group, collection_name,
11001100
source_group=source_group, target_group=target_group,
11011101
collection_name=collection_name, num_replicas=num_replicas,
11021102
**kwargs).run()
1103+
return res, check_result
1104+
1105+
@trace()
1106+
def create_field_schema(self, client, name, data_type, desc='', timeout=None, check_task=None, check_items=None, **kwargs):
1107+
timeout = TIMEOUT if timeout is None else timeout
1108+
kwargs.update({"timeout": timeout})
1109+
1110+
func_name = sys._getframe().f_code.co_name
1111+
res, check = api_request([client.create_field_schema, name, data_type, desc], **kwargs)
1112+
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
1113+
**kwargs).run()
1114+
return res, check_result
1115+
1116+
@trace()
1117+
def add_collection_field(self, client, collection_name, field_schema, timeout=None, check_task=None, check_items=None, **kwargs):
1118+
timeout = TIMEOUT if timeout is None else timeout
1119+
kwargs.update({"timeout": timeout})
1120+
1121+
func_name = sys._getframe().f_code.co_name
1122+
res, check = api_request([client.add_collection_field, collection_name, field_schema], **kwargs)
1123+
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
1124+
**kwargs).run()
11031125
return res, check_result

tests/python_client/check/func_check.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from common import common_func as cf
88
from common.common_type import CheckTasks, Connect_Object_Name
99
# from common.code_mapping import ErrorCode, ErrorMessage
10-
from pymilvus import Collection, Partition, ResourceGroupInfo
10+
from pymilvus import Collection, Partition, ResourceGroupInfo, DataType
1111
import check.param_check as pc
12-
12+
import numpy as np
13+
from ml_dtypes import bfloat16
1314

1415
class Error:
1516
def __init__(self, error):
@@ -259,8 +260,27 @@ def check_describe_collection_property(res, func_name, check_items):
259260
if check_items.get("id_name", "id"):
260261
assert res["fields"][0]["name"] == check_items.get("id_name", "id")
261262
if check_items.get("vector_name", "vector"):
262-
assert res["fields"][1]["name"] == check_items.get("vector_name", "vector")
263+
vector_name_list = []
264+
vector_name_list_expected = check_items.get("vector_name", "vector")
265+
for field in res["fields"]:
266+
if field["type"] in [101, 102, 103, 105]:
267+
vector_name_list.append(field["name"])
268+
if isinstance(vector_name_list_expected, str):
269+
assert vector_name_list[0] == check_items.get("vector_name", "vector")
270+
else:
271+
assert vector_name_list == vector_name_list_expected
263272
if check_items.get("dim", None) is not None:
273+
dim_list = []
274+
# here dim support int for only one vector field and list for multiple vector fields, and the order
275+
# should be the same of the order adding schema
276+
dim_list_expected = check_items.get("dim")
277+
for field in res["fields"]:
278+
if field["type"] in [101, 102, 103, 105]:
279+
dim_list.append(field["params"]["dim"])
280+
if isinstance(dim_list_expected, int):
281+
assert dim_list[0] == dim_list_expected
282+
else:
283+
assert dim_list == dim_list_expected
264284
assert res["fields"][1]["params"]["dim"] == check_items.get("dim")
265285
if check_items.get("nullable_fields", None) is not None:
266286
nullable_fields = check_items.get("nullable_fields")
@@ -272,7 +292,7 @@ def check_describe_collection_property(res, func_name, check_items):
272292
assert field["nullable"] is True
273293
assert res["fields"][0]["is_primary"] is True
274294
assert res["fields"][0]["field_id"] == 100 and (res["fields"][0]["type"] == 5 or 21)
275-
assert res["fields"][1]["field_id"] == 101 and res["fields"][1]["type"] == 101
295+
assert res["fields"][1]["field_id"] == 101 and (res["fields"][1]["type"] == 101 or 105)
276296

277297
return True
278298

@@ -540,6 +560,22 @@ def check_query_results(query_res, func_name, check_items):
540560
exp_res = check_items.get("exp_res", None)
541561
with_vec = check_items.get("with_vec", False)
542562
pk_name = check_items.get("pk_name", ct.default_primary_field_name)
563+
vector_type = check_items.get("vector_type", "FLOAT_VECTOR")
564+
if vector_type == DataType.FLOAT16_VECTOR:
565+
for single_exp_res in exp_res:
566+
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
567+
for single_query_result in query_res:
568+
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=np.float16).tolist()
569+
if vector_type == DataType.BFLOAT16_VECTOR:
570+
for single_exp_res in exp_res:
571+
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
572+
for single_query_result in query_res:
573+
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=bfloat16).tolist()
574+
if vector_type == DataType.INT8_VECTOR:
575+
for single_exp_res in exp_res:
576+
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
577+
for single_query_result in query_res:
578+
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=np.int8).tolist()
543579
if exp_res is not None:
544580
if isinstance(query_res, list):
545581
assert pc.equal_entities_list(exp=exp_res, actual=query_res, primary_field=pk_name,

tests/python_client/common/common_func.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -696,17 +696,18 @@ def gen_float_vec_field(name=ct.default_float_vec_field_name, is_primary=False,
696696

697697
if vector_data_type != DataType.SPARSE_FLOAT_VECTOR:
698698
float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=vector_data_type,
699-
description=description, dim=dim,
700-
is_primary=is_primary, **kwargs)
699+
description=description, dim=dim,
700+
is_primary=is_primary, **kwargs)
701701
else:
702702
# no dim for sparse vector
703703
float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.SPARSE_FLOAT_VECTOR,
704-
description=description,
705-
is_primary=is_primary, **kwargs)
704+
description=description,
705+
is_primary=is_primary, **kwargs)
706706

707707
return float_vec_field
708708

709709

710+
710711
def gen_binary_vec_field(name=ct.default_binary_vec_field_name, is_primary=False, dim=ct.default_dim,
711712
description=ct.default_desc, **kwargs):
712713
binary_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.BINARY_VECTOR,
@@ -1120,6 +1121,31 @@ def gen_schema_multi_string_fields(string_fields):
11201121
return schema
11211122

11221123

1124+
def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
1125+
vectors = []
1126+
if vector_data_type == DataType.FLOAT_VECTOR:
1127+
vectors = [[random.random() for _ in range(dim)] for _ in range(nb)]
1128+
elif vector_data_type == DataType.FLOAT16_VECTOR:
1129+
vectors = gen_fp16_vectors(nb, dim)[1]
1130+
elif vector_data_type == DataType.BFLOAT16_VECTOR:
1131+
vectors = gen_bf16_vectors(nb, dim)[1]
1132+
elif vector_data_type == DataType.SPARSE_FLOAT_VECTOR:
1133+
vectors = gen_sparse_vectors(nb, dim)
1134+
elif vector_data_type == ct.text_sparse_vector:
1135+
vectors = gen_text_vectors(nb) # for Full Text Search
1136+
elif vector_data_type == DataType.INT8_VECTOR:
1137+
vectors = gen_int8_vectors(nb, dim)[1]
1138+
elif vector_data_type == DataType.BINARY_VECTOR:
1139+
vectors = gen_binary_vectors(nb, dim)[1]
1140+
else:
1141+
log.error(f"Invalid vector data type: {vector_data_type}")
1142+
raise Exception(f"Invalid vector data type: {vector_data_type}")
1143+
if dim > 1:
1144+
if vector_data_type == DataType.FLOAT_VECTOR:
1145+
vectors = preprocessing.normalize(vectors, axis=1, norm='l2')
1146+
vectors = vectors.tolist()
1147+
return vectors
1148+
11231149
def gen_string(nb):
11241150
string_values = [str(random.random()) for _ in range(nb)]
11251151
return string_values
@@ -3295,15 +3321,6 @@ def gen_sparse_vectors(nb, dim=1000, sparse_format="dok", empty_percentage=0):
32953321
]
32963322
return vectors
32973323

3298-
def gen_int8_vectors(num, dim):
3299-
raw_vectors = []
3300-
int8_vectors = []
3301-
for _ in range(num):
3302-
raw_vector = [random.randint(-128, 127) for _ in range(dim)]
3303-
raw_vectors.append(raw_vector)
3304-
int8_vector = np.array(raw_vector, dtype=np.int8)
3305-
int8_vectors.append(int8_vector)
3306-
return raw_vectors, int8_vectors
33073324

33083325
def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
33093326
vectors = []
@@ -3331,6 +3348,17 @@ def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
33313348
return vectors
33323349

33333350

3351+
def gen_int8_vectors(num, dim):
3352+
raw_vectors = []
3353+
int8_vectors = []
3354+
for _ in range(num):
3355+
raw_vector = [random.randint(-128, 127) for _ in range(dim)]
3356+
raw_vectors.append(raw_vector)
3357+
int8_vector = np.array(raw_vector, dtype=np.int8)
3358+
int8_vectors.append(int8_vector)
3359+
return raw_vectors, int8_vectors
3360+
3361+
33343362
def gen_text_vectors(nb, language="en"):
33353363

33363364
fake = Faker("en_US")
@@ -3339,6 +3367,7 @@ def gen_text_vectors(nb, language="en"):
33393367
vectors = [" milvus " + fake.text() for _ in range(nb)]
33403368
return vectors
33413369

3370+
33423371
def field_types() -> dict:
33433372
return dict(sorted(dict(DataType.__members__).items(), key=lambda item: item[0], reverse=True))
33443373

tests/python_client/common/common_type.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@
6767
DataType.BINARY_VECTOR: "HAMMING",
6868
}
6969

70-
all_dense_vector_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR]
71-
all_float_vector_dtypes = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR]
7270

73-
append_vector_type = [DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR]
71+
append_vector_type = [DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
72+
all_dense_vector_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.INT8_VECTOR]
73+
all_float_vector_dtypes = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
74+
all_vector_data_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
7475
default_sparse_vec_field_name = "sparse_vector"
7576
default_partition_name = "_default"
7677
default_resource_group_name = '__default_resource_group'
@@ -253,6 +254,8 @@
253254

254255
inverted_index_algo = ['TAAT_NAIVE', 'DAAT_WAND', 'DAAT_MAXSCORE']
255256

257+
int8_vector_index = ["HNSW"]
258+
256259
default_all_indexes_params = [{}, {"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8},
257260
{"nlist": 128, "refine": 'true', "refine_type": "SQ8"},
258261
{"M": 32, "efConstruction": 360}, {"nlist": 128}, {},

tests/python_client/milvus_client/test_milvus_client_collection.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,10 @@ def test_milvus_client_collection_fast_creation_all_params(self, dim, metric_typ
278278

279279
@pytest.mark.tags(CaseLabel.L0)
280280
@pytest.mark.parametrize("nullable", [True, False])
281-
def test_milvus_client_collection_self_creation_default(self, nullable):
281+
@pytest.mark.parametrize("vector_type", [DataType.FLOAT_VECTOR, DataType.INT8_VECTOR])
282+
def test_milvus_client_collection_self_creation_default(self, nullable, vector_type):
282283
"""
283-
target: test fast create collection normal case
284+
target: test self create collection normal case
284285
method: create collection
285286
expected: create collection with default schema, index, and load successfully
286287
"""
@@ -290,7 +291,7 @@ def test_milvus_client_collection_self_creation_default(self, nullable):
290291
# 1. create collection
291292
schema = self.create_schema(client, enable_dynamic_field=False)[0]
292293
schema.add_field("id_string", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
293-
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
294+
schema.add_field("embeddings", vector_type, dim=dim)
294295
schema.add_field("title", DataType.VARCHAR, max_length=64, is_partition_key=True)
295296
schema.add_field("nullable_field", DataType.INT64, nullable=nullable, default_value=10)
296297
schema.add_field("array_field", DataType.ARRAY, element_type=DataType.INT64, max_capacity=12,
@@ -318,6 +319,46 @@ def test_milvus_client_collection_self_creation_default(self, nullable):
318319
if self.has_collection(client, collection_name)[0]:
319320
self.drop_collection(client, collection_name)
320321

322+
@pytest.mark.tags(CaseLabel.L2)
323+
def test_milvus_client_collection_self_creation_multiple_vectors(self):
324+
"""
325+
target: test self create collection with multiple vectors
326+
method: create collection with multiple vectors
327+
expected: create collection with default schema, index, and load successfully
328+
"""
329+
client = self._client()
330+
collection_name = cf.gen_unique_str(prefix)
331+
dim = 128
332+
# 1. create collection
333+
schema = self.create_schema(client, enable_dynamic_field=False)[0]
334+
schema.add_field("id_int64", DataType.INT64, is_primary=True, auto_id=False)
335+
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
336+
schema.add_field("int8embeddings_1", DataType.INT8_VECTOR, dim=dim * 2)
337+
schema.add_field("int8embeddings_2", DataType.FLOAT16_VECTOR, dim=int(dim / 2))
338+
schema.add_field("int8embeddings_3", DataType.BFLOAT16_VECTOR, dim=int(dim / 2))
339+
index_params = self.prepare_index_params(client)[0]
340+
index_params.add_index("embeddings", metric_type="COSINE")
341+
index_params.add_index("embeddings_1", metric_type="IP")
342+
index_params.add_index("embeddings_2", metric_type="L2")
343+
index_params.add_index("embeddings_3", metric_type="COSINE")
344+
# index_params.add_index("title")
345+
self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
346+
collections = self.list_collections(client)[0]
347+
assert collection_name in collections
348+
check_items = {"collection_name": collection_name,
349+
"dim": [dim, dim * 2, dim / 2, dim / 2],
350+
"consistency_level": 0,
351+
"enable_dynamic_field": False,
352+
"id_name": "id_int64",
353+
"vector_name": ["embeddings", "embeddings_1", "embeddings_2", "embeddings_3"]}
354+
self.describe_collection(client, collection_name,
355+
check_task=CheckTasks.check_describe_collection_property,
356+
check_items=check_items)
357+
index = self.list_indexes(client, collection_name)[0]
358+
assert sorted(index) == sorted(['embeddings', 'embeddings_1', 'embeddings_2', 'embeddings_3'])
359+
if self.has_collection(client, collection_name)[0]:
360+
self.drop_collection(client, collection_name)
361+
321362
@pytest.mark.tags(CaseLabel.L1)
322363
def test_milvus_client_array_insert_search(self):
323364
"""

0 commit comments

Comments
 (0)