Skip to content

Commit 9fffd1e

Browse files
feat: Enhance MSSQL connection configuration with ODBC options and improve bulk copy handling
1 parent e2e7a06 commit 9fffd1e

File tree

2 files changed

+98
-3
lines changed

2 files changed

+98
-3
lines changed

sqlmesh/core/config/connection.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,16 @@ class MSSQLConnectionConfig(ConnectionConfig):
14221422
autocommit: t.Optional[bool] = False
14231423
tds_version: t.Optional[str] = None
14241424

1425+
# Driver options
1426+
driver: t.Literal["pymssql", "pyodbc"] = "pymssql"
1427+
# PyODBC specific options
1428+
driver_name: t.Optional[str] = None # e.g. "ODBC Driver 18 for SQL Server"
1429+
trust_server_certificate: t.Optional[bool] = None
1430+
encrypt: t.Optional[bool] = None
1431+
# Dictionary of arbitrary ODBC connection properties
1432+
# See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute
1433+
odbc_properties: t.Optional[t.Dict[str, t.Any]] = None
1434+
14251435
concurrent_tasks: int = 4
14261436
register_comments: bool = True
14271437
pre_ping: bool = True
@@ -1432,7 +1442,7 @@ class MSSQLConnectionConfig(ConnectionConfig):
14321442

14331443
@property
14341444
def _connection_kwargs_keys(self) -> t.Set[str]:
1435-
return {
1445+
base_keys = {
14361446
"host",
14371447
"user",
14381448
"password",
@@ -1447,15 +1457,96 @@ def _connection_kwargs_keys(self) -> t.Set[str]:
14471457
"tds_version",
14481458
}
14491459

1460+
if self.driver == "pyodbc":
1461+
base_keys.update(
1462+
{
1463+
"driver_name",
1464+
"trust_server_certificate",
1465+
"encrypt",
1466+
"odbc_properties",
1467+
}
1468+
)
1469+
# Remove pymssql-specific parameters
1470+
base_keys.discard("tds_version")
1471+
base_keys.discard("conn_properties")
1472+
1473+
return base_keys
1474+
14501475
@property
14511476
def _engine_adapter(self) -> t.Type[EngineAdapter]:
14521477
return engine_adapter.MSSQLEngineAdapter
14531478

14541479
@property
14551480
def _connection_factory(self) -> t.Callable:
1456-
import pymssql
1481+
if self.driver == "pymssql":
1482+
import pymssql
1483+
1484+
return pymssql.connect
1485+
1486+
import pyodbc
1487+
1488+
def connect(**kwargs: t.Any) -> t.Callable:
1489+
# Extract parameters for connection string
1490+
host = kwargs.pop("host")
1491+
port = kwargs.pop("port", 1433)
1492+
database = kwargs.pop("database", "")
1493+
user = kwargs.pop("user", None)
1494+
password = kwargs.pop("password", None)
1495+
driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server")
1496+
trust_server_certificate = kwargs.pop("trust_server_certificate", False)
1497+
encrypt = kwargs.pop("encrypt", True)
1498+
login_timeout = kwargs.pop("login_timeout", 60)
1499+
1500+
# Build connection string
1501+
conn_str_parts = [
1502+
f"DRIVER={{{driver_name}}}",
1503+
f"SERVER={host},{port}",
1504+
]
1505+
1506+
if database:
1507+
conn_str_parts.append(f"DATABASE={database}")
1508+
1509+
# Add security options
1510+
conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}")
1511+
if trust_server_certificate:
1512+
conn_str_parts.append("TrustServerCertificate=YES")
1513+
1514+
conn_str_parts.append(f"Connection Timeout={login_timeout}")
1515+
1516+
# Standard SQL Server authentication
1517+
if user:
1518+
conn_str_parts.append(f"UID={user}")
1519+
if password:
1520+
conn_str_parts.append(f"PWD={password}")
1521+
1522+
# Add any additional ODBC properties from the odbc_properties dictionary
1523+
if self.odbc_properties:
1524+
for key, value in self.odbc_properties.items():
1525+
# Skip properties that we've already set above
1526+
if key.lower() in (
1527+
"driver",
1528+
"server",
1529+
"database",
1530+
"uid",
1531+
"pwd",
1532+
"encrypt",
1533+
"trustservercertificate",
1534+
"connection timeout",
1535+
):
1536+
continue
14571537

1458-
return pymssql.connect
1538+
# Handle boolean values properly
1539+
if isinstance(value, bool):
1540+
conn_str_parts.append(f"{key}={'YES' if value else 'NO'}")
1541+
else:
1542+
conn_str_parts.append(f"{key}={value}")
1543+
1544+
# Create the connection string
1545+
conn_str = ";".join(conn_str_parts)
1546+
1547+
return pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False))
1548+
1549+
return connect
14591550

14601551
@property
14611552
def _extra_engine_config(self) -> t.Dict[str, t.Any]:

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@ def _df_to_source_queries(
219219
assert isinstance(df, pd.DataFrame)
220220
temp_table = self._get_temp_table(target_table or "pandas")
221221

222+
# Return the superclass implementation if the connection pool doesn't support bulk_copy
223+
if not hasattr(self._connection_pool.get(), "bulk_copy"):
224+
return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table)
225+
222226
def query_factory() -> Query:
223227
# It is possible for the factory to be called multiple times and if so then the temp table will already
224228
# be created so we skip creating again. This means we are assuming the first call is the same result

0 commit comments

Comments
 (0)