Skip to content

Commit 028d22f

Browse files
feat(csharp/src/Drivers/Apache/Spark): Add Lz4 compression support to arrow batch reader (#2669)
1. Create a new file [Lz4Utilities.cs](main...eric-wang-1990:arrow-adbc:add_lz4_to_arrowbatch#diff-7275463b6b9fcc3cf3b954580e3ded8b0b7237a90e0f4aea33eb11e613f3db39) to abstract common Lz4 decompress util functions for both cloud fetch and arrow batch. 2. Add support for decompress arrow batch with Lz4 3. Rename adbc.spark.cloudfetch.lz4.enabled to adbc.spark.lz4Compression.enabled since it is not specific to cloudfetch 4. Add test to test both cloudFetch and arrowBatch in StatementTests.
1 parent aa1053a commit 028d22f

File tree

7 files changed

+198
-47
lines changed

7 files changed

+198
-47
lines changed

Diff for: csharp/src/Drivers/Apache/Spark/CloudFetch/SparkCloudFetchReader.cs

+35-42
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
using System.Net.Http;
2323
using System.Threading;
2424
using System.Threading.Tasks;
25+
using Apache.Arrow.Adbc.Drivers.Apache;
2526
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
27+
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
2628
using Apache.Arrow.Ipc;
2729
using Apache.Hive.Service.Rpc.Thrift;
2830
using K4os.Compression.LZ4.Streams;
@@ -161,59 +163,52 @@ private HttpClient HttpClient
161163
var link = this.resultLinks[this.linkIndex++];
162164
byte[]? fileData = null;
163165

164-
// Retry logic for downloading files
165-
for (int retry = 0; retry < this.maxRetries; retry++)
166+
try
166167
{
167-
try
168+
// Try to download with retry logic
169+
for (int retry = 0; retry < this.maxRetries; retry++)
168170
{
169-
fileData = await DownloadFileAsync(link.FileLink, cancellationToken);
170-
break; // Success, exit retry loop
171+
try
172+
{
173+
fileData = await DownloadFileAsync(link.FileLink, cancellationToken);
174+
break; // Success, exit retry loop
175+
}
176+
catch (Exception) when (retry < this.maxRetries - 1)
177+
{
178+
// Only delay and retry if we haven't reached max retries
179+
await Task.Delay(this.retryDelayMs * (retry + 1), cancellationToken);
180+
}
171181
}
172-
catch (Exception ex) when (retry < this.maxRetries - 1)
182+
183+
// If download still failed after all retries
184+
if (fileData == null)
173185
{
174-
// Log the error and retry
175-
Debug.WriteLine($"Error downloading file (attempt {retry + 1}/{this.maxRetries}): {ex.Message}");
176-
await Task.Delay(this.retryDelayMs * (retry + 1), cancellationToken);
186+
throw new AdbcException($"Failed to download CloudFetch data from {link.FileLink} after {this.maxRetries} attempts");
177187
}
178-
}
179188

180-
// Process the downloaded file data
181-
MemoryStream dataStream;
189+
ReadOnlyMemory<byte> dataToUse = new ReadOnlyMemory<byte>(fileData);
182190

183-
// If the data is LZ4 compressed, decompress it
184-
if (this.isLz4Compressed)
185-
{
186-
try
187-
{
188-
dataStream = new MemoryStream();
189-
using (var inputStream = new MemoryStream(fileData!))
190-
using (var decompressor = LZ4Stream.Decode(inputStream))
191-
{
192-
await decompressor.CopyToAsync(dataStream);
193-
}
194-
dataStream.Position = 0;
195-
}
196-
catch (Exception ex)
191+
// If the data is LZ4 compressed, decompress it
192+
if (this.isLz4Compressed)
197193
{
198-
Debug.WriteLine($"Error decompressing data: {ex.Message}");
199-
continue; // Skip this link and try the next one
194+
dataToUse = Lz4Utilities.DecompressLz4(fileData);
200195
}
201-
}
202-
else
203-
{
204-
dataStream = new MemoryStream(fileData!);
205-
}
206196

207-
try
208-
{
209-
this.currentReader = new ArrowStreamReader(dataStream);
197+
// Use ChunkStream which supports ReadOnlyMemory<byte> directly
198+
this.currentReader = new ArrowStreamReader(new ChunkStream(this.schema, dataToUse));
210199
continue;
211200
}
212201
catch (Exception ex)
213202
{
214-
Debug.WriteLine($"Error creating Arrow reader: {ex.Message}");
215-
dataStream.Dispose();
216-
continue; // Skip this link and try the next one
203+
// Create concise error message based on exception type
204+
string errorPrefix = $"CloudFetch link {this.linkIndex-1}:";
205+
string errorMessage = ex switch
206+
{
207+
_ when ex.GetType().Name.Contains("LZ4") => $"{errorPrefix} LZ4 decompression failed - Data may be corrupted",
208+
HttpRequestException or TaskCanceledException => $"{errorPrefix} Download failed - {ex.Message}",
209+
_ => $"{errorPrefix} Processing failed - {ex.Message}" // Default case for any other exception
210+
};
211+
throw new AdbcException(errorMessage, ex);
217212
}
218213
}
219214

@@ -242,9 +237,7 @@ private HttpClient HttpClient
242237
}
243238
catch (Exception ex)
244239
{
245-
Debug.WriteLine($"Error fetching results from server: {ex.Message}");
246-
this.statement = null; // Mark as done due to error
247-
return null;
240+
throw new AdbcException($"Server request failed - {ex.Message}", ex);
248241
}
249242

250243
// Check if we have URL-based results

Diff for: csharp/src/Drivers/Apache/Spark/Lz4Utilities.cs

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
using System;
using System.Buffers;
using System.IO;
using K4os.Compression.LZ4.Streams;

namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
    /// <summary>
    /// Utility class for LZ4 compression/decompression operations.
    /// </summary>
    internal static class Lz4Utilities
    {
        /// <summary>
        /// Decompresses LZ4 compressed data into memory.
        /// </summary>
        /// <param name="compressedData">The compressed data bytes.</param>
        /// <returns>A <see cref="ReadOnlyMemory{T}"/> containing the decompressed data.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="compressedData"/> is null.</exception>
        /// <exception cref="AdbcException">Thrown when decompression fails.</exception>
        public static ReadOnlyMemory<byte> DecompressLz4(byte[] compressedData)
        {
            // Fail fast with a precise exception rather than letting a NullReferenceException
            // surface from inside the decoder and be wrapped in a misleading AdbcException below.
            if (compressedData == null)
            {
                throw new ArgumentNullException(nameof(compressedData));
            }

            try
            {
                var outputStream = new MemoryStream();
                using (var inputStream = new MemoryStream(compressedData))
                using (var decompressor = LZ4Stream.Decode(inputStream))
                {
                    decompressor.CopyTo(outputStream);
                }
                // Return the MemoryStream's internal buffer, bounded by its valid length,
                // to avoid copying the decompressed bytes a second time (ToArray would copy).
                // Note: outputStream is intentionally not disposed because its buffer is being
                // returned; the memory is reclaimed once the ReadOnlyMemory is no longer referenced.
                return new ReadOnlyMemory<byte>(outputStream.GetBuffer(), 0, (int)outputStream.Length);
            }
            catch (Exception ex)
            {
                // Normalize any decoder failure into the driver's exception type so callers
                // only need to handle AdbcException.
                throw new AdbcException($"Failed to decompress LZ4 data: {ex.Message}", ex);
            }
        }
    }
}

Diff for: csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGe
5656
}
5757
else
5858
{
59-
return new SparkDatabricksReader(statement, schema);
59+
return new SparkDatabricksReader(statement, schema, isLz4Compressed);
6060
}
6161
}
6262

Diff for: csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs

+50-1
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
* limitations under the License.
1616
*/
1717

18+
using System;
1819
using System.Collections.Generic;
20+
using System.IO;
1921
using System.Threading;
2022
using System.Threading.Tasks;
2123
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
2224
using Apache.Arrow.Ipc;
2325
using Apache.Hive.Service.Rpc.Thrift;
26+
using K4os.Compression.LZ4.Streams;
2427

2528
namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
2629
{
@@ -31,11 +34,18 @@ internal sealed class SparkDatabricksReader : IArrowArrayStream
3134
List<TSparkArrowBatch>? batches;
3235
int index;
3336
IArrowReader? reader;
37+
bool isLz4Compressed;
3438

3539
/// <summary>
/// Creates a reader that treats the server's result batches as uncompressed.
/// </summary>
public SparkDatabricksReader(HiveServer2Statement statement, Schema schema)
    : this(statement, schema, false)
{
}

/// <summary>
/// Creates a reader, optionally decompressing LZ4-compressed result batches.
/// </summary>
/// <param name="statement">The statement whose results are being read.</param>
/// <param name="schema">The Arrow schema of the result set.</param>
/// <param name="isLz4Compressed">True when the server sends LZ4-compressed batch data.</param>
public SparkDatabricksReader(HiveServer2Statement statement, Schema schema, bool isLz4Compressed)
{
    this.statement = statement;
    this.schema = schema;
    this.isLz4Compressed = isLz4Compressed;
}
4050

4151
public Schema Schema { get { return schema; } }
@@ -56,7 +66,7 @@ public SparkDatabricksReader(HiveServer2Statement statement, Schema schema)
5666

5767
if (this.batches != null && this.index < this.batches.Count)
5868
{
59-
this.reader = new ArrowStreamReader(new ChunkStream(this.schema, this.batches[this.index++].Batch));
69+
ProcessFetchedBatches();
6070
continue;
6171
}
6272

@@ -70,6 +80,8 @@ public SparkDatabricksReader(HiveServer2Statement statement, Schema schema)
7080

7181
TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
7282
TFetchResultsResp response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken);
83+
84+
// Make sure we get the arrowBatches
7385
this.batches = response.Results.ArrowBatches;
7486

7587
if (!response.HasMoreRows)
@@ -79,6 +91,43 @@ public SparkDatabricksReader(HiveServer2Statement statement, Schema schema)
7991
}
8092
}
8193

94+
// Turns the batch at the current index into an ArrowStreamReader,
// decompressing it first when the server sent LZ4-compressed data.
// Always advances the index, skipping batches with no payload.
private void ProcessFetchedBatches()
{
    var currentBatch = this.batches![this.index];

    // Nothing to read in this batch; move on to the next one.
    if (currentBatch.Batch == null || currentBatch.Batch.Length == 0)
    {
        this.index++;
        return;
    }

    try
    {
        // Wrap the raw bytes without copying, or decompress when LZ4 is enabled.
        ReadOnlyMemory<byte> payload = isLz4Compressed
            ? Lz4Utilities.DecompressLz4(currentBatch.Batch)
            : new ReadOnlyMemory<byte>(currentBatch.Batch);

        // ChunkStream prepends the schema so ArrowStreamReader can parse the payload.
        this.reader = new ArrowStreamReader(new ChunkStream(this.schema, payload));
    }
    catch (Exception ex)
    {
        // Create concise error message based on exception type
        string errorMessage = ex switch
        {
            _ when ex.GetType().Name.Contains("LZ4") => $"Batch {this.index}: LZ4 decompression failed - Data may be corrupted",
            _ => $"Batch {this.index}: Processing failed - {ex.Message}" // Default case for any other exception
        };
        throw new AdbcException(errorMessage, ex);
    }
    this.index++;
}
130+
82131
public void Dispose()
83132
{
84133
}

Diff for: csharp/src/Drivers/Apache/Spark/SparkStatement.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,10 @@ internal void SetMaxBytesPerFile(long maxBytesPerFile)
155155
/// </summary>
156156
public sealed class Options : ApacheParameters
{
    // Lz4 compression option (applies to both CloudFetch and arrow batch results)
    public const string CanDecompressLz4 = "adbc.spark.lz4_compression.enabled";

    // CloudFetch options
    public const string UseCloudFetch = "adbc.spark.cloudfetch.enabled";
    public const string MaxBytesPerFile = "adbc.spark.cloudfetch.max_bytes_per_file";
}
163164
}

Diff for: csharp/src/Drivers/Apache/Thrift/ChunkStream.cs

+8-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,17 @@ namespace Apache.Arrow.Adbc.Drivers.Apache
2626
internal class ChunkStream : Stream
2727
{
2828
ReadOnlyMemory<byte> currentBuffer;
29-
byte[] data;
29+
ReadOnlyMemory<byte> data;
3030
bool first;
3131
int position;
3232

3333
/// <summary>
/// Convenience overload for raw byte arrays; delegates to the
/// <see cref="ReadOnlyMemory{T}"/> constructor to avoid duplication.
/// </summary>
public ChunkStream(Schema schema, byte[] data)
    : this(schema, new ReadOnlyMemory<byte>(data))
{
}
38+
39+
public ChunkStream(Schema schema, ReadOnlyMemory<byte> data)
3440
{
3541
MemoryStream buffer = new MemoryStream();
3642
ArrowStreamWriter writer = new ArrowStreamWriter(buffer, schema, leaveOpen: true);
@@ -70,7 +76,7 @@ public override int Read(byte[] buffer, int offset, int count)
7076
{
7177
return 0;
7278
}
73-
this.currentBuffer = new ReadOnlyMemory<byte>(this.data);
79+
this.currentBuffer = this.data;
7480
this.position = 0;
7581
remaining = this.currentBuffer.Length - this.position;
7682
}

Diff for: csharp/test/Drivers/Apache/Spark/StatementTests.cs

+45
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
using System;
1919
using System.Collections.Generic;
20+
using System.Linq;
2021
using System.Threading.Tasks;
2122
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
2223
using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common;
@@ -39,6 +40,50 @@ internal override void StatementTimeoutTest(StatementWithExceptions statementWit
3940
base.StatementTimeoutTest(statementWithExceptions);
4041
}
4142

43+
[SkippableTheory]
[InlineData(true, "CloudFetch enabled")]
[InlineData(false, "CloudFetch disabled")]
public async Task LZ4DecompressionCapabilityTest(bool useCloudFetch, string configName)
{
    OutputHelper?.WriteLine($"Testing with LZ4 decompression capability enabled ({configName})");

    // Build a connection and statement from the shared test configuration.
    using AdbcConnection connection = NewConnection();
    using var statement = connection.CreateStatement();

    // LZ4 decompression is on by default; only CloudFetch is toggled per theory case.
    statement.SetOption(SparkStatement.Options.UseCloudFetch, useCloudFetch.ToString().ToLower());
    OutputHelper?.WriteLine($"CloudFetch is {(useCloudFetch ? "enabled" : "disabled")}");
    OutputHelper?.WriteLine("LZ4 decompression capability is enabled by default");

    // Run a query that produces a known number of rows.
    statement.SqlQuery = "SELECT id, CAST(id AS STRING) as id_string, id * 2 as id_doubled FROM RANGE(100)";
    QueryResult result = statement.ExecuteQuery();

    Assert.NotNull(result.Stream);

    // Drain every batch from the stream, counting rows as we go.
    int totalRows = 0;
    int batchCount = 0;

    while (result.Stream != null)
    {
        using var batch = await result.Stream.ReadNextRecordBatchAsync();
        if (batch is null)
        {
            break;
        }

        batchCount++;
        totalRows += batch.Length;
        OutputHelper?.WriteLine($"Batch {batchCount}: Read {batch.Length} rows");
    }

    // The query must yield exactly the 100 rows RANGE(100) produces.
    Assert.Equal(100, totalRows);
    OutputHelper?.WriteLine($"Successfully read {totalRows} rows in {batchCount} batches with {configName}");
    OutputHelper?.WriteLine("NOTE: Whether actual LZ4 compression was used is determined by the server");
}
86+
4287
internal class LongRunningStatementTimeoutTestData : ShortRunningStatementTimeoutTestData
4388
{
4489
public LongRunningStatementTimeoutTestData()

0 commit comments

Comments
 (0)