Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(csharp): improve handling of StructArrays #2587

Merged
merged 7 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 70 additions & 29 deletions csharp/src/Apache.Arrow.Adbc/Extensions/IArrowArrayExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@

namespace Apache.Arrow.Adbc.Extensions
{
public enum StructResultType
{
JsonString,
Object
}

public static class IArrowArrayExtensions
{
/// <summary>
/// Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
/// Overloaded. Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
/// </summary>
/// <param name="arrowArray">
/// The Arrow array.
Expand All @@ -37,10 +43,30 @@ public static class IArrowArrayExtensions
/// The index in the array to get the value from.
/// </param>
public static object? ValueAt(this IArrowArray arrowArray, int index)
{
return ValueAt(arrowArray, index, StructResultType.JsonString);
}

/// <summary>
/// Overloaded. Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
/// </summary>
/// <param name="arrowArray">
/// The Arrow array.
/// </param>
/// <param name="index">
/// The index in the array to get the value from.
/// </param>
/// <param name="resultType">
/// T
/// </param>
public static object? ValueAt(this IArrowArray arrowArray, int index, StructResultType resultType = StructResultType.JsonString)
{
if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray));
if (index < 0) throw new ArgumentOutOfRangeException(nameof(index));

if (arrowArray.IsNull(index))
return null;

switch (arrowArray.Data.DataType.TypeId)
{
case ArrowTypeId.Null:
Expand Down Expand Up @@ -76,7 +102,9 @@ public static class IArrowArrayExtensions
case ArrowTypeId.Int64:
return ((Int64Array)arrowArray).GetValue(index);
case ArrowTypeId.String:
return ((StringArray)arrowArray).GetString(index);
StringArray sArray = (StringArray)arrowArray;
if (sArray.Length == 0) { return null; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we get here? Why is this not an error, and why does it impact only StringArray and not other arrays?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still curious about this as it looks strictly wrong. Is there a call stack which shows how we'd get here?

return sArray.GetString(index);
#if NET6_0_OR_GREATER
case ArrowTypeId.Time32:
return ((Time32Array)arrowArray).GetTime(index);
Expand Down Expand Up @@ -127,39 +155,47 @@ public static class IArrowArrayExtensions
throw new NotSupportedException($"Unsupported interval unit: {((IntervalType)arrowArray.Data.DataType).Unit}");
}
case ArrowTypeId.Binary:
if (!arrowArray.IsNull(index))
{
return ((BinaryArray)arrowArray).GetBytes(index).ToArray();
}
else
{
return null;
}
return ((BinaryArray)arrowArray).GetBytes(index).ToArray();
case ArrowTypeId.List:
return ((ListArray)arrowArray).GetSlicedValues(index);
case ArrowTypeId.Struct:
return SerializeToJson(((StructArray)arrowArray), index);
StructArray structArray = (StructArray)arrowArray;
return resultType == StructResultType.JsonString ? SerializeToJson(structArray, index) : ParseStructArray(structArray, index);

// not covered:
// -- map array
// -- dictionary array
// -- fixed size binary
// -- union array
// not covered:
// -- map array
// -- dictionary array
// -- fixed size binary
// -- union array
}

return null;
}

/// <summary>
/// Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
/// Overloaded. Helper extension to get a value converter for the <see href="IArrowType"/>.
/// </summary>
/// <param name="arrowArray">
/// The Arrow array.
/// </param>
/// <param name="index">
/// The index in the array to get the value from.
/// <param name="arrayType">
/// The return type of an item in a StructArray.
/// </param>
public static Func<IArrowArray, int, object?> GetValueConverter(this IArrowType arrayType)
{
return GetValueConverter(arrayType, StructResultType.JsonString);
}

/// <summary>
/// Overloaded. Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
/// </summary>
/// <param name="arrayType">
/// The Arrow array type.
/// </param>
/// <param name="sourceType">
/// The incoming <see cref="SourceStringType"/>.
/// </param>
/// <param name="resultType">
/// The return type of an item in a StructArray.
/// </param>
public static Func<IArrowArray, int, object?> GetValueConverter(this IArrowType arrayType, StructResultType resultType)
{
if (arrayType == null) throw new ArgumentNullException(nameof(arrayType));

Expand Down Expand Up @@ -198,7 +234,9 @@ public static class IArrowArrayExtensions
case ArrowTypeId.Int64:
return (array, index) => ((Int64Array)array).GetValue(index);
case ArrowTypeId.String:
return (array, index) => ((StringArray)array).GetString(index);
return (array, index) => array.Data.DataType.TypeId == ArrowTypeId.Decimal256 ?
((Decimal256Array)array).GetString(index) :
((StringArray)array).GetString(index);
#if NET6_0_OR_GREATER
case ArrowTypeId.Time32:
return (array, index) => ((Time32Array)array).GetTime(index);
Expand Down Expand Up @@ -256,7 +294,9 @@ public static class IArrowArrayExtensions
case ArrowTypeId.List:
return (array, index) => ((ListArray)array).GetSlicedValues(index);
case ArrowTypeId.Struct:
return (array, index) => SerializeToJson((StructArray)array, index);
return resultType == StructResultType.JsonString ?
(array, index) => SerializeToJson((StructArray)array, index) :
(array, index) => ParseStructArray((StructArray)array, index);

// not covered:
// -- map array
Expand All @@ -273,20 +313,21 @@ public static class IArrowArrayExtensions
/// </summary>
private static string SerializeToJson(StructArray structArray, int index)
{
Dictionary<String, object?>? jsonDictionary = ParseStructArray(structArray, index);
Dictionary<string, object?>? obj = ParseStructArray(structArray, index);

return JsonSerializer.Serialize(jsonDictionary);
return JsonSerializer.Serialize(obj);
}

/// <summary>
/// Converts a StructArray to a Dictionary<String, object?>.
/// Converts an item in the StructArray at the index position to a Dictionary<string, object?>.
/// </summary>
private static Dictionary<String, object?>? ParseStructArray(StructArray structArray, int index)
private static Dictionary<string, object?>? ParseStructArray(StructArray structArray, int index)
{
if (structArray.IsNull(index))
return null;

Dictionary<String, object?> jsonDictionary = new Dictionary<String, object?>();
Dictionary<string, object?> jsonDictionary = new Dictionary<string, object?>();

StructType structType = (StructType)structArray.Data.DataType;
for (int i = 0; i < structArray.Data.Children.Length; i++)
{
Expand Down
8 changes: 6 additions & 2 deletions csharp/src/Client/AdbcDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult, De
this.DecimalBehavior = decimalBehavior;
this.StructBehavior = structBehavior;

StructResultType structResultType = this.StructBehavior == StructBehavior.JsonString ? StructResultType.JsonString : StructResultType.Object;

this.converters = new Func<IArrowArray, int, object?>[this.schema.FieldsList.Count];
for (int i = 0; i < this.converters.Length; i++)
{
this.converters[i] = this.schema.FieldsList[i].DataType.GetValueConverter();
this.converters[i] = this.schema.FieldsList[i].DataType.GetValueConverter(structResultType);
}
}

Expand Down Expand Up @@ -372,7 +374,9 @@ public ReadOnlyCollection<AdbcColumn> GetAdbcColumnSchema()
}
else
{
dbColumns.Add(new AdbcColumn(f.Name, t, f.DataType, f.IsNullable));
IArrowType arrowType = SchemaConverter.GetArrowTypeBasedOnRequestedBehavior(f.DataType, this.StructBehavior);

dbColumns.Add(new AdbcColumn(f.Name, t, arrowType, f.IsNullable));
}
}

Expand Down
21 changes: 16 additions & 5 deletions csharp/src/Client/SchemaConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlTypes;
Expand Down Expand Up @@ -60,7 +61,7 @@ public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStat
row[SchemaTableColumn.ColumnName] = f.Name;
row[SchemaTableColumn.ColumnOrdinal] = columnOrdinal;
row[SchemaTableColumn.AllowDBNull] = f.IsNullable;
row[SchemaTableColumn.ProviderType] = f.DataType;
row[SchemaTableColumn.ProviderType] = SchemaConverter.GetArrowTypeBasedOnRequestedBehavior(f.DataType, structBehavior);
Type t = ConvertArrowType(f, decimalBehavior, structBehavior);

row[SchemaTableColumn.DataType] = t;
Expand Down Expand Up @@ -193,10 +194,7 @@ public static Type GetArrowType(Field f, DecimalBehavior decimalBehavior, Struct
return typeof(string);

case ArrowTypeId.Struct:
if (structBehavior == StructBehavior.JsonString)
return typeof(string);
else
goto default;
return structBehavior == StructBehavior.JsonString ? typeof(string) : typeof(Dictionary<string, object?>);

case ArrowTypeId.Timestamp:
return typeof(DateTimeOffset);
Expand Down Expand Up @@ -271,5 +269,18 @@ public static Type GetArrowArrayType(IArrowType dataType)

throw new InvalidCastException($"Cannot determine the array type for {dataType.Name}");
}

/// <summary>
/// Get the IArrowType based on the input IArrowType and the desired <see cref="StructBehavior"/>.
/// If it's a StructType and the desired behavior is a JsonString then this returns StringType.
/// Otherwise, it returns the input IArrowType.
/// </summary>
/// <param name="defaultType">The default IArrowType to return.</param>
/// <param name="structBehavior">Desired behavior if the IArrowType is a StructType.</param>
/// <returns></returns>
public static IArrowType GetArrowTypeBasedOnRequestedBehavior(IArrowType defaultType, StructBehavior structBehavior)
{
return defaultType.TypeId == ArrowTypeId.Struct && structBehavior == StructBehavior.JsonString ? StringType.Default : defaultType;
}
}
}
3 changes: 3 additions & 0 deletions csharp/src/Client/StructBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

namespace Apache.Arrow.Adbc.Client
{
/// <summary>
/// Controls the behavior of how StructArrays should be handled in the results.
/// </summary>
public enum StructBehavior
{
/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Client/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ These properties are:

- __AdbcConnectionTimeout__ - This specifies the connection timeout value. The value needs to be in the form (driver.property.name, integer, unit) where the unit is one of `s` or `ms`, For example, `AdbcConnectionTimeout=(adbc.snowflake.sql.client_option.client_timeout,30,s)` would set the connection timeout to 30 seconds.
- __AdbcCommandTimeout__ - This specifies the command timeout value. This follows the same pattern as `AdbcConnectionTimeout` and sets the `AdbcCommandTimeoutProperty` and `CommandTimeout` values on the `AdbcCommand` object.
- __StructBehavior__ - This specifies the StructBehavior when working with Arrow Struct arrays. The valid values are `JsonString` (the default) or `Strict` (treat the struct as a native type).
- __StructBehavior__ - This specifies the StructBehavior when working with Arrow Struct arrays. The valid values are `JsonString` (the default) or `Strict` (treat the struct as a native type). If using JsonString, the return ArrowType will be StringType and the result a string value. If using Strict, the return ArrowType will be StructType and the result a Dictionary<string, object?>.
- __DecimalBehavior__ - This specifies the DecimalBehavior when parsing decimal values from Arrow libraries. The valid values are `UseSqlDecimal` or `OverflowDecimalAsString` where values like Decimal256 are treated as strings.
50 changes: 44 additions & 6 deletions csharp/src/Drivers/BigQuery/BigQueryStatement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,13 @@ public override QueryResult ExecuteQuery()
ReadSession rrs = readClient.CreateReadSession("projects/" + results.TableReference.ProjectId, rs, maxStreamCount);

long totalRows = results.TotalRows == null ? -1L : (long)results.TotalRows.Value;
IArrowArrayStream stream = new MultiArrowReader(TranslateSchema(results.Schema), rrs.Streams.Select(s => ReadChunk(readClient, s.Name)));

var readers = rrs.Streams
.Select(s => ReadChunk(readClient, s.Name))
.Where(chunk => chunk != null)
.Cast<IArrowReader>();

IArrowArrayStream stream = new MultiArrowReader(TranslateSchema(results.Schema), readers);

return new QueryResult(totalRows, stream);
}
Expand Down Expand Up @@ -175,8 +181,7 @@ private IArrowType TranslateType(TableFieldSchema field)
case "DATE":
return GetType(field, Date32Type.Default);
case "RECORD" or "STRUCT":
// its a json string
return GetType(field, StringType.Default);
return GetType(field, BuildStructType(field));

// treat these values as strings
case "GEOGRAPHY" or "JSON":
Expand All @@ -200,6 +205,19 @@ private IArrowType TranslateType(TableFieldSchema field)
}
}

private StructType BuildStructType(TableFieldSchema field)
{
List<Field> arrowFields = new List<Field>();

foreach (TableFieldSchema subfield in field.Fields)
{
Field arrowField = TranslateField(subfield);
arrowFields.Add(arrowField);
}

return new StructType(arrowFields.AsReadOnly());
}

private IArrowType GetType(TableFieldSchema field, IArrowType type)
{
if (field.Mode == "REPEATED")
Expand All @@ -208,7 +226,7 @@ private IArrowType GetType(TableFieldSchema field, IArrowType type)
return type;
}

static IArrowReader ReadChunk(BigQueryReadClient readClient, string streamName)
static IArrowReader? ReadChunk(BigQueryReadClient readClient, string streamName)
{
// Ideally we wouldn't need to indirect through a stream, but the necessary APIs in Arrow
// are internal. (TODO: consider changing Arrow).
Expand All @@ -217,7 +235,14 @@ static IArrowReader ReadChunk(BigQueryReadClient readClient, string streamName)

ReadRowsStream stream = new ReadRowsStream(enumerator);

return new ArrowStreamReader(stream);
if (stream.HasRows)
{
return new ArrowStreamReader(stream);
}
else
{
return null;
}
}

private QueryOptions ValidateOptions()
Expand Down Expand Up @@ -349,15 +374,28 @@ sealed class ReadRowsStream : Stream
ReadOnlyMemory<byte> currentBuffer;
bool first;
int position;
bool hasRows;

public ReadRowsStream(IAsyncEnumerator<ReadRowsResponse> response)
{
if (!response.MoveNextAsync().Result) { }
this.currentBuffer = response.Current.ArrowSchema.SerializedSchema.Memory;

if (response.Current != null)
{
this.currentBuffer = response.Current.ArrowSchema.SerializedSchema.Memory;
this.hasRows = true;
}
else
{
this.hasRows = false;
}

this.response = response;
this.first = true;
}

public bool HasRows => this.hasRows;

public override bool CanRead => true;

public override bool CanSeek => false;
Expand Down
Loading
Loading