Skip to content

feat(csharp/src/Drivers/Apache): enhance GetColumns with BASE_TYPE_NAME column #2695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ protected static Uri GetBaseAddress(string? uri, string? hostName, string? path,
return baseAddress;
}

protected IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) => columns
/// <summary>
/// Builds a lookup from column name to column index for the given column descriptors.
/// The index of each column is its <c>Position</c> minus <c>ColumnMapIndexOffset</c>.
/// </summary>
/// <param name="columns">The column descriptors to index.</param>
/// <returns>A read-only dictionary keyed by column name.</returns>
internal IReadOnlyDictionary<string, int> GetColumnIndexMap(List<TColumnDesc> columns) =>
    columns.ToDictionary(
        column => column.ColumnName,
        column => column.Position - ColumnMapIndexOffset);

Expand Down Expand Up @@ -1242,12 +1242,7 @@ private static StructArray GetColumnSchema(TableInfo tableInfo)
nullBitmapBuffer.Build());
}

protected abstract void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
int columnSize,
int decimalDigits);
/// <summary>
/// Driver-specific hook that derived connections override.
/// NOTE(review): callers in this change read <c>BaseTypeName</c>, <c>Precision</c> and <c>Scale</c>
/// from <paramref name="tableInfo"/> after invoking this method, so implementations presumably
/// append to those lists — confirm against each override.
/// </summary>
/// <param name="columnType">The column's type id as a 16-bit value.</param>
/// <param name="typeName">The type name reported for the column.</param>
/// <param name="tableInfo">The table info passed to the implementation; may be null.</param>
/// <param name="columnSize">The column size reported for the column.</param>
/// <param name="decimalDigits">The decimal digits reported for the column.</param>
internal abstract void SetPrecisionScaleAndTypeName(short columnType, string typeName, TableInfo? tableInfo, int columnSize, int decimalDigits);

public override Schema GetTableSchema(string? catalog, string? dbSchema, string? tableName)
{
Expand Down Expand Up @@ -1364,7 +1359,7 @@ private static IArrowType GetArrowType(int columnTypeId, string typeName, bool i
}
}

protected async Task<TRowSet> FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default)
internal async Task<TRowSet> FetchResultsAsync(TOperationHandle operationHandle, long batchSize = BatchSizeDefault, CancellationToken cancellationToken = default)
{
await PollForResponseAsync(operationHandle, Client, PollTimeMillisecondsDefault, cancellationToken);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ protected override TOpenSessionReq CreateSessionRequest()
return req;
}

protected override void SetPrecisionScaleAndTypeName(
internal override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
125 changes: 124 additions & 1 deletion csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Arrow.Types;
using Apache.Hive.Service.Rpc.Thrift;
using Thrift.Transport;

Expand Down Expand Up @@ -406,7 +407,36 @@ private async Task<QueryResult> GetColumnsAsync(CancellationToken cancellationTo
cancellationToken);
OperationHandle = resp.OperationHandle;

return await GetQueryResult(resp.DirectResults, cancellationToken);
// For GetColumns, we need to enhance the result with BASE_TYPE_NAME
if (Connection.AreResultsAvailableDirectly() && resp.DirectResults?.ResultSet?.Results != null)
{
TGetResultSetMetadataResp resultSetMetadata = resp.DirectResults.ResultSetMetadata;
Schema schema = Connection.SchemaParser.GetArrowSchema(resultSetMetadata.Schema, Connection.DataTypeConversion);
TRowSet rowSet = resp.DirectResults.ResultSet.Results;
int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);
IReadOnlyList<IArrowArray> data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion);

return EnhanceGetColumnsResult(schema, data, rowCount, resultSetMetadata, rowSet);
}
else
{
await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken);
Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken);

// Fetch the results manually to enhance them
TRowSet rowSet = await Connection.FetchResultsAsync(OperationHandle!, BatchSize, cancellationToken);
int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);

// Get metadata again to ensure we have the latest
TGetResultSetMetadataResp metadata = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it duplicates the call to GetResultSetSchemaAsync. Is there any way to avoid the extra call?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this, changed


// Get the arrays from the row set
IReadOnlyList<IArrowArray> data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion);

return EnhanceGetColumnsResult(schema, data, rowCount, metadata, rowSet);
}
}

private async Task<Schema> GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default)
Expand All @@ -426,12 +456,105 @@ private async Task<QueryResult> GetQueryResult(TSparkDirectResults? directResult
int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);
IReadOnlyList<IArrowArray> data = HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema, Connection.DataTypeConversion);

return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(schema, data));
}

await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken);
schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken);

return new QueryResult(-1, Connection.NewReader(this, schema));
}

/// <summary>
/// Produces the GetColumns result with one extra trailing string column, BASE_TYPE_NAME, and with the
/// COLUMN_SIZE and DECIMAL_DIGITS columns rebuilt row by row.
/// </summary>
/// <remarks>
/// For every row, the connection's <c>SetPrecisionScaleAndTypeName</c> is invoked with a fresh
/// <c>TableInfo</c>. The first entry it leaves in <c>BaseTypeName</c>, <c>Precision</c> and <c>Scale</c>
/// (when present and non-null) is used; otherwise the row's original type name, column size and
/// decimal digits are kept. Null values in the original COLUMN_SIZE / DECIMAL_DIGITS columns are read
/// through <c>GetValueOrDefault()</c>, so the rebuilt columns contain no nulls.
/// </remarks>
/// <param name="originalSchema">Schema of the server result; its fields and metadata are carried over.</param>
/// <param name="originalData">Column arrays of the server result, in schema order.</param>
/// <param name="rowCount">Row count reported on the returned <see cref="QueryResult"/>.</param>
/// <param name="metadata">Result-set metadata whose column descriptors locate the columns by name.</param>
/// <param name="rowSet">Raw row set; the DATA_TYPE value of each row is read from its 32-bit integer column.</param>
/// <returns>A query result over the enhanced schema and data.</returns>
protected internal QueryResult EnhanceGetColumnsResult(Schema originalSchema, IReadOnlyList<IArrowArray> originalData,
    int rowCount, TGetResultSetMetadataResp metadata, TRowSet rowSet)
{
    // Locate the columns of interest by name; the indexer throws if any of them is missing.
    IReadOnlyDictionary<string, int> indexByName = Connection.GetColumnIndexMap(metadata.Schema.Columns);
    int typeNameIndex = indexByName["TYPE_NAME"];
    int dataTypeIndex = indexByName["DATA_TYPE"];
    int columnSizeIndex = indexByName["COLUMN_SIZE"];
    int decimalDigitsIndex = indexByName["DECIMAL_DIGITS"];

    StringArray sourceTypeNames = (StringArray)originalData[typeNameIndex];
    Int32Array sourceColumnSizes = (Int32Array)originalData[columnSizeIndex];
    Int32Array sourceDecimalDigits = (Int32Array)originalData[decimalDigitsIndex];

    // Same fields and metadata as the original schema, plus the nullable BASE_TYPE_NAME field at the end.
    List<Field> fields = new List<Field>(originalSchema.FieldsList)
    {
        new Field("BASE_TYPE_NAME", StringType.Default, true),
    };
    Schema enhancedSchema = new Schema(fields, originalSchema.Metadata);

    int rows = sourceTypeNames.Length;
    var resolvedBaseTypeNames = new List<string>(rows);
    var resolvedColumnSizes = new List<int>(rows);
    var resolvedDecimalDigits = new List<int>(rows);

    for (int row = 0; row < rows; row++)
    {
        string? typeName = sourceTypeNames.GetString(row);
        short colType = (short)rowSet.Columns[dataTypeIndex].I32Val.Values.Values[row];
        int columnSize = sourceColumnSizes.GetValue(row).GetValueOrDefault();
        int decimalDigits = sourceDecimalDigits.GetValue(row).GetValueOrDefault();

        // A throwaway TableInfo collects what the connection derives for this single row.
        var rowInfo = new HiveServer2Connection.TableInfo(string.Empty);
        Connection.SetPrecisionScaleAndTypeName(colType, typeName ?? string.Empty, rowInfo, columnSize, decimalDigits);

        // Prefer the derived base type name; fall back to the reported type name.
        resolvedBaseTypeNames.Add(rowInfo.BaseTypeName.Count > 0
            ? rowInfo.BaseTypeName[0] ?? string.Empty
            : typeName ?? string.Empty);

        // Prefer the derived precision/scale; fall back to the reported size/digits.
        int? precision = rowInfo.Precision.Count > 0 ? rowInfo.Precision[0] : null;
        resolvedColumnSizes.Add(precision ?? columnSize);

        int? scale = rowInfo.Scale.Count > 0 ? rowInfo.Scale[0] : null;
        resolvedDecimalDigits.Add(scale ?? decimalDigits);
    }

    // Copy the original columns, swap in the rebuilt ones, and append the new column last
    // so it lines up with the field appended to the schema.
    var enhancedData = new List<IArrowArray>(originalData);
    enhancedData[columnSizeIndex] = new Int32Array.Builder().AppendRange(resolvedColumnSizes).Build();
    enhancedData[decimalDigitsIndex] = new Int32Array.Builder().AppendRange(resolvedDecimalDigits).Build();
    enhancedData.Add(new StringArray.Builder().AppendRange(resolvedBaseTypeNames).Build());

    return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(enhancedSchema, enhancedData));
}
}
}
6 changes: 5 additions & 1 deletion csharp/src/Drivers/Apache/Hive2/SqlTypeNameParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ internal abstract class SqlTypeNameParser<T> : ISqlTypeNameParser where T : SqlT

// Note: the INTERVAL sql type does not have an associated column type id.
private static readonly HashSet<ISqlTypeNameParser> s_parsers = new HashSet<ISqlTypeNameParser>(s_parserMap.Values
.Concat([SqlIntervalTypeParser.Default, SqlSimpleTypeParser.Default("VOID")]));
.Concat([
SqlIntervalTypeParser.Default,
SqlSimpleTypeParser.Default("VOID"),
SqlSimpleTypeParser.Default("VARIANT"),
]));

/// <summary>
/// Gets the base SQL type name without decoration or sub clauses
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp response, Cancel
protected internal override Task<TRowSet> GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken cancellationToken = default) =>
FetchResultsAsync(response.OperationHandle, cancellationToken: cancellationToken);

protected override void SetPrecisionScaleAndTypeName(
internal override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public override AdbcStatement CreateStatement()

protected internal override int PositionRequiredOffset => 1;

protected override void SetPrecisionScaleAndTypeName(
internal override void SetPrecisionScaleAndTypeName(
short colType,
string typeName,
TableInfo? tableInfo,
Expand Down
Loading
Loading