Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions src/duckberg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,26 @@ def __init__(
db_thread_limit: Optional[int] = DEFAULT_DB_THREAD_LIMIT,
db_mem_limit: Optional[str] = DEFAULT_MEM_LIMIT,
batch_size_rows: Optional[int] = BATCH_SIZE_ROWS,
single_db_per_request: bool = True
):
self.db_thread_limit = db_thread_limit
self.db_mem_limit = db_mem_limit
self.batch_size_rows = batch_size_rows
self.duckdb_connection = duckdb_connection
self.single_db_per_request = single_db_per_request

if self.duckdb_connection == None:
if self.duckdb_connection == None and not single_db_per_request:
self.duckdb_connection = duckdb.connect()
self.init_duckdb()
self.init_duckdb(self.duckdb_connection)

self.sql_parser = DuckBergSQLParser()
self.tables: dict[str, DuckBergTable] = {}

self.init_duckdb()
self.__get_tables(catalog_config, catalog_name)

def init_duckdb(self, db=None):
    """Apply the configured memory and thread limits to a DuckDB connection.

    Args:
        db: Connection to configure. When omitted, the instance's own
            ``duckdb_connection`` is configured instead, so the
            zero-argument call form keeps working.

    Raises:
        ValueError: If no connection is passed and the instance holds none.
    """
    # Fall back to the shared connection instead of dereferencing None.
    target = self.duckdb_connection if db is None else db
    if target is None:
        raise ValueError("init_duckdb() needs a DuckDB connection to configure")

    target.execute(f"SET memory_limit='{self.db_mem_limit}'")
    target.execute(f"SET threads TO {self.db_thread_limit}")

def __get_tables(self, catalog_config, catalog_name):
tables = {}
Expand Down Expand Up @@ -92,18 +93,28 @@ def select(
row_filter = table.comparisons

table_data_scan_as_arrow = self.tables[table_name].scan(row_filter=row_filter).to_arrow()
self.duckdb_connection.register(table_name, table_data_scan_as_arrow)

db = self.duckdb_connection
if self.single_db_per_request:
db = duckdb.connect()
self.init_duckdb(db)
db.register(table_name, table_data_scan_as_arrow)

if sql_params is None:
return self.duckdb_connection.execute(sql).fetch_record_batch(self.batch_size_rows)
return db.execute(sql).fetch_record_batch(self.batch_size_rows)
else:
return self.duckdb_connection.execute(sql, parameters=sql_params).fetch_record_batch(self.batch_size_rows)
return db.execute(sql, parameters=sql_params).fetch_record_batch(self.batch_size_rows)

def _select_old(self, sql: str, table: str, partition_filter: str, sql_params: Optional[list] = None):
    """Scan one table with a row filter, register it in DuckDB and run ``sql``.

    Args:
        sql: Statement to execute against the registered table.
        table: Key into ``self.tables``; also the name the scan result is
            registered under.
        partition_filter: Passed to the table scan as ``row_filter``.
        sql_params: Optional parameters forwarded to ``execute``.

    Returns:
        The result of ``fetch_record_batch(self.batch_size_rows)`` on the
        executed statement.
    """
    table_data_scan_as_arrow = self.tables[table].scan(row_filter=partition_filter).to_arrow()

    # Use a fresh, configured connection per call when requested; otherwise
    # reuse the connection held by the instance.
    if self.single_db_per_request:
        db = duckdb.connect()
        self.init_duckdb(db)
    else:
        db = self.duckdb_connection

    # Register and execute on the same connection, exactly once.
    db.register(table, table_data_scan_as_arrow)

    if sql_params is None:
        result = db.execute(sql)
    else:
        result = db.execute(sql, parameters=sql_params)
    return result.fetch_record_batch(self.batch_size_rows)
11 changes: 11 additions & 0 deletions tests/test_duckberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,14 @@ def test_select_3(get_duckberg):
query: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC)"
df = get_duckberg.select(sql=query, table="nyc.taxis", partition_filter="payment_type = 1").read_pandas()
assert df["count_star()"][0] == 1673


def test_select_multiple_one_request(get_duckberg):
    """Issue two selects on the fixture-provided instance before reading either.

    Both queries are submitted first and only consumed afterwards, so each
    result has to stay readable after a later select ran on the same object.
    """
    count_column = "count_star()"

    long_trips_sql = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE trip_distance > 40 ORDER BY tolls_amount DESC)"
    long_trips_reader = get_duckberg.select(sql=long_trips_sql)

    payment_1_sql = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC)"
    payment_1_reader = get_duckberg.select(sql=payment_1_sql)

    # Consume the results in submission order, only after both selects ran.
    long_trips_counts = long_trips_reader.read_pandas()[count_column]
    assert long_trips_counts[0] == 2614

    payment_1_counts = payment_1_reader.read_pandas()[count_column]
    assert payment_1_counts[0] == 1673