diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index 875150c5..e69c41e6 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -147,6 +147,28 @@ def getDecimalSeparator(): SQL_WCHAR = ConstantsDDBC.SQL_WCHAR.value SQL_WMETADATA = -99 +# Export connection attribute constants for set_attr() +# Only include driver-level attributes that the SQL Server ODBC driver can handle directly + +# Core driver-level attributes +SQL_ATTR_ACCESS_MODE = ConstantsDDBC.SQL_ATTR_ACCESS_MODE.value +SQL_ATTR_CONNECTION_TIMEOUT = ConstantsDDBC.SQL_ATTR_CONNECTION_TIMEOUT.value +SQL_ATTR_CURRENT_CATALOG = ConstantsDDBC.SQL_ATTR_CURRENT_CATALOG.value +SQL_ATTR_LOGIN_TIMEOUT = ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value +SQL_ATTR_PACKET_SIZE = ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value +SQL_ATTR_TXN_ISOLATION = ConstantsDDBC.SQL_ATTR_TXN_ISOLATION.value + +# Transaction Isolation Level Constants +SQL_TXN_READ_UNCOMMITTED = ConstantsDDBC.SQL_TXN_READ_UNCOMMITTED.value +SQL_TXN_READ_COMMITTED = ConstantsDDBC.SQL_TXN_READ_COMMITTED.value +SQL_TXN_REPEATABLE_READ = ConstantsDDBC.SQL_TXN_REPEATABLE_READ.value +SQL_TXN_SERIALIZABLE = ConstantsDDBC.SQL_TXN_SERIALIZABLE.value + +# Access Mode Constants +SQL_MODE_READ_WRITE = ConstantsDDBC.SQL_MODE_READ_WRITE.value +SQL_MODE_READ_ONLY = ConstantsDDBC.SQL_MODE_READ_ONLY.value + + from .pooling import PoolingManager def pooling(max_size=100, idle_timeout=600, enabled=True): # """ diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 832d2aac..48ed44f1 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -16,7 +16,7 @@ from typing import Any import threading from mssql_python.cursor import Cursor -from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, sanitize_user_input, log +from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, sanitize_user_input, log, validate_attribute_value from mssql_python import ddbc_bindings from mssql_python.pooling import PoolingManager from mssql_python.exceptions import InterfaceError, ProgrammingError @@ -109,6 +109,7 @@ class Connection: setencoding(encoding=None, ctype=None) -> None: setdecoding(sqltype, encoding=None, ctype=None) -> None: getdecoding(sqltype) -> dict: + set_attr(attribute, value) -> None: """ # DB-API 2.0 Exception attributes @@ -129,10 +130,16 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef Initialize the connection object with the specified connection string and parameters. Args: - - connection_str (str): The connection string to connect to. - - autocommit (bool): If True, causes a commit to be performed after each SQL statement. + connection_str (str): The connection string to connect to. + autocommit (bool): If True, causes a commit to be performed after each SQL statement. + attrs_before (dict, optional): Dictionary of connection attributes to set before + connection establishment. Keys are SQL_ATTR_* constants, + and values are their corresponding settings. + Use this for attributes that must be set before connecting, + such as SQL_ATTR_LOGIN_TIMEOUT, SQL_ATTR_ODBC_CURSORS, + and SQL_ATTR_PACKET_SIZE. + timeout (int): Login timeout in seconds. 0 means no timeout. **kwargs: Additional key/value pairs for the connection string. - Not including below properties since we are driver doesn't support this: Returns: None @@ -143,6 +150,12 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef This method sets up the initial state for the connection object, preparing it for further operations such as connecting to the database, executing queries, etc. + + Example: + >>> # Setting login timeout using attrs_before + >>> import mssql_python as ms + >>> conn = ms.connect("Server=myserver;Database=mydb", + ... attrs_before={ms.SQL_ATTR_LOGIN_TIMEOUT: 30}) """ self.connection_str = self._construct_connection_string( connection_str, **kwargs @@ -546,6 +559,71 @@ def getdecoding(self, sqltype): ) return self._decoding_settings[sqltype].copy() + + def set_attr(self, attribute, value): + """ + Set a connection attribute. + + This method sets a connection attribute using SQLSetConnectAttr. + It provides pyodbc-compatible functionality for configuring connection + behavior such as autocommit mode, transaction isolation level, and + connection timeouts. + + Args: + attribute (int): The connection attribute to set. Should be one of the + SQL_ATTR_* constants (e.g., SQL_ATTR_AUTOCOMMIT, + SQL_ATTR_TXN_ISOLATION). + value: The value to set for the attribute. Can be an integer, string, + bytes, or bytearray depending on the attribute type. + + Raises: + InterfaceError: If the connection is closed or attribute is invalid. + ProgrammingError: If the value type or range is invalid. + ProgrammingError: If the attribute cannot be set after connection. + + Example: + >>> conn.set_attr(SQL_ATTR_TXN_ISOLATION, SQL_TXN_READ_COMMITTED) + + Note: + Some attributes (like SQL_ATTR_LOGIN_TIMEOUT, SQL_ATTR_ODBC_CURSORS, and + SQL_ATTR_PACKET_SIZE) can only be set before connection establishment and + must be provided in the attrs_before parameter when creating the connection. + Attempting to set these attributes after connection will raise a ProgrammingError. + """ + if self._closed: + raise InterfaceError("Cannot set attribute on closed connection", "Connection is closed") + + # Use the integrated validation helper function with connection state + is_valid, error_message, sanitized_attr, sanitized_val = validate_attribute_value( + attribute, value, is_connected=True + ) + + if not is_valid: + # Use the already sanitized values for logging + log('warning', f"Invalid attribute or value: {sanitized_attr}={sanitized_val}, {error_message}") + raise ProgrammingError( + driver_error=f"Invalid attribute or value: {error_message}", + ddbc_error=error_message + ) + + # Log with sanitized values + log('debug', f"Setting connection attribute: {sanitized_attr}={sanitized_val}") + + try: + # Call the underlying C++ method + self._conn.set_attr(attribute, value) + log('info', f"Connection attribute {sanitized_attr} set successfully") + + except Exception as e: + error_msg = f"Failed to set connection attribute {sanitized_attr}: {str(e)}" + log('error', error_msg) + + # Determine appropriate exception type based on error content + error_str = str(e).lower() + if 'invalid' in error_str or 'unsupported' in error_str or 'cast' in error_str: + raise InterfaceError(error_msg, str(e)) from e + else: + raise ProgrammingError(error_msg, str(e)) from e @property def searchescape(self): diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 05df3e14..20ff3bf9 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -20,20 +20,14 @@ class ConstantsDDBC(Enum): SQL_STILL_EXECUTING = 2 SQL_NTS = -3 SQL_DRIVER_NOPROMPT = 0 - SQL_ATTR_ASYNC_DBC_EVENT = 119 SQL_IS_INTEGER = -6 - SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE = 117 SQL_OV_DDBC3_80 = 380 - SQL_ATTR_DDBC_VERSION = 200 - SQL_ATTR_ASYNC_ENABLE = 4 - SQL_ATTR_ASYNC_STMT_EVENT = 29 SQL_ERROR = -1 SQL_INVALID_HANDLE = -2 SQL_NULL_HANDLE = 0 SQL_OV_DDBC3 = 3 SQL_COMMIT = 0 SQL_ROLLBACK = 1 - SQL_ATTR_AUTOCOMMIT = 102 SQL_SMALLINT = 5 SQL_CHAR = 1 SQL_WCHAR = -8 @@ -94,20 +88,15 @@ class ConstantsDDBC(Enum): SQL_DESC_TYPE = 2 SQL_DESC_LENGTH = 3 SQL_DESC_NAME = 4 - SQL_ATTR_ROW_ARRAY_SIZE = 27 - SQL_ATTR_ROWS_FETCHED_PTR = 26 - SQL_ATTR_ROW_STATUS_PTR = 25 SQL_ROW_SUCCESS = 0 SQL_ROW_SUCCESS_WITH_INFO = 1 SQL_ROW_NOROW = 100 - SQL_ATTR_CURSOR_TYPE = 6 SQL_CURSOR_FORWARD_ONLY = 0 SQL_CURSOR_STATIC = 3 SQL_CURSOR_KEYSET_DRIVEN = 2 SQL_CURSOR_DYNAMIC = 3 SQL_NULL_DATA = -1 SQL_C_DEFAULT = 99 - SQL_ATTR_ROW_BIND_TYPE = 5 SQL_BIND_BY_COLUMN = 0 SQL_PARAM_INPUT = 1 SQL_PARAM_OUTPUT = 2 @@ -115,7 +104,6 @@ class ConstantsDDBC(Enum): SQL_C_WCHAR = -8 SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 - SQL_ATTR_QUERY_TIMEOUT = 2 SQL_FETCH_NEXT = 1 SQL_FETCH_FIRST = 2 @@ -136,6 +124,60 @@ class ConstantsDDBC(Enum): SQL_QUICK = 0 SQL_ENSURE = 1 + # Connection Attribute Constants for set_attr() + SQL_ATTR_ACCESS_MODE = 101 + SQL_ATTR_AUTOCOMMIT = 102 + SQL_ATTR_CURSOR_TYPE = 6 + SQL_ATTR_ROW_BIND_TYPE = 5 + SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE = 117 + SQL_ATTR_ROW_ARRAY_SIZE = 27 + SQL_ATTR_ASYNC_DBC_EVENT = 119 + SQL_ATTR_DDBC_VERSION = 200 + SQL_ATTR_ASYNC_STMT_EVENT = 29 + SQL_ATTR_ROWS_FETCHED_PTR = 26 + SQL_ATTR_ROW_STATUS_PTR = 25 + SQL_ATTR_CONNECTION_TIMEOUT = 113 + SQL_ATTR_CURRENT_CATALOG = 109 + SQL_ATTR_LOGIN_TIMEOUT = 103 + SQL_ATTR_ODBC_CURSORS = 110 + SQL_ATTR_PACKET_SIZE = 112 + SQL_ATTR_QUIET_MODE = 111 + SQL_ATTR_TXN_ISOLATION = 108 + SQL_ATTR_TRACE = 104 + SQL_ATTR_TRACEFILE = 105 + SQL_ATTR_TRANSLATE_LIB = 106 + SQL_ATTR_TRANSLATE_OPTION = 107 + SQL_ATTR_CONNECTION_POOLING = 201 + SQL_ATTR_CP_MATCH = 202 + SQL_ATTR_ASYNC_ENABLE = 4 + SQL_ATTR_ENLIST_IN_DTC = 1207 + SQL_ATTR_ENLIST_IN_XA = 1208 + SQL_ATTR_CONNECTION_DEAD = 1209 + SQL_ATTR_SERVER_NAME = 13 + SQL_ATTR_RESET_CONNECTION = 116 + + # Transaction Isolation Level Constants + SQL_TXN_READ_UNCOMMITTED = 1 + SQL_TXN_READ_COMMITTED = 2 + SQL_TXN_REPEATABLE_READ = 4 + SQL_TXN_SERIALIZABLE = 8 + + # Access Mode Constants + SQL_MODE_READ_WRITE = 0 + SQL_MODE_READ_ONLY = 1 + + # Connection Dead Constants + SQL_CD_TRUE = 1 + SQL_CD_FALSE = 0 + + # ODBC Cursors Constants + SQL_CUR_USE_IF_NEEDED = 0 + SQL_CUR_USE_ODBC = 1 + SQL_CUR_USE_DRIVER = 2 + + # Reset Connection Constants + SQL_RESET_CONNECTION_YES = 1 + class GetInfoConstants(Enum): """ These constants are used with various methods like getinfo(). @@ -324,4 +366,54 @@ def get_numeric_types(cls) -> set: ConstantsDDBC.SQL_SMALLINT.value, ConstantsDDBC.SQL_INTEGER.value, ConstantsDDBC.SQL_BIGINT.value, ConstantsDDBC.SQL_REAL.value, ConstantsDDBC.SQL_FLOAT.value, ConstantsDDBC.SQL_DOUBLE.value - } \ No newline at end of file + } + +class AttributeSetTime(Enum): + """ + Defines when connection attributes can be set in relation to connection establishment. + + This enum is used to validate if a specific connection attribute can be set before + connection, after connection, or at either time. + """ + BEFORE_ONLY = 1 # Must be set before connection is established + AFTER_ONLY = 2 # Can only be set after connection is established + EITHER = 3 # Can be set either before or after connection + +# Dictionary mapping attributes to their valid set times +ATTRIBUTE_SET_TIMING = { + # Must be set before connection + ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value: AttributeSetTime.BEFORE_ONLY, + ConstantsDDBC.SQL_ATTR_ODBC_CURSORS.value: AttributeSetTime.BEFORE_ONLY, + ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value: AttributeSetTime.BEFORE_ONLY, + + # Can only be set after connection + ConstantsDDBC.SQL_ATTR_CONNECTION_DEAD.value: AttributeSetTime.AFTER_ONLY, + ConstantsDDBC.SQL_ATTR_ENLIST_IN_DTC.value: AttributeSetTime.AFTER_ONLY, + ConstantsDDBC.SQL_ATTR_TRANSLATE_LIB.value: AttributeSetTime.AFTER_ONLY, + ConstantsDDBC.SQL_ATTR_TRANSLATE_OPTION.value: AttributeSetTime.AFTER_ONLY, + + # Can be set either before or after connection + ConstantsDDBC.SQL_ATTR_ACCESS_MODE.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_ASYNC_DBC_EVENT.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_ASYNC_DBC_FUNCTIONS_ENABLE.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_ASYNC_ENABLE.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_AUTOCOMMIT.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_CONNECTION_TIMEOUT.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_CURRENT_CATALOG.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_QUIET_MODE.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_TRACE.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_TRACEFILE.value: AttributeSetTime.EITHER, + ConstantsDDBC.SQL_ATTR_TXN_ISOLATION.value: AttributeSetTime.EITHER, +} + +def get_attribute_set_timing(attribute): + """ + Get when an attribute can be set (before connection, after, or either). + + Args: + attribute (int): The connection attribute (SQL_ATTR_*) + + Returns: + AttributeSetTime: When the attribute can be set + """ + return ATTRIBUTE_SET_TIMING.get(attribute, AttributeSetTime.AFTER_ONLY) \ No newline at end of file diff --git a/mssql_python/helpers.py b/mssql_python/helpers.py index 2ac3c669..82a6ca65 100644 --- a/mssql_python/helpers.py +++ b/mssql_python/helpers.py @@ -7,8 +7,8 @@ from mssql_python import ddbc_bindings from mssql_python.exceptions import raise_exception from mssql_python.logging_config import get_logger -import platform -from pathlib import Path +import re +from mssql_python.constants import ConstantsDDBC from mssql_python.ddbc_bindings import normalize_architecture logger = get_logger() @@ -155,6 +155,96 @@ def sanitize_user_input(user_input: str, max_length: int = 50) -> str: # Return placeholder if nothing remains after sanitization return sanitized if sanitized else "" +def validate_attribute_value(attribute, value, is_connected=True, sanitize_logs=True, max_log_length=50): + """ + Validates attribute and value pairs for connection attributes. + + Performs basic type checking and validation of ODBC connection attributes. + + Args: + attribute (int): The connection attribute to validate (SQL_ATTR_*) + value: The value to set for the attribute (int, str, bytes, or bytearray) + is_connected (bool): Whether the connection is already established + sanitize_logs (bool): Whether to include sanitized versions for logging + max_log_length (int): Maximum length of sanitized output for logging + + Returns: + tuple: (is_valid, error_message, sanitized_attribute, sanitized_value) + """ + # Sanitize a value for logging + def _sanitize_for_logging(input_val, max_length=max_log_length): + if not isinstance(input_val, str): + try: + input_val = str(input_val) + except: + return "" + + # Allow alphanumeric, dash, underscore, and dot + sanitized = re.sub(r'[^\w\-\.]', '', input_val) + + # Limit length + if len(sanitized) > max_length: + sanitized = sanitized[:max_length] + "..." + + return sanitized if sanitized else "" + + # Create sanitized versions for logging + sanitized_attr = _sanitize_for_logging(attribute) if sanitize_logs else str(attribute) + sanitized_val = _sanitize_for_logging(value) if sanitize_logs else str(value) + + # Basic attribute validation - must be an integer + if not isinstance(attribute, int): + return False, f"Attribute must be an integer, got {type(attribute).__name__}", sanitized_attr, sanitized_val + + # Define driver-level attributes that are supported + SUPPORTED_ATTRIBUTES = [ + ConstantsDDBC.SQL_ATTR_ACCESS_MODE.value, + ConstantsDDBC.SQL_ATTR_CONNECTION_TIMEOUT.value, + ConstantsDDBC.SQL_ATTR_CURRENT_CATALOG.value, + ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value, + ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value, + ConstantsDDBC.SQL_ATTR_TXN_ISOLATION.value + ] + + # Check if attribute is supported + if attribute not in SUPPORTED_ATTRIBUTES: + return False, f"Unsupported attribute: {attribute}", sanitized_attr, sanitized_val + + # Check timing constraints for these specific attributes + BEFORE_ONLY_ATTRIBUTES = [ + ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value, + ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value + ] + + # Check if attribute can be set at the current connection state + if is_connected and attribute in BEFORE_ONLY_ATTRIBUTES: + return False, (f"Attribute {attribute} must be set before connection establishment. " + "Use the attrs_before parameter when creating the connection."), sanitized_attr, sanitized_val + + # Basic value type validation + if isinstance(value, int): + # For integer values, check if negative (login timeout can be -1 for default) + if value < 0 and attribute != ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value: + return False, f"Integer value cannot be negative: {value}", sanitized_attr, sanitized_val + + elif isinstance(value, str): + # Basic string length check + MAX_STRING_SIZE = 8192 # 8KB maximum + if len(value) > MAX_STRING_SIZE: + return False, f"String value too large: {len(value)} bytes (max {MAX_STRING_SIZE})", sanitized_attr, sanitized_val + + elif isinstance(value, (bytes, bytearray)): + # Basic binary length check + MAX_BINARY_SIZE = 32768 # 32KB maximum + if len(value) > MAX_BINARY_SIZE: + return False, f"Binary value too large: {len(value)} bytes (max {MAX_BINARY_SIZE})", sanitized_attr, sanitized_val + + else: + # Reject unsupported value types + return False, f"Unsupported attribute value type: {type(value).__name__}", sanitized_attr, sanitized_val + + # All basic validations passed + return True, None, sanitized_attr, sanitized_val def log(level: str, message: str, *args) -> None: """ diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 03178c4f..b38645d8 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -7,6 +7,8 @@ #include "connection.h" #include "connection_pool.h" #include +#include +#include #include #define SQL_COPT_SS_ACCESS_TOKEN 1256 // Custom attribute ID for access token @@ -168,35 +170,118 @@ SqlHandlePtr Connection::allocStatementHandle() { return std::make_shared(static_cast(SQL_HANDLE_STMT), stmt); } - SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { LOG("Setting SQL attribute"); - SQLPOINTER ptr = nullptr; - SQLINTEGER length = 0; + //SQLPOINTER ptr = nullptr; + //SQLINTEGER length = 0; static std::string buffer; // to hold sensitive data temporarily - + if (py::isinstance(value)) { - int intValue = value.cast(); - ptr = reinterpret_cast(static_cast(intValue)); - length = SQL_IS_INTEGER; - } else if (py::isinstance(value) || py::isinstance(value)) { - buffer = value.cast(); // stack buffer - ptr = buffer.data(); - length = static_cast(buffer.size()); - } else { - LOG("Unsupported attribute value type"); - return SQL_ERROR; - } - - SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), attribute, ptr, length); - if (!SQL_SUCCEEDED(ret)) { - LOG("Failed to set attribute"); + // Get the integer value + long long longValue = value.cast(); + + SQLRETURN ret = SQLSetConnectAttr_ptr( + _dbcHandle->get(), + attribute, + (SQLPOINTER)(SQLULEN)longValue, + SQL_IS_INTEGER); + + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to set attribute"); + } + else { + LOG("Set attribute successfully"); + } + return ret; + } + else if (py::isinstance(value)) { + try { + static std::vector wstr_buffers; // Keep buffers alive + std::string utf8_str = value.cast(); + + // Convert to wide string + std::wstring wstr = Utf8ToWString(utf8_str); + if (wstr.empty() && !utf8_str.empty()) { + LOG("Failed to convert string value to wide string"); + return SQL_ERROR; + } + + // Limit static buffer growth for memory safety + constexpr size_t MAX_BUFFER_COUNT = 100; + if (wstr_buffers.size() >= MAX_BUFFER_COUNT) { + // Remove oldest 50% of entries when limit reached + wstr_buffers.erase(wstr_buffers.begin(), wstr_buffers.begin() + (MAX_BUFFER_COUNT / 2)); + } + + wstr_buffers.push_back(wstr); + + SQLPOINTER ptr; + SQLINTEGER length; + +#if defined(__APPLE__) || defined(__linux__) + // For macOS/Linux, convert wstring to SQLWCHAR buffer + std::vector sqlwcharBuffer = WStringToSQLWCHAR(wstr); + if (sqlwcharBuffer.empty() && !wstr.empty()) { + LOG("Failed to convert wide string to SQLWCHAR buffer"); + return SQL_ERROR; + } + + ptr = sqlwcharBuffer.data(); + length = static_cast(sqlwcharBuffer.size() * sizeof(SQLWCHAR)); +#else + // On Windows, wchar_t and SQLWCHAR are the same size + ptr = const_cast(wstr_buffers.back().c_str()); + length = static_cast(wstr.length() * sizeof(SQLWCHAR)); +#endif + + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), attribute, ptr, length); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to set string attribute"); + } + else { + LOG("Set string attribute successfully"); + } + return ret; + } + catch (const std::exception& e) { + LOG("Exception during string attribute setting: " + std::string(e.what())); + return SQL_ERROR; + } } + else if (py::isinstance(value) || py::isinstance(value)) { + try { + static std::vector buffers; + std::string binary_data = value.cast(); + + // Limit static buffer growth + constexpr size_t MAX_BUFFER_COUNT = 100; + if (buffers.size() >= MAX_BUFFER_COUNT) { + // Remove oldest 50% of entries when limit reached + buffers.erase(buffers.begin(), buffers.begin() + (MAX_BUFFER_COUNT / 2)); + } + + buffers.emplace_back(std::move(binary_data)); + SQLPOINTER ptr = const_cast(buffers.back().c_str()); + SQLINTEGER length = static_cast(buffers.back().size()); + + SQLRETURN ret = SQLSetConnectAttr_ptr(_dbcHandle->get(), attribute, ptr, length); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to set attribute with binary data"); + } + else { + LOG("Set attribute successfully with binary data"); + } + return ret; + } + catch (const std::exception& e) { + LOG("Exception during binary attribute setting: " + std::string(e.what())); + return SQL_ERROR; + } + } else { - LOG("Set attribute successfully"); + LOG("Unsupported attribute value type"); + return SQL_ERROR; } - - return ret; } void Connection::applyAttrsBefore(const py::dict& attrs) { @@ -208,11 +293,12 @@ void Connection::applyAttrsBefore(const py::dict& attrs) { continue; } - if (key == SQL_COPT_SS_ACCESS_TOKEN) { - SQLRETURN ret = setAttribute(key, py::reinterpret_borrow(item.second)); - if (!SQL_SUCCEEDED(ret)) { - ThrowStdException("Failed to set access token before connect"); - } + // Apply all supported attributes + SQLRETURN ret = setAttribute(key, py::reinterpret_borrow(item.second)); + if (!SQL_SUCCEEDED(ret)) { + std::string attrName = std::to_string(key); + std::string errorMsg = "Failed to set attribute " + attrName + " before connect"; + ThrowStdException(errorMsg); } } } @@ -380,3 +466,33 @@ py::object ConnectionHandle::getInfo(SQLUSMALLINT infoType) const { } return _conn->getInfo(infoType); } + +void ConnectionHandle::setAttr(int attribute, py::object value) { + if (!_conn) { + ThrowStdException("Connection not established"); + } + + // Use existing setAttribute with better error handling + SQLRETURN ret = _conn->setAttribute(static_cast(attribute), value); + if (!SQL_SUCCEEDED(ret)) { + // Get detailed error information from ODBC + try { + ErrorInfo errorInfo = SQLCheckError_Wrap(SQL_HANDLE_DBC, _conn->getDbcHandle(), ret); + + std::string errorMsg = "Failed to set connection attribute " + std::to_string(attribute); + if (!errorInfo.ddbcErrorMsg.empty()) { + // Convert wstring to string for concatenation + std::string ddbcErrorStr = WideToUTF8(errorInfo.ddbcErrorMsg); + errorMsg += ": " + ddbcErrorStr; + } + + LOG("Connection setAttribute failed: {}", errorMsg); + ThrowStdException(errorMsg); + } catch (...) { + // Fallback to generic error if detailed error retrieval fails + std::string errorMsg = "Failed to set connection attribute " + std::to_string(attribute); + LOG("Connection setAttribute failed: {}", errorMsg); + ThrowStdException(errorMsg); + } + } +} diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index 66dd5895..68c2a216 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -45,10 +45,14 @@ class Connection { // Get information about the driver and data source py::object getInfo(SQLUSMALLINT infoType) const; + SQLRETURN setAttribute(SQLINTEGER attribute, py::object value); + + // Add getter for DBC handle for error reporting + const SqlHandlePtr& getDbcHandle() const { return _dbcHandle; } + private: void allocateDbcHandle(); void checkError(SQLRETURN ret) const; - SQLRETURN setAttribute(SQLINTEGER attribute, py::object value); void applyAttrsBefore(const py::dict& attrs_before); std::wstring _connStr; @@ -69,6 +73,7 @@ class ConnectionHandle { void setAutocommit(bool enabled); bool getAutocommit() const; SqlHandlePtr allocStatementHandle(); + void setAttr(int attribute, py::object value); // Get information about the driver and data source py::object getInfo(SQLUSMALLINT infoType) const; diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 879e76cf..7fef5c7f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3813,6 +3813,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def("rollback", &ConnectionHandle::rollback, "Rollback the current transaction") .def("set_autocommit", &ConnectionHandle::setAutocommit) .def("get_autocommit", &ConnectionHandle::getAutocommit) + .def("set_attr", &ConnectionHandle::setAttr, py::arg("attribute"), py::arg("value"), "Set connection attribute") .def("alloc_statement_handle", &ConnectionHandle::allocStatementHandle) .def("get_info", &ConnectionHandle::getInfo, py::arg("info_type")); m.def("enable_pooling", &enable_pooling, "Enable global connection pooling"); diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 6f843c97..6ed9f0d4 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -19,7 +19,9 @@ - test_context_manager_connection_closes: Test that context manager closes the connection. """ +from mssql_python.exceptions import InterfaceError, ProgrammingError, DatabaseError import mssql_python +import sys import pytest import time from mssql_python import connect, Connection, pooling, SQL_CHAR, SQL_WCHAR @@ -5040,4 +5042,1514 @@ def test_connection_searchescape_consistency(db_connection): assert new_escape == escape1, "Searchescape should be consistent across connections" new_conn.close() except Exception as e: - print(f"Note: New connection comparison failed: {e}") \ No newline at end of file + print(f"Note: New connection comparison failed: {e}") +def test_setencoding_default_settings(db_connection): + """Test that default encoding settings are correct.""" + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "Default encoding should be utf-16le" + assert settings['ctype'] == -8, "Default ctype should be SQL_WCHAR (-8)" + +def test_setencoding_basic_functionality(db_connection): + """Test basic setencoding functionality.""" + # Test setting UTF-8 encoding + db_connection.setencoding(encoding='utf-8') + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-8', "Encoding should be set to utf-8" + assert settings['ctype'] == 1, "ctype should default to SQL_CHAR (1) for utf-8" + + # Test setting UTF-16LE with explicit ctype + db_connection.setencoding(encoding='utf-16le', ctype=-8) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "Encoding should be set to utf-16le" + assert settings['ctype'] == -8, "ctype should be SQL_WCHAR (-8)" + +def test_setencoding_automatic_ctype_detection(db_connection): + """Test automatic ctype detection based on encoding.""" + # UTF-16 variants should default to SQL_WCHAR + utf16_encodings = ['utf-16', 'utf-16le', 'utf-16be'] + for encoding in utf16_encodings: + db_connection.setencoding(encoding=encoding) + settings = db_connection.getencoding() + assert settings['ctype'] == -8, f"{encoding} should default to SQL_WCHAR (-8)" + + # Other encodings should default to SQL_CHAR + other_encodings = ['utf-8', 'latin-1', 'ascii'] + for encoding in other_encodings: + db_connection.setencoding(encoding=encoding) + settings = db_connection.getencoding() + assert settings['ctype'] == 1, f"{encoding} should default to SQL_CHAR (1)" + +def test_setencoding_explicit_ctype_override(db_connection): + """Test that explicit ctype parameter overrides automatic detection.""" + # Set UTF-8 with SQL_WCHAR (override default) + db_connection.setencoding(encoding='utf-8', ctype=-8) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-8', "Encoding should be utf-8" + assert settings['ctype'] == -8, "ctype should be SQL_WCHAR (-8) when explicitly set" + + # Set UTF-16LE with SQL_CHAR (override default) + db_connection.setencoding(encoding='utf-16le', ctype=1) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "Encoding should be utf-16le" + assert settings['ctype'] == 1, "ctype should be SQL_CHAR (1) when explicitly set" + +def test_setencoding_none_parameters(db_connection): + """Test setencoding with None parameters.""" + # Test with encoding=None (should use default) + db_connection.setencoding(encoding=None) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "encoding=None should use default utf-16le" + assert settings['ctype'] == -8, "ctype should be SQL_WCHAR for utf-16le" + + # Test with both None (should use defaults) + db_connection.setencoding(encoding=None, ctype=None) + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-16le', "encoding=None should use default utf-16le" + assert settings['ctype'] == -8, "ctype=None should use default SQL_WCHAR" + +def test_setencoding_invalid_encoding(db_connection): + """Test setencoding with invalid encoding.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setencoding(encoding='invalid-encoding-name') + + assert "Unsupported encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" + assert "invalid-encoding-name" in str(exc_info.value), "Error message should include the invalid encoding name" + +def test_setencoding_invalid_ctype(db_connection): + """Test setencoding with invalid ctype.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setencoding(encoding='utf-8', ctype=999) + + assert "Invalid ctype" in str(exc_info.value), "Should raise ProgrammingError for invalid ctype" + assert "999" in str(exc_info.value), "Error message should include the invalid ctype value" + +def test_setencoding_closed_connection(conn_str): + """Test setencoding on closed connection.""" + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.setencoding(encoding='utf-8') + + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" + +def test_setencoding_constants_access(): + """Test that SQL_CHAR and SQL_WCHAR constants are accessible.""" + + + # Test constants exist and have correct values + assert hasattr(mssql_python, 'SQL_CHAR'), "SQL_CHAR constant should be available" + assert hasattr(mssql_python, 'SQL_WCHAR'), "SQL_WCHAR constant should be available" + assert mssql_python.SQL_CHAR == 1, "SQL_CHAR should have value 1" + assert mssql_python.SQL_WCHAR == -8, "SQL_WCHAR should have value -8" + +def test_setencoding_with_constants(db_connection): + """Test setencoding using module constants.""" + + + # Test with SQL_CHAR constant + db_connection.setencoding(encoding='utf-8', ctype=mssql_python.SQL_CHAR) + settings = db_connection.getencoding() + assert settings['ctype'] == mssql_python.SQL_CHAR, "Should accept SQL_CHAR constant" + + # Test with SQL_WCHAR constant + db_connection.setencoding(encoding='utf-16le', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getencoding() + assert settings['ctype'] == mssql_python.SQL_WCHAR, "Should accept SQL_WCHAR constant" + +def test_setencoding_common_encodings(db_connection): + """Test setencoding with various common encodings.""" + common_encodings = [ + 'utf-8', + 'utf-16le', + 'utf-16be', + 'utf-16', + 'latin-1', + 'ascii', + 'cp1252' + ] + + for encoding in common_encodings: + try: + db_connection.setencoding(encoding=encoding) + settings = db_connection.getencoding() + assert settings['encoding'] == encoding, f"Failed to set encoding {encoding}" + except Exception as e: + pytest.fail(f"Failed to set valid encoding {encoding}: {e}") + +def test_setencoding_persistence_across_cursors(db_connection): + """Test that encoding settings persist across cursor operations.""" + # Set custom encoding + db_connection.setencoding(encoding='utf-8', ctype=1) + + # Create cursors and verify encoding persists + cursor1 = db_connection.cursor() + settings1 = db_connection.getencoding() + + cursor2 = db_connection.cursor() + settings2 = db_connection.getencoding() + + assert settings1 == settings2, "Encoding settings should persist across cursor creation" + assert settings1['encoding'] == 'utf-8', "Encoding should remain utf-8" + assert settings1['ctype'] == 1, "ctype should remain SQL_CHAR" + + cursor1.close() + cursor2.close() + +@pytest.mark.skip("Skipping Unicode data tests till we have support for Unicode") +def test_setencoding_with_unicode_data(db_connection): + """Test setencoding with actual Unicode data operations.""" + # Test UTF-8 encoding with Unicode data + db_connection.setencoding(encoding='utf-8') + cursor = db_connection.cursor() + + try: + # Create test table + cursor.execute("CREATE TABLE #test_encoding_unicode (text_col NVARCHAR(100))") + + # Test various Unicode strings + test_strings = [ + "Hello, World!", + "Hello, 世界!", # Chinese + "Привет, мир!", # Russian + "مرحبا بالعالم", # Arabic + "🌍🌎🌏", # Emoji + ] + + for test_string in test_strings: + # Insert data + cursor.execute("INSERT INTO #test_encoding_unicode (text_col) VALUES (?)", test_string) + + # Retrieve and verify + cursor.execute("SELECT text_col FROM #test_encoding_unicode WHERE text_col = ?", test_string) + result = cursor.fetchone() + + assert result is not None, f"Failed to retrieve Unicode string: {test_string}" + assert result[0] == test_string, f"Unicode string mismatch: expected {test_string}, got {result[0]}" + + # Clear for next test + cursor.execute("DELETE FROM #test_encoding_unicode") + + except Exception as e: + pytest.fail(f"Unicode data test failed with UTF-8 encoding: {e}") + finally: + try: + cursor.execute("DROP TABLE #test_encoding_unicode") + except: + pass + cursor.close() + +def test_setencoding_before_and_after_operations(db_connection): + """Test that setencoding works both before and after database operations.""" + cursor = db_connection.cursor() + + try: + # Initial encoding setting + db_connection.setencoding(encoding='utf-16le') + + # Perform database operation + cursor.execute("SELECT 'Initial test' as message") + result1 = cursor.fetchone() + assert result1[0] == 'Initial test', "Initial operation failed" + + # Change encoding after operation + db_connection.setencoding(encoding='utf-8') + settings = db_connection.getencoding() + assert settings['encoding'] == 'utf-8', "Failed to change encoding after operation" + + # Perform another operation with new encoding + cursor.execute("SELECT 'Changed encoding test' as message") + result2 = cursor.fetchone() + assert result2[0] == 'Changed encoding test', "Operation after encoding change failed" + + except Exception as e: + pytest.fail(f"Encoding change test failed: {e}") + finally: + cursor.close() + +def test_getencoding_default(conn_str): + """Test getencoding returns default settings""" + conn = connect(conn_str) + try: + encoding_info = conn.getencoding() + assert isinstance(encoding_info, dict) + assert 'encoding' in encoding_info + assert 'ctype' in encoding_info + # Default should be utf-16le with SQL_WCHAR + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_getencoding_returns_copy(conn_str): + """Test getencoding returns a copy (not reference)""" + conn = connect(conn_str) + try: + encoding_info1 = conn.getencoding() + encoding_info2 = conn.getencoding() + + # Should be equal but not the same object + assert encoding_info1 == encoding_info2 + assert encoding_info1 is not encoding_info2 + + # Modifying one shouldn't affect the other + encoding_info1['encoding'] = 'modified' + assert encoding_info2['encoding'] != 'modified' + finally: + conn.close() + +def test_getencoding_closed_connection(conn_str): + """Test getencoding on closed connection raises InterfaceError""" + conn = connect(conn_str) + conn.close() + + with pytest.raises(InterfaceError, match="Connection is closed"): + conn.getencoding() + +def test_setencoding_getencoding_consistency(conn_str): + """Test that setencoding and getencoding work consistently together""" + conn = connect(conn_str) + try: + test_cases = [ + ('utf-8', SQL_CHAR), + ('utf-16le', SQL_WCHAR), + ('latin-1', SQL_CHAR), + ('ascii', SQL_CHAR), + ] + + for encoding, expected_ctype in test_cases: + conn.setencoding(encoding) + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == encoding.lower() + assert encoding_info['ctype'] == expected_ctype + finally: + conn.close() + +def test_setencoding_default_encoding(conn_str): + """Test setencoding with default UTF-16LE encoding""" + conn = connect(conn_str) + try: + conn.setencoding() + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_setencoding_utf8(conn_str): + """Test setencoding with UTF-8 encoding""" + conn = connect(conn_str) + try: + conn.setencoding('utf-8') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-8' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setencoding_latin1(conn_str): + """Test setencoding with latin-1 encoding""" + conn = connect(conn_str) + try: + conn.setencoding('latin-1') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'latin-1' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setencoding_with_explicit_ctype_sql_char(conn_str): + """Test setencoding with explicit SQL_CHAR ctype""" + conn = connect(conn_str) + try: + conn.setencoding('utf-8', SQL_CHAR) + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-8' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setencoding_with_explicit_ctype_sql_wchar(conn_str): + """Test setencoding with explicit SQL_WCHAR ctype""" + conn = connect(conn_str) + try: + conn.setencoding('utf-16le', SQL_WCHAR) + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_setencoding_invalid_ctype_error(conn_str): + """Test setencoding with invalid ctype raises ProgrammingError""" + + conn = connect(conn_str) + try: + with pytest.raises(ProgrammingError, match="Invalid ctype"): + conn.setencoding('utf-8', 999) + finally: + conn.close() + +def test_setencoding_case_insensitive_encoding(conn_str): + """Test setencoding with case variations""" + conn = connect(conn_str) + try: + # Test various case formats + conn.setencoding('UTF-8') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-8' # Should be normalized + + conn.setencoding('Utf-16LE') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' # Should be normalized + finally: + conn.close() + +def test_setencoding_none_encoding_default(conn_str): + """Test setencoding with None encoding uses default""" + conn = connect(conn_str) + try: + conn.setencoding(None) + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_setencoding_override_previous(conn_str): + """Test setencoding overrides previous settings""" + conn = connect(conn_str) + try: + # Set initial encoding + conn.setencoding('utf-8') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-8' + assert encoding_info['ctype'] == SQL_CHAR + + # Override with different encoding + conn.setencoding('utf-16le') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'utf-16le' + assert encoding_info['ctype'] == SQL_WCHAR + finally: + conn.close() + +def test_setencoding_ascii(conn_str): + """Test setencoding with ASCII encoding""" + conn = connect(conn_str) + try: + conn.setencoding('ascii') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'ascii' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setencoding_cp1252(conn_str): + """Test setencoding with Windows-1252 encoding""" + conn = connect(conn_str) + try: + conn.setencoding('cp1252') + encoding_info = conn.getencoding() + assert encoding_info['encoding'] == 'cp1252' + assert encoding_info['ctype'] == SQL_CHAR + finally: + conn.close() + +def test_setdecoding_default_settings(db_connection): + """Test that default decoding settings are correct for all SQL types.""" + + # Check SQL_CHAR defaults + sql_char_settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert sql_char_settings['encoding'] == 'utf-8', "Default SQL_CHAR encoding should be utf-8" + assert sql_char_settings['ctype'] == mssql_python.SQL_CHAR, "Default SQL_CHAR ctype should be SQL_CHAR" + + # Check SQL_WCHAR defaults + sql_wchar_settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert sql_wchar_settings['encoding'] == 'utf-16le', "Default SQL_WCHAR encoding should be utf-16le" + assert sql_wchar_settings['ctype'] == mssql_python.SQL_WCHAR, "Default SQL_WCHAR ctype should be SQL_WCHAR" + + # Check SQL_WMETADATA defaults + sql_wmetadata_settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert sql_wmetadata_settings['encoding'] == 'utf-16le', "Default SQL_WMETADATA encoding should be utf-16le" + assert sql_wmetadata_settings['ctype'] == mssql_python.SQL_WCHAR, "Default SQL_WMETADATA ctype should be SQL_WCHAR" + +def test_setdecoding_basic_functionality(db_connection): + """Test basic setdecoding functionality for different SQL types.""" + + # Test setting SQL_CHAR decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "SQL_CHAR encoding should be set to latin-1" + assert settings['ctype'] == mssql_python.SQL_CHAR, "SQL_CHAR ctype should default to SQL_CHAR for latin-1" + + # Test setting SQL_WCHAR decoding + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16be') + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16be', "SQL_WCHAR encoding should be set to utf-16be" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "SQL_WCHAR ctype should default to SQL_WCHAR for utf-16be" + + # Test setting SQL_WMETADATA decoding + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16le') + settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert settings['encoding'] == 'utf-16le', "SQL_WMETADATA encoding should be set to utf-16le" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "SQL_WMETADATA ctype should default to SQL_WCHAR" + +def test_setdecoding_automatic_ctype_detection(db_connection): + """Test automatic ctype detection based on encoding for different SQL types.""" + + # UTF-16 variants should default to SQL_WCHAR + utf16_encodings = ['utf-16', 'utf-16le', 'utf-16be'] + for encoding in utf16_encodings: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['ctype'] == mssql_python.SQL_WCHAR, f"SQL_CHAR with {encoding} should auto-detect SQL_WCHAR ctype" + + # Other encodings should default to SQL_CHAR + other_encodings = ['utf-8', 'latin-1', 'ascii', 'cp1252'] + for encoding in other_encodings: + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['ctype'] == mssql_python.SQL_CHAR, f"SQL_WCHAR with {encoding} should auto-detect SQL_CHAR ctype" + +def test_setdecoding_explicit_ctype_override(db_connection): + """Test that explicit ctype parameter overrides automatic detection.""" + + # Set SQL_CHAR with UTF-8 encoding but explicit SQL_WCHAR ctype + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Encoding should be utf-8" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be SQL_WCHAR when explicitly set" + + # Set SQL_WCHAR with UTF-16LE encoding but explicit SQL_CHAR ctype + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_CHAR) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "Encoding should be utf-16le" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should be SQL_CHAR when explicitly set" + +def test_setdecoding_none_parameters(db_connection): + """Test setdecoding with None parameters uses appropriate defaults.""" + + # Test SQL_CHAR with encoding=None (should use utf-8 default) + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=None) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "SQL_CHAR with encoding=None should use utf-8 default" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should be SQL_CHAR for utf-8" + + # Test SQL_WCHAR with encoding=None (should use utf-16le default) + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=None) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "SQL_WCHAR with encoding=None should use utf-16le default" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be SQL_WCHAR for utf-16le" + + # Test with both parameters None + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=None, ctype=None) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "SQL_CHAR with both None should use utf-8 default" + assert settings['ctype'] == mssql_python.SQL_CHAR, "ctype should default to SQL_CHAR" + +def test_setdecoding_invalid_sqltype(db_connection): + """Test setdecoding with invalid sqltype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(999, encoding='utf-8') + + assert "Invalid sqltype" in str(exc_info.value), "Should raise ProgrammingError for invalid sqltype" + assert "999" in str(exc_info.value), "Error message should include the invalid sqltype value" + +def test_setdecoding_invalid_encoding(db_connection): + """Test setdecoding with invalid encoding raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='invalid-encoding-name') + + assert "Unsupported encoding" in str(exc_info.value), "Should raise ProgrammingError for invalid encoding" + assert "invalid-encoding-name" in str(exc_info.value), "Error message should include the invalid encoding name" + +def test_setdecoding_invalid_ctype(db_connection): + """Test setdecoding with invalid ctype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=999) + + assert "Invalid ctype" in str(exc_info.value), "Should raise ProgrammingError for invalid ctype" + assert "999" in str(exc_info.value), "Error message should include the invalid ctype value" + +def test_setdecoding_closed_connection(conn_str): + """Test setdecoding on closed connection raises InterfaceError.""" + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" + +def test_setdecoding_constants_access(): + """Test that SQL constants are accessible.""" + + # Test constants exist and have correct values + assert hasattr(mssql_python, 'SQL_CHAR'), "SQL_CHAR constant should be available" + assert hasattr(mssql_python, 'SQL_WCHAR'), "SQL_WCHAR constant should be available" + assert hasattr(mssql_python, 'SQL_WMETADATA'), "SQL_WMETADATA constant should be available" + + assert mssql_python.SQL_CHAR == 1, "SQL_CHAR should have value 1" + assert mssql_python.SQL_WCHAR == -8, "SQL_WCHAR should have value -8" + assert mssql_python.SQL_WMETADATA == -99, "SQL_WMETADATA should have value -99" + +def test_setdecoding_with_constants(db_connection): + """Test setdecoding using module constants.""" + + # Test with SQL_CHAR constant + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8', ctype=mssql_python.SQL_CHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['ctype'] == mssql_python.SQL_CHAR, "Should accept SQL_CHAR constant" + + # Test with SQL_WCHAR constant + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['ctype'] == mssql_python.SQL_WCHAR, "Should accept SQL_WCHAR constant" + + # Test with SQL_WMETADATA constant + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16be') + settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + assert settings['encoding'] == 'utf-16be', "Should accept SQL_WMETADATA constant" + +def test_setdecoding_common_encodings(db_connection): + """Test setdecoding with various common encodings.""" + + common_encodings = [ + 'utf-8', + 'utf-16le', + 'utf-16be', + 'utf-16', + 'latin-1', + 'ascii', + 'cp1252' + ] + + for encoding in common_encodings: + try: + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == encoding, f"Failed to set SQL_CHAR decoding to {encoding}" + + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding=encoding) + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == encoding, f"Failed to set SQL_WCHAR decoding to {encoding}" + except Exception as e: + pytest.fail(f"Failed to set valid encoding {encoding}: {e}") + +def test_setdecoding_case_insensitive_encoding(db_connection): + """Test setdecoding with case variations normalizes encoding.""" + + # Test various case formats + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='UTF-8') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Encoding should be normalized to lowercase" + + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='Utf-16LE') + settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + assert settings['encoding'] == 'utf-16le', "Encoding should be normalized to lowercase" + +def test_setdecoding_independent_sql_types(db_connection): + """Test that decoding settings for different SQL types are independent.""" + + # Set different encodings for each SQL type + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le') + db_connection.setdecoding(mssql_python.SQL_WMETADATA, encoding='utf-16be') + + # Verify each maintains its own settings + sql_char_settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + sql_wchar_settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) + sql_wmetadata_settings = db_connection.getdecoding(mssql_python.SQL_WMETADATA) + + assert sql_char_settings['encoding'] == 'utf-8', "SQL_CHAR should maintain utf-8" + assert sql_wchar_settings['encoding'] == 'utf-16le', "SQL_WCHAR should maintain utf-16le" + assert sql_wmetadata_settings['encoding'] == 'utf-16be', "SQL_WMETADATA should maintain utf-16be" + +def test_setdecoding_override_previous(db_connection): + """Test setdecoding overrides previous settings for the same SQL type.""" + + # Set initial decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'utf-8', "Initial encoding should be utf-8" + assert settings['ctype'] == mssql_python.SQL_CHAR, "Initial ctype should be SQL_CHAR" + + # Override with different settings + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1', ctype=mssql_python.SQL_WCHAR) + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "Encoding should be overridden to latin-1" + assert settings['ctype'] == mssql_python.SQL_WCHAR, "ctype should be overridden to SQL_WCHAR" + +def test_getdecoding_invalid_sqltype(db_connection): + """Test getdecoding with invalid sqltype raises ProgrammingError.""" + + with pytest.raises(ProgrammingError) as exc_info: + db_connection.getdecoding(999) + + assert "Invalid sqltype" in str(exc_info.value), "Should raise ProgrammingError for invalid sqltype" + assert "999" in str(exc_info.value), "Error message should include the invalid sqltype value" + +def test_getdecoding_closed_connection(conn_str): + """Test getdecoding on closed connection raises InterfaceError.""" + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.getdecoding(mssql_python.SQL_CHAR) + + assert "Connection is closed" in str(exc_info.value), "Should raise InterfaceError for closed connection" + +def test_getdecoding_returns_copy(db_connection): + """Test getdecoding returns a copy (not reference).""" + + # Set custom decoding + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + # Get settings twice + settings1 = db_connection.getdecoding(mssql_python.SQL_CHAR) + settings2 = db_connection.getdecoding(mssql_python.SQL_CHAR) + + # Should be equal but not the same object + assert settings1 == settings2, "Settings should be equal" + assert settings1 is not settings2, "Settings should be different objects" + + # Modifying one shouldn't affect the other + settings1['encoding'] = 'modified' + assert settings2['encoding'] != 'modified', "Modification should not affect other copy" + +def test_setdecoding_getdecoding_consistency(db_connection): + """Test that setdecoding and getdecoding work consistently together.""" + + test_cases = [ + (mssql_python.SQL_CHAR, 'utf-8', mssql_python.SQL_CHAR), + (mssql_python.SQL_CHAR, 'utf-16le', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WCHAR, 'latin-1', mssql_python.SQL_CHAR), + (mssql_python.SQL_WCHAR, 'utf-16be', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WMETADATA, 'utf-16le', mssql_python.SQL_WCHAR), + ] + + for sqltype, encoding, expected_ctype in test_cases: + db_connection.setdecoding(sqltype, encoding=encoding) + settings = db_connection.getdecoding(sqltype) + assert settings['encoding'] == encoding.lower(), f"Encoding should be {encoding.lower()}" + assert settings['ctype'] == expected_ctype, f"ctype should be {expected_ctype}" + +def test_setdecoding_persistence_across_cursors(db_connection): + """Test that decoding settings persist across cursor operations.""" + + # Set custom decoding settings + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1', ctype=mssql_python.SQL_CHAR) + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16be', ctype=mssql_python.SQL_WCHAR) + + # Create cursors and verify settings persist + cursor1 = db_connection.cursor() + char_settings1 = db_connection.getdecoding(mssql_python.SQL_CHAR) + wchar_settings1 = db_connection.getdecoding(mssql_python.SQL_WCHAR) + + cursor2 = db_connection.cursor() + char_settings2 = db_connection.getdecoding(mssql_python.SQL_CHAR) + wchar_settings2 = db_connection.getdecoding(mssql_python.SQL_WCHAR) + + # Settings should persist across cursor creation + assert char_settings1 == char_settings2, "SQL_CHAR settings should persist across cursors" + assert wchar_settings1 == wchar_settings2, "SQL_WCHAR settings should persist across cursors" + + assert char_settings1['encoding'] == 'latin-1', "SQL_CHAR encoding should remain latin-1" + assert wchar_settings1['encoding'] == 'utf-16be', "SQL_WCHAR encoding should remain utf-16be" + + cursor1.close() + cursor2.close() + +def test_setdecoding_before_and_after_operations(db_connection): + """Test that setdecoding works both before and after database operations.""" + cursor = db_connection.cursor() + + try: + # Initial decoding setting + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + + # Perform database operation + cursor.execute("SELECT 'Initial test' as message") + result1 = cursor.fetchone() + assert result1[0] == 'Initial test', "Initial operation failed" + + # Change decoding after operation + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='latin-1') + settings = db_connection.getdecoding(mssql_python.SQL_CHAR) + assert settings['encoding'] == 'latin-1', "Failed to change decoding after operation" + + # Perform another operation with new decoding + cursor.execute("SELECT 'Changed decoding test' as message") + result2 = cursor.fetchone() + assert result2[0] == 'Changed decoding test', "Operation after decoding change failed" + + except Exception as e: + pytest.fail(f"Decoding change test failed: {e}") + finally: + cursor.close() + +def test_setdecoding_all_sql_types_independently(conn_str): + """Test setdecoding with all SQL types on a fresh connection.""" + + conn = connect(conn_str) + try: + # Test each SQL type with different configurations + test_configs = [ + (mssql_python.SQL_CHAR, 'ascii', mssql_python.SQL_CHAR), + (mssql_python.SQL_WCHAR, 'utf-16le', mssql_python.SQL_WCHAR), + (mssql_python.SQL_WMETADATA, 'utf-16be', mssql_python.SQL_WCHAR), + ] + + for sqltype, encoding, ctype in test_configs: + conn.setdecoding(sqltype, encoding=encoding, ctype=ctype) + settings = conn.getdecoding(sqltype) + assert settings['encoding'] == encoding, f"Failed to set encoding for sqltype {sqltype}" + assert settings['ctype'] == ctype, f"Failed to set ctype for sqltype {sqltype}" + + finally: + conn.close() + +def test_setdecoding_security_logging(db_connection): + """Test that setdecoding logs invalid attempts safely.""" + + # These should raise exceptions but not crash due to logging + test_cases = [ + (999, 'utf-8', None), # Invalid sqltype + (mssql_python.SQL_CHAR, 'invalid-encoding', None), # Invalid encoding + (mssql_python.SQL_CHAR, 'utf-8', 999), # Invalid ctype + ] + + for sqltype, encoding, ctype in test_cases: + with pytest.raises(ProgrammingError): + db_connection.setdecoding(sqltype, encoding=encoding, ctype=ctype) + +@pytest.mark.skip("Skipping Unicode data tests till we have support for Unicode") +def test_setdecoding_with_unicode_data(db_connection): + """Test setdecoding with actual Unicode data operations.""" + + # Test different decoding configurations with Unicode data + db_connection.setdecoding(mssql_python.SQL_CHAR, encoding='utf-8') + db_connection.setdecoding(mssql_python.SQL_WCHAR, encoding='utf-16le') + + cursor = db_connection.cursor() + + try: + # Create test table with both CHAR and NCHAR columns + cursor.execute(""" + CREATE TABLE #test_decoding_unicode ( + char_col VARCHAR(100), + nchar_col NVARCHAR(100) + ) + """) + + # Test various Unicode strings + test_strings = [ + "Hello, World!", + "Hello, 世界!", # Chinese + "Привет, мир!", # Russian + "مرحبا بالعالم", # Arabic + ] + + for test_string in test_strings: + # Insert data + cursor.execute( + "INSERT INTO #test_decoding_unicode (char_col, nchar_col) VALUES (?, ?)", + test_string, test_string + ) + + # Retrieve and verify + cursor.execute("SELECT char_col, nchar_col FROM #test_decoding_unicode WHERE char_col = ?", test_string) + result = cursor.fetchone() + + assert result is not None, f"Failed to retrieve Unicode string: {test_string}" + assert result[0] == test_string, f"CHAR column mismatch: expected {test_string}, got {result[0]}" + assert result[1] == test_string, f"NCHAR column mismatch: expected {test_string}, got {result[1]}" + + # Clear for next test + cursor.execute("DELETE FROM #test_decoding_unicode") + + except Exception as e: + pytest.fail(f"Unicode data test failed with custom decoding: {e}") + finally: + try: + cursor.execute("DROP TABLE #test_decoding_unicode") + except: + pass + cursor.close() + +# ==================== SET_ATTR TEST CASES ==================== + +def test_set_attr_constants_access(): + """Test that only relevant connection attribute constants are accessible. + + This test distinguishes between driver-independent (ODBC standard) and + driver-manager–dependent (may not be supported everywhere) constants. + Only ODBC-standard, cross-platform constants should be public API. + """ + # ODBC-standard, driver-independent constants (should be public) + odbc_attr_constants = [ + 'SQL_ATTR_ACCESS_MODE', 'SQL_ATTR_CONNECTION_TIMEOUT', + 'SQL_ATTR_CURRENT_CATALOG', 'SQL_ATTR_LOGIN_TIMEOUT', + 'SQL_ATTR_PACKET_SIZE', 'SQL_ATTR_TXN_ISOLATION', + ] + odbc_value_constants = [ + 'SQL_TXN_READ_UNCOMMITTED', 'SQL_TXN_READ_COMMITTED', + 'SQL_TXN_REPEATABLE_READ', 'SQL_TXN_SERIALIZABLE', + 'SQL_MODE_READ_WRITE', 'SQL_MODE_READ_ONLY', + ] + + # Driver-manager–dependent or rarely supported constants (should NOT be public API) + dm_attr_constants = [ + 'SQL_ATTR_QUIET_MODE', 'SQL_ATTR_TRACE', 'SQL_ATTR_TRACEFILE', + 'SQL_ATTR_TRANSLATE_LIB', 'SQL_ATTR_TRANSLATE_OPTION', + 'SQL_ATTR_CONNECTION_POOLING', 'SQL_ATTR_CP_MATCH', + 'SQL_ATTR_ASYNC_ENABLE', 'SQL_ATTR_CONNECTION_DEAD', + 'SQL_ATTR_SERVER_NAME', 'SQL_ATTR_RESET_CONNECTION', + 'SQL_ATTR_ODBC_CURSORS', 'SQL_CUR_USE_IF_NEEDED', 'SQL_CUR_USE_ODBC', + 'SQL_CUR_USE_DRIVER' + ] + dm_value_constants = [ + 'SQL_CD_TRUE', 'SQL_CD_FALSE', 'SQL_RESET_CONNECTION_YES' + ] + + # Check ODBC-standard constants are present and int + for const_name in odbc_attr_constants + odbc_value_constants: + assert hasattr(mssql_python, const_name), f"{const_name} should be available (ODBC standard)" + const_value = getattr(mssql_python, const_name) + assert isinstance(const_value, int), f"{const_name} should be an integer" + + # Check driver-manager–dependent constants are NOT present + for const_name in dm_attr_constants + dm_value_constants: + assert not hasattr(mssql_python, const_name), f"{const_name} should NOT be public API" + +def test_set_attr_basic_functionality(db_connection): + """Test basic set_attr functionality with ODBC-standard attributes.""" + try: + db_connection.set_attr(mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, 30) + except Exception as e: + if "not supported" not in str(e).lower(): + pytest.fail(f"Unexpected error setting connection timeout: {e}") + +def test_set_attr_transaction_isolation(db_connection): + """Test setting transaction isolation level (ODBC-standard).""" + isolation_levels = [ + mssql_python.SQL_TXN_READ_UNCOMMITTED, + mssql_python.SQL_TXN_READ_COMMITTED, + mssql_python.SQL_TXN_REPEATABLE_READ, + mssql_python.SQL_TXN_SERIALIZABLE + ] + for level in isolation_levels: + try: + db_connection.set_attr(mssql_python.SQL_ATTR_TXN_ISOLATION, level) + break + except Exception as e: + error_str = str(e).lower() + if not any(phrase in error_str for phrase in ["not supported", "failed to set", "invalid", "error"]): + pytest.fail(f"Unexpected error setting isolation level {level}: {e}") + +def test_set_attr_invalid_attr_id_type(db_connection): + """Test set_attr with invalid attr_id type raises ProgrammingError.""" + from mssql_python.exceptions import ProgrammingError + invalid_attr_ids = ["string", 3.14, None, [], {}] + for invalid_attr_id in invalid_attr_ids: + with pytest.raises(ProgrammingError) as exc_info: + db_connection.set_attr(invalid_attr_id, 1) + + assert "Attribute must be an integer" in str(exc_info.value), \ + f"Should raise ProgrammingError for invalid attr_id type: {type(invalid_attr_id)}" + +def test_set_attr_invalid_value_type(db_connection): + """Test set_attr with invalid value type raises ProgrammingError.""" + from mssql_python.exceptions import ProgrammingError + + invalid_values = [3.14, None, [], {}] + + for invalid_value in invalid_values: + with pytest.raises(ProgrammingError) as exc_info: + db_connection.set_attr(mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, invalid_value) + + assert "Unsupported attribute value type" in str(exc_info.value), \ + f"Should raise ProgrammingError for invalid value type: {type(invalid_value)}" + +def test_set_attr_value_out_of_range(db_connection): + """Test set_attr with value out of SQLULEN range raises ProgrammingError.""" + from mssql_python.exceptions import ProgrammingError + + out_of_range_values = [-1, -100] + + for invalid_value in out_of_range_values: + with pytest.raises(ProgrammingError) as exc_info: + db_connection.set_attr(mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, invalid_value) + + assert "Integer value cannot be negative" in str(exc_info.value), \ + f"Should raise ProgrammingError for out of range value: {invalid_value}" + +def test_set_attr_closed_connection(conn_str): + """Test set_attr on closed connection raises InterfaceError.""" + from mssql_python.exceptions import InterfaceError + + + temp_conn = connect(conn_str) + temp_conn.close() + + with pytest.raises(InterfaceError) as exc_info: + temp_conn.set_attr(mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, 30) + + assert "Connection is closed" in str(exc_info.value), \ + "Should raise InterfaceError for closed connection" + +def test_set_attr_invalid_attribute_id(db_connection): + """Test set_attr with invalid/unsupported attribute ID.""" + from mssql_python.exceptions import ProgrammingError, DatabaseError + + # Use a clearly invalid attribute ID + invalid_attr_id = 999999 + + try: + db_connection.set_attr(invalid_attr_id, 1) + # If no exception, some drivers might silently ignore invalid attributes + pytest.skip("Driver silently accepts invalid attribute IDs") + except (ProgrammingError, DatabaseError) as e: + # Expected behavior - driver should reject invalid attribute + assert "attribute" in str(e).lower() or "invalid" in str(e).lower() or "not supported" in str(e).lower() + except Exception as e: + pytest.fail(f"Unexpected exception type for invalid attribute: {type(e).__name__}: {e}") + +def test_set_attr_valid_range_values(db_connection): + """Test set_attr with valid range of values.""" + + + # Test boundary values for SQLUINTEGER + valid_values = [0, 1, 100, 1000, 65535, 4294967295] + + for value in valid_values: + try: + # Use connection timeout as it's commonly supported + db_connection.set_attr(mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, value) + # If we get here, the value was accepted + except Exception as e: + # Some values might not be valid for specific attributes + if "invalid" not in str(e).lower() and "not supported" not in str(e).lower(): + pytest.fail(f"Unexpected error for valid value {value}: {e}") + +def test_set_attr_multiple_attributes(db_connection): + """Test setting multiple attributes in sequence.""" + + + # Test setting multiple safe attributes + attribute_value_pairs = [ + (mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, 60), + (mssql_python.SQL_ATTR_LOGIN_TIMEOUT, 30), + (mssql_python.SQL_ATTR_PACKET_SIZE, 4096), + ] + + successful_sets = 0 + for attr_id, value in attribute_value_pairs: + try: + db_connection.set_attr(attr_id, value) + successful_sets += 1 + except Exception as e: + # Some attributes might not be supported by all drivers + # Accept "not supported", "failed to set", or other driver errors + error_str = str(e).lower() + if not any(phrase in error_str for phrase in ["not supported", "failed to set", "invalid", "error"]): + pytest.fail(f"Unexpected error setting attribute {attr_id} to {value}: {e}") + + # At least one attribute setting should succeed on most drivers + if successful_sets == 0: + pytest.skip("No connection attributes supported by this driver configuration") + +def test_set_attr_with_constants(db_connection): + """Test set_attr using exported module constants.""" + + + # Test using the exported constants + test_cases = [ + (mssql_python.SQL_ATTR_TXN_ISOLATION, mssql_python.SQL_TXN_READ_COMMITTED), + (mssql_python.SQL_ATTR_ACCESS_MODE, mssql_python.SQL_MODE_READ_WRITE), + ] + + for attr_id, value in test_cases: + try: + db_connection.set_attr(attr_id, value) + # Success - the constants worked correctly + except Exception as e: + # Some attributes/values might not be supported + # Accept "not supported", "failed to set", "invalid", or other driver errors + error_str = str(e).lower() + if not any(phrase in error_str for phrase in ["not supported", "failed to set", "invalid", "error"]): + pytest.fail(f"Unexpected error using constants {attr_id}, {value}: {e}") + +def test_set_attr_persistence_across_operations(db_connection): + """Test that set_attr changes persist across database operations.""" + + + cursor = db_connection.cursor() + try: + # Set an attribute before operations + db_connection.set_attr(mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, 45) + + # Perform database operation + cursor.execute("SELECT 1 as test_value") + result = cursor.fetchone() + assert result[0] == 1, "Database operation should succeed" + + # Set attribute after operation + db_connection.set_attr(mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, 60) + + # Another operation + cursor.execute("SELECT 2 as test_value") + result = cursor.fetchone() + assert result[0] == 2, "Database operation after set_attr should succeed" + + except Exception as e: + if "not supported" not in str(e).lower(): + pytest.fail(f"Error in set_attr persistence test: {e}") + finally: + cursor.close() + +def test_set_attr_security_logging(db_connection): + """Test that set_attr logs invalid attempts safely.""" + from mssql_python.exceptions import ProgrammingError + + # These should raise exceptions but not crash due to logging + test_cases = [ + ("invalid_attr", 1), # Invalid attr_id type + (123, "invalid_value"), # Invalid value type + (123, -1), # Out of range value + ] + + for attr_id, value in test_cases: + with pytest.raises(ProgrammingError): + db_connection.set_attr(attr_id, value) + +def test_set_attr_edge_cases(db_connection): + """Test set_attr with edge case values.""" + + + # Test with boundary values + edge_cases = [ + (mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, 0), # Minimum value + (mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, 4294967295), # Maximum SQLUINTEGER + ] + + for attr_id, value in edge_cases: + try: + db_connection.set_attr(attr_id, value) + # Success with edge case value + except Exception as e: + # Some edge values might not be valid for specific attributes + if "out of range" in str(e).lower(): + pytest.fail(f"Edge case value {value} should be in valid range") + elif "not supported" not in str(e).lower() and "invalid" not in str(e).lower(): + pytest.fail(f"Unexpected error for edge case {attr_id}, {value}: {e}") + +def test_set_attr_txn_isolation_effect(db_connection): + """Test that setting transaction isolation level actually affects transactions.""" + import os + conn_str = os.getenv('DB_CONNECTION_STRING') + + # Create a temporary table for the test + cursor = db_connection.cursor() + try: + drop_table_if_exists(cursor, "##test_isolation") + cursor.execute("CREATE TABLE ##test_isolation (id INT, value VARCHAR(50))") + cursor.execute("INSERT INTO ##test_isolation VALUES (1, 'original')") + db_connection.commit() + + # First set transaction isolation level to SERIALIZABLE (most strict) + try: + db_connection.set_attr(mssql_python.SQL_ATTR_TXN_ISOLATION, mssql_python.SQL_TXN_SERIALIZABLE) + + # Create two separate connections for the test + conn1 = connect(conn_str) + conn2 = connect(conn_str) + + # Start transaction in first connection + cursor1 = conn1.cursor() + cursor1.execute("BEGIN TRANSACTION") + cursor1.execute("UPDATE ##test_isolation SET value = 'updated' WHERE id = 1") + + # Try to read from second connection - should be blocked or timeout + cursor2 = conn2.cursor() + cursor2.execute("SET LOCK_TIMEOUT 5000") # 5 second timeout + + with pytest.raises((DatabaseError, Exception)) as exc_info: + cursor2.execute("SELECT * FROM ##test_isolation WHERE id = 1") + + # Clean up + cursor1.execute("ROLLBACK") + cursor1.close() + conn1.close() + cursor2.close() + conn2.close() + + # Now set READ UNCOMMITTED (least strict) + db_connection.set_attr(mssql_python.SQL_ATTR_TXN_ISOLATION, mssql_python.SQL_TXN_READ_UNCOMMITTED) + + # Create two new connections + conn1 = connect(conn_str) + conn2 = connect(conn_str) + conn2.set_attr(mssql_python.SQL_ATTR_TXN_ISOLATION, mssql_python.SQL_TXN_READ_UNCOMMITTED) + + # Start transaction in first connection + cursor1 = conn1.cursor() + cursor1.execute("BEGIN TRANSACTION") + cursor1.execute("UPDATE ##test_isolation SET value = 'dirty read' WHERE id = 1") + + # Try to read from second connection - should succeed with READ UNCOMMITTED + cursor2 = conn2.cursor() + cursor2.execute("SET LOCK_TIMEOUT 5000") + cursor2.execute("SELECT value FROM ##test_isolation WHERE id = 1") + result = cursor2.fetchone()[0] + + # Should see uncommitted "dirty read" value + assert result == 'dirty read', "READ UNCOMMITTED should allow dirty reads" + + # Clean up + cursor1.execute("ROLLBACK") + cursor1.close() + conn1.close() + cursor2.close() + conn2.close() + + except Exception as e: + if "not supported" not in str(e).lower(): + pytest.fail(f"Unexpected error in transaction isolation test: {e}") + else: + pytest.skip("Transaction isolation level changes not supported by driver") + + finally: + # Clean up + try: + cursor.execute("DROP TABLE ##test_isolation") + except: + pass + cursor.close() + +def test_set_attr_connection_timeout_effect(db_connection): + """Test that setting connection timeout actually affects query timeout.""" + + cursor = db_connection.cursor() + try: + # Set a short timeout (3 seconds) + try: + # Try to set the connection timeout + db_connection.set_attr(mssql_python.SQL_ATTR_CONNECTION_TIMEOUT, 3) + + # Check if the timeout setting worked by running an actual query + # WAITFOR DELAY is a reliable way to test timeout + start_time = time.time() + try: + cursor.execute("WAITFOR DELAY '00:00:05'") # 5-second delay + # If we get here, the timeout didn't work, but we won't fail the test + # since not all drivers support this feature + end_time = time.time() + elapsed = end_time - start_time + if elapsed >= 4.5: + pytest.skip("Connection timeout attribute not effective with this driver") + except Exception as exc: + # If we got an exception, check if it's a timeout-related exception + error_msg = str(exc).lower() + if "timeout" in error_msg or "timed out" in error_msg or "canceled" in error_msg: + # This is the expected behavior if timeout works + assert True + else: + # It's some other error, not a timeout + pytest.skip(f"Connection timeout test encountered non-timeout error: {exc}") + + except Exception as e: + if "not supported" not in str(e).lower(): + pytest.fail(f"Unexpected error in connection timeout test: {e}") + else: + pytest.skip("Connection timeout not supported by driver") + + finally: + cursor.close() + +def test_set_attr_login_timeout_effect(conn_str): + """Test that setting login timeout affects connection time to invalid server.""" + + # Testing with a non-existent server to trigger a timeout + conn_parts = conn_str.split(';') + new_parts = [] + for part in conn_parts: + if part.startswith("Server=") or part.startswith("server="): + # Use an invalid server address that will timeout + new_parts.append("Server=invalidserver.example.com") + else: + new_parts.append(part) + + # Add explicit login timeout directly in the connection string + new_parts.append("Connect Timeout=5") + + invalid_conn_str = ';'.join(new_parts) + + # Test with a short timeout + start_time = time.time() + try: + # Create a new connection with login timeout in the connection string + conn = connect(invalid_conn_str) # Don't use the login_timeout parameter + conn.close() + pytest.fail("Connection to invalid server should have failed") + except Exception as e: + end_time = time.time() + elapsed = end_time - start_time + + # Be more lenient with the timeout verification - up to 20 seconds + # Network conditions and driver behavior can vary + if elapsed > 30: + pytest.skip(f"Login timeout test took too long ({elapsed:.1f}s) but this may be environment-dependent") + + # We expected an exception, so this is successful + assert True + +def test_set_attr_packet_size_effect(conn_str): + """Test that setting packet size affects network packet size.""" + + # Some drivers don't support changing packet size after connection + # Try with explicit packet size in connection string for the first size + packet_size = 4096 + try: + # Add packet size to connection string + if ";" in conn_str: + modified_conn_str = conn_str + f";Packet Size={packet_size}" + else: + modified_conn_str = conn_str + f" Packet Size={packet_size}" + + conn = connect(modified_conn_str) + + # Execute a query that returns a large result set to test packet size + cursor = conn.cursor() + + # Create a temp table with a large string column + drop_table_if_exists(cursor, "##test_packet_size") + cursor.execute("CREATE TABLE ##test_packet_size (id INT, large_data NVARCHAR(MAX))") + + # Insert a very large string + large_string = "X" * (packet_size // 2) # Unicode chars take 2 bytes each + cursor.execute("INSERT INTO ##test_packet_size VALUES (?, ?)", (1, large_string)) + conn.commit() + + # Fetch the large string + cursor.execute("SELECT large_data FROM ##test_packet_size WHERE id = 1") + result = cursor.fetchone()[0] + + assert result == large_string, "Data should be retrieved correctly" + + # Clean up + cursor.execute("DROP TABLE ##test_packet_size") + conn.commit() + cursor.close() + conn.close() + + except Exception as e: + if ("not supported" not in str(e).lower() and + "attribute" not in str(e).lower()): + pytest.fail(f"Unexpected error in packet size test: {e}") + else: + pytest.skip(f"Packet size setting not supported: {e}") + +def test_set_attr_current_catalog_effect(db_connection, conn_str): + """Test that setting the current catalog/database actually changes the context.""" + # This only works if we have multiple databases available + cursor = db_connection.cursor() + try: + # Get current database name + cursor.execute("SELECT DB_NAME()") + original_db = cursor.fetchone()[0] + + # Get list of other databases + cursor.execute("SELECT name FROM sys.databases WHERE database_id > 4 AND name != DB_NAME()") + rows = cursor.fetchall() + if not rows: + pytest.skip("No other user databases available for testing") + + other_db = rows[0][0] + + # Try to switch database using set_attr + try: + db_connection.set_attr(mssql_python.SQL_ATTR_CURRENT_CATALOG, other_db) + + # Verify we're now in the other database + cursor.execute("SELECT DB_NAME()") + new_db = cursor.fetchone()[0] + + assert new_db == other_db, f"Database should have changed to {other_db} but is {new_db}" + + # Switch back + db_connection.set_attr(mssql_python.SQL_ATTR_CURRENT_CATALOG, original_db) + + # Verify we're back in the original database + cursor.execute("SELECT DB_NAME()") + current_db = cursor.fetchone()[0] + + assert current_db == original_db, f"Database should have changed back to {original_db} but is {current_db}" + + except Exception as e: + if "not supported" not in str(e).lower(): + pytest.fail(f"Unexpected error in current catalog test: {e}") + else: + pytest.skip("Current catalog changes not supported by driver") + + finally: + cursor.close() + +# ==================== TEST ATTRS_BEFORE AND SET_ATTR TIMING ==================== + +def test_attrs_before_login_timeout(conn_str): + """Test setting login timeout before connection via attrs_before.""" + # Test with a reasonable timeout value + timeout_value = 30 + conn = connect(conn_str, attrs_before={ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value: timeout_value}) + + # Verify connection was successful + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchall() + assert result[0][0] == 1 + conn.close() + + +def test_attrs_before_packet_size(conn_str): + """Test setting packet size before connection via attrs_before.""" + # Use a valid packet size value + packet_size = 8192 # 8KB packet size + conn = connect(conn_str, attrs_before={ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value: packet_size}) + + # Verify connection was successful + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchall() + assert result[0][0] == 1 + conn.close() + + +def test_attrs_before_multiple_attributes(conn_str): + """Test setting multiple attributes before connection via attrs_before.""" + attrs = { + ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value: 30, + ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value: 8192, + ConstantsDDBC.SQL_ATTR_ACCESS_MODE.value: ConstantsDDBC.SQL_MODE_READ_WRITE.value, + ConstantsDDBC.SQL_ATTR_TXN_ISOLATION.value: ConstantsDDBC.SQL_TXN_READ_COMMITTED.value + } + + conn = connect(conn_str, attrs_before=attrs) + + # Verify connection was successful + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchall() + assert result[0][0] == 1 + conn.close() + + +def test_set_attr_access_mode_after_connect(db_connection): + """Test setting access mode after connection via set_attr.""" + # Set access mode to read-write (default, but explicitly set it) + db_connection.set_attr(ConstantsDDBC.SQL_ATTR_ACCESS_MODE.value, ConstantsDDBC.SQL_MODE_READ_WRITE.value) + + # Verify we can still execute writes + cursor = db_connection.cursor() + drop_table_if_exists(cursor, "#test_access_mode") + cursor.execute("CREATE TABLE #test_access_mode (id INT)") + cursor.execute("INSERT INTO #test_access_mode VALUES (1)") + cursor.execute("SELECT * FROM #test_access_mode") + result = cursor.fetchall() + assert result[0][0] == 1 + +def test_set_attr_current_catalog_after_connect(db_connection): + """Test setting current catalog after connection via set_attr.""" + # Get current database name + cursor = db_connection.cursor() + cursor.execute("SELECT DB_NAME()") + original_db = cursor.fetchone()[0] + + # Try to set current catalog to master + db_connection.set_attr(ConstantsDDBC.SQL_ATTR_CURRENT_CATALOG.value, "master") + + # Verify the change + cursor.execute("SELECT DB_NAME()") + new_db = cursor.fetchone()[0] + assert new_db.lower() == "master" + + # Set it back to the original + db_connection.set_attr(ConstantsDDBC.SQL_ATTR_CURRENT_CATALOG.value, original_db) + + +def test_set_attr_connection_timeout_after_connect(db_connection): + """Test setting connection timeout after connection via set_attr.""" + # Set connection timeout to a reasonable value + db_connection.set_attr(ConstantsDDBC.SQL_ATTR_CONNECTION_TIMEOUT.value, 60) + + # Verify we can still execute queries + cursor = db_connection.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchall() + assert result[0][0] == 1 + + +def test_set_attr_before_only_attributes_error(db_connection): + """Test that setting before-only attributes after connection raises error.""" + # Try to set login timeout after connection + with pytest.raises(ProgrammingError) as excinfo: + db_connection.set_attr(ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value, 30) + + assert "must be set before connection establishment" in str(excinfo.value) + + # Try to set packet size after connection + with pytest.raises(ProgrammingError) as excinfo: + db_connection.set_attr(ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value, 8192) + + assert "must be set before connection establishment" in str(excinfo.value) + + +def test_attrs_before_after_only_attributes(conn_str): + """Test that setting after-only attributes before connection is ignored.""" + # Try to set connection dead before connection (should be ignored) + conn = connect(conn_str, attrs_before={ConstantsDDBC.SQL_ATTR_CONNECTION_DEAD.value: 0}) + + # Verify connection was successful + cursor = conn.cursor() + cursor.execute("SELECT 1") + result = cursor.fetchall() + assert result[0][0] == 1 + conn.close() + + +def test_attrs_before_connection_types(conn_str): + """Test attrs_before with different data types for attribute values.""" + attrs = { + # Integer attribute + ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value: 30, + # String attribute (catalog name) + ConstantsDDBC.SQL_ATTR_CURRENT_CATALOG.value: "testdb" + } + + conn = connect(conn_str, attrs_before=attrs) + + # Verify connection was successful and current catalog was set + cursor = conn.cursor() + cursor.execute("SELECT DB_NAME()") + result = cursor.fetchone()[0] + assert result.lower() == "testdb" + conn.close() + +def test_set_attr_unsupported_attribute(db_connection): + """Test that setting an unsupported attribute raises an error.""" + # Choose an attribute not in the supported list + unsupported_attr = 999999 # A made-up attribute ID + + with pytest.raises(ProgrammingError) as excinfo: + db_connection.set_attr(unsupported_attr, 1) + + assert "Unsupported attribute" in str(excinfo.value) \ No newline at end of file