diff --git a/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj b/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj index c46342cdce..96245e648a 100644 --- a/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj +++ b/csharp/src/Apache.Arrow.Adbc/Apache.Arrow.Adbc.csproj @@ -7,6 +7,7 @@ + diff --git a/csharp/src/Apache.Arrow.Adbc/Tracing/ActivityExtensions.cs b/csharp/src/Apache.Arrow.Adbc/Tracing/ActivityExtensions.cs new file mode 100644 index 0000000000..24bab84cb5 --- /dev/null +++ b/csharp/src/Apache.Arrow.Adbc/Tracing/ActivityExtensions.cs @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Collections.Generic; +using System.Diagnostics; + +namespace Apache.Arrow.Adbc.Tracing +{ + public static class ActivityExtensions + { + /// + /// Add a new object to the list. + /// + /// The activity to add the event to. + /// The name of the event. + /// The optional list of tags to attach to the event. + /// for convenient chaining. + public static Activity AddEvent(this Activity activity, string eventName, IReadOnlyList>? tags = default) + { + ActivityTagsCollection? tagsCollection = tags == null ? null : new ActivityTagsCollection(tags); + return activity.AddEvent(new ActivityEvent(eventName, tags: tagsCollection)); + } + + /// + /// Add a new to the list. + /// + /// The activity to add the event to. + /// The traceParent id for the associated . + /// The optional list of tags to attach to the event. + /// for convenient chaining. + public static Activity AddLink(this Activity activity, string traceParent, IReadOnlyList>? tags = default) + { + ActivityTagsCollection? tagsCollection = tags == null ? null : new ActivityTagsCollection(tags); + return activity.AddLink(new ActivityLink(ActivityContext.Parse(traceParent, null), tags: tagsCollection)); + } + } +} diff --git a/csharp/src/Apache.Arrow.Adbc/Tracing/ActivityTrace.cs b/csharp/src/Apache.Arrow.Adbc/Tracing/ActivityTrace.cs new file mode 100644 index 0000000000..a9193c11e8 --- /dev/null +++ b/csharp/src/Apache.Arrow.Adbc/Tracing/ActivityTrace.cs @@ -0,0 +1,351 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace Apache.Arrow.Adbc.Tracing +{ + /// + /// Provides a base implementation for a tracing source. If drivers want to enable tracing, + /// they need to add a trace listener (e.g., ). + /// + public class ActivityTrace + { + private const string ProductVersionDefault = "1.0.0"; + private static readonly string s_assemblyVersion = GetProductVersion(); + private bool _disposedValue; + + /// + /// Constructs a new object. If is set, it provides the + /// activity source name, otherwise the current assembly name is used as the acctivity source name. + /// + /// + public ActivityTrace(string? activitySourceName = default, string? traceParent = default) + { + activitySourceName ??= GetType().Assembly.GetName().Name!; + if (string.IsNullOrWhiteSpace(activitySourceName)) + { + throw new ArgumentNullException(nameof(activitySourceName)); + } + + // This is required to be disposed + ActivitySource = new(activitySourceName, s_assemblyVersion); + TraceParent = traceParent; + } + + /// + /// Gets the . + /// + public ActivitySource ActivitySource { get; } + + /// + /// Gets the name of the + /// + public string ActivitySourceName => ActivitySource.Name; + + /// + /// Invokes the delegate within the context of a new started . + /// + /// The delegate to call within the context of a newly started + /// The name of the method for the activity. + /// Returns a new object if there is any listener to the Activity, returns null otherwise + /// + /// Creates and starts a new object if there is any listener for the ActivitySource. + /// Passes the Activity to the delegate and invokes the delegate. If there are no exceptions thrown by the delegate the + /// Activity status is set to . If an exception is thrown by the delegate, the Activity + /// status is set to and an Activity is added to the actitity + /// and finally the exception is rethrown. + /// + public void TraceActivity(Action call, [CallerMemberName] string? activityName = default, string? traceParent = default) + { + using Activity? activity = StartActivityInternal(activityName, ActivitySource, traceParent ?? TraceParent); + try + { + call.Invoke(activity); + if (activity?.Status == ActivityStatusCode.Unset) activity?.SetStatus(ActivityStatusCode.Ok); + } + catch (Exception ex) + { + TraceException(ex, activity); + throw; + } + } + + /// + /// Invokes the delegate within the context of a new started . + /// + /// The return type for the delegate. + /// The delegate to call within the context of a newly started + /// The name of the method for the activity. + /// The result of the call to the delegate. + /// + /// Creates and starts a new object if there is any listener for the ActivitySource. + /// Passes the Activity to the delegate and invokes the delegate. If there are no exceptions thrown by the delegate the + /// Activity status is set to and the result is returned. + /// If an exception is thrown by the delegate, the Activity status is set to + /// and an Event is added to the actitity and finally the exception is rethrown. + /// + public T TraceActivity(Func call, [CallerMemberName] string? activityName = default, string? traceParent = default) + { + using Activity? activity = StartActivityInternal(activityName, ActivitySource, traceParent ?? TraceParent); + try + { + T? result = call.Invoke(activity); + if (activity?.Status == ActivityStatusCode.Unset) activity?.SetStatus(ActivityStatusCode.Ok); + return result; + } + catch (Exception ex) + { + TraceException(ex, activity); + throw; + } + } + + /// + /// Invokes the delegate within the context of a new started . + /// + /// The delegate to call within the context of a newly started + /// The name of the method for the activity. + /// + /// + /// Creates and starts a new object if there is any listener for the ActivitySource. + /// Passes the Activity to the delegate and invokes the delegate. If there are no exceptions thrown by the delegate the + /// Activity status is set to and the result is returned. + /// If an exception is thrown by the delegate, the Activity status is set to + /// and an Event is added to the actitity and finally the exception is rethrown. + /// + public async Task TraceActivityAsync(Func call, [CallerMemberName] string? activityName = default, string? traceParent = default) + { + using Activity? activity = StartActivityInternal(activityName, ActivitySource, traceParent ?? TraceParent); + try + { + await call.Invoke(activity); + if (activity?.Status == ActivityStatusCode.Unset) activity?.SetStatus(ActivityStatusCode.Ok); + } + catch (Exception ex) + { + TraceException(ex, activity); + throw; + } + } + + /// + /// Invokes the delegate within the context of a new started . + /// + /// The return type for the delegate. + /// The delegate to call within the context of a newly started + /// The name of the method for the activity. + /// The result of the call to the delegate. + /// + /// Creates and starts a new object if there is any listener for the ActivitySource. + /// Passes the Activity to the delegate and invokes the delegate. If there are no exceptions thrown by the delegate the + /// Activity status is set to and the result is returned. + /// If an exception is thrown by the delegate, the Activity status is set to + /// and an Event is added to the actitity and finally the exception is rethrown. + /// + public async Task TraceActivityAsync(Func> call, [CallerMemberName] string? activityName = default, string? traceParent = default) + { + using Activity? activity = StartActivityInternal(activityName, ActivitySource, traceParent ?? TraceParent); + try + { + T? result = await call.Invoke(activity); + if (activity?.Status == ActivityStatusCode.Unset) activity?.SetStatus(ActivityStatusCode.Ok); + return result; + } + catch (Exception ex) + { + TraceException(ex, activity); + throw; + } + } + + /// + /// Invokes the delegate within the context of a new started . + /// + /// The to start the on. + /// The delegate to call within the context of a newly started + /// The name of the method for the activity. + /// + /// + /// Creates and starts a new object if there is any listener for the ActivitySource. + /// Passes the Activity to the delegate and invokes the delegate. If there are no exceptions thrown by the delegate the + /// Activity status is set to and the result is returned. + /// If an exception is thrown by the delegate, the Activity status is set to + /// and an Event is added to the actitity and finally the exception is rethrown. + /// + public static async Task TraceActivityAsync(ActivitySource activitySource, Func call, [CallerMemberName] string? activityName = default, string? traceParent = default) + { + using Activity? activity = StartActivityInternal(activityName, activitySource, traceParent); + try + { + await call.Invoke(activity); + if (activity?.Status == ActivityStatusCode.Unset) activity?.SetStatus(ActivityStatusCode.Ok); + } + catch (Exception ex) + { + TraceException(ex, activity); + throw; + } + } + + /// + /// Invokes the delegate within the context of a new started . + /// + /// The return type for the delegate. + /// The to start the on. + /// The delegate to call within the context of a newly started + /// The name of the method for the activity. + /// The result of the call to the delegate. + /// + /// Creates and starts a new object if there is any listener for the ActivitySource. + /// Passes the Activity to the delegate and invokes the delegate. If there are no exceptions thrown by the delegate the + /// Activity status is set to and the result is returned. + /// If an exception is thrown by the delegate, the Activity status is set to + /// and an Event is added to the actitity and finally the exception is rethrown. + /// + public static async Task TraceActivityAsync(ActivitySource activitySource, Func> call, [CallerMemberName] string? activityName = default, string? traceParent = default) + { + using Activity? activity = StartActivityInternal(activityName, activitySource, traceParent); + try + { + T? result = await call.Invoke(activity); + if (activity?.Status == ActivityStatusCode.Unset) activity?.SetStatus(ActivityStatusCode.Ok); + return result; + } + catch (Exception ex) + { + TraceException(ex, activity); + throw; + } + } + + /// + /// Writes the exception to the trace by adding an exception event to the current activity (span). + /// + /// The exception to trace. + /// The current activity where the exception is caught. + /// + /// An indicator that should be set to true if the exception event is recorded + /// at a point where it is known that the exception is escaping the scope of the span/activity. + /// For example, escaped should be true if the exception is caught and re-thrown. + /// However, escaped should be set to false if execution continues in the current scope. + /// + public static void TraceException(Exception exception, Activity? activity) => + WriteTraceException(exception, activity); + + /// + /// Starts an on the if there is + /// and active listener on the source. + /// + /// The name of the method for the activity. + /// If there is an active listener on the source, an is returned, null otherwise. + public Activity? StartActivity([CallerMemberName] string? activityName = default, string? traceParent = default) + { + return StartActivityInternal(activityName, ActivitySource, traceParent ?? TraceParent); + } + + /// + /// Gets or sets the trace parent context. + /// + public string? TraceParent { get; set; } + + /// + /// Gets the product version from the file version of the current assembly. + /// + /// + private static string GetProductVersion() + { + FileVersionInfo fileVersionInfo = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location); + return fileVersionInfo.ProductVersion ?? ProductVersionDefault; + } + + /// + /// Disposes managed and unmanaged objects. If overridden, ensure to call this base method. + /// + /// An indicator of whether this method is being called from the Dispose method. + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + ActivitySource.Dispose(); + } + _disposedValue = true; + } + } + + /// + public virtual void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + private static void WriteTraceException(Exception exception, Activity? activity) + { + activity?.AddException(exception); + activity?.SetStatus(ActivityStatusCode.Error); + } + + private static Activity? StartActivityInternal(string? activityName, ActivitySource activitySource, string? traceParent = default) + { + string fullActivityName = GetActivityName(activityName); + return StartActivity(activitySource, fullActivityName, traceParent); + } + + private static string GetActivityName(string? activityName) + { + string tracingBaseName = string.Empty; + if (!string.IsNullOrWhiteSpace(activityName)) + { + StackTrace stackTrace = new(); + StackFrame? frame = stackTrace.GetFrames().FirstOrDefault(f => f.GetMethod()?.Name == activityName); + tracingBaseName = frame?.GetMethod()?.DeclaringType?.FullName ?? string.Empty; + if (tracingBaseName != string.Empty) + { + tracingBaseName += "."; + } + } + else + { + activityName = "[unknown-member]"; + } + string fullActivityName = tracingBaseName + activityName; + return fullActivityName; + } + + /// + /// Creates and starts a new object if there is any listener to the Activity, returns null otherwise. + /// + /// The from which to start the activity. + /// The name of the method for the activity + /// Returns a new object if there is any listener to the Activity, returns null otherwise + private static Activity? StartActivity(ActivitySource activitySource, string activityName, string? traceParent = default) + { + return traceParent != null && ActivityContext.TryParse(traceParent, null, isRemote: true, out ActivityContext parentContext) + ? (activitySource.StartActivity(activityName, ActivityKind.Client, parentContext)) + : (activitySource.StartActivity(activityName, ActivityKind.Client)); + } + } +} diff --git a/csharp/src/Apache.Arrow.Adbc/Tracing/SerializableActivity.cs b/csharp/src/Apache.Arrow.Adbc/Tracing/SerializableActivity.cs new file mode 100644 index 0000000000..fec50e2c36 --- /dev/null +++ b/csharp/src/Apache.Arrow.Adbc/Tracing/SerializableActivity.cs @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text.Json.Serialization; + +namespace Apache.Arrow.Adbc.Tracing +{ + /// + /// Simplified version of that excludes some properties, etc. + /// + internal class SerializableActivity + { + [JsonConstructor] + public SerializableActivity() { } + + internal SerializableActivity( + ActivityStatusCode status, + string? statusDescription, + bool hasRemoteParent, + ActivityKind kind, + string operationName, + TimeSpan duration, + DateTime startTimeUtc, + string? id, + string? parentId, + string? rootId, + string? traceStateString, + ActivitySpanId spanId, + ActivityTraceId traceId, + bool recorded, + bool isAllDataRequested, + ActivityTraceFlags activityTraceFlags, + ActivitySpanId parentSpanId, + ActivityIdFormat idFormat, + IReadOnlyList> tagObjects, + IReadOnlyList events, + IReadOnlyList links, + IReadOnlyList> baggage) + { + Status = statusDescription ?? status.ToString(); + HasRemoteParent = hasRemoteParent; + Kind = kind.ToString(); + OperationName = operationName; + Duration = duration; + StartTimeUtc = startTimeUtc; + Id = id; + ParentId = parentId; + RootId = rootId; + TraceStateString = traceStateString; + SpanId = spanId.ToHexString(); + TraceId = traceId.ToHexString(); + Recorded = recorded; + IsAllDataRequested = isAllDataRequested; + ActivityTraceFlags = activityTraceFlags.ToString(); + ParentSpanId = parentSpanId.ToHexString(); + IdFormat = idFormat.ToString(); + TagObjects = tagObjects; + Events = events; + Links = links; + Baggage = baggage; + } + + internal SerializableActivity(Activity activity) : this( + activity.Status, + activity.StatusDescription, + activity.HasRemoteParent, + activity.Kind, + activity.OperationName, + activity.Duration, + activity.StartTimeUtc, + activity.Id, + activity.ParentId, + activity.RootId, + activity.TraceStateString, + activity.SpanId, + activity.TraceId, + activity.Recorded, + activity.IsAllDataRequested, + activity.ActivityTraceFlags, + activity.ParentSpanId, + activity.IdFormat, + activity.TagObjects.ToArray(), + activity.Events.Select(e => (SerializableActivityEvent)e).ToArray(), + activity.Links.Select(l => (SerializableActivityLink)l).ToArray(), + activity.Baggage.ToArray()) + { } + + public string? Status { get; set; } + public bool HasRemoteParent { get; set; } + public string? Kind { get; set; } + public string OperationName { get; set; } = ""; + public TimeSpan Duration { get; set; } + public DateTime StartTimeUtc { get; set; } + public string? Id { get; set; } + public string? ParentId { get; set; } + public string? RootId { get; set; } + + public string? TraceStateString { get; set; } + public string? SpanId { get; set; } + public string? TraceId { get; set; } + public bool Recorded { get; set; } + public bool IsAllDataRequested { get; set; } + public string? ActivityTraceFlags { get; set; } + public string? ParentSpanId { get; set; } + public string? IdFormat { get; set; } + + public IReadOnlyList> TagObjects { get; set; } = []; + public IReadOnlyList Events { get; set; } = []; + public IReadOnlyList Links { get; set; } = []; + public IReadOnlyList> Baggage { get; set; } = []; + } + + internal class SerializableActivityEvent + { + /// + /// Gets the name. + /// + public string? Name { get; set; } + + /// + /// Gets the timestamp. + /// + public DateTimeOffset Timestamp { get; set; } + + public IReadOnlyList> Tags { get; set; } = []; + + public static implicit operator SerializableActivityEvent(System.Diagnostics.ActivityEvent source) + { + return new SerializableActivityEvent() + { + Name = source.Name, + Timestamp = source.Timestamp, + Tags = source.Tags.ToArray(), + }; + } + } + + internal class SerializableActivityLink + { + public SerializableActivityContext? Context { get; set; } + + public IReadOnlyList>? Tags { get; set; } = []; + + public static implicit operator SerializableActivityLink(System.Diagnostics.ActivityLink source) + { + return new SerializableActivityLink() + { + Context = source.Context, + Tags = source.Tags?.ToArray(), + }; + } + } + + internal class SerializableActivityContext + { + public string? SpanId { get; set; } + public string? TraceId { get; set; } + public string? TraceState { get; set; } + public ActivityTraceFlags? TraceFlags { get; set; } + public bool IsRemote { get; set; } + + public static implicit operator SerializableActivityContext(System.Diagnostics.ActivityContext source) + { + return new SerializableActivityContext() + { + SpanId = source.SpanId.ToHexString(), + TraceId = source.TraceId.ToHexString(), + TraceState = source.TraceState, + TraceFlags = source.TraceFlags, + IsRemote = source.IsRemote, + }; + } + } +} diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index 7b09fb68a8..0bc74aa808 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -26,6 +26,7 @@ using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Thrift; using Apache.Arrow.Adbc.Extensions; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; @@ -83,7 +84,7 @@ internal struct ColumnsMetadataColumnNames public string TableName { get; internal set; } public string ColumnName { get; internal set; } public string DataType { get; internal set; } - public string TypeName { get; internal set; } + public string TypeName { get; internal set; } public string Nullable { get; internal set; } public string ColumnDef { get; internal set; } public string OrdinalPosition { get; internal set; } @@ -268,9 +269,11 @@ internal enum ColumnTypeId TIMESTAMP_WITH_TIMEZONE = 2014, } - internal HiveServer2Connection(IReadOnlyDictionary properties) + internal HiveServer2Connection(IReadOnlyDictionary properties, ActivityTrace trace) { Properties = properties; + Trace = trace; + // Note: "LazyThreadSafetyMode.PublicationOnly" is thread-safe initialization where // the first successful thread sets the value. If an exception is thrown, initialization // will retry until it successfully returns a value without an exception. @@ -287,6 +290,9 @@ internal HiveServer2Connection(IReadOnlyDictionary properties) } } + // internal for testing + protected internal ActivityTrace Trace { get; } + internal TCLIService.Client Client { get { return _client ?? throw new InvalidOperationException("connection not open"); } @@ -302,40 +308,45 @@ internal TCLIService.Client Client internal async Task OpenAsync() { - CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds, ApacheUtility.TimeUnit.Milliseconds); - try + await Trace.TraceActivityAsync(async (activity) => { - TTransport transport = CreateTransport(); - TProtocol protocol = await CreateProtocolAsync(transport, cancellationToken); - _transport = protocol.Transport; - _client = new TCLIService.Client(protocol); - TOpenSessionReq request = CreateSessionRequest(); + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(ConnectTimeoutMilliseconds, ApacheUtility.TimeUnit.Milliseconds); + activity?.AddTag("db.client.connection.timeout.ms", ConnectTimeoutMilliseconds); + try + { + TTransport transport = CreateTransport(); + TProtocol protocol = await CreateProtocolAsync(transport, cancellationToken); + _transport = protocol.Transport; + _client = new TCLIService.Client(protocol); + TOpenSessionReq request = CreateSessionRequest(); - TOpenSessionResp? session = await Client.OpenSession(request, cancellationToken); + TOpenSessionResp? session = await Client.OpenSession(request, cancellationToken); + + // Explicitly check the session status + if (session == null) + { + throw new HiveServer2Exception("Unable to open session. Unknown error."); + } + else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS) + { + throw new HiveServer2Exception(session.Status.ErrorMessage) + .SetNativeError(session.Status.ErrorCode) + .SetSqlState(session.Status.SqlState); + } - // Explicitly check the session status - if (session == null) + activity?.AddEvent("session.start", [new("session.id", new Guid(session.SessionHandle.SessionId.Guid).ToString())]); + SessionHandle = session.SessionHandle; + } + catch (Exception ex) when (ExceptionHelper.IsOperationCanceledOrCancellationRequested(ex, cancellationToken)) { - throw new HiveServer2Exception("Unable to open session. Unknown error."); + throw new TimeoutException("The operation timed out while attempting to open a session. Please try increasing connect timeout.", ex); } - else if (session.Status.StatusCode != TStatusCode.SUCCESS_STATUS) + catch (Exception ex) when (ex is not HiveServer2Exception) { - throw new HiveServer2Exception(session.Status.ErrorMessage) - .SetNativeError(session.Status.ErrorCode) - .SetSqlState(session.Status.SqlState); + // Handle other exceptions if necessary + throw new HiveServer2Exception($"An unexpected error occurred while opening the session. '{ex.Message}'", ex); } - - SessionHandle = session.SessionHandle; - } - catch (Exception ex) when (ExceptionHelper.IsOperationCanceledOrCancellationRequested(ex, cancellationToken)) - { - throw new TimeoutException("The operation timed out while attempting to open a session. Please try increasing connect timeout.", ex); - } - catch (Exception ex) when (ex is not HiveServer2Exception) - { - // Handle other exceptions if necessary - throw new HiveServer2Exception($"An unexpected error occurred while opening the session. '{ex.Message}'", ex); - } + }); } internal TSessionHandle? SessionHandle { get; private set; } @@ -356,230 +367,268 @@ internal async Task OpenAsync() internal abstract IArrowArrayStream NewReader(T statement, Schema schema) where T : HiveServer2Statement; - public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? columnNamePattern) + public override void SetOption(string key, string value) { - Dictionary>> catalogMap = new Dictionary>>(); - CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); - try + switch (key.ToLowerInvariant()) { - if (GetObjectsPatternsRequireLowerCase) - { - catalogPattern = catalogPattern?.ToLower(); - dbSchemaPattern = dbSchemaPattern?.ToLower(); - tableNamePattern = tableNamePattern?.ToLower(); - columnNamePattern = columnNamePattern?.ToLower(); - } - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) - { - TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); - if (AreResultsAvailableDirectly()) - { - SetDirectResults(getCatalogsReq); - } - - TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, cancellationToken).Result; + // Since this API only allows non-null values, we'll treat empty string as null to allow the TraceParent to be unset. + case HiveServer2Parameters.TraceParent: + Trace.TraceParent = !string.IsNullOrWhiteSpace(value) + ? ActivityContext.TryParse(value, null, out _) + ? value + : throw new ArgumentOutOfRangeException(nameof(value), $"Invalid trace_parent '{value}'.") + : null; + break; + + case HiveServer2Parameters.HostName: + case HiveServer2Parameters.Port: + case HiveServer2Parameters.Path: + case HiveServer2Parameters.AuthType: + case HiveServer2Parameters.TransportType: + case HiveServer2Parameters.DataTypeConv: + case HiveServer2Parameters.TLSOptions: + case HiveServer2Parameters.ConnectTimeoutMilliseconds: + throw new InvalidOperationException($"Options '{key}' cannot be set once connection is created."); - if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getCatalogsResp.Status.ErrorMessage); - } - var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); + default: + throw new ArgumentOutOfRangeException(nameof(key), $"Unsupported or unknown option '{key}'."); + } + } - string catalogRegexp = PatternToRegEx(catalogPattern); - TRowSet rowSet = GetRowSetAsync(getCatalogsResp, cancellationToken).Result; - IReadOnlyList list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; - for (int i = 0; i < list.Count; i++) + public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList? tableTypes, string? columnNamePattern) + { + return Trace.TraceActivity((activity) => + { + Dictionary>> catalogMap = new Dictionary>>(); + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + try { - string col = list[i]; - string catalog = col; - - if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + if (GetObjectsPatternsRequireLowerCase) { - catalogMap.Add(catalog, new Dictionary>()); + catalogPattern = catalogPattern?.ToLower(); + dbSchemaPattern = dbSchemaPattern?.ToLower(); + tableNamePattern = tableNamePattern?.ToLower(); + columnNamePattern = columnNamePattern?.ToLower(); + } + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Catalogs) + { + TGetCatalogsReq getCatalogsReq = new TGetCatalogsReq(SessionHandle); + if (AreResultsAvailableDirectly()) + { + SetDirectResults(getCatalogsReq); + } + + activity?.AddEvent("db.operation.name.start", [new("db.operation.name", nameof(Client.GetCatalogs))]); + TGetCatalogsResp getCatalogsResp = Client.GetCatalogs(getCatalogsReq, cancellationToken).Result; + var eventActivity = activity?.AddEvent("db.operation.name.end", + [ + new("db.operation.name", nameof(Client.GetCatalogs)), + new("db.response.status_code", getCatalogsResp.Status.StatusCode.ToString()) + ]); + + if (getCatalogsResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getCatalogsResp.Status.ErrorMessage); + } + var catalogsMetadata = GetResultSetMetadataAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(catalogsMetadata.Schema.Columns); + + string catalogRegexp = PatternToRegEx(catalogPattern); + TRowSet rowSet = GetRowSetAsync(getCatalogsResp, cancellationToken).Result; + IReadOnlyList list = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + for (int i = 0; i < list.Count; i++) + { + string col = list[i]; + string catalog = col; + + if (Regex.IsMatch(catalog, catalogRegexp, RegexOptions.IgnoreCase)) + { + catalogMap.Add(catalog, new Dictionary>()); + } + } + // Handle the case where server does not support 'catalog' in the namespace. + if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) + { + catalogMap.Add(string.Empty, []); + } + eventActivity?.AddTag("db.response.returned_rows", catalogMap.Count); } - } - // Handle the case where server does not support 'catalog' in the namespace. - if (list.Count == 0 && string.IsNullOrEmpty(catalogPattern)) - { - catalogMap.Add(string.Empty, []); - } - } - - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) - { - TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); - getSchemasReq.CatalogName = catalogPattern; - getSchemasReq.SchemaName = dbSchemaPattern; - if (AreResultsAvailableDirectly()) - { - SetDirectResults(getSchemasReq); - } - - TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, cancellationToken).Result; - if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getSchemasResp.Status.ErrorMessage); - } - - TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(getSchemasResp, cancellationToken).Result; - - IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - - for (int i = 0; i < catalogList.Count; i++) - { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). - catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary()); - } - } - - if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) - { - TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); - getTablesReq.CatalogName = catalogPattern; - getTablesReq.SchemaName = dbSchemaPattern; - getTablesReq.TableName = tableNamePattern; - if (AreResultsAvailableDirectly()) - { - SetDirectResults(getTablesReq); - } - - TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, cancellationToken).Result; - if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(getTablesResp.Status.ErrorMessage); - } - TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(getTablesResp, cancellationToken).Result; + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.DbSchemas) + { + TGetSchemasReq getSchemasReq = new TGetSchemasReq(SessionHandle); + getSchemasReq.CatalogName = catalogPattern; + getSchemasReq.SchemaName = dbSchemaPattern; + if (AreResultsAvailableDirectly()) + { + SetDirectResults(getSchemasReq); + } + + TGetSchemasResp getSchemasResp = Client.GetSchemas(getSchemasReq, cancellationToken).Result; + if (getSchemasResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getSchemasResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp schemaMetadata = GetResultSetMetadataAsync(getSchemasResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(schemaMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getSchemasResp, cancellationToken).Result; + + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCatalog]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + // It seems Spark sometimes returns empty string for catalog on some schema (temporary tables). + catalogMap.GetValueOrDefault(catalog)?.Add(schemaDb, new Dictionary()); + } + } - IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; - IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; - IReadOnlyList tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + if (depth == GetObjectsDepth.All || depth >= GetObjectsDepth.Tables) + { + TGetTablesReq getTablesReq = new TGetTablesReq(SessionHandle); + getTablesReq.CatalogName = catalogPattern; + getTablesReq.SchemaName = dbSchemaPattern; + getTablesReq.TableName = tableNamePattern; + if (AreResultsAvailableDirectly()) + { + SetDirectResults(getTablesReq); + } + + TGetTablesResp getTablesResp = Client.GetTables(getTablesReq, cancellationToken).Result; + if (getTablesResp.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(getTablesResp.Status.ErrorMessage); + } + + TGetResultSetMetadataResp tableMetadata = GetResultSetMetadataAsync(getTablesResp, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(tableMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(getTablesResp, cancellationToken).Result; + + IReadOnlyList catalogList = rowSet.Columns[columnMap[TableCat]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[TableSchem]].StringVal.Values; + IReadOnlyList tableList = rowSet.Columns[columnMap[TableName]].StringVal.Values; + IReadOnlyList tableTypeList = rowSet.Columns[columnMap[TableType]].StringVal.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + string catalog = catalogList[i]; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string tableType = tableTypeList[i]; + TableInfo tableInfo = new(tableType); + catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); + } + } - for (int i = 0; i < catalogList.Count; i++) - { - string catalog = catalogList[i]; - string schemaDb = schemaList[i]; - string tableName = tableList[i]; - string tableType = tableTypeList[i]; - TableInfo tableInfo = new(tableType); - catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.Add(tableName, tableInfo); - } - } + if (depth == GetObjectsDepth.All) + { + TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); + columnsReq.CatalogName = catalogPattern; + columnsReq.SchemaName = dbSchemaPattern; + columnsReq.TableName = tableNamePattern; + if (AreResultsAvailableDirectly()) + { + SetDirectResults(columnsReq); + } + + if (!string.IsNullOrEmpty(columnNamePattern)) + columnsReq.ColumnName = columnNamePattern; + + var columnsResponse = Client.GetColumns(columnsReq, cancellationToken).Result; + if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) + { + throw new Exception(columnsResponse.Status.ErrorMessage); + } + + TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result; + IReadOnlyDictionary columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); + TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; + + ColumnsMetadataColumnNames columnNames = GetColumnsMetadataColumnNames(); + IReadOnlyList catalogList = rowSet.Columns[columnMap[columnNames.TableCatalog]].StringVal.Values; + IReadOnlyList schemaList = rowSet.Columns[columnMap[columnNames.TableSchema]].StringVal.Values; + IReadOnlyList tableList = rowSet.Columns[columnMap[columnNames.TableName]].StringVal.Values; + IReadOnlyList columnNameList = rowSet.Columns[columnMap[columnNames.ColumnName]].StringVal.Values; + ReadOnlySpan columnTypeList = rowSet.Columns[columnMap[columnNames.DataType]].I32Val.Values.Values; + IReadOnlyList typeNameList = rowSet.Columns[columnMap[columnNames.TypeName]].StringVal.Values; + ReadOnlySpan nullableList = rowSet.Columns[columnMap[columnNames.Nullable]].I32Val.Values.Values; + IReadOnlyList columnDefaultList = rowSet.Columns[columnMap[columnNames.ColumnDef]].StringVal.Values; + ReadOnlySpan ordinalPosList = rowSet.Columns[columnMap[columnNames.OrdinalPosition]].I32Val.Values.Values; + IReadOnlyList isNullableList = rowSet.Columns[columnMap[columnNames.IsNullable]].StringVal.Values; + IReadOnlyList isAutoIncrementList = rowSet.Columns[columnMap[columnNames.IsAutoIncrement]].StringVal.Values; + ReadOnlySpan columnSizeList = rowSet.Columns[columnMap[columnNames.ColumnSize]].I32Val.Values.Values; + ReadOnlySpan decimalDigitsList = rowSet.Columns[columnMap[columnNames.DecimalDigits]].I32Val.Values.Values; + + for (int i = 0; i < catalogList.Count; i++) + { + // For systems that don't support 'catalog' in the namespace + string catalog = catalogList[i] ?? string.Empty; + string schemaDb = schemaList[i]; + string tableName = tableList[i]; + string columnName = columnNameList[i]; + short colType = (short)columnTypeList[i]; + string typeName = typeNameList[i]; + short nullable = (short)nullableList[i]; + string? isAutoIncrementString = isAutoIncrementList[i]; + bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); + string isNullable = isNullableList[i] ?? "YES"; + string columnDefault = columnDefaultList[i] ?? ""; + // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed + int ordinalPos = ordinalPosList[i] + PositionRequiredOffset; + int columnSize = columnSizeList[i]; + int decimalDigits = decimalDigitsList[i]; + TableInfo? tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); + tableInfo?.ColumnName.Add(columnName); + tableInfo?.ColType.Add(colType); + tableInfo?.Nullable.Add(nullable); + tableInfo?.IsAutoIncrement.Add(isAutoIncrement); + tableInfo?.IsNullable.Add(isNullable); + tableInfo?.ColumnDefault.Add(columnDefault); + tableInfo?.OrdinalPosition.Add(ordinalPos); + SetPrecisionScaleAndTypeName(colType, typeName, tableInfo, columnSize, decimalDigits); + } + } - if (depth == GetObjectsDepth.All) - { - TGetColumnsReq columnsReq = new TGetColumnsReq(SessionHandle); - columnsReq.CatalogName = catalogPattern; - columnsReq.SchemaName = dbSchemaPattern; - columnsReq.TableName = tableNamePattern; - if (AreResultsAvailableDirectly()) - { - SetDirectResults(columnsReq); - } + StringArray.Builder catalogNameBuilder = new StringArray.Builder(); + List catalogDbSchemasValues = new List(); - if (!string.IsNullOrEmpty(columnNamePattern)) - columnsReq.ColumnName = columnNamePattern; + foreach (KeyValuePair>> catalogEntry in catalogMap) + { + catalogNameBuilder.Append(catalogEntry.Key); + + if (depth == GetObjectsDepth.Catalogs) + { + catalogDbSchemasValues.Add(null); + } + else + { + catalogDbSchemasValues.Add(GetDbSchemas( + depth, catalogEntry.Value)); + } + } - var columnsResponse = Client.GetColumns(columnsReq, cancellationToken).Result; - if (columnsResponse.Status.StatusCode == TStatusCode.ERROR_STATUS) - { - throw new Exception(columnsResponse.Status.ErrorMessage); - } + Schema schema = StandardSchemas.GetObjectsSchema; + IReadOnlyList dataArrays = schema.Validate( + new List + { + catalogNameBuilder.Build(), + catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), + }); - TGetResultSetMetadataResp columnsMetadata = GetResultSetMetadataAsync(columnsResponse, cancellationToken).Result; - IReadOnlyDictionary columnMap = GetColumnIndexMap(columnsMetadata.Schema.Columns); - TRowSet rowSet = GetRowSetAsync(columnsResponse, cancellationToken).Result; - - ColumnsMetadataColumnNames columnNames = GetColumnsMetadataColumnNames(); - IReadOnlyList catalogList = rowSet.Columns[columnMap[columnNames.TableCatalog]].StringVal.Values; - IReadOnlyList schemaList = rowSet.Columns[columnMap[columnNames.TableSchema]].StringVal.Values; - IReadOnlyList tableList = rowSet.Columns[columnMap[columnNames.TableName]].StringVal.Values; - IReadOnlyList columnNameList = rowSet.Columns[columnMap[columnNames.ColumnName]].StringVal.Values; - ReadOnlySpan columnTypeList = rowSet.Columns[columnMap[columnNames.DataType]].I32Val.Values.Values; - IReadOnlyList typeNameList = rowSet.Columns[columnMap[columnNames.TypeName]].StringVal.Values; - ReadOnlySpan nullableList = rowSet.Columns[columnMap[columnNames.Nullable]].I32Val.Values.Values; - IReadOnlyList columnDefaultList = rowSet.Columns[columnMap[columnNames.ColumnDef]].StringVal.Values; - ReadOnlySpan ordinalPosList = rowSet.Columns[columnMap[columnNames.OrdinalPosition]].I32Val.Values.Values; - IReadOnlyList isNullableList = rowSet.Columns[columnMap[columnNames.IsNullable]].StringVal.Values; - IReadOnlyList isAutoIncrementList = rowSet.Columns[columnMap[columnNames.IsAutoIncrement]].StringVal.Values; - ReadOnlySpan columnSizeList = rowSet.Columns[columnMap[columnNames.ColumnSize]].I32Val.Values.Values; - ReadOnlySpan decimalDigitsList = rowSet.Columns[columnMap[columnNames.DecimalDigits]].I32Val.Values.Values; - - for (int i = 0; i < catalogList.Count; i++) - { - // For systems that don't support 'catalog' in the namespace - string catalog = catalogList[i] ?? string.Empty; - string schemaDb = schemaList[i]; - string tableName = tableList[i]; - string columnName = columnNameList[i]; - short colType = (short)columnTypeList[i]; - string typeName = typeNameList[i]; - short nullable = (short)nullableList[i]; - string? isAutoIncrementString = isAutoIncrementList[i]; - bool isAutoIncrement = (!string.IsNullOrEmpty(isAutoIncrementString) && (isAutoIncrementString.Equals("YES", StringComparison.InvariantCultureIgnoreCase) || isAutoIncrementString.Equals("TRUE", StringComparison.InvariantCultureIgnoreCase))); - string isNullable = isNullableList[i] ?? "YES"; - string columnDefault = columnDefaultList[i] ?? ""; - // Spark/Databricks reports ordinal index zero-indexed, instead of one-indexed - int ordinalPos = ordinalPosList[i] + PositionRequiredOffset; - int columnSize = columnSizeList[i]; - int decimalDigits = decimalDigitsList[i]; - TableInfo? tableInfo = catalogMap.GetValueOrDefault(catalog)?.GetValueOrDefault(schemaDb)?.GetValueOrDefault(tableName); - tableInfo?.ColumnName.Add(columnName); - tableInfo?.ColType.Add(colType); - tableInfo?.Nullable.Add(nullable); - tableInfo?.IsAutoIncrement.Add(isAutoIncrement); - tableInfo?.IsNullable.Add(isNullable); - tableInfo?.ColumnDefault.Add(columnDefault); - tableInfo?.OrdinalPosition.Add(ordinalPos); - SetPrecisionScaleAndTypeName(colType, typeName, tableInfo, columnSize, decimalDigits); + return new HiveInfoArrowStream(schema, dataArrays); } - } - - StringArray.Builder catalogNameBuilder = new StringArray.Builder(); - List catalogDbSchemasValues = new List(); - - foreach (KeyValuePair>> catalogEntry in catalogMap) - { - catalogNameBuilder.Append(catalogEntry.Key); - - if (depth == GetObjectsDepth.Catalogs) + catch (Exception ex) when (ExceptionHelper.IsOperationCanceledOrCancellationRequested(ex, cancellationToken)) { - catalogDbSchemasValues.Add(null); + throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); } - else + catch (Exception ex) when (ex is not HiveServer2Exception) { - catalogDbSchemasValues.Add(GetDbSchemas( - depth, catalogEntry.Value)); + throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex); } - } - - Schema schema = StandardSchemas.GetObjectsSchema; - IReadOnlyList dataArrays = schema.Validate( - new List - { - catalogNameBuilder.Build(), - catalogDbSchemasValues.BuildListArrayForType(new StructType(StandardSchemas.DbSchemaSchema)), - }); - - return new HiveInfoArrowStream(schema, dataArrays); - } - catch (Exception ex) when (ExceptionHelper.IsOperationCanceledOrCancellationRequested(ex, cancellationToken)) - { - throw new TimeoutException("The metadata query execution timed out. Consider increasing the query timeout value.", ex); - } - catch (Exception ex) when (ex is not HiveServer2Exception) - { - throw new HiveServer2Exception($"An unexpected error occurred while running metadata query. '{ex.Message}'", ex); - } + }); } public override IArrowArrayStream GetTableTypes() @@ -683,16 +732,20 @@ private string GetInfoTypeStringValue(TGetInfoType infoType) public override void Dispose() { - if (_client != null) + Trace.TraceActivity((activity) => { - CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); - TCloseSessionReq r6 = new(SessionHandle); - _client.CloseSession(r6, cancellationToken).Wait(); - _transport?.Close(); - _client.Dispose(); - _transport = null; - _client = null; - } + if (_client != null) + { + CancellationToken cancellationToken = ApacheUtility.GetCancellationToken(QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds); + TCloseSessionReq r6 = new(SessionHandle); + _client.CloseSession(r6, cancellationToken).Wait(); + activity?.AddEvent("session.end"); + _transport?.Close(); + _client.Dispose(); + _transport = null; + _client = null; + } + }); } internal static async Task GetResultSetMetadataAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default) @@ -798,7 +851,7 @@ protected IReadOnlyDictionary GetColumnIndexMap(List c protected abstract string ProductVersion { get; } - protected abstract bool GetObjectsPatternsRequireLowerCase { get; } + protected abstract bool GetObjectsPatternsRequireLowerCase { get; } protected abstract bool IsColumnSizeValidForDecimal { get; } diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2ConnectionFactory.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2ConnectionFactory.cs index 1666ff264e..2589b3301d 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2ConnectionFactory.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2ConnectionFactory.cs @@ -17,12 +17,13 @@ using System; using System.Collections.Generic; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { internal class HiveServer2ConnectionFactory { - public static HiveServer2Connection NewConnection(IReadOnlyDictionary properties) + public static HiveServer2Connection NewConnection(IReadOnlyDictionary properties, ActivityTrace trace) { if (!properties.TryGetValue(HiveServer2Parameters.TransportType, out string? type)) { @@ -33,7 +34,7 @@ public static HiveServer2Connection NewConnection(IReadOnlyDictionary properties; - internal HiveServer2Database(IReadOnlyDictionary properties) + internal HiveServer2Database(IReadOnlyDictionary properties, ActivityTrace trace) { this.properties = properties; + this.Trace = trace; } public override AdbcConnection Connect(IReadOnlyDictionary? options) @@ -38,9 +41,11 @@ public override AdbcConnection Connect(IReadOnlyDictionary? opti : options .Concat(properties.Where(x => !options.Keys.Contains(x.Key, StringComparer.OrdinalIgnoreCase))) .ToDictionary(kvp => kvp.Key, kvp => kvp.Value); - HiveServer2Connection connection = HiveServer2ConnectionFactory.NewConnection(mergedProperties); + HiveServer2Connection connection = HiveServer2ConnectionFactory.NewConnection(mergedProperties, Trace); connection.OpenAsync().Wait(); return connection; } + + protected ActivityTrace Trace { get; } } } diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Driver.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Driver.cs index 984fe6bd50..5678d45524 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Driver.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Driver.cs @@ -16,14 +16,28 @@ */ using System.Collections.Generic; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { public class HiveServer2Driver : AdbcDriver { + public HiveServer2Driver(string? activitySourceName = default, string? traceParent = default) + { + Trace = new ActivityTrace(activitySourceName, traceParent); + } + public override AdbcDatabase Open(IReadOnlyDictionary parameters) { - return new HiveServer2Database(parameters); + return new HiveServer2Database(parameters, Trace); + } + + protected ActivityTrace Trace { get; } + + public override void Dispose() + { + Trace.Dispose(); + base.Dispose(); } } } diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs index 6c22ed9cc9..35cbebd5a5 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs @@ -25,6 +25,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; using Thrift; @@ -46,7 +47,7 @@ internal class HiveServer2HttpConnection : HiveServer2Connection protected override string ProductVersion => _productVersion.Value; - public HiveServer2HttpConnection(IReadOnlyDictionary properties) : base(properties) + public HiveServer2HttpConnection(IReadOnlyDictionary properties, ActivityTrace trace) : base(properties, trace) { ValidateProperties(); _productVersion = new Lazy(() => GetProductVersion(), LazyThreadSafetyMode.PublicationOnly); @@ -142,13 +143,14 @@ private void ValidateOptions() public override AdbcStatement CreateStatement() { - return new HiveServer2Statement(this); + return new HiveServer2Statement(this, Trace); } internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader( statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion, + trace: Trace, enableBatchSizeStopCondition: false); protected override TTransport CreateTransport() diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs index b9d7f2c2a1..a6af413dc4 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Parameters.cs @@ -27,6 +27,7 @@ public static class HiveServer2Parameters public const string DataTypeConv = "adbc.hive.data_type_conv"; public const string TLSOptions = "adbc.hive.tls_options"; public const string ConnectTimeoutMilliseconds = "adbc.hive.connect_timeout_ms"; + public const string TraceParent = "adbc.hive.trace_parent"; } public static class HiveServer2AuthTypeConstants diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs index c571947b28..fb2d79857a 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs @@ -19,13 +19,17 @@ using System.Buffers.Text; using System.Collections.Generic; using System.Data.SqlTypes; +using System.Diagnostics; using System.Globalization; +using System.Linq; using System.Threading; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; using Thrift.Transport; +using static Apache.Hive.Service.Rpc.Thrift.TCLIService; namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { @@ -76,50 +80,65 @@ public HiveServer2Reader( HiveServer2Statement statement, Schema schema, DataTypeConversion dataTypeConversion, + ActivityTrace trace, bool enableBatchSizeStopCondition = true) { _statement = statement; Schema = schema; _dataTypeConversion = dataTypeConversion; + Trace = trace; _enableBatchSizeStopCondition = enableBatchSizeStopCondition; } + protected ActivityTrace Trace { get; } + public Schema Schema { get; } public async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) { - // All records have been exhausted - if (_statement == null) + return await Trace.TraceActivityAsync(async (activity) => { - return null; - } + // All records have been exhausted + if (_statement == null) + { + return null; + } - try - { - // Await the fetch response - TFetchResultsResp response = await FetchNext(_statement, cancellationToken); + try + { + // Await the fetch response + activity?.AddEvent("db.operation.name.start", [new("db.operation.name", nameof(Client.FetchResults))]); + TFetchResultsResp response = await FetchNext(_statement, cancellationToken); + + int columnCount = GetColumnCount(response); + int rowCount = GetRowCount(response, columnCount); + activity?.AddEvent("db.operation.name.end", + [ + new("db.operation.name", nameof(Client.FetchResults)), + new("db.response.status_code", response.Status.StatusCode.ToString()), + new("db.client.response.returned_rows", rowCount), + ]); + + if ((_enableBatchSizeStopCondition && _statement.BatchSize > 0 && rowCount < _statement.BatchSize) || rowCount == 0) + { + // This is the last batch + _statement = null; + } - int columnCount = GetColumnCount(response); - int rowCount = GetRowCount(response, columnCount); - if ((_enableBatchSizeStopCondition && _statement.BatchSize > 0 && rowCount < _statement.BatchSize) || rowCount == 0) + // Build the current batch, if any data exists + return rowCount > 0 ? CreateBatch(response, columnCount, rowCount) : null; + } + catch (Exception ex) + when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || + (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) { - // This is the last batch - _statement = null; + throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); } - - // Build the current batch, if any data exists - return rowCount > 0 ? CreateBatch(response, columnCount, rowCount) : null; - } - catch (Exception ex) - when (ApacheUtility.ContainsException(ex, out OperationCanceledException? _) || - (ApacheUtility.ContainsException(ex, out TTransportException? _) && cancellationToken.IsCancellationRequested)) - { - throw new TimeoutException("The query execution timed out. Consider increasing the query timeout value.", ex); - } - catch (Exception ex) when (ex is not HiveServer2Exception) - { - throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); - } + catch (Exception ex) when (ex is not HiveServer2Exception) + { + throw new HiveServer2Exception($"An unexpected error occurred while fetching results. '{ex.Message}'", ex); + } + }); } private RecordBatch CreateBatch(TFetchResultsResp response, int columnCount, int rowCount) diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index c08f997caa..d1db710916 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; using Thrift.Transport; @@ -27,12 +28,15 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2 { internal class HiveServer2Statement : AdbcStatement { - internal HiveServer2Statement(HiveServer2Connection connection) + protected internal HiveServer2Statement(HiveServer2Connection connection, ActivityTrace trace) { Connection = connection; + Trace = trace; ValidateOptions(connection.Properties); } + protected ActivityTrace Trace { get; } + protected virtual void SetStatementProperties(TExecuteStatementReq statement) { statement.QueryTimeout = QueryTimeoutSeconds; diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs index 238302435f..f48e0d22cb 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs @@ -21,6 +21,7 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; @@ -42,8 +43,8 @@ internal abstract class ImpalaConnection : HiveServer2Connection private const int DefaultHttpTransportPort = 28000; */ - internal ImpalaConnection(IReadOnlyDictionary properties) - : base(properties) + internal ImpalaConnection(IReadOnlyDictionary properties, ActivityTrace trace) + : base(properties, trace) { ValidateProperties(); _productVersion = new Lazy(() => GetProductVersion(), LazyThreadSafetyMode.PublicationOnly); @@ -62,7 +63,7 @@ private void ValidateProperties() public override AdbcStatement CreateStatement() { - return new ImpalaStatement(this); + return new ImpalaStatement(this, Trace); } protected override Task GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken cancellationToken = default) => diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnectionFactory.cs b/csharp/src/Drivers/Apache/Impala/ImpalaConnectionFactory.cs index 5a03231238..9ea165f0bd 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaConnectionFactory.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnectionFactory.cs @@ -17,12 +17,13 @@ using System; using System.Collections.Generic; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { internal class ImpalaConnectionFactory { - public static ImpalaConnection NewConnection(IReadOnlyDictionary properties) + public static ImpalaConnection NewConnection(IReadOnlyDictionary properties, ActivityTrace trace) { if (!properties.TryGetValue(ImpalaParameters.Type, out string? type) && string.IsNullOrEmpty(type)) { @@ -34,8 +35,8 @@ public static ImpalaConnection NewConnection(IReadOnlyDictionary } return serverTypeValue switch { - ImpalaServerType.Http => new ImpalaHttpConnection(properties), - ImpalaServerType.Standard => new ImpalaStandardConnection(properties), + ImpalaServerType.Http => new ImpalaHttpConnection(properties, trace), + ImpalaServerType.Standard => new ImpalaStandardConnection(properties, trace), _ => throw new ArgumentOutOfRangeException(nameof(properties), $"Unsupported or unknown value '{type}' given for property '{ImpalaParameters.Type}'. Supported types: {ServerTypeParser.SupportedList}"), }; } diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs b/csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs index 9c36994592..6065b74795 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaDatabase.cs @@ -18,6 +18,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { @@ -25,11 +26,14 @@ internal class ImpalaDatabase : AdbcDatabase { readonly IReadOnlyDictionary properties; - internal ImpalaDatabase(IReadOnlyDictionary properties) + internal ImpalaDatabase(IReadOnlyDictionary properties, ActivityTrace trace) { this.properties = properties; + this.Trace = trace; } + protected ActivityTrace Trace { get; } + public override AdbcConnection Connect(IReadOnlyDictionary? options) { IReadOnlyDictionary mergedProperties = options == null @@ -37,7 +41,7 @@ public override AdbcConnection Connect(IReadOnlyDictionary? opti : options .Concat(properties.Where(x => !options.Keys.Contains(x.Key, StringComparer.OrdinalIgnoreCase))) .ToDictionary(kvp => kvp.Key, kvp => kvp.Value); - ImpalaConnection connection = ImpalaConnectionFactory.NewConnection(mergedProperties); + ImpalaConnection connection = ImpalaConnectionFactory.NewConnection(mergedProperties, Trace); connection.OpenAsync().Wait(); return connection; } diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaDriver.cs b/csharp/src/Drivers/Apache/Impala/ImpalaDriver.cs index 69674f0f2d..7bc0ed5b60 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaDriver.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaDriver.cs @@ -16,14 +16,28 @@ */ using System.Collections.Generic; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { public class ImpalaDriver : AdbcDriver { + public ImpalaDriver(string? activitySourceName = default, string? traceParent = default) + { + Trace = new ActivityTrace(activitySourceName, traceParent); + } + public override AdbcDatabase Open(IReadOnlyDictionary parameters) { - return new ImpalaDatabase(parameters); + return new ImpalaDatabase(parameters, Trace); + } + + protected ActivityTrace Trace { get; } + + public override void Dispose() + { + Trace.Dispose(); + base.Dispose(); } } } diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs index f32cad0113..984cc271a5 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs @@ -27,6 +27,7 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; using Thrift; @@ -39,7 +40,8 @@ internal class ImpalaHttpConnection : ImpalaConnection { private const string BasicAuthenticationScheme = "Basic"; - public ImpalaHttpConnection(IReadOnlyDictionary properties) : base(properties) + public ImpalaHttpConnection(IReadOnlyDictionary properties, ActivityTrace trace) + : base(properties, trace) { } @@ -124,7 +126,7 @@ protected override void ValidateOptions() } } - internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion, Trace); protected override TTransport CreateTransport() { diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs index 1c8cb78fdf..5f78e16a42 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs @@ -22,6 +22,7 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; using Thrift.Protocol; @@ -31,7 +32,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { internal class ImpalaStandardConnection : ImpalaConnection { - public ImpalaStandardConnection(IReadOnlyDictionary properties) : base(properties) + public ImpalaStandardConnection(IReadOnlyDictionary properties, ActivityTrace trace) + : base(properties, trace) { } @@ -149,7 +151,7 @@ protected override TOpenSessionReq CreateSessionRequest() return request; } - internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion, Trace); internal override ImpalaServerType ServerType => ImpalaServerType.Standard; diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs index 840fbe297b..f88c3d2bd8 100644 --- a/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs +++ b/csharp/src/Drivers/Apache/Impala/ImpalaStatement.cs @@ -16,13 +16,14 @@ */ using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Impala { internal class ImpalaStatement : HiveServer2Statement { - internal ImpalaStatement(ImpalaConnection connection) - : base(connection) + internal ImpalaStatement(ImpalaConnection connection, ActivityTrace trace) + : base(connection, trace) { } diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs index b9b40dfd1e..926796f0d3 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs @@ -22,6 +22,7 @@ using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; using Apache.Arrow.Adbc.Extensions; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Arrow.Types; using Apache.Hive.Service.Rpc.Thrift; @@ -45,8 +46,8 @@ internal abstract class SparkConnection : HiveServer2Connection { "spark.thriftserver.arrowBasedRowSet.timestampAsString", "false" } }; - internal SparkConnection(IReadOnlyDictionary properties) - : base(properties) + internal SparkConnection(IReadOnlyDictionary properties, ActivityTrace trace) + : base(properties, trace) { ValidateProperties(); _productVersion = new Lazy(() => GetProductVersion(), LazyThreadSafetyMode.PublicationOnly); @@ -63,7 +64,7 @@ private void ValidateProperties() public override AdbcStatement CreateStatement() { - return new SparkStatement(this); + return new SparkStatement(this, Trace); } protected internal override int PositionRequiredOffset => 1; diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs b/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs index 4feaf4183b..2cb5fade09 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkConnectionFactory.cs @@ -17,12 +17,13 @@ using System; using System.Collections.Generic; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { internal class SparkConnectionFactory { - public static SparkConnection NewConnection(IReadOnlyDictionary properties) + public static SparkConnection NewConnection(IReadOnlyDictionary properties, ActivityTrace trace) { if (!properties.TryGetValue(SparkParameters.Type, out string? type) && string.IsNullOrEmpty(type)) { @@ -35,8 +36,8 @@ public static SparkConnection NewConnection(IReadOnlyDictionary return serverTypeValue switch { - SparkServerType.Databricks => new SparkDatabricksConnection(properties), - SparkServerType.Http => new SparkHttpConnection(properties), + SparkServerType.Databricks => new SparkDatabricksConnection(properties, trace), + SparkServerType.Http => new SparkHttpConnection(properties, trace), // TODO: Re-enable when properly supported //SparkServerType.Standard => new SparkStandardConnection(properties), _ => throw new ArgumentOutOfRangeException(nameof(properties), $"Unsupported or unknown value '{type}' given for property '{SparkParameters.Type}'. Supported types: {ServerTypeParser.SupportedList}"), diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs index ff12cdb8e9..348f5b9c28 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabase.cs @@ -18,6 +18,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { @@ -25,11 +26,14 @@ internal class SparkDatabase : AdbcDatabase { readonly IReadOnlyDictionary properties; - internal SparkDatabase(IReadOnlyDictionary properties) + internal SparkDatabase(IReadOnlyDictionary properties, ActivityTrace trace) { + Trace = trace; this.properties = properties; } + protected ActivityTrace Trace { get; } + public override AdbcConnection Connect(IReadOnlyDictionary? options) { // connection options takes precedence over database properties for the same option @@ -38,7 +42,7 @@ public override AdbcConnection Connect(IReadOnlyDictionary? opti : options .Concat(properties.Where(x => !options.Keys.Contains(x.Key, StringComparer.OrdinalIgnoreCase))) .ToDictionary(kvp => kvp.Key, kvp => kvp.Value); - SparkConnection connection = SparkConnectionFactory.NewConnection(mergedProperties); // new SparkConnection(mergedProperties); + SparkConnection connection = SparkConnectionFactory.NewConnection(mergedProperties, Trace); connection.OpenAsync().Wait(); return connection; } diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs index d51ef42b9b..d078b53617 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs @@ -18,6 +18,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; @@ -25,11 +26,12 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { internal class SparkDatabricksConnection : SparkHttpConnection { - public SparkDatabricksConnection(IReadOnlyDictionary properties) : base(properties) + public SparkDatabricksConnection(IReadOnlyDictionary properties, ActivityTrace trace) + : base(properties, trace) { } - internal override IArrowArrayStream NewReader(T statement, Schema schema) => new SparkDatabricksReader(statement, schema); + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new SparkDatabricksReader(statement, schema, Trace); internal override SchemaParser SchemaParser => new SparkDatabricksSchemaParser(); diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs index 059ab1690b..251dc6b77d 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs @@ -16,11 +16,14 @@ */ using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; +using static Apache.Hive.Service.Rpc.Thrift.TCLIService; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { @@ -32,51 +35,64 @@ internal sealed class SparkDatabricksReader : IArrowArrayStream int index; IArrowReader? reader; - public SparkDatabricksReader(HiveServer2Statement statement, Schema schema) + public SparkDatabricksReader(HiveServer2Statement statement, Schema schema, ActivityTrace trace) { this.statement = statement; this.schema = schema; + this.Trace = trace; } + private ActivityTrace Trace { get; } + public Schema Schema { get { return schema; } } public async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) { - while (true) + return await Trace.TraceActivityAsync(async (activity) => { - if (this.reader != null) + while (true) { - RecordBatch? next = await this.reader.ReadNextRecordBatchAsync(cancellationToken); - if (next != null) + if (this.reader != null) { - return next; + RecordBatch? next = await this.reader.ReadNextRecordBatchAsync(cancellationToken); + if (next != null) + { + return next; + } + this.reader = null; } - this.reader = null; - } - if (this.batches != null && this.index < this.batches.Count) - { - this.reader = new ArrowStreamReader(new ChunkStream(this.schema, this.batches[this.index++].Batch)); - continue; - } + if (this.batches != null && this.index < this.batches.Count) + { + this.reader = new ArrowStreamReader(new ChunkStream(this.schema, this.batches[this.index++].Batch)); + continue; + } - this.batches = null; - this.index = 0; + this.batches = null; + this.index = 0; - if (this.statement == null) - { - return null; - } + if (this.statement == null) + { + return null; + } - TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize); - TFetchResultsResp response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken); - this.batches = response.Results.ArrowBatches; + activity?.AddEvent("db.operation.name.start", [new("db.operation.name", nameof(Client.FetchResults))]); + TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize); + TFetchResultsResp response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken); + activity?.AddEvent("db.operation.name.end", + [ + new("db.operation.name", nameof(Client.FetchResults)), + new("db.response.status_code", response.Status.StatusCode.ToString()), + new("db.client.response.returned_rows", response.Results.ArrowBatches.Sum(b => b.RowCount)) + ]); + this.batches = response.Results.ArrowBatches; - if (!response.HasMoreRows) - { - this.statement = null; + if (!response.HasMoreRows) + { + this.statement = null; + } } - } + }); } public void Dispose() diff --git a/csharp/src/Drivers/Apache/Spark/SparkDriver.cs b/csharp/src/Drivers/Apache/Spark/SparkDriver.cs index 359f654349..f6da7d92ba 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkDriver.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkDriver.cs @@ -16,14 +16,29 @@ */ using System.Collections.Generic; +using System.Diagnostics; +using Apache.Arrow.Adbc.Tracing; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { public class SparkDriver : AdbcDriver { + public SparkDriver(string? activitySourceName = default, string? traceParent = default) + { + Trace = new ActivityTrace(activitySourceName, traceParent); + } + public override AdbcDatabase Open(IReadOnlyDictionary parameters) { - return new SparkDatabase(parameters); + return new SparkDatabase(parameters, Trace); + } + + protected ActivityTrace Trace { get; } + + public override void Dispose() + { + Trace.Dispose(); + base.Dispose(); } } } diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs index 75abb1196b..fbac2d901a 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs @@ -26,6 +26,7 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tracing; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; using Thrift; @@ -40,7 +41,8 @@ internal class SparkHttpConnection : SparkConnection private const string BasicAuthenticationScheme = "Basic"; private const string BearerAuthenticationScheme = "Bearer"; - public SparkHttpConnection(IReadOnlyDictionary properties) : base(properties) + public SparkHttpConnection(IReadOnlyDictionary properties, ActivityTrace trace) + : base(properties, trace) { } @@ -132,45 +134,52 @@ protected override void ValidateOptions() } } - internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion); + internal override IArrowArrayStream NewReader(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion, Trace); protected override TTransport CreateTransport() { - // Assumption: parameters have already been validated. - Properties.TryGetValue(SparkParameters.HostName, out string? hostName); - Properties.TryGetValue(SparkParameters.Path, out string? path); - Properties.TryGetValue(SparkParameters.Port, out string? port); - Properties.TryGetValue(SparkParameters.AuthType, out string? authType); - if (!SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue)) + return Trace.TraceActivity((activity) => { - throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); - } - Properties.TryGetValue(SparkParameters.Token, out string? token); - Properties.TryGetValue(AdbcOptions.Username, out string? username); - Properties.TryGetValue(AdbcOptions.Password, out string? password); - Properties.TryGetValue(AdbcOptions.Uri, out string? uri); + // Assumption: parameters have already been validated. + Properties.TryGetValue(SparkParameters.HostName, out string? hostName); + Properties.TryGetValue(SparkParameters.Path, out string? path); + Properties.TryGetValue(SparkParameters.Port, out string? port); + Properties.TryGetValue(SparkParameters.AuthType, out string? authType); + if (!SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue)) + { + throw new ArgumentOutOfRangeException(SparkParameters.AuthType, authType, $"Unsupported {SparkParameters.AuthType} value."); + } + Properties.TryGetValue(SparkParameters.Token, out string? token); + Properties.TryGetValue(AdbcOptions.Username, out string? username); + Properties.TryGetValue(AdbcOptions.Password, out string? password); + Properties.TryGetValue(AdbcOptions.Uri, out string? uri); - Uri baseAddress = GetBaseAddress(uri, hostName, path, port, SparkParameters.HostName); - AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password); + Uri baseAddress = GetBaseAddress(uri, hostName, path, port, SparkParameters.HostName); + AuthenticationHeaderValue? authenticationHeaderValue = GetAuthenticationHeaderValue(authTypeValue, token, username, password); - HttpClientHandler httpClientHandler = NewHttpClientHandler(); - HttpClient httpClient = new(httpClientHandler); - httpClient.BaseAddress = baseAddress; - httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; - httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); - httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); - httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); - httpClient.DefaultRequestHeaders.ExpectContinue = false; + HttpClientHandler httpClientHandler = NewHttpClientHandler(); + HttpClient httpClient = new(httpClientHandler); + httpClient.BaseAddress = baseAddress; + httpClient.DefaultRequestHeaders.Authorization = authenticationHeaderValue; + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(s_userAgent); + httpClient.DefaultRequestHeaders.AcceptEncoding.Clear(); + httpClient.DefaultRequestHeaders.AcceptEncoding.Add(new StringWithQualityHeaderValue("identity")); + httpClient.DefaultRequestHeaders.ExpectContinue = false; - TConfiguration config = new(); - ThriftHttpTransport transport = new(httpClient, config) - { - // This value can only be set before the first call/request. So if a new value for query timeout - // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token - // to ensure cancelled correctly. - ConnectTimeout = int.MaxValue, - }; - return transport; + TConfiguration config = new(); + ThriftHttpTransport transport = new(httpClient, config) + { + // This value can only be set before the first call/request. So if a new value for query timeout + // is set, we won't be able to update the value. Setting to ~infinite and relying on cancellation token + // to ensure cancelled correctly. + ConnectTimeout = int.MaxValue, + }; + activity?.AddTag("network.protocol.name", baseAddress.Scheme) + .AddTag("server.address", baseAddress.Host) + .AddTag("server.port", baseAddress.Port) + .AddTag("server.authentication.type", authType); + return transport; + }); } private HttpClientHandler NewHttpClientHandler() diff --git a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs index 2c28ea8e13..37e9707b3d 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkStandardConnection.cs @@ -20,6 +20,7 @@ using System.Net; using System.Threading; using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tracing; using Apache.Hive.Service.Rpc.Thrift; using Thrift.Protocol; using Thrift.Transport; @@ -28,7 +29,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { internal class SparkStandardConnection : SparkHttpConnection { - public SparkStandardConnection(IReadOnlyDictionary properties) : base(properties) + public SparkStandardConnection(IReadOnlyDictionary properties, ActivityTrace trace) + : base(properties, trace) { } diff --git a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs index 5decbddb02..c1a290d13a 100644 --- a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs +++ b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs @@ -17,14 +17,15 @@ using System.Collections.Generic; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Tracing; using Apache.Hive.Service.Rpc.Thrift; namespace Apache.Arrow.Adbc.Drivers.Apache.Spark { internal class SparkStatement : HiveServer2Statement { - internal SparkStatement(SparkConnection connection) - : base(connection) + internal SparkStatement(SparkConnection connection, ActivityTrace trace) + : base(connection, trace) { } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj b/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj index e0637dc8bb..6a7ab2f726 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj @@ -1,4 +1,4 @@ - + net8.0;net472 @@ -11,6 +11,8 @@ + + @@ -18,6 +20,7 @@ runtime; build; native; contentfiles; analyzers + diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Tracing/TracingTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Tracing/TracingTests.cs new file mode 100644 index 0000000000..cddb569694 --- /dev/null +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Tracing/TracingTests.cs @@ -0,0 +1,356 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Text; +using System.Text.Json; +using Apache.Arrow.Adbc.Tracing; +using OpenTelemetry; +using OpenTelemetry.Trace; +using Xunit; +using Xunit.Abstractions; + +namespace Apache.Arrow.Adbc.Tests.Tracing +{ + public class TracingTests(ITestOutputHelper? outputHelper) : IDisposable + { + private readonly ITestOutputHelper? _outputHelper = outputHelper; + private bool _disposed; + + [Fact] + internal void CanStartActivity() + { + string activitySourceName = NewName(); + using MemoryStream stream = new(); + using TracerProvider provider = Sdk.CreateTracerProviderBuilder() + .AddSource(activitySourceName) + .AddTestMemoryExporter(stream) + .Build(); + + var testClass = new TraceInheritor(activitySourceName); + testClass.MethodWithNoInstrumentation(); + Assert.Equal(0, stream.Length); + + testClass.MethodWithActivity(); + Assert.True(stream.Length > 0); + long currLength = stream.Length; + + testClass.MethodWithNoInstrumentation(); + Assert.Equal(currLength, stream.Length); + + stream.Seek(0, SeekOrigin.Begin); + StreamReader reader = new(stream); + + int lineCount = 0; + string? text = reader.ReadLine(); + while (text != null) + { + lineCount++; + SerializableActivity? activity = JsonSerializer.Deserialize(text); + Assert.NotNull(activity); + Assert.Contains(nameof(TraceInheritor.MethodWithActivity), activity.OperationName); + Assert.DoesNotContain(nameof(TraceInheritor.MethodWithNoInstrumentation), activity.OperationName); + text = reader.ReadLine(); + } + Assert.Equal(1, lineCount); + } + + [Fact] + internal void CanAddEvent() + { + string activitySourceName = NewName(); + using MemoryStream stream = new(); + using TracerProvider provider = Sdk.CreateTracerProviderBuilder() + .AddSource(activitySourceName) + .AddTestMemoryExporter(stream) + .Build(); + + var testClass = new TraceInheritor(activitySourceName); + testClass.MethodWithNoInstrumentation(); + Assert.Equal(0, stream.Length); + + string eventName = NewName(); + testClass.MethodWithEvent(eventName); + Assert.True(stream.Length > 0); + long currLength = stream.Length; + + testClass.MethodWithNoInstrumentation(); + Assert.Equal(currLength, stream.Length); + + stream.Seek(0, SeekOrigin.Begin); + StreamReader reader = new(stream); + + int lineCount = 0; + string? text = reader.ReadLine(); + while (text != null) + { + lineCount++; + Assert.Contains(nameof(TraceInheritor.MethodWithEvent), text); + Assert.DoesNotContain(nameof(TraceInheritor.MethodWithNoInstrumentation), text); + Assert.Contains(eventName, text); + text = reader.ReadLine(); + } + Assert.Equal(1, lineCount); + } + + [Fact] + internal void CanSerializeDeserializeActivity() + { + string activitySourceName = NewName(); + using MemoryStream stream = new(); + using TracerProvider provider = Sdk.CreateTracerProviderBuilder() + .AddSource(activitySourceName) + .AddTestMemoryExporter(stream) + .Build(); + + var testClass = new TraceInheritor(activitySourceName); + string activityName = NewName(); + string eventName = NewName(); + const string rootId = "3236da27af79882bd317c4d1c3776982"; + string traceParent = $"00-{rootId}-a3cc9bd52ccd58e6-01"; + IReadOnlyList> tags = + [ + new (NewName(), NewName()), + new (NewName(), NewName()), + ]; + testClass.MethodWithAllProperties(activityName, eventName, tags, traceParent); + stream.Seek(0, SeekOrigin.Begin); + StreamReader reader = new(stream); + + int lineCount = 0; + string? text = reader.ReadLine(); + while (text != null) + { + lineCount++; + SerializableActivity? activity = JsonSerializer.Deserialize(text); + Assert.NotNull(activity); + string activityJson = JsonSerializer.Serialize(activity); + Assert.Equal(rootId, activity.TraceId); + Assert.Equal(rootId, activity.RootId); + Assert.Contains(rootId, activity.ParentId); + Assert.True(activity.HasRemoteParent); + Assert.Equal(text, activityJson); + + text = reader.ReadLine(); + } + Assert.Equal(1, lineCount); + } + + [Fact] + internal void CanAddActivityWithDepth() + { + string activitySourceName = NewName(); + using MemoryStream stream = new(); + using TracerProvider provider = Sdk.CreateTracerProviderBuilder() + .AddSource(activitySourceName) + .AddTestMemoryExporter(stream) + .Build(); + + var testClass = new TraceInheritor(activitySourceName); + const int recurseCount = 5; + testClass.MethodWithActivityRecursive(nameof(TraceInheritor.MethodWithActivityRecursive), recurseCount); + + stream.Seek(0, SeekOrigin.Begin); + StreamReader reader = new(stream); + + int lineCount = 0; + string? text = reader.ReadLine(); + while (text != null) + { + if (string.IsNullOrWhiteSpace(text)) continue; + lineCount++; + Assert.Contains(nameof(TraceInheritor.MethodWithActivityRecursive), text); + Assert.DoesNotContain(nameof(TraceInheritor.MethodWithNoInstrumentation), text); + SerializableActivity? activity = JsonSerializer.Deserialize(text); + Assert.Contains(nameof(TraceInheritor.MethodWithActivityRecursive), activity?.OperationName); + Assert.NotNull(activity); + text = reader.ReadLine(); + } + Assert.Equal(recurseCount, lineCount); + } + + [Fact] + internal void CanAddTraceParent() + { + string activitySourceName = NewName(); + using MemoryStream stream = new(); + stream.SetLength(0); + using TracerProvider provider1 = Sdk.CreateTracerProviderBuilder() + .AddSource(activitySourceName) + .AddTestMemoryExporter(stream) + .Build(); + + var testClass = new TraceInheritor(activitySourceName); + testClass.MethodWithNoInstrumentation(); + Assert.Equal(0, stream.Length); + + const string eventNameWithParent = "eventNameWithParent"; + const string eventNameWithoutParent = "eventNameWithoutParent"; + testClass.MethodWithActivity(eventNameWithoutParent); + Assert.True(stream.Length > 0); + + const string traceParent = "00-3236da27af79882bd317c4d1c3776982-a3cc9bd52ccd58e6-01"; + + testClass.SetTraceParent(traceParent); + const int withParentCountExpected = 10; + for (int i = 0; i < withParentCountExpected; i++) + { + testClass.MethodWithActivity(eventNameWithParent); + } + testClass.SetTraceParent(null); + + testClass.MethodWithActivity(eventNameWithoutParent); + Assert.True(stream.Length > 0); + + stream.Seek(0, SeekOrigin.Begin); + StreamReader reader = new(stream); + + int lineCount = 0; + int withParentCount = 0; + int withoutParentCount = 0; + string? text = reader.ReadLine(); + while (text != null) + { + lineCount++; + SerializableActivity? clientActivity = JsonSerializer.Deserialize(text); + Assert.NotNull(clientActivity); + if (clientActivity.OperationName.Contains(eventNameWithoutParent)) + { + withoutParentCount++; + Assert.Null(clientActivity.ParentId); + } + else if (clientActivity.OperationName.Contains(eventNameWithParent)) + { + withParentCount++; + Assert.Equal(traceParent, clientActivity.ParentId); + } + text = reader.ReadLine(); + } + Assert.Equal(2, withoutParentCount); + Assert.Equal(withParentCountExpected, withParentCount); + } + + internal static string NewName() => Guid.NewGuid().ToString().Replace("-", "").ToLower(); + + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + } + _disposed = true; + } + } + + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + private class TraceInheritor : ActivityTrace + { + internal TraceInheritor(string? activitySourceName = default) : base(activitySourceName) { } + + internal void MethodWithNoInstrumentation() + { + + } + + internal void MethodWithActivity() + { + TraceActivity(_ => { }); + } + + internal void MethodWithActivity(string activityName) + { + TraceActivity(activity => { }, activityName: activityName); + } + + internal void MethodWithActivityRecursive(string activityName, int recurseCount) + { + TraceActivity(_ => + { + recurseCount--; + if (recurseCount > 0) + { + MethodWithActivityRecursive(activityName, recurseCount); + } + }, activityName: activityName + recurseCount.ToString()); + } + + internal void MethodWithEvent(string eventName) + { + TraceActivity((activity) => activity?.AddEvent(eventName)); + } + + internal void MethodWithAllProperties( + string activityName, + string eventName, + IReadOnlyList> tags, + string traceParent) + { + TraceActivity(activity => + { + foreach (KeyValuePair tag in tags) + { + activity?.AddTag(tag.Key, tag.Value) + .AddBaggage(tag.Key, tag.Value?.ToString()); + } + activity?.AddEvent(eventName, tags) + .AddLink(traceParent, tags); + }, activityName: activityName, traceParent: traceParent); + } + + internal void SetTraceParent(string? traceParent) + { + TraceParent = traceParent; + } + } + + internal class MemoryStreamExporter(MemoryStream stream) : BaseExporter + { + private readonly MemoryStream _stream = stream; + + public override ExportResult Export(in Batch batch) + { + byte[] newLine = Encoding.UTF8.GetBytes(Environment.NewLine); + foreach (Activity activity in batch) + { + var sa = new SerializableActivity(activity); + byte[] jsonString = JsonSerializer.SerializeToUtf8Bytes(sa); + _stream.Write(jsonString, 0, jsonString.Length); + _stream.Write(newLine, 0, newLine.Length); + } + return ExportResult.Success; + } + } + } + + public static class AdbcMemoryTestExporterExtensions + { + public static TracerProviderBuilder AddTestMemoryExporter(this TracerProviderBuilder builder, MemoryStream stream) + { + return builder.AddProcessor(sp => new SimpleActivityExportProcessor(new TracingTests.MemoryStreamExporter(stream))); + } + } +} diff --git a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs index 417885e0b9..004a7fca2e 100644 --- a/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs +++ b/csharp/test/Drivers/Apache/Spark/SparkConnectionTest.cs @@ -133,6 +133,39 @@ internal void MetadataTimeoutTest(MetadataWithExceptions metadataWithException) } } + [SkippableTheory] + [InlineData(HiveServer2Parameters.TraceParent, null)] + [InlineData(HiveServer2Parameters.TraceParent, "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")] + [InlineData(HiveServer2Parameters.TraceParent, "0af7651916cd43dd8448eb211c80319c", typeof(ArgumentOutOfRangeException))] + [InlineData(HiveServer2Parameters.TraceParent, "invalid-traceparent", typeof(ArgumentOutOfRangeException))] + [InlineData(HiveServer2Parameters.AuthType, null, typeof(InvalidOperationException))] + [InlineData(HiveServer2Parameters.ConnectTimeoutMilliseconds, null, typeof(InvalidOperationException))] + [InlineData(HiveServer2Parameters.DataTypeConv, null, typeof(InvalidOperationException))] + [InlineData(HiveServer2Parameters.HostName, null, typeof(InvalidOperationException))] + [InlineData(HiveServer2Parameters.Path, null, typeof(InvalidOperationException))] + [InlineData(HiveServer2Parameters.Port, null, typeof(InvalidOperationException))] + [InlineData(HiveServer2Parameters.TLSOptions, null, typeof(InvalidOperationException))] + [InlineData(HiveServer2Parameters.TransportType, null, typeof(InvalidOperationException))] + [InlineData("invalid.option", null, typeof(ArgumentOutOfRangeException))] + internal void SetOptionTest(string key, string? value, Type? exceptionType = default) + { + HiveServer2Connection connection = (HiveServer2Connection)TestEnvironment.Connection; + if (exceptionType == null) + { + connection.SetOption(key, value!); + switch (key.ToLower()) + { + case HiveServer2Parameters.TraceParent: + Assert.Equal(value, connection.Trace.TraceParent); + break; + } + } + else + { + OutputHelper?.WriteLine(Assert.Throws(exceptionType, () => connection.SetOption(key, value!)).Message); + } + } + /// /// Data type used for metadata timeout tests. ///