Skip to content

Commit 109ebf6

Browse files
committed
parse connect_args from DATABASE_URL env var
1 parent ecd91c1 commit 109ebf6

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/sql/connection.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from sqlalchemy.exc import NoSuchModuleError
77
from IPython.core.error import UsageError
88
import difflib
9+
import urllib.parse
10+
import ast
911

1012
PLOOMBER_SUPPORT_LINK_STR = (
1113
"For technical support: https://ploomber.io/community"
@@ -278,12 +280,23 @@ def set(cls, descriptor, displaycon, connect_args=None, creator=None, alias=None
278280
# display list of connections
279281
print(cls.connection_list())
280282
elif os.getenv("DATABASE_URL"):
281-
cls.current = Connection.from_connect_str(
282-
connect_str=os.getenv("DATABASE_URL"),
283-
connect_args=connect_args,
284-
creator=creator,
285-
alias=alias,
286-
)
283+
# try and extract connect_args from DATABASE_URL
284+
if not connect_args and '&' in os.getenv("DATABASE_URL"):
285+
cls.current = Connection.from_connect_str(
286+
connect_str=os.getenv("DATABASE_URL").split('&')[0],
287+
connect_args=ast.literal_eval(
288+
'{' + urllib.parse.unquote_plus(os.getenv('DATABASE_URL').split('&')[1]).replace(
289+
'requests_kwargs=','"requests_kwargs":') + '}'),
290+
creator=creator,
291+
alias=alias,
292+
)
293+
else:
294+
cls.current = Connection.from_connect_str(
295+
connect_str=os.getenv("DATABASE_URL"),
296+
connect_args=connect_args,
297+
creator=creator,
298+
alias=alias,
299+
)
287300
else:
288301
raise cls._error_no_connection()
289302

src/sql/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def _commit(conn, config):
385385

386386
def run(conn, sql, config, user_namespace):
387387
if sql.strip():
388+
conn.dialect.supports_statement_cache = False # disable SQL compilation caching warning
388389
for statement in sqlparse.split(sql):
389390
first_word = sql.strip().split()[0].lower()
390391
if first_word == "begin":

0 commit comments

Comments
 (0)