Skip to content

Commit 716c78c

Browse files
IcemapMini256
andauthored
fix: make check_vector_column to use the proper method for identifying vector columns (#126)
* fix: check_vector_column has an error way to check the vector column * drop table should check_first by default * build: bump pytidb version to 0.0.8.post2 --------- Co-authored-by: Mini256 <minianter@foxmail.com>
1 parent 3f42bcf commit 716c78c

File tree

4 files changed

+75
-5
lines changed

4 files changed

+75
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "pytidb"
3-
version = "0.0.8.post1"
3+
version = "0.0.8.post2"
44
description = "A Python library for TiDB."
55
readme = "README.md"
66
requires-python = ">=3.10"

pytidb/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def drop_table(self, table_name: str):
148148
table = sqlalchemy.Table(
149149
table_name, Base.metadata, autoload_with=self._db_engine
150150
)
151-
return table.drop(self._db_engine)
151+
return table.drop(self._db_engine, checkfirst=True)
152152

153153
# Raw SQL API
154154

pytidb/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,15 @@ def filter_vector_columns(columns: Dict) -> List[Column]:
8484

8585

8686
def check_vector_column(columns: Dict, column_name: str) -> Optional[str]:
87+
if not isinstance(column_name, str):
88+
raise ValueError(f"Invalid vector column name: {column_name}")
89+
8790
if column_name not in columns:
8891
raise ValueError(f"Non-exists vector column: {column_name}")
8992

9093
vector_column = columns[column_name]
91-
if vector_column.type != VectorType:
92-
raise ValueError(f"Invalid vector column: {vector_column}")
94+
if not isinstance(vector_column.type, VectorType):
95+
raise ValueError(f"Invalid vector column: {column_name}")
9396

9497
return vector_column
9598

@@ -140,4 +143,8 @@ def get_row_id_from_row(row: Row, table: Table) -> Optional[RowKeyType]:
140143

141144

142145
def get_index_type(index: Index) -> str:
143-
return index.dialect_kwargs.get("mysql_prefix", "").lower()
146+
dialect_kwargs = getattr(index, "dialect_kwargs", None)
147+
if dialect_kwargs is None:
148+
return ""
149+
mysql_prefix = dialect_kwargs.get("mysql_prefix", "")
150+
return mysql_prefix.lower() if mysql_prefix else ""

tests/test_search_vector.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,66 @@ def test_rerank(vector_table: Table, reranker: BaseReranker):
212212
assert len(reranked_results) > 0
213213
assert reranked_results[0]["text"] == "bar"
214214
assert reranked_results[0]["_score"] > 0
215+
216+
217+
def test_with_multi_vector_fields(client: TiDBClient):
218+
class ChunkWithMultiVec(TableModel):
219+
__tablename__ = "test_vector_search_multi_vec"
220+
id: int = Field(None, primary_key=True)
221+
title: str = Field(None)
222+
title_vec: list[float] = Field(sa_column=Column(Vector(3)))
223+
body: str = Field(None)
224+
body_vec: list[float] = Field(sa_column=Column(Vector(3)))
225+
226+
tbl = client.create_table(schema=ChunkWithMultiVec, mode="overwrite")
227+
tbl.bulk_insert(
228+
[
229+
ChunkWithMultiVec(
230+
id=1, title="tree", title_vec=[4, 5, 6], body="cat", body_vec=[1, 2, 3]
231+
),
232+
ChunkWithMultiVec(
233+
id=2, title="grass", title_vec=[1, 2, 3], body="dog", body_vec=[7, 8, 9]
234+
),
235+
ChunkWithMultiVec(
236+
id=3, title="leaf", title_vec=[7, 8, 9], body="bird", body_vec=[4, 5, 6]
237+
),
238+
]
239+
)
240+
241+
with pytest.raises(ValueError, match="more than two vector columns"):
242+
tbl.search([1, 2, 3], search_type="vector").limit(3).to_list()
243+
244+
with pytest.raises(ValueError, match="Invalid vector column"):
245+
tbl.search([1, 2, 3], search_type="vector").vector_column("title").limit(3)
246+
247+
with pytest.raises(ValueError, match="Non-exists vector column"):
248+
tbl.search([1, 2, 3], search_type="vector").vector_column(
249+
"non_exist_column"
250+
).limit(3)
251+
252+
with pytest.raises(ValueError, match="Invalid vector column name"):
253+
tbl.search([1, 2, 3], search_type="vector").vector_column(None).limit(3)
254+
255+
results = (
256+
tbl.search([1, 2, 3], search_type="vector")
257+
.vector_column("title_vec")
258+
.limit(3)
259+
.to_list()
260+
)
261+
assert len(results) == 3
262+
assert results[0]["id"] == 2
263+
assert results[0]["title"] == "grass"
264+
assert results[0]["_distance"] == 0
265+
assert results[0]["_score"] == 1
266+
267+
results = (
268+
tbl.search([1, 2, 3], search_type="vector")
269+
.vector_column("body_vec")
270+
.limit(3)
271+
.to_list()
272+
)
273+
assert len(results) == 3
274+
assert results[0]["id"] == 1
275+
assert results[0]["body"] == "cat"
276+
assert results[0]["_distance"] == 0
277+
assert results[0]["_score"] == 1

0 commit comments

Comments
 (0)