Skip to content

Commit 0824ead

Browse files
committed
feat: add new init function for Connector
1 parent dbc23e6 commit 0824ead

File tree

3 files changed

+86
-14
lines changed

3 files changed

+86
-14
lines changed

.github/workflows/auto-ci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ jobs:
8888
- name: Try Pytest
8989
working-directory: ./
9090
run: |
91+
pip install duckdb_engine
9192
pip install pytest
9293
pytest tests
9394
- name: Uploading notebooks

pygwalker/data_parsers/database_parser.py

+49-14
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import io
77

88
from sqlalchemy import create_engine, text
9-
from sqlalchemy.engine import Engine
9+
from sqlalchemy.engine import Engine, Connection
1010
import pandas as pd
1111
import sqlglot.expressions as exp
1212
import sqlglot
@@ -59,30 +59,62 @@ class Connector:
5959
}
6060

6161
def __init__(self, url: str, view_sql: str, engine_params: Optional[Dict[str, Any]] = None) -> None:
    """Build a connector from a database URL.

    Args:
        url: SQLAlchemy database URL. Engines are cached per URL by
            ``_get_or_create_engine``.
        view_sql: SQL text of the view; it is validated by
            ``_check_view_sql`` during ``_init_instance``.
        engine_params: Extra keyword arguments forwarded to
            ``create_engine`` when a new engine has to be created.
            ``None`` means no extra arguments.
    """
    # ``__init__`` always returns None, so it is annotated that way
    # (the previous ``-> "Connector"`` annotation was incorrect).
    if engine_params is None:
        engine_params = {}

    engine = self._get_or_create_engine(url, engine_params)
    self._init_instance(engine, view_sql)
66+
67+
@classmethod
def from_sqlalchemy_engine(cls, engine: Engine, view_sql: str) -> "Connector":
    """Alternate constructor: wrap an already-built SQLAlchemy ``Engine``.

    ``__init__`` is bypassed, so the engine is used exactly as given: it is
    not looked up in (or added to) the per-URL engine cache.

    Args:
        engine: Engine that every query of this connector will run on.
        view_sql: SQL text of the view.

    Returns:
        The initialised connector.
    """
    connector = cls.__new__(cls)
    connector._init_instance(engine, view_sql)
    return connector
73+
74+
@classmethod
def from_sqlalchemy_connection(cls, connection: Connection, view_sql: str) -> "Connector":
    """Alternate constructor: reuse an existing SQLAlchemy ``Connection``.

    The connector is initialised from ``connection.engine`` and then pinned
    to ``connection`` itself, so queries see state that only exists on that
    connection (the original author had the DuckDB connector in mind).

    Note:
        - All subsequent queries will use the same connection.
        - The caller is responsible for managing and closing the connection
          when it is no longer needed.

    Args:
        connection: Open connection to run all queries on.
        view_sql: SQL text of the view.

    Returns:
        The initialised connector bound to ``connection``.
    """
    connector = cls.__new__(cls)
    # NOTE(review): _init_instance runs any dialect pre-init SQL through
    # ``connection.engine``, i.e. not necessarily on ``connection`` itself —
    # confirm that is acceptable for dialects that have pre-init SQL.
    connector._init_instance(connection.engine, view_sql)
    # Must be assigned after _init_instance, which resets it to None.
    connector._existing_conn = connection
    return connector
88+
89+
def _init_instance(self, engine: Engine, view_sql: str):
    """Shared initialisation used by ``__init__`` and the ``from_*`` constructors.

    Validates ``view_sql`` first (nothing is assigned if validation raises),
    stores the engine and derived attributes, then runs the dialect's
    pre-init SQL.

    Args:
        engine: Engine backing this connector.
        view_sql: SQL text of the view.
    """
    _check_view_sql(view_sql)

    self.engine = engine
    # NOTE(review): depending on the SQLAlchemy version, str(engine.url) may
    # mask the password, so this can differ from the URL string a caller
    # passed to __init__ — confirm nothing relies on the raw URL.
    self.url = str(engine.url)
    self.view_sql = view_sql
    # Dialects without an entry get an empty set of JSON type codes.
    self._json_type_code_set = self.JSON_TYPE_CODE_SET_MAP.get(self.dialect_name, set())
    # No pinned connection by default; from_sqlalchemy_connection overrides this.
    self._existing_conn = None

    self._run_pre_init_sql(engine)
7097

71-
def _get_engine(self, engine_params: Dict[str, Any]) -> Engine:
72-
if self.url not in self.engine_map:
73-
engine = create_engine(self.url, **engine_params)
98+
def _get_or_create_engine(self, url: str, engine_params: Dict[str, Any]) -> Engine:
    """Return the cached engine for ``url``, creating and caching it if needed.

    Args:
        url: SQLAlchemy database URL; also the cache key in ``engine_map``.
        engine_params: Keyword arguments for ``create_engine``. They are only
            applied when no engine is cached for ``url`` yet; on a cache hit
            they are ignored.

    Returns:
        The engine stored in ``engine_map`` for ``url``.
    """
    # Fast path: an engine for this URL already exists.
    if url in self.engine_map:
        return self.engine_map[url]

    new_engine = create_engine(url, **engine_params)
    new_engine.dialect.requires_name_normalize = False
    self.engine_map[url] = new_engine
    return self.engine_map[url]
80105

81-
return self.engine_map[self.url]
106+
def _run_pre_init_sql(self, engine: Engine) -> None:
    """Execute the dialect's pre-init SQL, if one is registered.

    ``engine.dialect.name`` is looked up in ``PRE_INIT_SQL_MAP``. When the
    dialect has no entry this is a no-op; otherwise the statement is executed
    on a short-lived connection checked out from ``engine``.

    Args:
        engine: Engine whose dialect selects the statement and which provides
            the connection it runs on.
    """
    if engine.dialect.name not in self.PRE_INIT_SQL_MAP:
        return

    pre_init_sql = self.PRE_INIT_SQL_MAP[engine.dialect.name]
    # Call connect() without arguments: the positional flag previously passed
    # here (the legacy ``close_with_result`` option) is not accepted by
    # SQLAlchemy 2.x and raises TypeError. The ``with`` block already closes
    # the connection on exit, so the flag was redundant.
    with engine.connect() as connection:
        connection.execute(text(pre_init_sql))
82111

83112
def query_datas(self, sql: str) -> List[Dict[str, Any]]:
84113
field_type_map = {}
85-
with self.engine.connect() as connection:
114+
should_close_connection = self._existing_conn is None
115+
connection = self._existing_conn or self.engine.connect()
116+
117+
try:
86118
result = connection.execute(text(sql))
87119
if self.dialect_name in self.JSON_TYPE_CODE_SET_MAP:
88120
field_type_map = {
@@ -96,6 +128,9 @@ def query_datas(self, sql: str) -> List[Dict[str, Any]]:
96128
}
97129
for item in result.mappings()
98130
]
131+
finally:
132+
if should_close_connection:
133+
connection.close()
99134

100135
@property
101136
def dialect_name(self) -> str:

tests/test_data_parsers.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import os.path
2+
3+
from sqlalchemy import create_engine
14
import pandas as pd
25
import polars as pl
36
import pytest
47

58
from pygwalker.services.data_parsers import get_parser
9+
from pygwalker.data_parsers.database_parser import Connector, text
610
from pygwalker.data_parsers.database_parser import _check_view_sql
711
from pygwalker.errors import ViewSqlSameColumnError
812

@@ -72,3 +76,35 @@ def test_check_view_sql():
7276
_check_view_sql("SELECT * FROM a left join b on a.id = b.id")
7377
with pytest.raises(ViewSqlSameColumnError):
7478
_check_view_sql("SELECT a.* FROM a left join b on a.id = b.id")
79+
80+
81+
def test_connector():
    """Check the three ways of building a ``Connector`` against in-memory DuckDB."""
    csv_file = os.path.join(os.path.dirname(__file__), "bike_sharing_dc.csv")
    database_url = "duckdb:///:memory:"
    view_sql = "SELECT 1"
    data_count = 17379
    csv_count_sql = f"SELECT COUNT(1) count FROM read_csv_auto('{csv_file}')"

    def check(connector, count_sql):
        # Same assertions for every construction path.
        result = connector.query_datas(count_sql)
        assert result[0]["count"] == data_count
        assert connector.dialect_name == "duckdb"
        assert connector.view_sql == view_sql
        assert connector.url == database_url

    # 1) Regular constructor: built from a URL.
    check(Connector(database_url, view_sql), csv_count_sql)

    # 2) Built from an existing engine.
    engine = create_engine(database_url)
    check(Connector.from_sqlalchemy_engine(engine, view_sql), csv_count_sql)

    # 3) Built from an existing connection: the table is created on that
    #    connection and then queried through the connector.
    engine = create_engine(database_url)
    with engine.connect() as conn:
        conn.execute(text(f"CREATE TABLE test_datas AS SELECT * FROM read_csv_auto('{csv_file}')"))
        check(
            Connector.from_sqlalchemy_connection(conn, view_sql),
            "SELECT COUNT(1) count FROM test_datas",
        )

0 commit comments

Comments
 (0)