Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
349 changes: 345 additions & 4 deletions cpp/src/arrow/flight/sql/odbc/odbc_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) {
return SQL_INVALID_HANDLE;
}

inline bool IsValidStringFieldArgs(SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
SQLSMALLINT* string_length_ptr, bool is_unicode) {
const SQLSMALLINT char_size = is_unicode ? GetSqlWCharSize() : sizeof(char);
const bool has_valid_buffer =
diag_info_ptr && buffer_length >= 0 && buffer_length % char_size == 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need diag_info_ptr && here?

Suggested change
diag_info_ptr && buffer_length >= 0 && buffer_length % char_size == 0;
buffer_length >= 0 && buffer_length % char_size == 0;


// regardless of capacity return false if invalid
if (diag_info_ptr && !has_valid_buffer) {
return false;
}

return has_valid_buffer || string_length_ptr;
}

SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
SQLSMALLINT rec_number, SQLSMALLINT diag_identifier,
SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length,
Expand All @@ -76,8 +90,259 @@ SQLRETURN SQLGetDiagField(SQLSMALLINT handle_type, SQLHANDLE handle,
<< ", diag_info_ptr: " << diag_info_ptr
<< ", buffer_length: " << buffer_length << ", string_length_ptr: "
<< static_cast<const void*>(string_length_ptr);
// GH-46575 TODO: Implement SQLGetDiagField
return SQL_INVALID_HANDLE;
// GH-46575 TODO: Add tests for SQLGetDiagField
using arrow::flight::sql::odbc::Diagnostics;
using ODBC::GetStringAttribute;
using ODBC::ODBCConnection;
using ODBC::ODBCDescriptor;
using ODBC::ODBCEnvironment;
using ODBC::ODBCStatement;

if (!handle) {
return SQL_INVALID_HANDLE;
}

if (!diag_info_ptr && !string_length_ptr) {
return SQL_ERROR;
}

// If buffer length derived from null terminated string
if (diag_info_ptr && buffer_length == SQL_NTS) {
const wchar_t* str = reinterpret_cast<wchar_t*>(diag_info_ptr);
buffer_length = wcslen(str) * arrow::flight::sql::odbc::GetSqlWCharSize();
}

// Set character type to be Unicode by default
const bool is_unicode = true;
Diagnostics* diagnostics = nullptr;

switch (handle_type) {
case SQL_HANDLE_ENV: {
ODBCEnvironment* environment = reinterpret_cast<ODBCEnvironment*>(handle);
diagnostics = &environment->GetDiagnostics();
break;
}

case SQL_HANDLE_DBC: {
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle);
diagnostics = &connection->GetDiagnostics();
break;
}

case SQL_HANDLE_DESC: {
ODBCDescriptor* descriptor = reinterpret_cast<ODBCDescriptor*>(handle);
diagnostics = &descriptor->GetDiagnostics();
break;
}

case SQL_HANDLE_STMT: {
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(handle);
diagnostics = &statement->GetDiagnostics();
break;
}

default:
return SQL_ERROR;
}

if (!diagnostics) {
return SQL_ERROR;
}

// Retrieve and return if header level diagnostics
switch (diag_identifier) {
case SQL_DIAG_NUMBER: {
if (diag_info_ptr) {
*static_cast<SQLINTEGER*>(diag_info_ptr) =
static_cast<SQLINTEGER>(diagnostics->GetRecordCount());
}

if (string_length_ptr) {
*string_length_ptr = sizeof(SQLINTEGER);
}

return SQL_SUCCESS;
}

// TODO implement return code function
case SQL_DIAG_RETURNCODE: {
return SQL_SUCCESS;
}

case SQL_DIAG_CURSOR_ROW_COUNT: {
if (handle_type == SQL_HANDLE_STMT) {
if (diag_info_ptr) {
// Will always be 0 if only SELECT supported
*static_cast<SQLLEN*>(diag_info_ptr) = 0;
}

if (string_length_ptr) {
*string_length_ptr = sizeof(SQLLEN);
}

return SQL_SUCCESS;
}

return SQL_ERROR;
}

// Not supported
case SQL_DIAG_DYNAMIC_FUNCTION:
case SQL_DIAG_DYNAMIC_FUNCTION_CODE: {
if (handle_type == SQL_HANDLE_STMT) {
return SQL_SUCCESS;
}

return SQL_ERROR;
}

case SQL_DIAG_ROW_COUNT: {
if (handle_type == SQL_HANDLE_STMT) {
if (diag_info_ptr) {
// Will always be 0 if only SELECT is supported
*static_cast<SQLLEN*>(diag_info_ptr) = 0;
}

if (string_length_ptr) {
*string_length_ptr = sizeof(SQLLEN);
}

return SQL_SUCCESS;
}

return SQL_ERROR;
}
}

// If not a diagnostic header field then the record number must be 1 or greater
if (rec_number < 1) {
return SQL_ERROR;
}

// Retrieve record level diagnostics from specified 1 based record
const uint32_t record_index = static_cast<uint32_t>(rec_number - 1);
if (!diagnostics->HasRecord(record_index)) {
return SQL_NO_DATA;
}

// Retrieve record field data
switch (diag_identifier) {
case SQL_DIAG_MESSAGE_TEXT: {
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
is_unicode)) {
const std::string& message = diagnostics->GetMessageText(record_index);
return GetStringAttribute(is_unicode, message, true, diag_info_ptr, buffer_length,
string_length_ptr, *diagnostics);
}

return SQL_ERROR;
}

case SQL_DIAG_NATIVE: {
if (diag_info_ptr) {
*static_cast<SQLINTEGER*>(diag_info_ptr) =
diagnostics->GetNativeError(record_index);
}

if (string_length_ptr) {
*string_length_ptr = sizeof(SQLINTEGER);
}

return SQL_SUCCESS;
}

case SQL_DIAG_SERVER_NAME: {
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
is_unicode)) {
switch (handle_type) {
case SQL_HANDLE_DBC: {
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(handle);
std::string dsn = connection->GetDSN();
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
string_length_ptr, *diagnostics);
}

case SQL_HANDLE_DESC: {
ODBCDescriptor* descriptor = reinterpret_cast<ODBCDescriptor*>(handle);
ODBCConnection* connection = &descriptor->GetConnection();
std::string dsn = connection->GetDSN();
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
string_length_ptr, *diagnostics);
break;
}

case SQL_HANDLE_STMT: {
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(handle);
ODBCConnection* connection = &statement->GetConnection();
std::string dsn = connection->GetDSN();
return GetStringAttribute(is_unicode, dsn, true, diag_info_ptr, buffer_length,
string_length_ptr, *diagnostics);
}

default:
return SQL_ERROR;
}
}

return SQL_ERROR;
}

case SQL_DIAG_SQLSTATE: {
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
is_unicode)) {
const std::string& state = diagnostics->GetSQLState(record_index);
return GetStringAttribute(is_unicode, state, true, diag_info_ptr, buffer_length,
string_length_ptr, *diagnostics);
}

return SQL_ERROR;
}

// Return valid dummy variable for unimplemented field
case SQL_DIAG_COLUMN_NUMBER: {
if (diag_info_ptr) {
*static_cast<SQLINTEGER*>(diag_info_ptr) = SQL_NO_COLUMN_NUMBER;
}

if (string_length_ptr) {
*string_length_ptr = sizeof(SQLINTEGER);
}

return SQL_SUCCESS;
}

// Return empty string dummy variable for unimplemented fields
case SQL_DIAG_CLASS_ORIGIN:
case SQL_DIAG_CONNECTION_NAME:
case SQL_DIAG_SUBCLASS_ORIGIN: {
if (IsValidStringFieldArgs(diag_info_ptr, buffer_length, string_length_ptr,
is_unicode)) {
return GetStringAttribute(is_unicode, "", true, diag_info_ptr, buffer_length,
string_length_ptr, *diagnostics);
}

return SQL_ERROR;
}

// Return valid dummy variable for unimplemented field
case SQL_DIAG_ROW_NUMBER: {
if (diag_info_ptr) {
*static_cast<SQLLEN*>(diag_info_ptr) = SQL_NO_ROW_NUMBER;
}

if (string_length_ptr) {
*string_length_ptr = sizeof(SQLLEN);
}

return SQL_SUCCESS;
}

default: {
return SQL_ERROR;
}
}

return SQL_ERROR;
}

SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT rec_number,
Expand All @@ -91,8 +356,84 @@ SQLRETURN SQLGetDiagRec(SQLSMALLINT handle_type, SQLHANDLE handle, SQLSMALLINT r
<< ", message_text: " << static_cast<const void*>(message_text)
<< ", buffer_length: " << buffer_length
<< ", text_length_ptr: " << static_cast<const void*>(text_length_ptr);
// GH-46575 TODO: Implement SQLGetDiagRec
return SQL_INVALID_HANDLE;
// GH-46575 TODO: Add tests for SQLGetDiagRec
using arrow::flight::sql::odbc::Diagnostics;
using ODBC::GetStringAttribute;
using ODBC::ODBCConnection;
using ODBC::ODBCDescriptor;
using ODBC::ODBCEnvironment;
using ODBC::ODBCStatement;

if (!handle) {
return SQL_INVALID_HANDLE;
}

// Record number must be greater or equal to 1
if (rec_number < 1 || buffer_length < 0) {
return SQL_ERROR;
}

// Set character type to be Unicode by default
const bool is_unicode = true;
Diagnostics* diagnostics = nullptr;

switch (handle_type) {
case SQL_HANDLE_ENV: {
auto* environment = ODBCEnvironment::Of(handle);
diagnostics = &environment->GetDiagnostics();
break;
}

case SQL_HANDLE_DBC: {
auto* connection = ODBCConnection::Of(handle);
diagnostics = &connection->GetDiagnostics();
break;
}

case SQL_HANDLE_DESC: {
auto* descriptor = ODBCDescriptor::Of(handle);
diagnostics = &descriptor->GetDiagnostics();
break;
}

case SQL_HANDLE_STMT: {
auto* statement = ODBCStatement::Of(handle);
diagnostics = &statement->GetDiagnostics();
break;
}

default:
return SQL_INVALID_HANDLE;
}

if (!diagnostics) {
return SQL_ERROR;
}

// Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage
const size_t record_index = static_cast<size_t>(rec_number - 1);
if (!diagnostics->HasRecord(record_index)) {
return SQL_NO_DATA;
}

if (sql_state) {
// The length of the sql state is always 5 characters plus null
SQLSMALLINT size = 6;
const std::string& state = diagnostics->GetSQLState(record_index);
GetStringAttribute(is_unicode, state, false, sql_state, size, &size, *diagnostics);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not need to check the return value?

}

if (native_error_ptr) {
*native_error_ptr = diagnostics->GetNativeError(record_index);
}

if (message_text || text_length_ptr) {
const std::string& message = diagnostics->GetMessageText(record_index);
return GetStringAttribute(is_unicode, message, false, message_text, buffer_length,
text_length_ptr, *diagnostics);
}

return SQL_SUCCESS;
}

SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER value_ptr,
Expand Down
Loading