diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 879e76cf..2366a75a 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -21,6 +21,7 @@ #define SQL_SS_TIMESTAMPOFFSET (-155) #define SQL_C_SS_TIMESTAMPOFFSET (0x4001) #define MAX_DIGITS_IN_NUMERIC 64 +#define SQL_SS_XML (-152) #define STRINGIFY_FOR_CASE(x) \ case x: \ @@ -2525,6 +2526,12 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } break; } + case SQL_SS_XML: + { + LOG("Streaming XML for column {}", i); + row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false)); + break; + } case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { @@ -3395,6 +3402,7 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) { case SQL_LONGVARCHAR: rowSize += columnSize; break; + case SQL_SS_XML: case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: @@ -3499,7 +3507,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || - dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) && + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { lobColumns.push_back(i + 1); // 1-based } @@ -3621,7 +3629,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || - dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) && + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { lobColumns.push_back(i + 1); // 1-based } diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 8faeea5a..f8d37112 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -11409,6 +11409,90 @@ def test_datetime_string_parameter_binding(cursor, db_connection): drop_table_if_exists(cursor, table_name) db_connection.commit() +SMALL_XML = "1" +LARGE_XML = "" + "".join(f"{i}" for i in range(10000)) + "" +EMPTY_XML = "" +INVALID_XML = "" # malformed + +def test_xml_basic_insert_fetch(cursor, db_connection): + """Test insert and fetch of a small XML value.""" + try: + cursor.execute("CREATE TABLE #pytest_xml_basic (id INT PRIMARY KEY IDENTITY(1,1), xml_col XML NULL);") + db_connection.commit() + + cursor.execute("INSERT INTO #pytest_xml_basic (xml_col) VALUES (?);", SMALL_XML) + db_connection.commit() + + row = cursor.execute("SELECT xml_col FROM #pytest_xml_basic;").fetchone() + assert row[0] == SMALL_XML + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_xml_basic;") + db_connection.commit() + + +def test_xml_empty_and_null(cursor, db_connection): + """Test insert and fetch of empty XML and NULL values.""" + try: + cursor.execute("CREATE TABLE #pytest_xml_empty_null (id INT PRIMARY KEY IDENTITY(1,1), xml_col XML NULL);") + db_connection.commit() + + cursor.execute("INSERT INTO #pytest_xml_empty_null (xml_col) VALUES (?);", EMPTY_XML) + cursor.execute("INSERT INTO #pytest_xml_empty_null (xml_col) VALUES (?);", None) + db_connection.commit() + + rows = [r[0] for r in cursor.execute("SELECT xml_col FROM #pytest_xml_empty_null ORDER BY id;").fetchall()] + assert rows[0] == EMPTY_XML + assert rows[1] is None + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_xml_empty_null;") + db_connection.commit() + + +def test_xml_large_insert(cursor, db_connection): + """Test insert and fetch of a large XML value to verify streaming/DAE.""" + try: + cursor.execute("CREATE TABLE #pytest_xml_large (id INT PRIMARY KEY IDENTITY(1,1), xml_col XML NULL);") + db_connection.commit() + + cursor.execute("INSERT INTO #pytest_xml_large (xml_col) VALUES (?);", LARGE_XML) + db_connection.commit() + + row = cursor.execute("SELECT xml_col FROM #pytest_xml_large;").fetchone() + assert row[0] == LARGE_XML + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_xml_large;") + db_connection.commit() + + +def test_xml_batch_insert(cursor, db_connection): + """Test batch insert (executemany) of multiple XML values.""" + try: + cursor.execute("CREATE TABLE #pytest_xml_batch (id INT PRIMARY KEY IDENTITY(1,1), xml_col XML NULL);") + db_connection.commit() + + xmls = [f"{i}" for i in range(5)] + cursor.executemany("INSERT INTO #pytest_xml_batch (xml_col) VALUES (?);", [(x,) for x in xmls]) + db_connection.commit() + + rows = [r[0] for r in cursor.execute("SELECT xml_col FROM #pytest_xml_batch ORDER BY id;").fetchall()] + assert rows == xmls + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_xml_batch;") + db_connection.commit() + + +def test_xml_malformed_input(cursor, db_connection): + """Verify driver raises error for invalid XML input.""" + try: + cursor.execute("CREATE TABLE #pytest_xml_invalid (id INT PRIMARY KEY IDENTITY(1,1), xml_col XML NULL);") + db_connection.commit() + + with pytest.raises(Exception): + cursor.execute("INSERT INTO #pytest_xml_invalid (xml_col) VALUES (?);", INVALID_XML) + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_xml_invalid;") + db_connection.commit() + def test_close(db_connection): """Test closing the cursor""" try: