Skip to content

feat(csharp/src/Drivers/Apache): enhance GetColumns with BASE_TYPE_NAME column #2695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ protected static Uri GetBaseAddress(string? uri, string? hostName, string? path,
return baseAddress;
}

protected IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns
/// <summary>
/// Builds a lookup from column name to column index for the supplied Thrift column descriptors.
/// The index stored for each column is its <c>Position</c> minus <c>ColumnMapIndexOffset</c>.
/// </summary>
/// <param name="columns">The column descriptors to index.</param>
/// <returns>
/// A read-only dictionary keyed by <c>ColumnName</c>. Because it is produced by <c>ToDictionary</c>,
/// two descriptors sharing the same name will cause an <see cref="ArgumentException"/>.
/// </returns>
internal IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns
.Select(t => new { Index = t.Position - ColumnMapIndexOffset, t.ColumnName })
.ToDictionary(t => t.ColumnName, t => t.Index);

Expand Down Expand Up @@ -1242,12 +1242,7 @@ private static StructArray GetColumnSchema(TableInfo tableInfo)
nullBitmapBuffer.Build());
}

protected abstract void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
int columnSize,
int decimalDigits);
/// <summary>
/// Driver-specific hook that receives a column's type id, its type name, and the reported
/// column size and decimal digits, together with an optional <c>TableInfo</c>.
/// </summary>
/// <remarks>
/// NOTE(review): the statement-side caller reads <c>BaseTypeName</c>, <c>Precision</c> and <c>Scale</c>
/// from the <c>TableInfo</c> after this call, so implementations are presumably expected to
/// append to those lists - confirm against each override.
/// </remarks>
/// <param name="columnType">The column type id.</param>
/// <param name="typeName">The column type name as reported for the column.</param>
/// <param name="tableInfo">The table info passed by the caller; may be null.</param>
/// <param name="columnSize">The reported column size.</param>
/// <param name="decimalDigits">The reported number of decimal digits.</param>
public abstract void SetPrecisionScaleAndTypeName(short columnType, string typeName, TableInfo? tableInfo, int columnSize, int decimalDigits);

public override Schema GetTableSchema(string? catalog, string? dbSchema, string? tableName)
{
Expand Down Expand Up @@ -1364,7 +1359,7 @@ private static IArrowType GetArrowType(int columnTypeId, string typeName, bool i
}
}

protected async Task<TRowSet> FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default)
public async Task<TRowSet> FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default)
{
await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault, cancellationToken);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ protected override TOpenSessionReq CreateSessionRequest()
return req;
}

protected override void SetPrecisionScaleAndTypeName(
public override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
117 changes: 117 additions & 0 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Arrow.Types;
using Apache.Hive.Service.Rpc.Thrift;
using Thrift.Transport;

Expand Down Expand Up @@ -426,12 +427,128 @@ private async Task<QueryResult> GetQueryResult(TSparkDirectResults? directResult
int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);
IReadOnlyList<IArrowArray> data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion);

// Enhance column schema results if this is a GetColumns query
if (SqlQuery?.ToLowerInvariant() == GetColumnsCommandName)
{
return EnhanceGetColumnsResult(schema, data, rowCount, resultSetMetadata, rowSet);
}

return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(schema, data));
}

await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken);
schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken);

// For GetColumns operation, we need to fetch the results and enhance them
if (SqlQuery?.ToLowerInvariant() == GetColumnsCommandName)
{
// Fetch the results manually to enhance them
TRowSet rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken);
int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);

// Get metadata again to ensure we have the latest
TGetResultSetMetadataResp metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken);

// Get the arrays from the row set
IReadOnlyList<IArrowArray> data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion);

return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet);
}

return new QueryResult(-1, Connection.NewReader(this, schema));
}

/// <summary>
/// Post-processes a GetColumns result set: appends a BASE_TYPE_NAME column and replaces the
/// COLUMN_SIZE and DECIMAL_DIGITS columns with values recomputed per row through
/// <c>Connection.SetPrecisionScaleAndTypeName</c>.
/// </summary>
/// <param name="originalSchema">Schema of the result as returned by the server.</param>
/// <param name="originalData">Arrow arrays for the result, one per column of <paramref name="originalSchema"/>.</param>
/// <param name="rowCount">Row count reported on the returned <c>QueryResult</c>.</param>
/// <param name="metadata">Result set metadata used to resolve column positions by name.</param>
/// <param name="rowSet">Raw Thrift row set; the DATA_TYPE value of each row is read from it.</param>
/// <returns>A query result whose stream carries the enhanced schema and data.</returns>
private QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList<IArrowArray> originalData,
    int rowCount, TGetResultSetMetadataResp metadata, TRowSet rowSet)
{
    // Resolve the positions of the columns we need by name; the indexer throws if one is missing.
    var indexByName = Connection.GetColumnIndexMap(metadata.Schema.Columns);
    int typeNameIndex = indexByName["TYPE_NAME"];
    int dataTypeIndex = indexByName["DATA_TYPE"];
    int columnSizeIndex = indexByName["COLUMN_SIZE"];
    int decimalDigitsIndex = indexByName["DECIMAL_DIGITS"];

    // Typed views over the server-provided columns.
    StringArray typeNameColumn = (StringArray)originalData[typeNameIndex];
    Int32Array sizeColumn = (Int32Array)originalData[columnSizeIndex];
    Int32Array digitsColumn = (Int32Array)originalData[decimalDigitsIndex];

    // The enhanced schema is the original schema plus one trailing nullable string field.
    var fields = originalSchema.FieldsList.ToList();
    fields.Add(new Field("BASE_TYPE_NAME", StringType.Default, true));
    Schema enhancedSchema = new Schema(fields, originalSchema.Metadata);

    int rowTotal = typeNameColumn.Length;
    List<string> baseNames = new List<string>(rowTotal);
    List<int> sizes = new List<int>(rowTotal);
    List<int> digits = new List<int>(rowTotal);

    for (int row = 0; row < rowTotal; row++)
    {
        string? declaredTypeName = typeNameColumn.GetString(row);
        // DATA_TYPE is read from the raw row set rather than from the converted Arrow data.
        short sqlTypeCode = (short)rowSet.Columns[dataTypeIndex].I32Val.Values.Values[row];
        int reportedSize = sizeColumn.GetValue(row).GetValueOrDefault();
        int reportedDigits = digitsColumn.GetValue(row).GetValueOrDefault();

        // A fresh TableInfo per row collects whatever the connection derives for this column.
        var parsed = new HiveServer2Connection.TableInfo(string.Empty);
        Connection.SetPrecisionScaleAndTypeName(sqlTypeCode, declaredTypeName ?? string.Empty, parsed, reportedSize, reportedDigits);

        // Base type name: first derived entry when present, otherwise the declared type name.
        string? baseNameCandidate = parsed.BaseTypeName.Count > 0
            ? parsed.BaseTypeName[0]
            : declaredTypeName;
        baseNames.Add(baseNameCandidate ?? string.Empty);

        // Precision: first derived entry when present and non-null, otherwise the reported size.
        int? derivedPrecision = null;
        if (parsed.Precision.Count > 0)
        {
            derivedPrecision = parsed.Precision[0];
        }
        sizes.Add(derivedPrecision.GetValueOrDefault(reportedSize));

        // Scale: first derived entry when present and non-null, otherwise the reported digits.
        int? derivedScale = null;
        if (parsed.Scale.Count > 0)
        {
            derivedScale = parsed.Scale[0];
        }
        digits.Add(derivedScale.GetValueOrDefault(reportedDigits));
    }

    // Materialize the collected values as Arrow arrays.
    StringArray baseTypeNameArray = new StringArray.Builder().AppendRange(baseNames).Build();
    Int32Array columnSizeArray = new Int32Array.Builder().AppendRange(sizes).Build();
    Int32Array decimalDigitsArray = new Int32Array.Builder().AppendRange(digits).Build();

    // Copy the original columns, overwrite the two recomputed ones, and append the new column last
    // so that it lines up with the field appended to the schema.
    var columns = new List<IArrowArray>(originalData);
    columns[columnSizeIndex] = columnSizeArray;
    columns[decimalDigitsIndex] = decimalDigitsArray;
    columns.Add(baseTypeNameArray);

    var stream = new HiveServer2Connection.HiveInfoArrowStream(enhancedSchema, columns);
    return new QueryResult(rowCount, stream);
}
}
}
6 changes: 5 additions & 1 deletion csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ internal abstract class SqlTypeNameParser<T> : ISqlTypeNameParser where T : SqlT

// Note: the INTERVAL sql type does not have an associated column type id.
private static readonly HashSet<ISqlTypeNameParser> s_parsers = new HashSet<ISqlTypeNameParser>(s_parserMap.Values
.Concat([SqlIntervalTypeParser.Default, SqlSimpleTypeParser.Default("VOID")]));
.Concat([
SqlIntervalTypeParser.Default,
SqlSimpleTypeParser.Default("VOID"),
SqlSimpleTypeParser.Default("VARIANT"),
]));

/// <summary>
/// Gets the base SQL type name without decoration or sub clauses
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp response, Cancel
protected internal override Task<TRowSet> GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) =>
FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken);

protected override void SetPrecisionScaleAndTypeName(
public override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public override AdbcStatement CreateStatement()

protected internal override int PositionRequiredOffset => 1;

protected override void SetPrecisionScaleAndTypeName(
public override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
Loading
Loading