-
Notifications
You must be signed in to change notification settings - Fork 3.9k
GH-46575: [C++][FlightRPC] ODBC Diagnostics Report #47763
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
alinaliBQ
wants to merge
1
commit into
apache:main
Choose a base branch
from
Bit-Quill:gh-46575-odbc-diagnostics
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+345
−4
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,6 +65,20 @@ SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) { | |
return SQL_INVALID_HANDLE; | ||
} | ||
|
||
inline bool IsValidStringFieldArgs(SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length, | ||
SQLSMALLINT* string_length_ptr, bool is_unicode) { | ||
const SQLSMALLINT char_size = is_unicode ? GetSqlWCharSize() : sizeof(char); | ||
const bool has_valid_buffer = | ||
diag_info_ptr && buffer_length >= 0 && buffer_length % char_size == 0; | ||
|
||
// regardless of capacity return false if invalid | ||
if (diag_info_ptr && !has_valid_buffer) { | ||
return false; | ||
} | ||
|
||
return has_valid_buffer || string_length_ptr; | ||
} | ||
|
||
SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle, | ||
SQLSMALLINT rec_number, SQLSMALLINT diag_identifier, | ||
SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length, | ||
|
@@ -76,8 +90,259 @@ SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle, | |
<< ", diag_info_ptr: " << diag_info_ptr | ||
<< ", buffer_length: " << buffer_length << ", string_length_ptr: " | ||
<< static_cast<const void*>(string_length_ptr); | ||
// GH-46575 TODO: Implement SQLGetDiagField | ||
return SQL_INVALID_HANDLE; | ||
// GH-46575 TODO: Add tests for SQLGetDiagField | ||
using arrow::flight::sql::odbc::Diagnostics; | ||
using ODBC::GetStringAttribute; | ||
using ODBC::ODBCConnection; | ||
using ODBC::ODBCDescriptor; | ||
using ODBC::ODBCEnvironment; | ||
using ODBC::ODBCStatement; | ||
|
||
if (!handle) { | ||
return SQL_INVALID_HANDLE; | ||
} | ||
|
||
if (!diag_info_ptr && !string_length_ptr) { | ||
return SQL_ERROR; | ||
} | ||
|
||
// If buffer length derived from null terminated string | ||
if (diag_info_ptr && buffer_length == SQL_NTS) { | ||
const wchar_t* str = reinterpret_cast<wchar_t*>(diag_info_ptr); | ||
buffer_length = wcslen(str) * arrow::flight::sql::odbc::GetSqlWCharSize(); | ||
} | ||
|
||
// Set character type to be Unicode by default | ||
const bool is_unicode = true; | ||
Diagnostics* diagnostics = nullptr; | ||
|
||
switch (handle_type) { | ||
case SQL_HANDLE_ENV: { | ||
ODBCEnvironment* environment = reinterpret_cast<ODBCEnvironment*>(handle); | ||
diagnostics = &environment->GetDiagnostics(); | ||
break; | ||
} | ||
|
||
case SQL_HANDLE_DBC: { | ||
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle); | ||
diagnostics = &connection->GetDiagnostics(); | ||
break; | ||
} | ||
|
||
case SQL_HANDLE_DESC: { | ||
ODBCDescriptor* descriptor = reinterpret_cast<ODBCDescriptor*>(handle); | ||
diagnostics = &descriptor->GetDiagnostics(); | ||
break; | ||
} | ||
|
||
case SQL_HANDLE_STMT: { | ||
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(handle); | ||
diagnostics = &statement->GetDiagnostics(); | ||
break; | ||
} | ||
|
||
default: | ||
return SQL_ERROR; | ||
} | ||
|
||
if (!diagnostics) { | ||
return SQL_ERROR; | ||
} | ||
|
||
// Retrieve and return if header level diagnostics | ||
switch (diag_identifier) { | ||
case SQL_DIAG_NUMBER: { | ||
if (diag_info_ptr) { | ||
*static_cast<SQLINTEGER*>(diag_info_ptr) = | ||
static_cast<SQLINTEGER>(diagnostics->GetRecordCount()); | ||
} | ||
|
||
if (string_length_ptr) { | ||
*string_length_ptr = sizeof(SQLINTEGER); | ||
} | ||
|
||
return SQL_SUCCESS; | ||
} | ||
|
||
// TODO implement return code function | ||
case SQL_DIAG_RETURNCODE: { | ||
return SQL_SUCCESS; | ||
} | ||
|
||
case SQL_DIAG_CURSOR_ROW_COUNT: { | ||
if (handle_type == SQL_HANDLE_STMT) { | ||
if (diag_info_ptr) { | ||
// Will always be 0 if only SELECT supported | ||
*static_cast<SQLLEN*>(diag_info_ptr) = 0; | ||
} | ||
|
||
if (string_length_ptr) { | ||
*string_length_ptr = sizeof(SQLLEN); | ||
} | ||
|
||
return SQL_SUCCESS; | ||
} | ||
|
||
return SQL_ERROR; | ||
} | ||
|
||
// Not supported | ||
case SQL_DIAG_DYNAMIC_FUNCTION: | ||
case SQL_DIAG_DYNAMIC_FUNCTION_CODE: { | ||
if (handle_type == SQL_HANDLE_STMT) { | ||
return SQL_SUCCESS; | ||
} | ||
|
||
return SQL_ERROR; | ||
} | ||
|
||
case SQL_DIAG_ROW_COUNT: { | ||
if (handle_type == SQL_HANDLE_STMT) { | ||
if (diag_info_ptr) { | ||
// Will always be 0 if only SELECT is supported | ||
*static_cast<SQLLEN*>(diag_info_ptr) = 0; | ||
} | ||
|
||
if (string_length_ptr) { | ||
*string_length_ptr = sizeof(SQLLEN); | ||
} | ||
|
||
return SQL_SUCCESS; | ||
} | ||
|
||
return SQL_ERROR; | ||
} | ||
} | ||
|
||
// If not a diagnostic header field then the record number must be 1 or greater | ||
if (rec_number < 1) { | ||
return SQL_ERROR; | ||
} | ||
|
||
// Retrieve record level diagnostics from specified 1 based record | ||
const uint32_t record_index = static_cast<uint32_t>(rec_number - 1); | ||
if (!diagnostics->HasRecord(record_index)) { | ||
return SQL_NO_DATA; | ||
} | ||
|
||
// Retrieve record field data | ||
switch (diag_identifier) { | ||
case SQL_DIAG_MESSAGE_TEXT: { | ||
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr, | ||
is_unicode)) { | ||
const std::string& message = diagnostics->GetMessageText(record_index); | ||
return GetStringAttribute(is_unicode, message, true, diag_info_ptr, buffer_length, | ||
string_length_ptr, *diagnostics); | ||
} | ||
|
||
return SQL_ERROR; | ||
} | ||
|
||
case SQL_DIAG_NATIVE: { | ||
if (diag_info_ptr) { | ||
*static_cast<SQLINTEGER*>(diag_info_ptr) = | ||
diagnostics->GetNativeError(record_index); | ||
} | ||
|
||
if (string_length_ptr) { | ||
*string_length_ptr = sizeof(SQLINTEGER); | ||
} | ||
|
||
return SQL_SUCCESS; | ||
} | ||
|
||
case SQL_DIAG_SERVER_NAME: { | ||
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr, | ||
is_unicode)) { | ||
switch (handle_type) { | ||
case SQL_HANDLE_DBC: { | ||
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle); | ||
std::string dsn = connection->GetDSN(); | ||
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length, | ||
string_length_ptr, *diagnostics); | ||
} | ||
|
||
case SQL_HANDLE_DESC: { | ||
ODBCDescriptor* descriptor = reinterpret_cast<ODBCDescriptor*>(handle); | ||
ODBCConnection* connection = &descriptor->GetConnection(); | ||
std::string dsn = connection->GetDSN(); | ||
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length, | ||
string_length_ptr, *diagnostics); | ||
break; | ||
} | ||
|
||
case SQL_HANDLE_STMT: { | ||
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(handle); | ||
ODBCConnection* connection = &statement->GetConnection(); | ||
std::string dsn = connection->GetDSN(); | ||
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length, | ||
string_length_ptr, *diagnostics); | ||
} | ||
|
||
default: | ||
return SQL_ERROR; | ||
} | ||
} | ||
|
||
return SQL_ERROR; | ||
} | ||
|
||
case SQL_DIAG_SQLSTATE: { | ||
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr, | ||
is_unicode)) { | ||
const std::string& state = diagnostics->GetSQLState(record_index); | ||
return GetStringAttribute(is_unicode, state, true, diag_info_ptr, buffer_length, | ||
string_length_ptr, *diagnostics); | ||
} | ||
|
||
return SQL_ERROR; | ||
} | ||
|
||
// Return valid dummy variable for unimplemented field | ||
case SQL_DIAG_COLUMN_NUMBER: { | ||
if (diag_info_ptr) { | ||
*static_cast<SQLINTEGER*>(diag_info_ptr) = SQL_NO_COLUMN_NUMBER; | ||
} | ||
|
||
if (string_length_ptr) { | ||
*string_length_ptr = sizeof(SQLINTEGER); | ||
} | ||
|
||
return SQL_SUCCESS; | ||
} | ||
|
||
// Return empty string dummy variable for unimplemented fields | ||
case SQL_DIAG_CLASS_ORIGIN: | ||
case SQL_DIAG_CONNECTION_NAME: | ||
case SQL_DIAG_SUBCLASS_ORIGIN: { | ||
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr, | ||
is_unicode)) { | ||
return GetStringAttribute(is_unicode, "", true, diag_info_ptr, buffer_length, | ||
string_length_ptr, *diagnostics); | ||
} | ||
|
||
return SQL_ERROR; | ||
} | ||
|
||
// Return valid dummy variable for unimplemented field | ||
case SQL_DIAG_ROW_NUMBER: { | ||
if (diag_info_ptr) { | ||
*static_cast<SQLLEN*>(diag_info_ptr) = SQL_NO_ROW_NUMBER; | ||
} | ||
|
||
if (string_length_ptr) { | ||
*string_length_ptr = sizeof(SQLLEN); | ||
} | ||
|
||
return SQL_SUCCESS; | ||
} | ||
|
||
default: { | ||
return SQL_ERROR; | ||
} | ||
} | ||
|
||
return SQL_ERROR; | ||
} | ||
|
||
SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT rec_number, | ||
|
@@ -91,8 +356,84 @@ SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT r | |
<< ", message_text: " << static_cast<const void*>(message_text) | ||
<< ", buffer_length: " << buffer_length | ||
<< ", text_length_ptr: " << static_cast<const void*>(text_length_ptr); | ||
// GH-46575 TODO: Implement SQLGetDiagRec | ||
return SQL_INVALID_HANDLE; | ||
// GH-46575 TODO: Add tests for SQLGetDiagRec | ||
using arrow::flight::sql::odbc::Diagnostics; | ||
using ODBC::GetStringAttribute; | ||
using ODBC::ODBCConnection; | ||
using ODBC::ODBCDescriptor; | ||
using ODBC::ODBCEnvironment; | ||
using ODBC::ODBCStatement; | ||
|
||
if (!handle) { | ||
return SQL_INVALID_HANDLE; | ||
} | ||
|
||
// Record number must be greater or equal to 1 | ||
if (rec_number < 1 || buffer_length < 0) { | ||
return SQL_ERROR; | ||
} | ||
|
||
// Set character type to be Unicode by default | ||
const bool is_unicode = true; | ||
Diagnostics* diagnostics = nullptr; | ||
|
||
switch (handle_type) { | ||
case SQL_HANDLE_ENV: { | ||
auto* environment = ODBCEnvironment::Of(handle); | ||
diagnostics = &environment->GetDiagnostics(); | ||
break; | ||
} | ||
|
||
case SQL_HANDLE_DBC: { | ||
auto* connection = ODBCConnection::Of(handle); | ||
diagnostics = &connection->GetDiagnostics(); | ||
break; | ||
} | ||
|
||
case SQL_HANDLE_DESC: { | ||
auto* descriptor = ODBCDescriptor::Of(handle); | ||
diagnostics = &descriptor->GetDiagnostics(); | ||
break; | ||
} | ||
|
||
case SQL_HANDLE_STMT: { | ||
auto* statement = ODBCStatement::Of(handle); | ||
diagnostics = &statement->GetDiagnostics(); | ||
break; | ||
} | ||
|
||
default: | ||
return SQL_INVALID_HANDLE; | ||
} | ||
|
||
if (!diagnostics) { | ||
return SQL_ERROR; | ||
} | ||
|
||
// Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage | ||
const size_t record_index = static_cast<size_t>(rec_number - 1); | ||
if (!diagnostics->HasRecord(record_index)) { | ||
return SQL_NO_DATA; | ||
} | ||
|
||
if (sql_state) { | ||
// The length of the sql state is always 5 characters plus null | ||
SQLSMALLINT size = 6; | ||
const std::string& state = diagnostics->GetSQLState(record_index); | ||
GetStringAttribute(is_unicode, state, false, sql_state, size, &size, *diagnostics); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we not need to check the return value? |
||
} | ||
|
||
if (native_error_ptr) { | ||
*native_error_ptr = diagnostics->GetNativeError(record_index); | ||
} | ||
|
||
if (message_text || text_length_ptr) { | ||
const std::string& message = diagnostics->GetMessageText(record_index); | ||
return GetStringAttribute(is_unicode, message, false, message_text, buffer_length, | ||
text_length_ptr, *diagnostics); | ||
} | ||
|
||
return SQL_SUCCESS; | ||
} | ||
|
||
SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER value_ptr, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need
diag_info_ptr &&
here?