6
6
import io
7
7
8
8
from sqlalchemy import create_engine , text
9
- from sqlalchemy .engine import Engine
9
+ from sqlalchemy .engine import Engine , Connection
10
10
import pandas as pd
11
11
import sqlglot .expressions as exp
12
12
import sqlglot
@@ -59,30 +59,62 @@ class Connector:
59
59
}
60
60
61
61
def __init__ (self , url : str , view_sql : str , engine_params : Optional [Dict [str , Any ]] = None ) -> "Connector" :
62
- _check_view_sql (view_sql )
63
62
if engine_params is None :
64
63
engine_params = {}
65
64
66
- self .url = url
67
- self .engine = self ._get_engine (engine_params )
65
+ self ._init_instance (self ._get_or_create_engine (url , engine_params ), view_sql )
66
+
67
+ @classmethod
68
+ def from_sqlalchemy_engine (cls , engine : Engine , view_sql : str ) -> "Connector" :
69
+ """Create connector from engine"""
70
+ instance = cls .__new__ (cls )
71
+ instance ._init_instance (engine , view_sql )
72
+ return instance
73
+
74
+ @classmethod
75
+ def from_sqlalchemy_connection (cls , connection : Connection , view_sql : str ) -> "Connector" :
76
+ """
77
+ Create a Connector instance from an existing SQLAlchemy connection.
78
+ This adapts the DuckDB connector.
79
+
80
+ Note:
81
+ - All subsequent queries will use the same connection.
82
+ - The caller is responsible for managing and closing the connection when no longer needed.
83
+ """
84
+ instance = cls .__new__ (cls )
85
+ instance ._init_instance (connection .engine , view_sql )
86
+ instance ._existing_conn = connection
87
+ return instance
88
+
89
+ def _init_instance (self , engine : Engine , view_sql : str ):
90
+ _check_view_sql (view_sql )
91
+ self .engine = engine
92
+ self .url = str (engine .url )
68
93
self .view_sql = view_sql
69
94
self ._json_type_code_set = self .JSON_TYPE_CODE_SET_MAP .get (self .dialect_name , set ())
95
+ self ._existing_conn = None
96
+ self ._run_pre_init_sql (engine )
70
97
71
- def _get_engine (self , engine_params : Dict [str , Any ]) -> Engine :
72
- if self . url not in self .engine_map :
73
- engine = create_engine (self . url , ** engine_params )
98
+ def _get_or_create_engine (self , url : str , engine_params : Dict [str , Any ]) -> Engine :
99
+ if url not in self .engine_map :
100
+ engine = create_engine (url , ** engine_params )
74
101
engine .dialect .requires_name_normalize = False
75
- self .engine_map [self .url ] = engine
76
- if engine .dialect .name in self .PRE_INIT_SQL_MAP :
77
- pre_init_sql = self .PRE_INIT_SQL_MAP [engine .dialect .name ]
78
- with engine .connect (True ) as connection :
79
- connection .execute (text (pre_init_sql ))
102
+ self .engine_map [url ] = engine
103
+
104
+ return self .engine_map [url ]
80
105
81
- return self .engine_map [self .url ]
106
+ def _run_pre_init_sql (self , engine : Engine ) -> None :
107
+ if engine .dialect .name in self .PRE_INIT_SQL_MAP :
108
+ pre_init_sql = self .PRE_INIT_SQL_MAP [engine .dialect .name ]
109
+ with engine .connect (True ) as connection :
110
+ connection .execute (text (pre_init_sql ))
82
111
83
112
def query_datas (self , sql : str ) -> List [Dict [str , Any ]]:
84
113
field_type_map = {}
85
- with self .engine .connect () as connection :
114
+ should_close_connection = self ._existing_conn is None
115
+ connection = self ._existing_conn or self .engine .connect ()
116
+
117
+ try :
86
118
result = connection .execute (text (sql ))
87
119
if self .dialect_name in self .JSON_TYPE_CODE_SET_MAP :
88
120
field_type_map = {
@@ -96,6 +128,9 @@ def query_datas(self, sql: str) -> List[Dict[str, Any]]:
96
128
}
97
129
for item in result .mappings ()
98
130
]
131
+ finally :
132
+ if should_close_connection :
133
+ connection .close ()
99
134
100
135
@property
101
136
def dialect_name (self ) -> str :
0 commit comments