diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
index a12f9bde96..45a496b64e 100644
--- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
@@ -33,12 +33,72 @@ internal class DatabricksConnection : SparkHttpConnection
     {
         private bool _applySSPWithQueries = false;
 
+        // CloudFetch configuration
+        private const long DefaultMaxBytesPerFile = 20 * 1024 * 1024; // 20MB
+
+        private bool _useCloudFetch = true;
+        private bool _canDecompressLz4 = true;
+        private long _maxBytesPerFile = DefaultMaxBytesPerFile;
+
         public DatabricksConnection(IReadOnlyDictionary<string, string> properties) : base(properties)
         {
-            if (Properties.TryGetValue(DatabricksParameters.ApplySSPWithQueries, out string? applySSPWithQueriesStr) &&
-                bool.TryParse(applySSPWithQueriesStr, out bool applySSPWithQueriesValue))
+            ValidateProperties();
+        }
+
+        private void ValidateProperties()
+        {
+            if (Properties.TryGetValue(DatabricksParameters.ApplySSPWithQueries, out string? applySSPWithQueriesStr))
+            {
+                if (bool.TryParse(applySSPWithQueriesStr, out bool applySSPWithQueriesValue))
+                {
+                    _applySSPWithQueries = applySSPWithQueriesValue;
+                }
+                else
+                {
+                    throw new ArgumentException($"Parameter '{DatabricksParameters.ApplySSPWithQueries}' value '{applySSPWithQueriesStr}' could not be parsed. Valid values are 'true' and 'false'.");
+                }
+            }
+
+            // Parse CloudFetch options from connection properties
+            if (Properties.TryGetValue(DatabricksParameters.UseCloudFetch, out string? useCloudFetchStr))
+            {
+                if (bool.TryParse(useCloudFetchStr, out bool useCloudFetchValue))
+                {
+                    _useCloudFetch = useCloudFetchValue;
+                }
+                else
+                {
+                    throw new ArgumentException($"Parameter '{DatabricksParameters.UseCloudFetch}' value '{useCloudFetchStr}' could not be parsed. Valid values are 'true' and 'false'.");
+                }
+            }
+
+            if (Properties.TryGetValue(DatabricksParameters.CanDecompressLz4, out string? canDecompressLz4Str))
+            {
+                if (bool.TryParse(canDecompressLz4Str, out bool canDecompressLz4Value))
+                {
+                    _canDecompressLz4 = canDecompressLz4Value;
+                }
+                else
+                {
+                    throw new ArgumentException($"Parameter '{DatabricksParameters.CanDecompressLz4}' value '{canDecompressLz4Str}' could not be parsed. Valid values are 'true' and 'false'.");
+                }
+            }
+
+            if (Properties.TryGetValue(DatabricksParameters.MaxBytesPerFile, out string? maxBytesPerFileStr))
             {
-                _applySSPWithQueries = applySSPWithQueriesValue;
+                if (!long.TryParse(maxBytesPerFileStr, out long maxBytesPerFileValue))
+                {
+                    throw new ArgumentException($"Parameter '{DatabricksParameters.MaxBytesPerFile}' value '{maxBytesPerFileStr}' could not be parsed. Valid values are positive integers.");
+                }
+
+                if (maxBytesPerFileValue <= 0)
+                {
+                    throw new ArgumentOutOfRangeException(
+                        nameof(Properties),
+                        maxBytesPerFileValue,
+                        $"Parameter '{DatabricksParameters.MaxBytesPerFile}' value must be a positive integer.");
+                }
+                _maxBytesPerFile = maxBytesPerFileValue;
            }
        }
 
@@ -47,6 +107,21 @@ public DatabricksConnection(IReadOnlyDictionary<string, string> properties) : ba
         /// </summary>
         internal bool ApplySSPWithQueries => _applySSPWithQueries;
 
+        /// <summary>
+        /// Gets whether CloudFetch is enabled.
+        /// </summary>
+        internal bool UseCloudFetch => _useCloudFetch;
+
+        /// <summary>
+        /// Gets whether LZ4 decompression is enabled.
+        /// </summary>
+        internal bool CanDecompressLz4 => _canDecompressLz4;
+
+        /// <summary>
+        /// Gets the maximum bytes per file for CloudFetch.
+        /// </summary>
+        internal long MaxBytesPerFile => _maxBytesPerFile;
+
         internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null)
         {
             // Get result format from metadata response if available
diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
index 9d06a9479e..1c963b4d7b 100644
--- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
@@ -25,6 +25,24 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
     public class DatabricksParameters : SparkParameters
     {
         // CloudFetch configuration parameters
+        /// <summary>
+        /// Whether to use CloudFetch for retrieving results.
+        /// Default value is true if not specified.
+        /// </summary>
+        public const string UseCloudFetch = "adbc.databricks.cloudfetch.enabled";
+
+        /// <summary>
+        /// Whether the client can decompress LZ4 compressed results.
+        /// Default value is true if not specified.
+        /// </summary>
+        public const string CanDecompressLz4 = "adbc.databricks.cloudfetch.lz4.enabled";
+
+        /// <summary>
+        /// Maximum bytes per file for CloudFetch.
+        /// Default value is 20MB if not specified.
+        /// </summary>
+        public const string MaxBytesPerFile = "adbc.databricks.cloudfetch.max_bytes_per_file";
+
         /// <summary>
         /// Maximum number of retry attempts for CloudFetch downloads.
         /// Default value is 3 if not specified.
diff --git a/csharp/src/Drivers/Databricks/DatabricksStatement.cs b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
index 62f576eb98..f2c9e643ad 100644
--- a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
@@ -27,18 +27,17 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
     /// </summary>
     internal class DatabricksStatement : SparkStatement
     {
-        // Default maximum bytes per file for CloudFetch
-        private const long DefaultMaxBytesPerFile = 20 * 1024 * 1024; // 20MB
-
-        // CloudFetch configuration
-        private bool useCloudFetch = true;
-        private bool canDecompressLz4 = true;
-        private long maxBytesPerFile = DefaultMaxBytesPerFile;
+        private bool useCloudFetch;
+        private bool canDecompressLz4;
+        private long maxBytesPerFile;
 
         public DatabricksStatement(DatabricksConnection connection)
             : base(connection)
         {
-
+            // Inherit CloudFetch settings from connection
+            useCloudFetch = connection.UseCloudFetch;
+            canDecompressLz4 = connection.CanDecompressLz4;
+            maxBytesPerFile = connection.MaxBytesPerFile;
         }
 
         protected override void SetStatementProperties(TExecuteStatementReq statement)
@@ -55,7 +54,7 @@ public override void SetOption(string key, string value)
         {
             switch (key)
             {
-                case Options.UseCloudFetch:
+                case DatabricksParameters.UseCloudFetch:
                     if (bool.TryParse(value, out bool useCloudFetchValue))
                     {
                         this.useCloudFetch = useCloudFetchValue;
@@ -65,7 +64,7 @@ public override void SetOption(string key, string value)
                         throw new ArgumentException($"Invalid value for {key}: {value}. Expected a boolean value.");
                     }
                     break;
-                case Options.CanDecompressLz4:
+                case DatabricksParameters.CanDecompressLz4:
                     if (bool.TryParse(value, out bool canDecompressLz4Value))
                     {
                         this.canDecompressLz4 = canDecompressLz4Value;
@@ -75,7 +74,7 @@ public override void SetOption(string key, string value)
                         throw new ArgumentException($"Invalid value for {key}: {value}. Expected a boolean value.");
                     }
                     break;
-                case Options.MaxBytesPerFile:
+                case DatabricksParameters.MaxBytesPerFile:
                     if (long.TryParse(value, out long maxBytesPerFileValue))
                     {
                         this.maxBytesPerFile = maxBytesPerFileValue;
@@ -132,16 +131,5 @@ internal void SetMaxBytesPerFile(long maxBytesPerFile)
         {
             this.maxBytesPerFile = maxBytesPerFile;
         }
-
-        /// <summary>
-        /// Provides the constant string key values to the method.
-        /// </summary>
-        public sealed class Options : ApacheParameters
-        {
-            // CloudFetch options
-            public const string UseCloudFetch = "adbc.databricks.cloudfetch.enabled";
-            public const string CanDecompressLz4 = "adbc.databricks.cloudfetch.lz4.enabled";
-            public const string MaxBytesPerFile = "adbc.databricks.cloudfetch.max_bytes_per_file";
-        }
     }
 }
diff --git a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
index e57a55ec29..0d9bbfa90b 100644
--- a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
+++ b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
@@ -50,13 +50,25 @@ public async Task TestRealDatabricksCloudFetchLargeResultSet()
             await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000);
         }
 
-        private async Task TestRealDatabricksCloudFetchLargeQuery(string query, int rowCount)
+        [Fact]
+        public async Task TestRealDatabricksNoCloudFetchSmallResultSet()
+        {
+            await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM range(1000)", 1000, false);
+        }
+
+        [Fact]
+        public async Task TestRealDatabricksNoCloudFetchLargeResultSet()
+        {
+            await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000, false);
+        }
+
+        private async Task TestRealDatabricksCloudFetchLargeQuery(string query, int rowCount, bool useCloudFetch = true)
         {
             // Create a statement with CloudFetch enabled
             var statement = Connection.CreateStatement();
-            statement.SetOption(DatabricksStatement.Options.UseCloudFetch, "true");
-            statement.SetOption(DatabricksStatement.Options.CanDecompressLz4, "true");
-            statement.SetOption(DatabricksStatement.Options.MaxBytesPerFile, "10485760"); // 10MB
+            statement.SetOption(DatabricksParameters.UseCloudFetch, useCloudFetch.ToString());
+            statement.SetOption(DatabricksParameters.CanDecompressLz4, "true");
+            statement.SetOption(DatabricksParameters.MaxBytesPerFile, "10485760"); // 10MB
 
             // Execute a query that generates a large result set using range function
             statement.SqlQuery = query;
diff --git a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
index a5667de537..3462f15058 100644
--- a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
+++ b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
@@ -22,6 +22,7 @@
 using Apache.Arrow.Adbc.Drivers.Apache;
 using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
 using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Apache.Arrow.Adbc.Drivers.Databricks;
 using Thrift.Transport;
 using Xunit;
 using Xunit.Abstractions;
@@ -295,6 +296,10 @@ public InvalidConnectionParametersTestData()
             Add(new([], typeof(ArgumentException)));
             Add(new(new() { [SparkParameters.Type] = " " }, typeof(ArgumentException)));
             Add(new(new() { [SparkParameters.Type] = "xxx" }, typeof(ArgumentException)));
+            Add(new(new() { [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [DatabricksParameters.UseCloudFetch] = "notabool" }, typeof(ArgumentException)));
+            Add(new(new() { [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [DatabricksParameters.CanDecompressLz4] = "notabool"}, typeof(ArgumentException)));
+            Add(new(new() { [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [DatabricksParameters.MaxBytesPerFile] = "notanumber" }, typeof(ArgumentException)));
+            Add(new(new() { [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [DatabricksParameters.MaxBytesPerFile] = "-100" }, typeof(ArgumentOutOfRangeException)));
             Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = "-1" }, typeof(ArgumentOutOfRangeException)));
             Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = IPEndPoint.MinPort.ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException)));
             Add(new(new() { /*[SparkParameters.Type] = SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] = "valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port] = (IPEndPoint.MaxPort + 1).ToString(CultureInfo.InvariantCulture) }, typeof(ArgumentOutOfRangeException)));
diff --git a/csharp/test/Drivers/Databricks/StatementTests.cs b/csharp/test/Drivers/Databricks/StatementTests.cs
index 8f06c63c8d..16bc67bd45 100644
--- a/csharp/test/Drivers/Databricks/StatementTests.cs
+++ b/csharp/test/Drivers/Databricks/StatementTests.cs
@@ -44,7 +44,7 @@ public async Task LZ4DecompressionCapabilityTest(bool useCloudFetch, string conf
             using var statement = connection.CreateStatement();
 
             // Set options for LZ4 decompression (enabled by default) and CloudFetch as specified
-            statement.SetOption(DatabricksStatement.Options.UseCloudFetch, useCloudFetch.ToString().ToLower());
+            statement.SetOption(DatabricksParameters.UseCloudFetch, useCloudFetch.ToString().ToLower());
 
             OutputHelper?.WriteLine($"CloudFetch is {(useCloudFetch ? "enabled" : "disabled")}");
             OutputHelper?.WriteLine("LZ4 decompression capability is enabled by default");