Skip to content

Commit 861f009

Browse files
authored
feat(csharp): improve handling of StructArrays (#2587)
- improves the handling of structs to return objects or JsonString (defaults to JsonString to not break existing callers) - additional testing for each return type - updates to the ADO.NET wrapper to support both struct types - fixes #2586 --------- Co-authored-by: David Coe <>
1 parent 6a8e350 commit 861f009

File tree

11 files changed

+359
-114
lines changed

11 files changed

+359
-114
lines changed

csharp/src/Apache.Arrow.Adbc/Extensions/IArrowArrayExtensions.cs

+81-40
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,16 @@
2525

2626
namespace Apache.Arrow.Adbc.Extensions
2727
{
28+
public enum StructResultType
29+
{
30+
JsonString,
31+
Object
32+
}
33+
2834
public static class IArrowArrayExtensions
2935
{
3036
/// <summary>
31-
/// Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
37+
/// Overloaded. Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
3238
/// </summary>
3339
/// <param name="arrowArray">
3440
/// The Arrow array.
@@ -37,10 +43,30 @@ public static class IArrowArrayExtensions
3743
/// The index in the array to get the value from.
3844
/// </param>
3945
public static object? ValueAt(this IArrowArray arrowArray, int index)
46+
{
47+
return ValueAt(arrowArray, index, StructResultType.JsonString);
48+
}
49+
50+
/// <summary>
51+
/// Overloaded. Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
52+
/// </summary>
53+
/// <param name="arrowArray">
54+
/// The Arrow array.
55+
/// </param>
56+
/// <param name="index">
57+
/// The index in the array to get the value from.
58+
/// </param>
59+
/// <param name="resultType">
60+
/// T
61+
/// </param>
62+
public static object? ValueAt(this IArrowArray arrowArray, int index, StructResultType resultType = StructResultType.JsonString)
4063
{
4164
if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray));
4265
if (index < 0) throw new ArgumentOutOfRangeException(nameof(index));
4366

67+
if (arrowArray.IsNull(index))
68+
return null;
69+
4470
switch (arrowArray.Data.DataType.TypeId)
4571
{
4672
case ArrowTypeId.Null:
@@ -127,39 +153,47 @@ public static class IArrowArrayExtensions
127153
throw new NotSupportedException($"Unsupported interval unit: {((IntervalType)arrowArray.Data.DataType).Unit}");
128154
}
129155
case ArrowTypeId.Binary:
130-
if (!arrowArray.IsNull(index))
131-
{
132-
return ((BinaryArray)arrowArray).GetBytes(index).ToArray();
133-
}
134-
else
135-
{
136-
return null;
137-
}
156+
return ((BinaryArray)arrowArray).GetBytes(index).ToArray();
138157
case ArrowTypeId.List:
139158
return ((ListArray)arrowArray).GetSlicedValues(index);
140159
case ArrowTypeId.Struct:
141-
return SerializeToJson(((StructArray)arrowArray), index);
160+
StructArray structArray = (StructArray)arrowArray;
161+
return resultType == StructResultType.JsonString ? SerializeToJson(structArray, index) : ParseStructArray(structArray, index);
142162

143-
// not covered:
144-
// -- map array
145-
// -- dictionary array
146-
// -- fixed size binary
147-
// -- union array
163+
// not covered:
164+
// -- map array
165+
// -- dictionary array
166+
// -- fixed size binary
167+
// -- union array
148168
}
149169

150170
return null;
151171
}
152172

153173
/// <summary>
154-
/// Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
174+
/// Overloaded. Helper extension to get a value converter for the <see href="IArrowType"/>.
155175
/// </summary>
156-
/// <param name="arrowArray">
157-
/// The Arrow array.
158-
/// </param>
159-
/// <param name="index">
160-
/// The index in the array to get the value from.
176+
/// <param name="arrayType">
177+
/// The return type of an item in a StructArray.
161178
/// </param>
162179
public static Func<IArrowArray, int, object?> GetValueConverter(this IArrowType arrayType)
180+
{
181+
return GetValueConverter(arrayType, StructResultType.JsonString);
182+
}
183+
184+
/// <summary>
185+
/// Overloaded. Helper extension to get a value from the <see cref="IArrowArray"/> at the specified index.
186+
/// </summary>
187+
/// <param name="arrayType">
188+
/// The Arrow array type.
189+
/// </param>
190+
/// <param name="sourceType">
191+
/// The incoming <see cref="SourceStringType"/>.
192+
/// </param>
193+
/// <param name="resultType">
194+
/// The return type of an item in a StructArray.
195+
/// </param>
196+
public static Func<IArrowArray, int, object?> GetValueConverter(this IArrowType arrayType, StructResultType resultType)
163197
{
164198
if (arrayType == null) throw new ArgumentNullException(nameof(arrayType));
165199

@@ -198,7 +232,9 @@ public static class IArrowArrayExtensions
198232
case ArrowTypeId.Int64:
199233
return (array, index) => ((Int64Array)array).GetValue(index);
200234
case ArrowTypeId.String:
201-
return (array, index) => ((StringArray)array).GetString(index);
235+
return (array, index) => array.Data.DataType.TypeId == ArrowTypeId.Decimal256 ?
236+
((Decimal256Array)array).GetString(index) :
237+
((StringArray)array).GetString(index);
202238
#if NET6_0_OR_GREATER
203239
case ArrowTypeId.Time32:
204240
return (array, index) => ((Time32Array)array).GetTime(index);
@@ -256,7 +292,9 @@ public static class IArrowArrayExtensions
256292
case ArrowTypeId.List:
257293
return (array, index) => ((ListArray)array).GetSlicedValues(index);
258294
case ArrowTypeId.Struct:
259-
return (array, index) => SerializeToJson((StructArray)array, index);
295+
return resultType == StructResultType.JsonString ?
296+
(array, index) => SerializeToJson((StructArray)array, index) :
297+
(array, index) => ParseStructArray((StructArray)array, index);
260298

261299
// not covered:
262300
// -- map array
@@ -273,42 +311,45 @@ public static class IArrowArrayExtensions
273311
/// </summary>
274312
private static string SerializeToJson(StructArray structArray, int index)
275313
{
276-
Dictionary<String, object?>? jsonDictionary = ParseStructArray(structArray, index);
314+
Dictionary<string, object?>? obj = ParseStructArray(structArray, index);
277315

278-
return JsonSerializer.Serialize(jsonDictionary);
316+
return JsonSerializer.Serialize(obj);
279317
}
280318

281319
/// <summary>
282-
/// Converts a StructArray to a Dictionary<String, object?>.
320+
/// Converts an item in the StructArray at the index position to a Dictionary<string, object?>.
283321
/// </summary>
284-
private static Dictionary<String, object?>? ParseStructArray(StructArray structArray, int index)
322+
private static Dictionary<string, object?>? ParseStructArray(StructArray structArray, int index)
285323
{
286324
if (structArray.IsNull(index))
287325
return null;
288326

289-
Dictionary<String, object?> jsonDictionary = new Dictionary<String, object?>();
327+
Dictionary<string, object?> jsonDictionary = new Dictionary<string, object?>();
328+
290329
StructType structType = (StructType)structArray.Data.DataType;
291330
for (int i = 0; i < structArray.Data.Children.Length; i++)
292331
{
293332
string name = structType.Fields[i].Name;
294-
object? value = ValueAt(structArray.Fields[i], index);
333+
334+
// keep the results as StructArray internally
335+
object? value = ValueAt(structArray.Fields[i], index, StructResultType.Object);
295336

296337
if (value is StructArray structArray1)
297338
{
298-
List<Dictionary<string, object?>?> children = new List<Dictionary<string, object?>?>();
299-
300-
for (int j = 0; j < structArray1.Length; j++)
339+
if (structArray1.Length == 0)
301340
{
302-
children.Add(ParseStructArray(structArray1, j));
303-
}
304-
305-
if (children.Count > 0)
306-
{
307-
jsonDictionary.Add(name, children);
341+
jsonDictionary.Add(name, null);
308342
}
309343
else
310344
{
311-
jsonDictionary.Add(name, ParseStructArray(structArray1, index));
345+
List<Dictionary<string, object?>?> children = new List<Dictionary<string, object?>?>();
346+
347+
for (int j = 0; j < structArray1.Length; j++)
348+
{
349+
children.Add(ParseStructArray(structArray1, j));
350+
}
351+
352+
jsonDictionary.Add(name, children);
312353
}
313354
}
314355
else if (value is IArrowArray arrowArray)
@@ -319,7 +360,7 @@ private static string SerializeToJson(StructArray structArray, int index)
319360
{
320361
for (int j = 0; j < arrowArray.Length; j++)
321362
{
322-
values.Add(ValueAt(arrowArray, j));
363+
values.Add(ValueAt(arrowArray, j, StructResultType.Object));
323364
}
324365

325366
jsonDictionary.Add(name, values);

csharp/src/Client/AdbcDataReader.cs

+6-2
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult, De
8686
this.DecimalBehavior = decimalBehavior;
8787
this.StructBehavior = structBehavior;
8888

89+
StructResultType structResultType = this.StructBehavior == StructBehavior.JsonString ? StructResultType.JsonString : StructResultType.Object;
90+
8991
this.converters = new Func<IArrowArray, int, object?>[this.schema.FieldsList.Count];
9092
for (int i = 0; i < this.converters.Length; i++)
9193
{
92-
this.converters[i] = this.schema.FieldsList[i].DataType.GetValueConverter();
94+
this.converters[i] = this.schema.FieldsList[i].DataType.GetValueConverter(structResultType);
9395
}
9496
}
9597

@@ -372,7 +374,9 @@ public ReadOnlyCollection<AdbcColumn> GetAdbcColumnSchema()
372374
}
373375
else
374376
{
375-
dbColumns.Add(new AdbcColumn(f.Name, t, f.DataType, f.IsNullable));
377+
IArrowType arrowType = SchemaConverter.GetArrowTypeBasedOnRequestedBehavior(f.DataType, this.StructBehavior);
378+
379+
dbColumns.Add(new AdbcColumn(f.Name, t, arrowType, f.IsNullable));
376380
}
377381
}
378382

csharp/src/Client/SchemaConverter.cs

+16-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717

1818
using System;
19+
using System.Collections.Generic;
1920
using System.Data;
2021
using System.Data.Common;
2122
using System.Data.SqlTypes;
@@ -60,7 +61,7 @@ public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStat
6061
row[SchemaTableColumn.ColumnName] = f.Name;
6162
row[SchemaTableColumn.ColumnOrdinal] = columnOrdinal;
6263
row[SchemaTableColumn.AllowDBNull] = f.IsNullable;
63-
row[SchemaTableColumn.ProviderType] = f.DataType;
64+
row[SchemaTableColumn.ProviderType] = SchemaConverter.GetArrowTypeBasedOnRequestedBehavior(f.DataType, structBehavior);
6465
Type t = ConvertArrowType(f, decimalBehavior, structBehavior);
6566

6667
row[SchemaTableColumn.DataType] = t;
@@ -193,10 +194,7 @@ public static Type GetArrowType(Field f, DecimalBehavior decimalBehavior, Struct
193194
return typeof(string);
194195

195196
case ArrowTypeId.Struct:
196-
if (structBehavior == StructBehavior.JsonString)
197-
return typeof(string);
198-
else
199-
goto default;
197+
return structBehavior == StructBehavior.JsonString ? typeof(string) : typeof(Dictionary<string, object?>);
200198

201199
case ArrowTypeId.Timestamp:
202200
return typeof(DateTimeOffset);
@@ -271,5 +269,18 @@ public static Type GetArrowArrayType(IArrowType dataType)
271269

272270
throw new InvalidCastException($"Cannot determine the array type for {dataType.Name}");
273271
}
272+
273+
/// <summary>
274+
/// Get the IArrowType based on the input IArrowType and the desired <see cref="StructBehavior"/>.
275+
/// If it's a StructType and the desired behavior is a JsonString then this returns StringType.
276+
/// Otherwise, it returns the input IArrowType.
277+
/// </summary>
278+
/// <param name="defaultType">The default IArrowType to return.</param>
279+
/// <param name="structBehavior">Desired behavior if the IArrowType is a StructType.</param>
280+
/// <returns></returns>
281+
public static IArrowType GetArrowTypeBasedOnRequestedBehavior(IArrowType defaultType, StructBehavior structBehavior)
282+
{
283+
return defaultType.TypeId == ArrowTypeId.Struct && structBehavior == StructBehavior.JsonString ? StringType.Default : defaultType;
284+
}
274285
}
275286
}

csharp/src/Client/StructBehavior.cs

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
namespace Apache.Arrow.Adbc.Client
1919
{
20+
/// <summary>
21+
/// Controls the behavior of how StructArrays should be handled in the results.
22+
/// </summary>
2023
public enum StructBehavior
2124
{
2225
/// <summary>

csharp/src/Client/readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,5 @@ These properties are:
8080

8181
- __AdbcConnectionTimeout__ - This specifies the connection timeout value. The value needs to be in the form (driver.property.name, integer, unit) where the unit is one of `s` or `ms`, For example, `AdbcConnectionTimeout=(adbc.snowflake.sql.client_option.client_timeout,30,s)` would set the connection timeout to 30 seconds.
8282
- __AdbcCommandTimeout__ - This specifies the command timeout value. This follows the same pattern as `AdbcConnectionTimeout` and sets the `AdbcCommandTimeoutProperty` and `CommandTimeout` values on the `AdbcCommand` object.
83-
- __StructBehavior__ - This specifies the StructBehavior when working with Arrow Struct arrays. The valid values are `JsonString` (the default) or `Strict` (treat the struct as a native type).
83+
- __StructBehavior__ - This specifies the StructBehavior when working with Arrow Struct arrays. The valid values are `JsonString` (the default) or `Strict` (treat the struct as a native type). If using JsonString, the return ArrowType will be StringType and the result a string value. If using Strict, the return ArrowType will be StructType and the result a Dictionary<string, object?>.
8484
- __DecimalBehavior__ - This specifies the DecimalBehavior when parsing decimal values from Arrow libraries. The valid values are `UseSqlDecimal` or `OverflowDecimalAsString` where values like Decimal256 are treated as strings.

csharp/src/Drivers/BigQuery/BigQueryStatement.cs

+44-6
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,13 @@ public override QueryResult ExecuteQuery()
112112
ReadSession rrs = readClient.CreateReadSession("projects/" + results.TableReference.ProjectId, rs, maxStreamCount);
113113

114114
long totalRows = results.TotalRows == null ? -1L : (long)results.TotalRows.Value;
115-
IArrowArrayStream stream = new MultiArrowReader(TranslateSchema(results.Schema), rrs.Streams.Select(s => ReadChunk(readClient, s.Name)));
115+
116+
var readers = rrs.Streams
117+
.Select(s => ReadChunk(readClient, s.Name))
118+
.Where(chunk => chunk != null)
119+
.Cast<IArrowReader>();
120+
121+
IArrowArrayStream stream = new MultiArrowReader(TranslateSchema(results.Schema), readers);
116122

117123
return new QueryResult(totalRows, stream);
118124
}
@@ -175,8 +181,7 @@ private IArrowType TranslateType(TableFieldSchema field)
175181
case "DATE":
176182
return GetType(field, Date32Type.Default);
177183
case "RECORD" or "STRUCT":
178-
// its a json string
179-
return GetType(field, StringType.Default);
184+
return GetType(field, BuildStructType(field));
180185

181186
// treat these values as strings
182187
case "GEOGRAPHY" or "JSON":
@@ -200,6 +205,19 @@ private IArrowType TranslateType(TableFieldSchema field)
200205
}
201206
}
202207

208+
private StructType BuildStructType(TableFieldSchema field)
209+
{
210+
List<Field> arrowFields = new List<Field>();
211+
212+
foreach (TableFieldSchema subfield in field.Fields)
213+
{
214+
Field arrowField = TranslateField(subfield);
215+
arrowFields.Add(arrowField);
216+
}
217+
218+
return new StructType(arrowFields.AsReadOnly());
219+
}
220+
203221
private IArrowType GetType(TableFieldSchema field, IArrowType type)
204222
{
205223
if (field.Mode == "REPEATED")
@@ -208,7 +226,7 @@ private IArrowType GetType(TableFieldSchema field, IArrowType type)
208226
return type;
209227
}
210228

211-
static IArrowReader ReadChunk(BigQueryReadClient readClient, string streamName)
229+
static IArrowReader? ReadChunk(BigQueryReadClient readClient, string streamName)
212230
{
213231
// Ideally we wouldn't need to indirect through a stream, but the necessary APIs in Arrow
214232
// are internal. (TODO: consider changing Arrow).
@@ -217,7 +235,14 @@ static IArrowReader ReadChunk(BigQueryReadClient readClient, string streamName)
217235

218236
ReadRowsStream stream = new ReadRowsStream(enumerator);
219237

220-
return new ArrowStreamReader(stream);
238+
if (stream.HasRows)
239+
{
240+
return new ArrowStreamReader(stream);
241+
}
242+
else
243+
{
244+
return null;
245+
}
221246
}
222247

223248
private QueryOptions ValidateOptions()
@@ -349,15 +374,28 @@ sealed class ReadRowsStream : Stream
349374
ReadOnlyMemory<byte> currentBuffer;
350375
bool first;
351376
int position;
377+
bool hasRows;
352378

353379
public ReadRowsStream(IAsyncEnumerator<ReadRowsResponse> response)
354380
{
355381
if (!response.MoveNextAsync().Result) { }
356-
this.currentBuffer = response.Current.ArrowSchema.SerializedSchema.Memory;
382+
383+
if (response.Current != null)
384+
{
385+
this.currentBuffer = response.Current.ArrowSchema.SerializedSchema.Memory;
386+
this.hasRows = true;
387+
}
388+
else
389+
{
390+
this.hasRows = false;
391+
}
392+
357393
this.response = response;
358394
this.first = true;
359395
}
360396

397+
public bool HasRows => this.hasRows;
398+
361399
public override bool CanRead => true;
362400

363401
public override bool CanSeek => false;

0 commit comments

Comments
 (0)