diff --git a/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs b/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs index f4f4441e24..2c88fcc732 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkAuthType.cs @@ -23,6 +23,7 @@ internal enum SparkAuthType UsernameOnly, Basic, Token, + OAuth, Empty = int.MaxValue, } @@ -48,6 +49,9 @@ internal static bool TryParse(string? authType, out SparkAuthType authTypeValue) case SparkAuthTypeConstants.Token: authTypeValue = SparkAuthType.Token; return true; + case SparkAuthTypeConstants.OAuth: + authTypeValue = SparkAuthType.OAuth; + return true; default: authTypeValue = default; return false; diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs index 75abb1196b..3edb7614e0 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs @@ -51,6 +51,7 @@ protected override void ValidateAuthentication() Properties.TryGetValue(AdbcOptions.Username, out string? username); Properties.TryGetValue(AdbcOptions.Password, out string? password); Properties.TryGetValue(SparkParameters.AuthType, out string? authType); + Properties.TryGetValue(SparkParameters.AccessToken, out string? access_token); if (!SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue)) { throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); @@ -83,6 +84,13 @@ protected override void ValidateAuthentication() $"Parameters must include valid authentiation settings. Please provide either '{SparkParameters.Token}'; or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'.", nameof(Properties)); break; + + case SparkAuthType.OAuth: + if (string.IsNullOrWhiteSpace(access_token)) + throw new ArgumentException( + $"Parameter '{SparkParameters.AuthType}' is set to '{SparkAuthTypeConstants.OAuth}' but parameter '{SparkParameters.AccessToken}' is not set. Please provide a value for '{SparkParameters.AccessToken}'.", + nameof(Properties)); + break; default: throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); } @@ -146,12 +154,13 @@ protected override TTransport CreateTransport() throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); } Properties.TryGetValue(SparkParameters.Token, out string? token); + Properties.TryGetValue(SparkParameters.AccessToken, out string? access_token); Properties.TryGetValue(AdbcOptions.Username, out string? username); Properties.TryGetValue(AdbcOptions.Password, out string? password); Properties.TryGetValue(AdbcOptions.Uri, out string? uri); Uri baseAddress = GetBaseAddress(uri, hostName, path, port, SparkParameters.HostName); - AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password); + AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password, access_token); HttpClientHandler httpClientHandler = NewHttpClientHandler(); HttpClient httpClient = new(httpClientHandler); @@ -191,7 +200,7 @@ private HttpClientHandler NewHttpClientHandler() return httpClientHandler; } - private static AuthenticationHeaderValue? GetAuthenticationHeaderValue(SparkAuthType authType, string? token, string? username, string? password) + private static AuthenticationHeaderValue? GetAuthenticationHeaderValue(SparkAuthType authType, string? token, string? username, string? password, string? access_token) { if (!string.IsNullOrEmpty(token) && (authType == SparkAuthType.Empty || authType == SparkAuthType.Token)) { @@ -205,6 +214,10 @@ private HttpClientHandler NewHttpClientHandler() { return new AuthenticationHeaderValue(BasicAuthenticationScheme, Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:"))); } + else if (!string.IsNullOrEmpty(access_token) && authType == SparkAuthType.OAuth) + { + return new AuthenticationHeaderValue(BearerAuthenticationScheme, access_token); + } else if (authType == SparkAuthType.None) { return null; diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs index 6cb96dd5f1..d3bd9be6d5 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs @@ -26,6 +26,9 @@ public static class SparkParameters public const string Port = "adbc.spark.port"; public const string Path = "adbc.spark.path"; public const string Token = "adbc.spark.token"; + + // access_token is required when authType is oauth + public const string AccessToken = "adbc.spark.access_token"; public const string AuthType = "adbc.spark.auth_type"; public const string Type = "adbc.spark.type"; public const string DataTypeConv = "adbc.spark.data_type_conv"; @@ -39,6 +42,7 @@ public static class SparkAuthTypeConstants public const string UsernameOnly = "username_only"; public const string Basic = "basic"; public const string Token = "token"; + public const string OAuth = "oauth"; } public static class SparkServerTypeConstants diff --git a/csharp/test/Drivers/Apache/Spark/DriverTests.cs b/csharp/test/Drivers/Apache/Spark/DriverTests.cs index 03699c7a70..ec49f38a0d 100644 --- a/csharp/test/Drivers/Apache/Spark/DriverTests.cs +++ b/csharp/test/Drivers/Apache/Spark/DriverTests.cs @@ -84,19 +84,24 @@ public override void CanDetectInvalidAuthentication() Dictionary parameters = GetDriverParameters(TestConfiguration); bool hasToken = parameters.TryGetValue(SparkParameters.Token, out var token) && !string.IsNullOrEmpty(token); + bool hasAccessToken = parameters.TryGetValue(SparkParameters.Token, out var access_token) && !string.IsNullOrEmpty(access_token); bool hasUsername = parameters.TryGetValue(AdbcOptions.Username, out var username) && !string.IsNullOrEmpty(username); bool hasPassword = parameters.TryGetValue(AdbcOptions.Password, out var password) && !string.IsNullOrEmpty(password); if (hasToken) { parameters[SparkParameters.Token] = "invalid-token"; } + else if (hasAccessToken) + { + parameters[SparkParameters.AccessToken] = "invalid-access-token"; + } else if (hasUsername && hasPassword) { parameters[AdbcOptions.Password] = "invalid-password"; } else { - Assert.Fail($"Unexpected configuration. Must provide '{SparkParameters.Token}' or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'."); + Assert.Fail($"Unexpected configuration. Must provide '{SparkParameters.Token}' or '{SparkParameters.AccessToken}' or '{AdbcOptions.Username}' and '{AdbcOptions.Password}'."); } AdbcDatabase database = driver.Open(parameters); diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs b/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs index 5ada5abeb3..abf4400bc1 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkTestConfiguration.cs @@ -25,5 +25,8 @@ public class SparkTestConfiguration : ApacheTestConfiguration [JsonPropertyName("token"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] public string Token { get; set; } = string.Empty; + [JsonPropertyName("access_token"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string AccessToken { get; set; } = string.Empty; + } } diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs index ca0fb1fe71..9c9124d6d4 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs @@ -77,6 +77,10 @@ public override Dictionary GetDriverParameters(SparkTestConfigur { parameters.Add(SparkParameters.Token, testConfiguration.Token!); } + if (!string.IsNullOrEmpty(testConfiguration.AccessToken)) + { + parameters.Add(SparkParameters.AccessToken, testConfiguration.AccessToken); + } if (!string.IsNullOrEmpty(testConfiguration.Username)) { parameters.Add(AdbcOptions.Username, testConfiguration.Username!);