Skip to content

feat(csharp/src/Drivers/Apache): enhance GetColumns with BASE_TYPE_NAME column #2695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ protected static Uri GetBaseAddress(string? uri, string? hostName, string? path,
return baseAddress;
}

protected IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns
internal IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns
.Select(t => new { Index = t.Position - ColumnMapIndexOffset, t.ColumnName })
.ToDictionary(t => t.ColumnName, t => t.Index);

Expand Down Expand Up @@ -1242,12 +1242,7 @@ private static StructArray GetColumnSchema(TableInfo tableInfo)
nullBitmapBuffer.Build());
}

protected abstract void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
int columnSize,
int decimalDigits);
internal abstract void SetPrecisionScaleAndTypeName(short columnType, string typeName, TableInfo? tableInfo, int columnSize, int decimalDigits);

public override Schema GetTableSchema(string? catalog, string? dbSchema, string? tableName)
{
Expand Down Expand Up @@ -1364,7 +1359,7 @@ private static IArrowType GetArrowType(int columnTypeId, string typeName, bool i
}
}

protected async Task<TRowSet> FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default)
internal async Task<TRowSet> FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default)
{
await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault, cancellationToken);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ protected override TOpenSessionReq CreateSessionRequest()
return req;
}

protected override void SetPrecisionScaleAndTypeName(
internal override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
128 changes: 127 additions & 1 deletion csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Arrow.Types;
using Apache.Hive.Service.Rpc.Thrift;
using Thrift.Transport;

Expand Down Expand Up @@ -406,7 +407,39 @@ private async Task<QueryResult> GetColumnsAsync(CancellationToken cancellationTo
cancellationToken);
OperationHandle = resp.OperationHandle;

return await GetQueryResult(resp.DirectResults, cancellationToken);
// Common variables declared upfront
TGetResultSetMetadataResp metadata;
Schema schema;
TRowSet rowSet;

// For GetColumns, we need to enhance the result with BASE_TYPE_NAME
if (Connection.AreResultsAvailableDirectly() && resp.DirectResults?.ResultSet?.Results != null)
{
// Get data from direct results
metadata = resp.DirectResults.ResultSetMetadata;
schema = Connection.SchemaParser.GetArrowSchema(metadata.Schema, Connection.DataTypeConversion);
rowSet = resp.DirectResults.ResultSet.Results;
}
else
{
// Poll and fetch results
await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken);

// Get metadata
metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken);
schema = Connection.SchemaParser.GetArrowSchema(metadata.Schema, Connection.DataTypeConversion);

// Fetch the results
rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken);
}

// Common processing for both paths
int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);
IReadOnlyList<IArrowArray> data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion);

// Return the enhanced result with added BASE_TYPE_NAME column
return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet);
}

private async Task<Schema> GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default)
Expand All @@ -426,12 +459,105 @@ private async Task<QueryResult> GetQueryResult(TSparkDirectResults? directResult
int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);
IReadOnlyList<IArrowArray> data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion);

return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(schema, data));
}

await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken);
schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken);

return new QueryResult(-1, Connection.NewReader(this, schema));
}

protected internal QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList<IArrowArray> originalData,
int rowCount, TGetResultSetMetadataResp metadata, TRowSet rowSet)
{
// Create a column map using Connection's GetColumnIndexMap method
var columnMap = Connection.GetColumnIndexMap(metadata.Schema.Columns);

// Get column indices - we know these columns always exist
int typeNameIndex = columnMap["TYPE_NAME"];
int dataTypeIndex = columnMap["DATA_TYPE"];
int columnSizeIndex = columnMap["COLUMN_SIZE"];
int decimalDigitsIndex = columnMap["DECIMAL_DIGITS"];

// Extract the existing arrays
StringArray typeNames = (StringArray)originalData[typeNameIndex];
Int32Array originalColumnSizes = (Int32Array)originalData[columnSizeIndex];
Int32Array originalDecimalDigits = (Int32Array)originalData[decimalDigitsIndex];

// Create enhanced schema with BASE_TYPE_NAME column
var enhancedFields = originalSchema.FieldsList.ToList();
enhancedFields.Add(new Field("BASE_TYPE_NAME", StringType.Default, true));
Schema enhancedSchema = new Schema(enhancedFields, originalSchema.Metadata);

// Pre-allocate arrays to store our values
int length = typeNames.Length;
List<string> baseTypeNames = new List<string>(length);
List<int> columnSizeValues = new List<int>(length);
List<int> decimalDigitsValues = new List<int>(length);

// Process each row
for (int i = 0; i < length; i++)
{
string? typeName = typeNames.GetString(i);
short colType = (short)rowSet.Columns[dataTypeIndex].I32Val.Values.Values[i];
int columnSize = originalColumnSizes.GetValue(i).GetValueOrDefault();
int decimalDigits = originalDecimalDigits.GetValue(i).GetValueOrDefault();

// Create a TableInfo for this row
var tableInfo = new HiveServer2Connection.TableInfo(string.Empty);

// Process all types through SetPrecisionScaleAndTypeName
Connection.SetPrecisionScaleAndTypeName(colType, typeName ?? string.Empty, tableInfo, columnSize, decimalDigits);

// Get base type name
string baseTypeName;
if (tableInfo.BaseTypeName.Count > 0)
{
string? baseTypeNameValue = tableInfo.BaseTypeName[0];
baseTypeName = baseTypeNameValue ?? string.Empty;
}
else
{
baseTypeName = typeName ?? string.Empty;
}
baseTypeNames.Add(baseTypeName);

// Get precision/scale values
if (tableInfo.Precision.Count > 0)
{
int? precisionValue = tableInfo.Precision[0];
columnSizeValues.Add(precisionValue.GetValueOrDefault(columnSize));
}
else
{
columnSizeValues.Add(columnSize);
}

if (tableInfo.Scale.Count > 0)
{
int? scaleValue = tableInfo.Scale[0];
decimalDigitsValues.Add(scaleValue.GetValueOrDefault(decimalDigits));
}
else
{
decimalDigitsValues.Add(decimalDigits);
}
}

// Create the Arrow arrays directly from our data arrays
StringArray baseTypeNameArray = new StringArray.Builder().AppendRange(baseTypeNames).Build();
Int32Array columnSizeArray = new Int32Array.Builder().AppendRange(columnSizeValues).Build();
Int32Array decimalDigitsArray = new Int32Array.Builder().AppendRange(decimalDigitsValues).Build();

// Create enhanced data with modified columns
var enhancedData = new List<IArrowArray>(originalData);
enhancedData[columnSizeIndex] = columnSizeArray;
enhancedData[decimalDigitsIndex] = decimalDigitsArray;
enhancedData.Add(baseTypeNameArray);

return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(enhancedSchema, enhancedData));
}
}
}
6 changes: 5 additions & 1 deletion csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ internal abstract class SqlTypeNameParser<T> : ISqlTypeNameParser where T : SqlT

// Note: the INTERVAL sql type does not have an associated column type id.
private static readonly HashSet<ISqlTypeNameParser> s_parsers = new HashSet<ISqlTypeNameParser>(s_parserMap.Values
.Concat([SqlIntervalTypeParser.Default, SqlSimpleTypeParser.Default("VOID")]));
.Concat([
SqlIntervalTypeParser.Default,
SqlSimpleTypeParser.Default("VOID"),
SqlSimpleTypeParser.Default("VARIANT"),
]));

/// <summary>
/// Gets the base SQL type name without decoration or sub clauses
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp response, Cancel
protected internal override Task<TRowSet> GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) =>
FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken);

protected override void SetPrecisionScaleAndTypeName(
internal override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public override AdbcStatement CreateStatement()

protected internal override int PositionRequiredOffset => 1;

protected override void SetPrecisionScaleAndTypeName(
internal override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
Loading
Loading