|
31 | 31 | # from ml_dtypes import bfloat16
|
32 | 32 |
|
33 | 33 |
|
34 |
| -@pytest.mark.skip("please fix me") |
35 | 34 | class TestTypes:
|
36 |
| - @pytest.mark.parametrize("input_expect", [ |
37 |
| - ([1], DataType.FLOAT_VECTOR), |
38 |
| - ([True], DataType.UNKNOWN), |
39 |
| - ([1.0, 2.0], DataType.FLOAT_VECTOR), |
40 |
| - (["abc"], DataType.UNKNOWN), |
41 |
| - (bytes("abc", encoding='ascii'), DataType.BINARY_VECTOR), |
42 |
| - (1, DataType.INT64), |
43 |
| - (True, DataType.BOOL), |
44 |
| - ("abc", DataType.VARCHAR), |
45 |
| - (np.int8(1), DataType.INT8), |
46 |
| - (np.int16(1), DataType.INT16), |
47 |
| - ([np.int8(1)], DataType.FLOAT_VECTOR), |
48 |
| - ([np.float16(1.0)], DataType.FLOAT16_VECTOR), |
49 |
| - # ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR), |
50 |
| - ]) |
| 35 | + @pytest.mark.skip("please fix me") |
| 36 | + @pytest.mark.parametrize( |
| 37 | + "input_expect", |
| 38 | + [ |
| 39 | + ([1], DataType.FLOAT_VECTOR), |
| 40 | + ([True], DataType.UNKNOWN), |
| 41 | + ([1.0, 2.0], DataType.FLOAT_VECTOR), |
| 42 | + (["abc"], DataType.UNKNOWN), |
| 43 | + (bytes("abc", encoding="ascii"), DataType.BINARY_VECTOR), |
| 44 | + (1, DataType.INT64), |
| 45 | + (True, DataType.BOOL), |
| 46 | + ("abc", DataType.VARCHAR), |
| 47 | + (np.int8(1), DataType.INT8), |
| 48 | + (np.int16(1), DataType.INT16), |
| 49 | + ([np.int8(1)], DataType.FLOAT_VECTOR), |
| 50 | + ([np.float16(1.0)], DataType.FLOAT16_VECTOR), |
| 51 | + # ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR), |
| 52 | + ], |
| 53 | + ) |
51 | 54 | def test_infer_dtype_bydata(self, input_expect):
|
52 | 55 | data, expect = input_expect
|
53 | 56 | got = infer_dtype_bydata(data)
|
54 | 57 | assert got == expect
|
55 | 58 |
|
| 59 | + def test_str_of_data_type(self): |
| 60 | + for v in DataType: |
| 61 | + assert isinstance(v, DataType) |
| 62 | + assert str(v) == str(v.value) |
| 63 | + assert str(v) != v.name |
| 64 | + |
56 | 65 |
|
57 | 66 | class TestConsistencyLevel:
|
58 | 67 | def test_consistency_level_int(self):
|
@@ -91,6 +100,8 @@ def test_shard(self):
|
91 | 100 | def test_shard_dup_nodeIDs(self):
|
92 | 101 | s = Shard("channel-1", (1, 1, 1), 1)
|
93 | 102 | assert s.channel_name == "channel-1"
|
94 |
| - assert s.shard_nodes == {1,} |
| 103 | + assert s.shard_nodes == { |
| 104 | + 1, |
| 105 | + } |
95 | 106 | assert s.shard_leader == 1
|
96 | 107 | print(s)
|
0 commit comments