Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(csharp): improve handling of StructArrays #2587

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 46 additions & 18 deletions csharp/src/Apache.Arrow.Adbc/Extensions/IArrowArrayExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@
using System.Collections;
using System.Collections.Generic;
using System.Data.SqlTypes;
using System.Dynamic;
using System.IO;
using System.Text.Json;
using Apache.Arrow.Types;

namespace Apache.Arrow.Adbc.Extensions
{
public enum StructResultType
{
JsonString,
Object
}

public static class IArrowArrayExtensions
{
/// <summary>
Expand All @@ -36,7 +43,7 @@ public static class IArrowArrayExtensions
/// <param name="index">
/// The index in the array to get the value from.
/// </param>
public static object? ValueAt(this IArrowArray arrowArray, int index)
public static object? ValueAt(this IArrowArray arrowArray, int index, StructResultType resultType = StructResultType.JsonString)
{
if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray));
if (index < 0) throw new ArgumentOutOfRangeException(nameof(index));
Expand Down Expand Up @@ -76,7 +83,9 @@ public static class IArrowArrayExtensions
case ArrowTypeId.Int64:
return ((Int64Array)arrowArray).GetValue(index);
case ArrowTypeId.String:
return ((StringArray)arrowArray).GetString(index);
StringArray sArray = (StringArray)arrowArray;
if (sArray.Length == 0) { return null; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we get here? Why is this not an error, and why does it impact only StringArray and not other arrays?

return sArray.GetString(index);
#if NET6_0_OR_GREATER
case ArrowTypeId.Time32:
return ((Time32Array)arrowArray).GetTime(index);
Expand Down Expand Up @@ -138,13 +147,14 @@ public static class IArrowArrayExtensions
case ArrowTypeId.List:
return ((ListArray)arrowArray).GetSlicedValues(index);
case ArrowTypeId.Struct:
return SerializeToJson(((StructArray)arrowArray), index);
StructArray structArray = (StructArray)arrowArray;
return resultType == StructResultType.JsonString ? SerializeToJson(structArray, index) : ParseStructArray(structArray, index);

// not covered:
// -- map array
// -- dictionary array
// -- fixed size binary
// -- union array
// not covered:
// -- map array
// -- dictionary array
// -- fixed size binary
// -- union array
}

return null;
Expand All @@ -159,7 +169,7 @@ public static class IArrowArrayExtensions
/// <param name="index">
/// The index in the array to get the value from.
/// </param>
public static Func<IArrowArray, int, object?> GetValueConverter(this IArrowType arrayType)
public static Func<IArrowArray, int, object?> GetValueConverter(this IArrowType arrayType, StructResultType resultType = StructResultType.JsonString)
{
if (arrayType == null) throw new ArgumentNullException(nameof(arrayType));

Expand Down Expand Up @@ -198,7 +208,23 @@ public static class IArrowArrayExtensions
case ArrowTypeId.Int64:
return (array, index) => ((Int64Array)array).GetValue(index);
case ArrowTypeId.String:
return (array, index) => ((StringArray)array).GetString(index);
return (array, index) =>
{
StringArray? sArray = array as StringArray;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gives up some of the performance benefit of the approach. Instead of having to add a check to each invocation, can we have the caller tell us in advance what the expected source type is and return one of two different delegates?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldnt figure out an elegant way to have two separate delegates so I went with one and checking the DataType of the array that's passed in.


if (sArray != null)
{
return sArray.GetString(index);
}
else
{
// some callers treat the Decimal256Array values as a string
Decimal256Array? array256 = array as Decimal256Array;
return array256?.GetString(index);
}

throw new AdbcException($"Cannot get the value at {index}. A String type was requested but neither a StringArray or Decimal256Array was found.", AdbcStatusCode.InvalidData);
};
#if NET6_0_OR_GREATER
case ArrowTypeId.Time32:
return (array, index) => ((Time32Array)array).GetTime(index);
Expand Down Expand Up @@ -256,7 +282,7 @@ public static class IArrowArrayExtensions
case ArrowTypeId.List:
return (array, index) => ((ListArray)array).GetSlicedValues(index);
case ArrowTypeId.Struct:
return (array, index) => SerializeToJson((StructArray)array, index);
return (array, index) => resultType == StructResultType.JsonString ? SerializeToJson((StructArray)array, index) : ParseStructArray((StructArray)array, index) ;

// not covered:
// -- map array
Expand All @@ -273,20 +299,22 @@ public static class IArrowArrayExtensions
/// </summary>
private static string SerializeToJson(StructArray structArray, int index)
{
Dictionary<String, object?>? jsonDictionary = ParseStructArray(structArray, index);
ExpandoObject? obj = ParseStructArray(structArray, index);

return JsonSerializer.Serialize(jsonDictionary);
return JsonSerializer.Serialize(obj);
}

/// <summary>
/// Converts a StructArray to a Dictionary<String, object?>.
/// Converts an item in the StructArray at the index position to an ExpandoObject.
/// </summary>
private static Dictionary<String, object?>? ParseStructArray(StructArray structArray, int index)
private static ExpandoObject? ParseStructArray(StructArray structArray, int index)
{
if (structArray.IsNull(index))
return null;

Dictionary<String, object?> jsonDictionary = new Dictionary<String, object?>();
var expando = new ExpandoObject();
var jsonDictionary = (IDictionary<string, object?>)expando;

StructType structType = (StructType)structArray.Data.DataType;
for (int i = 0; i < structArray.Data.Children.Length; i++)
{
Expand All @@ -295,7 +323,7 @@ private static string SerializeToJson(StructArray structArray, int index)

if (value is StructArray structArray1)
{
List<Dictionary<string, object?>?> children = new List<Dictionary<string, object?>?>();
List<ExpandoObject?> children = new List<ExpandoObject?>();

for (int j = 0; j < structArray1.Length; j++)
{
Expand Down Expand Up @@ -335,7 +363,7 @@ private static string SerializeToJson(StructArray structArray, int index)
}
}

return jsonDictionary;
return expando;
}

/// <summary>
Expand Down
8 changes: 6 additions & 2 deletions csharp/src/Client/AdbcDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult, De
this.DecimalBehavior = decimalBehavior;
this.StructBehavior = structBehavior;

StructResultType structResultType = this.StructBehavior == StructBehavior.JsonString ? StructResultType.JsonString : StructResultType.Object;

this.converters = new Func<IArrowArray, int, object?>[this.schema.FieldsList.Count];
for (int i = 0; i < this.converters.Length; i++)
{
this.converters[i] = this.schema.FieldsList[i].DataType.GetValueConverter();
this.converters[i] = this.schema.FieldsList[i].DataType.GetValueConverter(structResultType);
}
}

Expand Down Expand Up @@ -372,7 +374,9 @@ public ReadOnlyCollection<AdbcColumn> GetAdbcColumnSchema()
}
else
{
dbColumns.Add(new AdbcColumn(f.Name, t, f.DataType, f.IsNullable));
IArrowType arrowType = SchemaConverter.GetArrowTypeBasedOnRequestedBehavior(f.DataType, this.StructBehavior);

dbColumns.Add(new AdbcColumn(f.Name, t, arrowType, f.IsNullable));
}
}

Expand Down
21 changes: 16 additions & 5 deletions csharp/src/Client/SchemaConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
using System.Data;
using System.Data.Common;
using System.Data.SqlTypes;
using System.Dynamic;
using Apache.Arrow.Scalars;
using Apache.Arrow.Types;

Expand Down Expand Up @@ -60,7 +61,7 @@ public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStat
row[SchemaTableColumn.ColumnName] = f.Name;
row[SchemaTableColumn.ColumnOrdinal] = columnOrdinal;
row[SchemaTableColumn.AllowDBNull] = f.IsNullable;
row[SchemaTableColumn.ProviderType] = f.DataType;
row[SchemaTableColumn.ProviderType] = SchemaConverter.GetArrowTypeBasedOnRequestedBehavior(f.DataType, structBehavior);
Type t = ConvertArrowType(f, decimalBehavior, structBehavior);

row[SchemaTableColumn.DataType] = t;
Expand Down Expand Up @@ -193,10 +194,7 @@ public static Type GetArrowType(Field f, DecimalBehavior decimalBehavior, Struct
return typeof(string);

case ArrowTypeId.Struct:
if (structBehavior == StructBehavior.JsonString)
return typeof(string);
else
goto default;
return structBehavior == StructBehavior.JsonString ? typeof(string) : typeof(ExpandoObject);

case ArrowTypeId.Timestamp:
return typeof(DateTimeOffset);
Expand Down Expand Up @@ -271,5 +269,18 @@ public static Type GetArrowArrayType(IArrowType dataType)

throw new InvalidCastException($"Cannot determine the array type for {dataType.Name}");
}

/// <summary>
/// Get the IArrowType based on the input IArrowType and the desired <see cref="StructBehavior"/>.
/// If it's a StructType and the desired behavior is a JsonString then this returns StringType.
/// Otherwise, it returns the input IArrowType.
/// </summary>
/// <param name="defaultType">The default IArrowType to return.</param>
/// <param name="structBehavior">Desired behavior if the IArrowType is a StructType.</param>
/// <returns></returns>
public static IArrowType GetArrowTypeBasedOnRequestedBehavior(IArrowType defaultType, StructBehavior structBehavior)
{
return defaultType.TypeId == ArrowTypeId.Struct && structBehavior == StructBehavior.JsonString ? StringType.Default : defaultType;
}
}
}
3 changes: 3 additions & 0 deletions csharp/src/Client/StructBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

namespace Apache.Arrow.Adbc.Client
{
/// <summary>
/// Controls the behavior of how StructArrays should be handled in the results.
/// </summary>
public enum StructBehavior
{
/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Client/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ These properties are:

- __AdbcConnectionTimeout__ - This specifies the connection timeout value. The value needs to be in the form (driver.property.name, integer, unit) where the unit is one of `s` or `ms`, For example, `AdbcConnectionTimeout=(adbc.snowflake.sql.client_option.client_timeout,30,s)` would set the connection timeout to 30 seconds.
- __AdbcCommandTimeout__ - This specifies the command timeout value. This follows the same pattern as `AdbcConnectionTimeout` and sets the `AdbcCommandTimeoutProperty` and `CommandTimeout` values on the `AdbcCommand` object.
- __StructBehavior__ - This specifies the StructBehavior when working with Arrow Struct arrays. The valid values are `JsonString` (the default) or `Strict` (treat the struct as a native type).
- __StructBehavior__ - This specifies the StructBehavior when working with Arrow Struct arrays. The valid values are `JsonString` (the default) or `Strict` (treat the struct as a native type). If using JsonString, the return ArrowType will be StringType and a string value. If using Strict, the return ArrowType will be StructType and an ExpandoObject.
- __DecimalBehavior__ - This specifies the DecimalBehavior when parsing decimal values from Arrow libraries. The valid values are `UseSqlDecimal` or `OverflowDecimalAsString` where values like Decimal256 are treated as strings.
49 changes: 43 additions & 6 deletions csharp/src/Drivers/BigQuery/BigQueryStatement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,13 @@ public override QueryResult ExecuteQuery()
ReadSession rrs = readClient.CreateReadSession("projects/" + results.TableReference.ProjectId, rs, maxStreamCount);

long totalRows = results.TotalRows == null ? -1L : (long)results.TotalRows.Value;
IArrowArrayStream stream = new MultiArrowReader(TranslateSchema(results.Schema), rrs.Streams.Select(s => ReadChunk(readClient, s.Name)));

var readers = rrs.Streams
Copy link
Contributor Author

@davidhcoe davidhcoe Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make up for the internal contract change of allowing null on ReadChunk, only pass valid readers (that aren't null) to the MultiArrowReader. If it is empty, then no errors occur.

.Select(s => ReadChunk(readClient, s.Name))
.Where(chunk => chunk != null)
.Cast<IArrowReader>();

IArrowArrayStream stream = new MultiArrowReader(TranslateSchema(results.Schema), readers);

return new QueryResult(totalRows, stream);
}
Expand Down Expand Up @@ -175,8 +181,7 @@ private IArrowType TranslateType(TableFieldSchema field)
case "DATE":
return GetType(field, Date32Type.Default);
case "RECORD" or "STRUCT":
// its a json string
return GetType(field, StringType.Default);
return GetType(field, BuildStructType(field));

// treat these values as strings
case "GEOGRAPHY" or "JSON":
Expand All @@ -200,6 +205,19 @@ private IArrowType TranslateType(TableFieldSchema field)
}
}

private StructType BuildStructType(TableFieldSchema field)
{
List<Field> arrowFields = new List<Field>();

foreach (TableFieldSchema subfield in field.Fields)
{
Field arrowField = TranslateField(subfield);
arrowFields.Add(arrowField);
}

return new StructType(arrowFields.AsReadOnly());
}

private IArrowType GetType(TableFieldSchema field, IArrowType type)
{
if (field.Mode == "REPEATED")
Expand All @@ -208,7 +226,7 @@ private IArrowType GetType(TableFieldSchema field, IArrowType type)
return type;
}

static IArrowReader ReadChunk(BigQueryReadClient readClient, string streamName)
static IArrowReader? ReadChunk(BigQueryReadClient readClient, string streamName)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes the internal contract slightly because ReadChunk can now result in a null if the stream doesn't have any rows.

{
// Ideally we wouldn't need to indirect through a stream, but the necessary APIs in Arrow
// are internal. (TODO: consider changing Arrow).
Expand All @@ -217,7 +235,14 @@ static IArrowReader ReadChunk(BigQueryReadClient readClient, string streamName)

ReadRowsStream stream = new ReadRowsStream(enumerator);

return new ArrowStreamReader(stream);
if (stream.HasRows)
{
return new ArrowStreamReader(stream);
}
else
{
return null;
}
}

private QueryOptions ValidateOptions()
Expand Down Expand Up @@ -349,15 +374,27 @@ sealed class ReadRowsStream : Stream
ReadOnlyMemory<byte> currentBuffer;
bool first;
int position;
bool hasRows = true;

public ReadRowsStream(IAsyncEnumerator<ReadRowsResponse> response)
{
if (!response.MoveNextAsync().Result) { }
this.currentBuffer = response.Current.ArrowSchema.SerializedSchema.Memory;

if (response.Current != null)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A NullReferenceException occurs if there are no results from the query because response.Current is null. So, this uses an indicator of "HasRows" to dictate the behavior of it.

{
this.currentBuffer = response.Current.ArrowSchema.SerializedSchema.Memory;
}
else
{
this.hasRows = false;
}

this.response = response;
this.first = true;
}

public bool HasRows { get => this.hasRows; }

public override bool CanRead => true;

public override bool CanSeek => false;
Expand Down
Loading
Loading