Skip to content

Commit 5edb054

Browse files
authored
fix: Unify output of str(DataType) for different pythons (#2635) (#2638)
See also: #2633 pr: #2635 Signed-off-by: yangxuan <[email protected]>
1 parent 2139c65 commit 5edb054

File tree

2 files changed

+56
-38
lines changed

2 files changed

+56
-38
lines changed

pymilvus/client/types.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
ExceptionsMessage,
88
InvalidConsistencyLevel,
99
)
10-
from pymilvus.grpc_gen import common_pb2, rg_pb2
10+
from pymilvus.grpc_gen import common_pb2, rg_pb2, schema_pb2
1111
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
1212

1313
Status = TypeVar("Status")
@@ -84,29 +84,36 @@ def OK(self):
8484

8585

8686
class DataType(IntEnum):
87-
NONE = 0
88-
BOOL = 1
89-
INT8 = 2
90-
INT16 = 3
91-
INT32 = 4
92-
INT64 = 5
93-
94-
FLOAT = 10
95-
DOUBLE = 11
96-
97-
STRING = 20
98-
VARCHAR = 21
99-
ARRAY = 22
100-
JSON = 23
101-
102-
BINARY_VECTOR = 100
103-
FLOAT_VECTOR = 101
104-
FLOAT16_VECTOR = 102
105-
BFLOAT16_VECTOR = 103
106-
SPARSE_FLOAT_VECTOR = 104
87+
"""
88+
String of DataType is str of its value, e.g.: str(DataType.BOOL) == "1"
89+
"""
90+
91+
NONE = 0 # schema_pb2.None, this is an invalid representation in python
92+
BOOL = schema_pb2.Bool
93+
INT8 = schema_pb2.Int8
94+
INT16 = schema_pb2.Int16
95+
INT32 = schema_pb2.Int32
96+
INT64 = schema_pb2.Int64
97+
98+
FLOAT = schema_pb2.Float
99+
DOUBLE = schema_pb2.Double
100+
101+
STRING = schema_pb2.String
102+
VARCHAR = schema_pb2.VarChar
103+
ARRAY = schema_pb2.Array
104+
JSON = schema_pb2.JSON
105+
106+
BINARY_VECTOR = schema_pb2.BinaryVector
107+
FLOAT_VECTOR = schema_pb2.FloatVector
108+
FLOAT16_VECTOR = schema_pb2.Float16Vector
109+
BFLOAT16_VECTOR = schema_pb2.BFloat16Vector
110+
SPARSE_FLOAT_VECTOR = schema_pb2.SparseFloatVector
107111

108112
UNKNOWN = 999
109113

114+
def __str__(self) -> str:
115+
return str(self.value)
116+
110117

111118
class RangeType(IntEnum):
112119
LT = 0 # less than

tests/test_types.py

+28-17
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,37 @@
3131
# from ml_dtypes import bfloat16
3232

3333

34-
@pytest.mark.skip("please fix me")
3534
class TestTypes:
36-
@pytest.mark.parametrize("input_expect", [
37-
([1], DataType.FLOAT_VECTOR),
38-
([True], DataType.UNKNOWN),
39-
([1.0, 2.0], DataType.FLOAT_VECTOR),
40-
(["abc"], DataType.UNKNOWN),
41-
(bytes("abc", encoding='ascii'), DataType.BINARY_VECTOR),
42-
(1, DataType.INT64),
43-
(True, DataType.BOOL),
44-
("abc", DataType.VARCHAR),
45-
(np.int8(1), DataType.INT8),
46-
(np.int16(1), DataType.INT16),
47-
([np.int8(1)], DataType.FLOAT_VECTOR),
48-
([np.float16(1.0)], DataType.FLOAT16_VECTOR),
49-
# ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR),
50-
])
35+
@pytest.mark.skip("please fix me")
36+
@pytest.mark.parametrize(
37+
"input_expect",
38+
[
39+
([1], DataType.FLOAT_VECTOR),
40+
([True], DataType.UNKNOWN),
41+
([1.0, 2.0], DataType.FLOAT_VECTOR),
42+
(["abc"], DataType.UNKNOWN),
43+
(bytes("abc", encoding="ascii"), DataType.BINARY_VECTOR),
44+
(1, DataType.INT64),
45+
(True, DataType.BOOL),
46+
("abc", DataType.VARCHAR),
47+
(np.int8(1), DataType.INT8),
48+
(np.int16(1), DataType.INT16),
49+
([np.int8(1)], DataType.FLOAT_VECTOR),
50+
([np.float16(1.0)], DataType.FLOAT16_VECTOR),
51+
# ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR),
52+
],
53+
)
5154
def test_infer_dtype_bydata(self, input_expect):
5255
data, expect = input_expect
5356
got = infer_dtype_bydata(data)
5457
assert got == expect
5558

59+
def test_str_of_data_type(self):
60+
for v in DataType:
61+
assert isinstance(v, DataType)
62+
assert str(v) == str(v.value)
63+
assert str(v) != v.name
64+
5665

5766
class TestConsistencyLevel:
5867
def test_consistency_level_int(self):
@@ -91,6 +100,8 @@ def test_shard(self):
91100
def test_shard_dup_nodeIDs(self):
92101
s = Shard("channel-1", (1, 1, 1), 1)
93102
assert s.channel_name == "channel-1"
94-
assert s.shard_nodes == {1,}
103+
assert s.shard_nodes == {
104+
1,
105+
}
95106
assert s.shard_leader == 1
96107
print(s)

0 commit comments

Comments
 (0)