Skip to content

Commit cbf80bb

Browse files
committed
allow columns with quotes
1 parent 05dd0da commit cbf80bb

File tree

7 files changed

+61
-69
lines changed

7 files changed

+61
-69
lines changed

libs/libapi/src/libapi/duckdb.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from typing import Any, Optional, TypeVar
1414

1515
import anyio
16-
import duckdb
1716
import filelock
1817
from datasets import Features
1918
from filelock import AsyncFileLock
@@ -27,8 +26,11 @@
2726
DUCKDB_DEFAULT_PARTIAL_INDEX_FILENAME,
2827
compute_transformed_data,
2928
create_index,
29+
duckdb_connect,
3030
get_indexable_columns,
3131
get_monolingual_stemmer,
32+
key_sql,
33+
varchar_sql,
3234
)
3335
from libcommon.parquet_utils import (
3436
ParquetFileMetadataItem,
@@ -110,7 +112,7 @@ def build_index_file(
110112
split_parquet_files = split_parquet_files[:num_parquet_files_to_index]
111113
parquet_file_names = [parquet_file["filename"] for parquet_file in split_parquet_files]
112114

113-
column_names_sql = ",".join(f'"{column}"' for column in features)
115+
column_names_sql = ",".join(key_sql(column) for column in features)
114116

115117
# look for indexable columns (= possibly nested columns containing string data)
116118
indexable_columns = get_indexable_columns(Features.from_dict(features))
@@ -142,18 +144,18 @@ def build_index_file(
142144
Path(index_file_location).parent.mkdir(exist_ok=True, parents=True)
143145

144146
try:
145-
with duckdb.connect(index_file_location) as con:
147+
with duckdb_connect(index_file_location=index_file_location, extensions_directory=extensions_directory) as con:
146148
if transformed_df is not None:
147149
logging.debug(transformed_df.head())
148150
# update original data with results of transformations (string lengths, audio durations, etc.):
149151
logging.info(f"Updating data with {transformed_df.columns}")
150152
create_command_sql = CREATE_TABLE_JOIN_WITH_TRANSFORMED_DATA_COMMAND_FROM_LIST_OF_PARQUET_FILES.format(
151-
columns=column_names_sql, source=[str(p) for p in all_split_parquets]
153+
columns=column_names_sql, source="[" + ",".join(varchar_sql(str(p)) for p in all_split_parquets) + "]"
152154
)
153155

154156
else:
155157
create_command_sql = CREATE_TABLE_COMMAND_FROM_LIST_OF_PARQUET_FILES.format(
156-
columns=column_names_sql, source=[str(p) for p in all_split_parquets]
158+
columns=column_names_sql, source="[" + ",".join(varchar_sql(str(p)) for p in all_split_parquets) + "]"
157159
)
158160

159161
logging.info(create_command_sql)

libs/libcommon/src/libcommon/duckdb_utils.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434
DEFAULT_STEMMER = "none" # Exact word matches
3535
DUCKDB_DEFAULT_INDEX_FILENAME = "index.duckdb"
3636
DUCKDB_DEFAULT_PARTIAL_INDEX_FILENAME = "partial-index.duckdb"
37+
38+
ATTACH_DATABASE = "ATTACH '{database}' as db; USE db;"
39+
ATTACH_READ_ONLY_DATABASE = "ATTACH '{database}' as db (READ_ONLY); USE db;"
40+
LOAD_FTS_COMMAND = "INSTALL 'fts'; LOAD 'fts';"
41+
DISABLE_EXTERNAL_ACCESS_COMMAND = "SET enable_external_access=false;"
42+
LOCK_CONFIG_COMMAND = "SET lock_configuration=true;"
43+
SET_EXTENSIONS_DIRECTORY_COMMAND = "SET extension_directory='{directory}';"
44+
3745
CREATE_INDEX_COMMAND = (
3846
f"PRAGMA create_fts_index('data', '{ROW_IDX_COLUMN}', {{columns}}, stemmer='{{stemmer}}', overwrite=1);"
3947
)
@@ -216,16 +224,7 @@ def _sql(con: duckdb.DuckDBPyConnection, query: str) -> duckdb.DuckDBPyRelation:
216224
return out
217225

218226
with tempfile.TemporaryDirectory(suffix=".duckdb") as tmp_dir:
219-
with duckdb.connect(":memory:") as con:
220-
# configure duckdb extensions
221-
if extensions_directory is not None:
222-
con.execute(SET_EXTENSIONS_DIRECTORY_COMMAND.format(directory=extensions_directory))
223-
con.execute(INSTALL_AND_LOAD_EXTENSION_COMMAND)
224-
225-
# init
226-
_sql(con, "ATTACH '%database%' as db;")
227-
_sql(con, "USE db;")
228-
227+
with duckdb_connect(database=database, extensions_directory=extensions_directory) as con:
229228
# check input_table and get number of rows
230229
_count = _sql(con, "SELECT count(*) FROM %input_table%;").fetchone()
231230
if _count and isinstance(_count[0], int):
@@ -255,7 +254,8 @@ def _sql(con: duckdb.DuckDBPyConnection, query: str) -> duckdb.DuckDBPyRelation:
255254
)
256255

257256
# create fields table
258-
field_values = ", ".join(f"({i}, '{field}')" for i, field in enumerate(columns))
257+
258+
field_values = ", ".join(f"({i}, {varchar_sql(field)})" for i, field in enumerate(columns))
259259
_sql(
260260
con,
261261
"""
@@ -278,17 +278,8 @@ def _sql(con: duckdb.DuckDBPyConnection, query: str) -> duckdb.DuckDBPyRelation:
278278
batch_size = 1 + count // num_jobs
279279
commands = [
280280
(
281-
(
282-
SET_EXTENSIONS_DIRECTORY_COMMAND.format(directory=extensions_directory)
283-
if extensions_directory is not None
284-
else ""
285-
)
286-
+ INSTALL_AND_LOAD_EXTENSION_COMMAND
287-
+ (
288-
"ATTACH IF NOT EXISTS '%database%' as db (READ_ONLY);" # nosec - tmp_dir, batch_size, rank and i are safe
289-
"USE db;"
290-
f"ATTACH '{tmp_dir}/tmp_{rank}_{i}.duckdb' as tmp_{rank}_{i};"
291-
f"""
281+
f"ATTACH '{tmp_dir}/tmp_{rank}_{i}.duckdb' as tmp_{rank}_{i};" # nosec - tmp_dir, batch_size, rank and i are safe
282+
f"""
292283
CREATE TABLE tmp_{rank}_{i}.tokenized AS (
293284
SELECT unnest(%fts_schema%.tokenize(fts_ii."{column}")) AS w,
294285
{rank * batch_size} + row_number() OVER () - 1 AS docid,
@@ -299,14 +290,13 @@ def _sql(con: duckdb.DuckDBPyConnection, query: str) -> duckdb.DuckDBPyRelation:
299290
);
300291
CHECKPOINT;
301292
"""
302-
)
303293
)
304294
for rank in range(num_jobs)
305295
for i, column in enumerate(columns)
306296
]
307297

308298
def _parallel_sql(command: str) -> None:
309-
with duckdb.connect(":memory:") as rank_con:
299+
with duckdb_connect(database=database, extensions_directory=extensions_directory) as rank_con:
310300
_sql(rank_con, command)
311301

312302
thread_map(_parallel_sql, commands, desc="Tokenize")
@@ -324,16 +314,9 @@ def _parallel_sql(command: str) -> None:
324314
# """)
325315
# union_fields_query = " UNION ALL ".join(f"SELECT * FROM tmp.tokenized_{i}" for i in range(len(columns)))
326316

327-
with duckdb.connect(":memory:") as con:
328-
# configure duckdb extensions
329-
if extensions_directory is not None:
330-
con.execute(SET_EXTENSIONS_DIRECTORY_COMMAND.format(directory=extensions_directory))
331-
con.execute(INSTALL_AND_LOAD_EXTENSION_COMMAND)
332-
317+
with duckdb_connect(database=database, extensions_directory=extensions_directory) as con:
333318
# init
334319
_sql(con, f"ATTACH '{tmp_dir}/tmp.duckdb' as tmp;") # nosec - tmp_dir is safe
335-
_sql(con, "ATTACH '%database%' as db;")
336-
_sql(con, "USE db;")
337320
_sql(
338321
con,
339322
";".join(
@@ -526,3 +509,32 @@ def _parallel_sql(command: str) -> None:
526509
""",
527510
)
528511
_sql(con, "CHECKPOINT;")
512+
513+
514+
def varchar_sql(value: str) -> str:
515+
"""escape the value and return the varchar `'{value}'`"""
516+
return "'" + value.replace("'", "''") + "'"
517+
518+
519+
def key_sql(value: str) -> str:
520+
"""escape the value and return the key `"{value}"`"""
521+
return '"' + value.replace('"', '""') + '"'
522+
523+
524+
def duckdb_connect(
525+
index_file_location: Optional[str] = None,
526+
database: Optional[str] = None,
527+
extensions_directory: Optional[str] = None,
528+
read_only: bool = False,
529+
**kwargs: Any,
530+
) -> duckdb.DuckDBPyConnection:
531+
"""In-memory session with the current database attached with read-only and fts extension"""
532+
con = duckdb.connect(":memory:" if index_file_location is None else index_file_location, **kwargs)
533+
if database is not None:
534+
con.execute((ATTACH_READ_ONLY_DATABASE if read_only else ATTACH_DATABASE).format(database=database))
535+
if extensions_directory is not None:
536+
con.execute(SET_EXTENSIONS_DIRECTORY_COMMAND.format(directory=extensions_directory))
537+
con.sql(LOAD_FTS_COMMAND)
538+
con.sql(DISABLE_EXTERNAL_ACCESS_COMMAND)
539+
con.sql(LOCK_CONFIG_COMMAND)
540+
return con

services/search/src/search/duckdb_connection.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

services/search/src/search/routes/filter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,13 @@
3030
get_json_ok_response,
3131
)
3232
from libcommon.constants import ROW_IDX_COLUMN
33+
from libcommon.duckdb_utils import duckdb_connect, key_sql
3334
from libcommon.prometheus import StepProfiler
3435
from libcommon.storage import StrPath, clean_dir
3536
from libcommon.storage_client import StorageClient
3637
from starlette.requests import Request
3738
from starlette.responses import Response
3839

39-
from search.duckdb_connection import duckdb_connect_readonly
40-
4140
FILTER_QUERY = """\
4241
SELECT {columns}
4342
FROM data
@@ -224,9 +223,11 @@ def execute_filter_query(
224223
offset: int,
225224
extensions_directory: Optional[str] = None,
226225
) -> tuple[int, pa.Table]:
227-
with duckdb_connect_readonly(extensions_directory=extensions_directory, database=index_file_location) as con:
226+
with duckdb_connect(
227+
index_file_location=index_file_location, extensions_directory=extensions_directory, read_only=True
228+
) as con:
228229
filter_query = FILTER_QUERY.format(
229-
columns=",".join([f'"{column}"' for column in columns]),
230+
columns=",".join([key_sql(column) for column in columns]),
230231
where=f"WHERE {where}" if where else "",
231232
orderby=f"ORDER BY {orderby}" if orderby else "",
232233
limit=limit,

services/search/src/search/routes/search.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,14 @@
3232
)
3333
from libcommon.constants import HF_FTS_SCORE, MAX_NUM_ROWS_PER_PAGE, ROW_IDX_COLUMN
3434
from libcommon.dtos import PaginatedResponse
35+
from libcommon.duckdb_utils import duckdb_connect, key_sql
3536
from libcommon.prometheus import StepProfiler
3637
from libcommon.storage import StrPath, clean_dir
3738
from libcommon.storage_client import StorageClient
3839
from libcommon.viewer_utils.features import to_features_list
3940
from starlette.requests import Request
4041
from starlette.responses import Response
4142

42-
from search.duckdb_connection import duckdb_connect_readonly
43-
4443
logger = logging.getLogger(__name__)
4544

4645
FTS_STAGE_TABLE_COMMAND = f"SELECT * FROM (SELECT {ROW_IDX_COLUMN}, fts_main_data.match_bm25({ROW_IDX_COLUMN}, ?) AS {HF_FTS_SCORE} FROM data) A WHERE {HF_FTS_SCORE} IS NOT NULL;" # nosec
@@ -57,13 +56,15 @@ def full_text_search(
5756
length: int,
5857
extensions_directory: Optional[str] = None,
5958
) -> tuple[int, pa.Table]:
60-
with duckdb_connect_readonly(extensions_directory=extensions_directory, database=index_file_location) as con:
59+
with duckdb_connect(
60+
extensions_directory=extensions_directory, database=index_file_location, read_only=True
61+
) as con:
6162
fts_stage_table = con.execute(query=FTS_STAGE_TABLE_COMMAND, parameters=[query]).arrow().read_all()
6263
num_rows_total = fts_stage_table.num_rows
6364
logging.info(f"got {num_rows_total=} results for {query=} using {offset=} {length=}")
6465
fts_stage_table = fts_stage_table.sort_by([(HF_FTS_SCORE, "descending")]).slice(offset, length)
6566
join_stage_and_data_query = JOIN_STAGE_AND_DATA_COMMAND.format(
66-
columns=",".join([f'"{column}"' for column in columns]),
67+
columns=",".join([key_sql(column) for column in columns]),
6768
row_idx_column=ROW_IDX_COLUMN,
6869
hf_fts_score=HF_FTS_SCORE,
6970
)

services/search/tests/routes/test_filter.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ def ds() -> Dataset:
4848
def index_file_location(ds: Dataset) -> Generator[str, None, None]:
4949
index_file_location = "index.duckdb"
5050
con = duckdb.connect(index_file_location)
51-
con.execute("INSTALL 'httpfs';")
52-
con.execute("LOAD 'httpfs';")
5351
con.execute("INSTALL 'fts';")
5452
con.execute("LOAD 'fts';")
5553
con.sql("CREATE OR REPLACE SEQUENCE serial START 0 MINVALUE 0;")

services/search/tests/routes/test_search.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ def test_full_text_search(
6363
# simulate index file
6464
index_file_location = "index.duckdb"
6565
con = duckdb.connect(index_file_location)
66-
con.execute("INSTALL 'httpfs';")
67-
con.execute("LOAD 'httpfs';")
6866
con.execute("INSTALL 'fts';")
6967
con.execute("LOAD 'fts';")
7068
con.sql("CREATE OR REPLACE SEQUENCE serial START 0 MINVALUE 0;")

0 commit comments

Comments
 (0)