Skip to content

Commit ae41296

Browse files
committed
Updating synapse connection manager to allow sql auth mode in v.1.8.1 and above
1 parent eef6fda commit ae41296

File tree

7 files changed

+294
-8
lines changed

7 files changed

+294
-8
lines changed

dbt/adapters/synapse/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "1.8.0"
1+
version = "1.8.1"
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,291 @@
1-
from dbt.adapters.fabric import FabricConnectionManager
1+
import struct
2+
import time
3+
from itertools import chain, repeat
4+
from typing import Callable, Dict, Mapping, Optional
5+
6+
import pyodbc
7+
from azure.core.credentials import AccessToken
8+
from azure.identity import AzureCliCredential, DefaultAzureCredential, EnvironmentCredential
9+
from dbt.adapters.contracts.connection import Connection, ConnectionState
10+
from dbt.adapters.events.logging import AdapterLogger
11+
from dbt.adapters.fabric import FabricConnectionManager, __version__
12+
from dbt.adapters.fabric.fabric_credentials import FabricCredentials
13+
14+
AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default"
15+
_TOKEN: Optional[AccessToken] = None
16+
AZURE_AUTH_FUNCTION_TYPE = Callable[[FabricCredentials], AccessToken]
17+
18+
logger = AdapterLogger("fabric")
19+
20+
21+
def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes:
22+
"""
23+
Convert bytes to a Microsoft windows byte string.
24+
25+
Parameters
26+
----------
27+
value : bytes
28+
The bytes.
29+
30+
Returns
31+
-------
32+
out : bytes
33+
The Microsoft byte string.
34+
"""
35+
encoded_bytes = bytes(chain.from_iterable(zip(value, repeat(0))))
36+
return struct.pack("<i", len(encoded_bytes)) + encoded_bytes
37+
38+
39+
def convert_access_token_to_mswindows_byte_string(token: AccessToken) -> bytes:
40+
"""
41+
Convert an access token to a Microsoft windows byte string.
42+
43+
Parameters
44+
----------
45+
token : AccessToken
46+
The token.
47+
48+
Returns
49+
-------
50+
out : bytes
51+
The Microsoft byte string.
52+
"""
53+
value = bytes(token.token, "UTF-8")
54+
return convert_bytes_to_mswindows_byte_string(value)
55+
56+
57+
def get_cli_access_token(credentials: FabricCredentials) -> AccessToken:
58+
"""
59+
Get an Azure access token using the CLI credentials
60+
61+
First login with:
62+
63+
```bash
64+
az login
65+
```
66+
67+
Parameters
68+
----------
69+
credentials: FabricConnectionManager
70+
The credentials.
71+
72+
Returns
73+
-------
74+
out : AccessToken
75+
Access token.
76+
"""
77+
_ = credentials
78+
token = AzureCliCredential().get_token(AZURE_CREDENTIAL_SCOPE)
79+
return token
80+
81+
82+
def get_auto_access_token(credentials: FabricCredentials) -> AccessToken:
83+
"""
84+
Get an Azure access token automatically through azure-identity
85+
86+
Parameters
87+
-----------
88+
credentials: FabricCredentials
89+
Credentials.
90+
91+
Returns
92+
-------
93+
out : AccessToken
94+
The access token.
95+
"""
96+
token = DefaultAzureCredential().get_token(AZURE_CREDENTIAL_SCOPE)
97+
return token
98+
99+
100+
def get_environment_access_token(credentials: FabricCredentials) -> AccessToken:
101+
"""
102+
Get an Azure access token by reading environment variables
103+
104+
Parameters
105+
-----------
106+
credentials: FabricCredentials
107+
Credentials.
108+
109+
Returns
110+
-------
111+
out : AccessToken
112+
The access token.
113+
"""
114+
token = EnvironmentCredential().get_token(AZURE_CREDENTIAL_SCOPE)
115+
return token
116+
117+
118+
AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = {
119+
"cli": get_cli_access_token,
120+
"auto": get_auto_access_token,
121+
"environment": get_environment_access_token,
122+
}
123+
124+
125+
def get_pyodbc_attrs_before(credentials: FabricCredentials) -> Dict:
126+
"""
127+
Get the pyodbc attrs before.
128+
129+
Parameters
130+
----------
131+
credentials : FabricCredentials
132+
Credentials.
133+
134+
Returns
135+
-------
136+
out : Dict
137+
The pyodbc attrs before.
138+
139+
Source
140+
------
141+
Authentication for SQL server with an access token:
142+
https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory?view=sql-server-ver15#authenticating-with-an-access-token
143+
"""
144+
global _TOKEN
145+
attrs_before: Dict
146+
MAX_REMAINING_TIME = 300
147+
148+
authentication = str(credentials.authentication).lower()
149+
if authentication in AZURE_AUTH_FUNCTIONS:
150+
time_remaining = (_TOKEN.expires_on - time.time()) if _TOKEN else MAX_REMAINING_TIME
151+
152+
if _TOKEN is None or (time_remaining < MAX_REMAINING_TIME):
153+
azure_auth_function = AZURE_AUTH_FUNCTIONS[authentication]
154+
_TOKEN = azure_auth_function(credentials)
155+
156+
token_bytes = convert_access_token_to_mswindows_byte_string(_TOKEN)
157+
sql_copt_ss_access_token = 1256 # see source in docstring
158+
attrs_before = {sql_copt_ss_access_token: token_bytes}
159+
else:
160+
attrs_before = {}
161+
162+
return attrs_before
163+
164+
165+
def bool_to_connection_string_arg(key: str, value: bool) -> str:
166+
"""
167+
Convert a boolean to a connection string argument.
168+
169+
Parameters
170+
----------
171+
key : str
172+
The key to use in the connection string.
173+
value : bool
174+
The boolean to convert.
175+
176+
Returns
177+
-------
178+
out : str
179+
The connection string argument.
180+
"""
181+
return f'{key}={"Yes" if value else "No"}'
2182

3183

4184
class SynapseConnectionManager(FabricConnectionManager):
5185
TYPE = "synapse"
6186
TOKEN = None
187+
188+
@classmethod
189+
def open(cls, connection: Connection) -> Connection:
190+
if connection.state == ConnectionState.OPEN:
191+
logger.debug("Connection is already open, skipping open.")
192+
return connection
193+
194+
credentials = cls.get_credentials(connection.credentials)
195+
196+
con_str = [f"DRIVER={{{credentials.driver}}}"]
197+
198+
if "\\" in credentials.host:
199+
# If there is a backslash \ in the host name, the host is a
200+
# SQL Server named instance. In this case then port number has to be omitted.
201+
con_str.append(f"SERVER={credentials.host}")
202+
else:
203+
con_str.append(f"SERVER={credentials.host}")
204+
205+
con_str.append(f"Database={credentials.database}")
206+
207+
assert credentials.authentication is not None
208+
209+
if "ActiveDirectory" in credentials.authentication:
210+
con_str.append(f"Authentication={credentials.authentication}")
211+
212+
if credentials.authentication == "ActiveDirectoryPassword":
213+
con_str.append(f"UID={{{credentials.UID}}}")
214+
con_str.append(f"PWD={{{credentials.PWD}}}")
215+
if credentials.authentication == "ActiveDirectoryServicePrincipal":
216+
con_str.append(f"UID={{{credentials.client_id}}}")
217+
con_str.append(f"PWD={{{credentials.client_secret}}}")
218+
elif credentials.authentication == "ActiveDirectoryInteractive":
219+
con_str.append(f"UID={{{credentials.UID}}}")
220+
221+
elif credentials.windows_login:
222+
con_str.append("trusted_connection=Yes")
223+
elif credentials.authentication == "sql":
224+
con_str.append(f"UID={{{credentials.UID}}}")
225+
con_str.append(f"PWD={{{credentials.PWD}}}")
226+
227+
# https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15
228+
assert credentials.encrypt is not None
229+
assert credentials.trust_cert is not None
230+
231+
con_str.append(bool_to_connection_string_arg("encrypt", credentials.encrypt))
232+
con_str.append(
233+
bool_to_connection_string_arg("TrustServerCertificate", credentials.trust_cert)
234+
)
235+
236+
plugin_version = __version__.version
237+
application_name = f"dbt-{credentials.type}/{plugin_version}"
238+
con_str.append(f"APP={application_name}")
239+
240+
try:
241+
if int(credentials.retries) > 0:
242+
con_str.append(f"ConnectRetryCount={credentials.retries}")
243+
244+
except Exception as e:
245+
logger.debug(
246+
"Retry count should be integer value. Skipping retries in the connection string.",
247+
str(e),
248+
)
249+
250+
con_str_concat = ";".join(con_str)
251+
252+
index = []
253+
for i, elem in enumerate(con_str):
254+
if "pwd=" in elem.lower():
255+
index.append(i)
256+
257+
if len(index) != 0:
258+
con_str[index[0]] = "PWD=***"
259+
260+
con_str_display = ";".join(con_str)
261+
262+
retryable_exceptions = [ # https://github.com/mkleehammer/pyodbc/wiki/Exceptions
263+
pyodbc.InternalError, # not used according to docs, but defined in PEP-249
264+
pyodbc.OperationalError,
265+
]
266+
267+
if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS:
268+
# Temporary login/token errors fall into this category when using AAD
269+
retryable_exceptions.append(pyodbc.InterfaceError)
270+
271+
def connect():
272+
logger.debug(f"Using connection string: {con_str_display}")
273+
274+
attrs_before = get_pyodbc_attrs_before(credentials)
275+
handle = pyodbc.connect(
276+
con_str_concat,
277+
attrs_before=attrs_before,
278+
autocommit=True,
279+
timeout=credentials.login_timeout,
280+
)
281+
handle.timeout = credentials.query_timeout
282+
logger.debug(f"Connected to db: {credentials.database}")
283+
return handle
284+
285+
return cls.retry_connection(
286+
connection,
287+
connect=connect,
288+
logger=logger,
289+
retry_limit=credentials.retries,
290+
retryable_exceptions=retryable_exceptions,
291+
)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def run(self):
7070
long_description_content_type="text/markdown",
7171
license="MIT",
7272
author=", ".join(authors_list),
73-
url="https://github.com/dbt-msft/dbt-synapse",
73+
url="https://github.com/microsoft/dbt-synapse",
7474
packages=find_namespace_packages(include=["dbt", "dbt.*"]),
7575
include_package_data=True,
7676
install_requires=[dbt_fabric_requirement],

tests/conftest.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def pytest_addoption(parser):
10-
parser.addoption("--profile", action="store", default="user_azure", type=str)
10+
parser.addoption("--profile", action="store", default="user", type=str)
1111

1212

1313
@pytest.fixture(scope="class")
@@ -54,6 +54,7 @@ def _profile_user():
5454
"user": os.getenv("SYNAPSE_TEST_USER"),
5555
"pass": os.getenv("SYNAPSE_TEST_PASS"),
5656
"database": os.getenv("SYNAPSE_TEST_DWH_NAME"),
57+
"authentication": "sql",
5758
},
5859
}
5960

tests/functional/adapter/test_docs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TestDocsGenerateSynapse(BaseDocsGenerate):
1919
def expected_catalog(self, project):
2020
return base_expected_catalog(
2121
project,
22-
role="dbo",
22+
role="dbttestuser",
2323
id_type="int",
2424
text_type="varchar",
2525
time_type="datetime2",
@@ -34,7 +34,7 @@ class TestDocsGenReferencesSynapse(BaseDocsGenReferences):
3434
def expected_catalog(self, project):
3535
return expected_references_catalog(
3636
project,
37-
role="dbo",
37+
role="dbttestuser",
3838
id_type="int",
3939
text_type="varchar",
4040
time_type="datetime2",

tests/functional/adapter/test_model_hooks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def check_hooks(self, state, project, host, count=1):
3030
assert ctx["target_schema"] == project.test_schema
3131
assert ctx["target_threads"] == 1
3232
assert ctx["target_type"] == "synapse"
33-
assert ctx["target_user"] == "None"
33+
# assert ctx["target_user"] == "dbttestuser"
3434
assert ctx["target_pass"] == ""
3535

3636
assert (

tests/functional/adapter/test_run_hooks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def check_hooks(self, state, project, host):
6464
assert ctx["target_schema"] == project.test_schema
6565
assert ctx["target_threads"] == 1
6666
assert ctx["target_type"] == "synapse"
67-
assert ctx["target_user"] == "None"
67+
# assert ctx["target_user"] == "None"
6868
assert ctx["target_pass"] == ""
6969

7070
assert (

0 commit comments

Comments
 (0)