diff --git a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs index 62961901ff..81d40564ea 100644 --- a/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs +++ b/src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeWrapper.cs @@ -166,7 +166,7 @@ internal static unsafe uint SniOpenSyncEx( string connString, ref IntPtr pConn, ref string spn, - byte[] instanceName, + ref string instanceName, bool fOverrideCache, bool fSync, int timeout, @@ -181,9 +181,13 @@ internal static unsafe uint SniOpenSyncEx( SQLDNSInfo cachedDnsInfo, string hostNameInCertificate) { - fixed (byte* pInstanceName = instanceName) + // Size of this buffer is as specified by netlibs. + ReadOnlySpan instanceNameBuffer = stackalloc byte[256]; + + fixed (byte* pInstanceName = instanceNameBuffer) { SniClientConsumerInfo clientConsumerInfo = new SniClientConsumerInfo(); + uint result; // initialize client ConsumerInfo part first MarshalConsumerInfo(consumerInfo, ref clientConsumerInfo.ConsumerInfo); @@ -192,7 +196,7 @@ internal static unsafe uint SniOpenSyncEx( clientConsumerInfo.HostNameInCertificate = hostNameInCertificate; clientConsumerInfo.networkLibrary = Prefix.UNKNOWN_PREFIX; clientConsumerInfo.szInstanceName = pInstanceName; - clientConsumerInfo.cchInstanceName = (uint)instanceName.Length; + clientConsumerInfo.cchInstanceName = (uint)instanceNameBuffer.Length; clientConsumerInfo.fOverrideLastConnectCache = fOverrideCache; clientConsumerInfo.fSynchronousConnection = fSync; clientConsumerInfo.timeout = timeout; @@ -231,22 +235,21 @@ internal static unsafe uint SniOpenSyncEx( { // An empty string implies we need to find the SPN so we supply a buffer for the max size var array = ArrayPool.Shared.Rent(SniMaxComposedSpnLength); - array.AsSpan().Clear(); + Span arraySpan = array.AsSpan(); + arraySpan.Clear(); try { - fixed (byte* pin_spnBuffer = array) + fixed (byte* pin_spnBuffer = arraySpan) { clientConsumerInfo.szSPN = pin_spnBuffer; clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength; - var result = s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); - if (result == 0) + result = s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + if (result is TdsEnums.SNI_SUCCESS) { - spn = Encoding.Unicode.GetString(array).TrimEnd('\0'); + spn = Encoding.Unicode.CreateStringFromNullTerminated(arraySpan); } - - return result; } } finally @@ -270,7 +273,7 @@ internal static unsafe uint SniOpenSyncEx( { clientConsumerInfo.szSPN = pin_spnBuffer; clientConsumerInfo.cchSPN = (uint)writer.WrittenCount; - return s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + result = s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); } } finally @@ -279,9 +282,18 @@ internal static unsafe uint SniOpenSyncEx( } } } + else + { + // Otherwise leave szSPN null (SQL Auth) + result = s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + } + + if (result is TdsEnums.SNI_SUCCESS) + { + instanceName = Encoding.UTF8.CreateStringFromNullTerminated(instanceNameBuffer); + } - // Otherwise leave szSPN null (SQL Auth) - return s_nativeMethods.SniOpenSyncExWrapper(ref clientConsumerInfo, out pConn); + return result; } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs index e62541137c..a830969079 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs @@ -248,8 +248,6 @@ internal class SqlConnectionInternal : DbConnectionInternal, IDisposable /// private readonly DbConnectionPoolIdentity _identity; - private string _instanceName = string.Empty; - private SqlLoginAck _loginAck; /// @@ -559,11 +557,9 @@ internal bool IgnoreEnvChange get => RoutingInfo != null; } - // @TODO: Make auto-property - internal string InstanceName - { - get => _instanceName; - } + internal string UserInstanceName { get; private set; } = string.Empty; + + internal string InstanceName { get; set; } internal bool Is2008OrNewer { @@ -1211,7 +1207,7 @@ internal void OnEnvChange(SqlEnvChange rec) break; case TdsEnums.ENV_USERINSTANCE: - _instanceName = rec._newValue; + UserInstanceName = rec._newValue; break; case TdsEnums.ENV_ROUTING: @@ -3301,7 +3297,7 @@ private void LoginNoFailover( _currentLanguage = _originalLanguage = ConnectionOptions.CurrentLanguage; CurrentDatabase = _originalDatabase = ConnectionOptions.InitialCatalog; ServerProvidedFailoverPartner = null; - _instanceName = string.Empty; + UserInstanceName = string.Empty; routingAttempts++; @@ -3615,7 +3611,7 @@ private void LoginWithFailover( _currentLanguage = _originalLanguage = ConnectionOptions.CurrentLanguage; CurrentDatabase = _originalDatabase = connectionOptions.InitialCatalog; ServerProvidedFailoverPartner = null; - _instanceName = string.Empty; + UserInstanceName = string.Empty; AttemptOneLogin( currentServerInfo, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Diagnostics/DiagnosticScope.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Diagnostics/DiagnosticScope.cs index 8222e662c1..bb4efd4f84 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Diagnostics/DiagnosticScope.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Diagnostics/DiagnosticScope.cs @@ -5,33 +5,44 @@ using System; using System.Runtime.CompilerServices; +#nullable enable + namespace Microsoft.Data.SqlClient.Diagnostics { + /// + /// Provides a scope for emitting diagnostic events related to SQL command and connection operations. Used to track + /// the start, completion, and error states of database operations for diagnostic listeners. + /// + /// + /// A DiagnosticScope is typically created using the method + /// and is intended to be used in a using statement to ensure proper disposal. Disposing the scope emits the + /// appropriate completion or error event based on whether an exception was set. + /// internal ref struct DiagnosticScope : IDisposable { private const int CommandOperation = 1; private const int ConnectionOpenOperation = 2; private readonly object _context1; - private readonly object _context2; + private readonly object? _context2; private readonly SqlDiagnosticListener _diagnostics; private readonly int _operation; private readonly Guid _operationId; private readonly string _operationName; - private Exception _exception; + private Exception? _exception; private DiagnosticScope( SqlDiagnosticListener diagnostics, int operation, - Guid operationsId, + Guid operationId, string operationName, object context1, - object context2) + object? context2) { _diagnostics = diagnostics; _operation = operation; - _operationId = operationsId; + _operationId = operationId; _operationName = operationName; _context1 = context1; _context2 = context2; @@ -41,7 +52,7 @@ private DiagnosticScope( public static DiagnosticScope CreateCommandScope( SqlDiagnosticListener diagnostics, SqlCommand command, - SqlTransaction transaction, + SqlTransaction? transaction, [CallerMemberName] string operationName = "") { @@ -49,19 +60,17 @@ public static DiagnosticScope CreateCommandScope( return new DiagnosticScope(diagnostics, CommandOperation, operationId, operationName, command, transaction); } - // Although ref structs do not allow for inheriting from interfaces (< C#13), but the - // compiler will know to treat this like an IDisposable (> C# 8) - public void Dispose() + public readonly void Dispose() { switch (_operation) { case CommandOperation: - if (_exception != null) + if (_exception is not null) { _diagnostics.WriteCommandError( _operationId, (SqlCommand)_context1, - (SqlTransaction)_context2, + (SqlTransaction?)_context2, _exception, _operationName); } @@ -70,13 +79,13 @@ public void Dispose() _diagnostics.WriteCommandAfter( _operationId, (SqlCommand)_context1, - (SqlTransaction)_context2, + (SqlTransaction?)_context2, _operationName); } break; case ConnectionOpenOperation: - if (_exception != null) + if (_exception is not null) { _diagnostics.WriteConnectionOpenError( _operationId, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Diagnostics/TelemetryConstants.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Diagnostics/TelemetryConstants.cs new file mode 100644 index 0000000000..a3bb34a795 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Diagnostics/TelemetryConstants.cs @@ -0,0 +1,150 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +namespace Microsoft.Data.SqlClient.Diagnostics; + +internal static class TelemetryAttributes +{ + /// + /// Attributes prefixed with "db." and "sqlclient.db.". + /// + /// + /// + public static class Database + { + private const string StandardsPrefix = "db."; + + private const string LibrarySpecificPrefix = "sqlclient.db."; + + /// + /// This must always be . + /// + public const string SystemName = $"{StandardsPrefix}system.name"; + + /// + /// This is always in the format {instance name}|{database name}. + /// + /// + /// + /// The instance name is not included if the default instance is used. + /// + /// + /// This namespace is a logical construct, and does not necessarily reflect + /// the instance which the connection is currently connected to, or the value specified + /// in the connection string: + /// + /// + /// An AlwaysOn Availability Group between two named SQL Server instances which are accessed + /// through a listener. The namespace will not contain an instance name; it will only be aware + /// of connecting to the default instance on the listener. + /// + /// + /// Database mirroring between two SQL Server instances. The failover partner details supplied + /// by the server will contain a port number, not an instance name. The namespace will contain + /// an instance name prior to failover, and will not contain an instance name after failover. + /// + /// + /// On Windows, configuring a client-side alias for a non-default SQL Server instance. The connection + /// string will not contain an instance name, but alias resolution will discover it. The namespace + /// will contain an instance name. + /// + /// + /// An AlwaysOn Availability Group between two named SQL Server instances, accessed directly and + /// using read-only routing. The namespace will contain an instance name when connected using a + /// connection intent of ReadWrite, and will not contain an instance name when connected + /// using a connection intent of ReadOnly (since read-only routing will re-route the connection + /// to a server based on port number, and will not provide an instance name.) + /// + /// + /// + /// + /// In these more complex scenarios, clients are recommended to use the + /// and attributes to uniquely identify the server being connected to, and + /// the attribute to identify the database being used. + /// + /// + public const string Namespace = $"{StandardsPrefix}namespace"; + + public const string DatabaseName = $"{LibrarySpecificPrefix}database.name"; + + public const string OperationName = $"{StandardsPrefix}operation.name"; + + public const string StoredProcedureName = $"{StandardsPrefix}stored_procedure.name"; + + public const string ResponseStatusCode = $"{StandardsPrefix}response.status_code"; + + public const string QueryText = $"{StandardsPrefix}query.text"; + + public const string OperationBatchSize = $"{StandardsPrefix}operation.batch.size"; + } + + /// + /// Attributes prefixed with "error.". + /// + /// + /// + public static class Error + { + private const string StandardsPrefix = "error."; + + public const string Type = $"{StandardsPrefix}type"; + } + + /// + /// Attributes prefixed with "exception.". + /// + /// + /// + /// + /// + public static class Exception + { + private const string StandardsPrefix = "exception."; + + public const string Type = $"{StandardsPrefix}type"; + + public const string Message = $"{StandardsPrefix}message"; + + public const string StackTrace = $"{StandardsPrefix}stacktrace"; + } + + /// + /// Attributes prefixed with "server.". + /// + /// + /// + public static class Server + { + private const string StandardsPrefix = "server."; + + public const string Address = $"{StandardsPrefix}address"; + + public const string Port = $"{StandardsPrefix}port"; + } +} + +internal static class TelemetryAttributeValues +{ + /// + /// Values for attributes described in . + /// + /// + public static class Database + { + public const string SystemName = "microsoft.sql_server"; + + public const string ExecuteOperationName = "EXECUTE"; + } +} + +internal static class TelemetryEventNames +{ + /// + /// The name of an exception event. + /// + /// + public const string Exception = "exception"; +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ManagedSni/SniProxy.netcore.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ManagedSni/SniProxy.netcore.cs index da6885062d..f83741fb6f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ManagedSni/SniProxy.netcore.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ManagedSni/SniProxy.netcore.cs @@ -50,7 +50,7 @@ internal class SniProxy internal static SniHandle CreateConnectionHandle( string fullServerName, TimeoutTimer timeout, - out byte[] instanceName, + out string instanceName, out ResolvedServerSpn resolvedSpn, string serverSPN, bool flushCache, @@ -64,7 +64,7 @@ internal static SniHandle CreateConnectionHandle( string hostNameInCertificate, string serverCertificateFilename) { - instanceName = new byte[1]; + instanceName = null; resolvedSpn = default; bool errorWithLocalDBProcessing; @@ -100,6 +100,8 @@ internal static SniHandle CreateConnectionHandle( break; } + instanceName = details.InstanceName; + if (isIntegratedSecurity) { try diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Batch.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Batch.cs index 7b6d16a420..3f021697e2 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Batch.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Batch.cs @@ -60,6 +60,9 @@ internal SqlBatchCommand GetCurrentBatchCommand() internal int GetCurrentBatchIndex() => _batchRPCMode ? _currentlyExecutingBatch : -1; + internal int BatchSize => + _batchRPCMode ? _RPCList.Capacity : 1; + // @TODO: Indicate this is for batch RPC usage internal SqlException GetErrors(int commandIndex) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index d9f6c33233..a6f418f39d 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -111,10 +111,7 @@ public SqlDataReader EndExecuteReader(IAsyncResult asyncResult) // between entry into Execute* API and the thread obtaining the stateObject. _pendingCancel = false; - // @TODO: Do we want to use a command scope here like nonquery and xml? or is operation id ok? - Guid operationId = s_diagnosticListener.WriteCommandBefore(this, _transaction); - Exception e = null; - + using var diagnosticScope = s_diagnosticListener.CreateCommandScope(this, _transaction); using var eventScope = SqlClientEventScope.Create($"SqlCommand.ExecuteReader | API | Object Id {ObjectID}"); // @TODO: Do we want to have a correlation trace event here like nonquery and xml? // @TODO: Basically, this doesn't follow the same pattern as nonquery, scalar, or xml. Doesn't seem right. @@ -137,7 +134,7 @@ public SqlDataReader EndExecuteReader(IAsyncResult asyncResult) // @TODO: CER Exception Handling was removed here (see GH#3581) catch (Exception ex) { - e = ex; + diagnosticScope.SetException(ex); if (ex is SqlException sqlException) { @@ -150,15 +147,6 @@ public SqlDataReader EndExecuteReader(IAsyncResult asyncResult) { SqlStatistics.StopTimer(statistics); WriteEndExecuteEvent(success, sqlExceptionNumber, isSynchronous: true); - - if (e is not null) - { - s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); - } - else - { - s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); - } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs index 30010d0d4b..97ed1ae465 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs @@ -639,7 +639,7 @@ protected virtual DbConnectionInternal CreateConnection( { // NOTE: Retrieve here. This user instance name will be // used below to connect to the SQL Express User Instance. - instanceName = sseConnection.InstanceName; + instanceName = sseConnection.UserInstanceName; // Set future transient fault handling based on connection options sqlOwningConnection._applyTransientFaultHandling = opt != null && opt.ConnectRetryCount > 0; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 53b89fbc90..c366a383b8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -452,7 +452,7 @@ bool withFailover trustServerCert = false; } - byte[] instanceName = null; + string instanceName = null; Debug.Assert(_connHandler != null, "SqlConnectionInternalTds handler can not be null at this point."); _connHandler.TimeoutErrorInternal.EndPhase(SqlConnectionTimeoutErrorPhase.PreLoginBegin); @@ -493,6 +493,7 @@ bool withFailover } _connHandler.pendingSQLDNSObject = null; + _connHandler.InstanceName = null; // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server _physicalStateObj.CreatePhysicalSNIHandle( @@ -572,7 +573,7 @@ bool withFailover // UNDONE - send "" for instance now, need to fix later SqlClientEventSource.Log.TryTraceEvent(" Sending prelogin handshake"); - SendPreLoginHandshake(instanceName, encrypt, integratedSecurity, serverCertificateFilename); + SendPreLoginHandshake(encrypt, integratedSecurity, serverCertificateFilename); _connHandler.TimeoutErrorInternal.EndPhase(SqlConnectionTimeoutErrorPhase.SendPreLoginHandshake); _connHandler.TimeoutErrorInternal.SetAndBeginPhase(SqlConnectionTimeoutErrorPhase.ConsumePreLoginHandshake); @@ -632,7 +633,7 @@ bool withFailover _physicalStateObj.AssignPendingDNSInfo(serverInfo.UserProtocol, FQDNforDNSCache, ref _connHandler.pendingSQLDNSObject); } - SendPreLoginHandshake(instanceName, encrypt, integratedSecurity, serverCertificateFilename); + SendPreLoginHandshake(encrypt, integratedSecurity, serverCertificateFilename); status = ConsumePreLoginHandshake( encrypt, trustServerCert, @@ -652,6 +653,7 @@ bool withFailover } SqlClientEventSource.Log.TryTraceEvent(" Prelogin handshake successful"); + _connHandler.InstanceName = instanceName; if (_authenticationProvider is { }) { _authenticationProvider.Initialize(serverInfo, _physicalStateObj, this, resolvedServerSpn.Primary, resolvedServerSpn.Secondary); @@ -776,7 +778,6 @@ internal void PutSession(TdsParserStateObject session) } private void SendPreLoginHandshake( - byte[] instanceName, SqlConnectionEncryptOption encrypt, bool integratedSecurity, string serverCertificateFilename) @@ -871,21 +872,12 @@ private void SendPreLoginHandshake( break; case (int)PreLoginOptions.INSTANCE: - int i = 0; + // Always send an empty null-terminated string + payload[payloadLength] = 0; - while (instanceName[i] != 0) - { - payload[payloadLength] = instanceName[i]; - payloadLength++; - i++; - } - - payload[payloadLength] = 0; // null terminate payloadLength++; - i++; - - offset += i; - optionDataSize = i; + offset += 1; + optionDataSize = 1; break; case (int)PreLoginOptions.THREADID: @@ -4411,7 +4403,7 @@ private TdsOperationStatus TryProcessLoginAck(TdsParserStateObject stateObj, out // Fail if SSE UserInstance and we have not received this info. if (_connHandler.ConnectionOptions.UserInstance && - string.IsNullOrEmpty(_connHandler.InstanceName)) + string.IsNullOrEmpty(_connHandler.UserInstanceName)) { stateObj.AddError(new SqlError(0, 0, TdsEnums.FATAL_ERROR_CLASS, Server, SQLMessage.UserInstanceFailure(), "", 0)); ThrowExceptionAndWarning(stateObj); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.windows.cs index ed55a76fbb..77db5516b0 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.windows.cs @@ -147,7 +147,7 @@ internal SNIHandle( string serverName, ref string spn, int timeout, - out byte[] instanceName, + out string instanceName, bool flushCache, bool fSync, bool fParallel, @@ -161,7 +161,7 @@ internal SNIHandle( : base(IntPtr.Zero, true) { _fSync = fSync; - instanceName = new byte[256]; // Size as specified by netlibs. + instanceName = null; // Option ignoreSniOpenTimeout is no longer available //if (ignoreSniOpenTimeout) //{ @@ -178,7 +178,7 @@ internal SNIHandle( serverName, ref base.handle, ref spn, - instanceName, + ref instanceName, flushCache, fSync, timeout, @@ -194,7 +194,7 @@ internal SNIHandle( serverName, ref base.handle, ref spn, - instanceName, + ref instanceName, flushCache, fSync, timeout, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index b7fa5c865b..baf48ff255 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -557,7 +557,7 @@ internal long TimeoutTime internal abstract void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, - out byte[] instanceName, + out string instanceName, out ResolvedServerSpn resolvedSpn, bool flushCache, bool async, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.netcore.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.netcore.cs index c3667fda1c..52abbdcbea 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.netcore.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.netcore.cs @@ -83,7 +83,7 @@ protected override uint SniPacketGetData(PacketHandle packet, byte[] inBuff, ref internal override void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, - out byte[] instanceName, + out string instanceName, out ResolvedServerSpn resolvedSpn, bool flushCache, bool async, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.windows.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.windows.cs index 7812860b5e..b083718b2f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.windows.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.windows.cs @@ -140,7 +140,7 @@ private ConsumerInfo CreateConsumerInfo(bool async) internal override void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, - out byte[] instanceName, + out string instanceName, out ManagedSni.ResolvedServerSpn resolvedSpn, bool flushCache, bool async, diff --git a/src/Microsoft.Data.SqlClient/src/System/Text/EncodingExtensions.cs b/src/Microsoft.Data.SqlClient/src/System/Text/EncodingExtensions.cs new file mode 100644 index 0000000000..8b8a6813cd --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/System/Text/EncodingExtensions.cs @@ -0,0 +1,131 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; + +#nullable enable + +namespace System.Text; + +internal static class EncodingExtensions +{ + private const byte NullByte = 0x00; + private static ReadOnlySpan MultiByteNull => [NullByte, NullByte]; + + /// + /// Creates a new string from a null-terminated sequence of bytes, decoding them using the specified encoding. + /// + /// The encoding to use to decode the bytes. + /// The null-terminated sequence of bytes to decode. + /// The decoded string. + public static unsafe string CreateStringFromNullTerminated(this Encoding encoding, ReadOnlySpan nullTerminatedBytes) + { + int preNullBytes = nullTerminatedBytes.IndexOf(MultiByteNull); + + // If the sequence starts with a null terminator, avoid allocating a new zero-length string. + if (preNullBytes == 0 || nullTerminatedBytes.Length == 0) + { + return string.Empty; + } + + // IndexOf has searched for a multi-byte null terminator. This will work in most circumstances, assuming that + // every value after the end of the string is zeroed out. + if (preNullBytes == -1) + { + // If the byte sequence is [NullByte], we're in the same position as before. Return an empty string. + if (nullTerminatedBytes.Length == 1 + && nullTerminatedBytes[0] == NullByte) + { + return string.Empty; + } + // If the byte sequence is only long enough to contain an encoded string followed by a single null byte, + // adjust the null index to account for the false positive. + else if (nullTerminatedBytes.Length > 1 + && nullTerminatedBytes[nullTerminatedBytes.Length - 1] == NullByte) + { + preNullBytes = nullTerminatedBytes.Length - 1; + } + // Otherwise, there is no null terminator. Use the entire byte array. + else + { + preNullBytes = nullTerminatedBytes.Length; + } + } + // If we work with unicode encodings and strings containing nothing but ASCII characters, every other byte will + // be a null byte. In such a case, the last byte of the string will be null. This means that preNullBytes will be + // one byte too long. Adjust to account for that. + else if (encoding is UnicodeEncoding) + { + if (preNullBytes % 2 != 0) + { + // If we have a match, we already know that it'll be less than or equal to [array size - search string length]. + Debug.Assert(preNullBytes + 1 <= nullTerminatedBytes.Length); + + preNullBytes++; + } + } + + fixed (byte* pBytes = nullTerminatedBytes) + { + return encoding.GetString(pBytes, preNullBytes); + } + } + + #if NETFRAMEWORK + public static int GetByteCount(this Encoding encoding, string? s, int offset, int count) + { + if (s is null) + { + throw new ArgumentNullException(nameof(s)); + } + + ReadOnlySpan slicedString = s.AsSpan(offset, count); + + if (slicedString.Length == 0) + { + return 0; + } + + unsafe + { + fixed (char* str = slicedString) + { + return encoding.GetByteCount(str, slicedString.Length); + } + } + } + + public static byte[] GetBytes(this Encoding encoding, string? s, int index, int count) + { + if (s is null) + { + throw new ArgumentNullException(nameof(s)); + } + + ReadOnlySpan slicedString = s.AsSpan(index, count); + + if (slicedString.Length == 0) + { + return Array.Empty(); + } + + unsafe + { + fixed (char* str = slicedString) + { + int byteCount = encoding.GetByteCount(str, slicedString.Length); + byte[] bytes = new byte[byteCount]; + + fixed (byte* destArray = &bytes[0]) + { + int bytesWritten = encoding.GetBytes(str, slicedString.Length, destArray, bytes.Length); + + Debug.Assert(bytesWritten == byteCount); + return bytes; + } + } + } + } + #endif +} diff --git a/src/Microsoft.Data.SqlClient/src/System/Text/EncodingExtensions.netfx.cs b/src/Microsoft.Data.SqlClient/src/System/Text/EncodingExtensions.netfx.cs deleted file mode 100644 index baf2a275e5..0000000000 --- a/src/Microsoft.Data.SqlClient/src/System/Text/EncodingExtensions.netfx.cs +++ /dev/null @@ -1,71 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -#if NETFRAMEWORK - -using System.Diagnostics; - -#nullable enable - -namespace System.Text; - -internal static class EncodingExtensions -{ - public static int GetByteCount(this Encoding encoding, string? s, int offset, int count) - { - if (s is null) - { - throw new ArgumentNullException(nameof(s)); - } - - ReadOnlySpan slicedString = s.AsSpan(offset, count); - - if (slicedString.Length == 0) - { - return 0; - } - - unsafe - { - fixed (char* str = slicedString) - { - return encoding.GetByteCount(str, slicedString.Length); - } - } - } - - public static byte[] GetBytes(this Encoding encoding, string? s, int index, int count) - { - if (s is null) - { - throw new ArgumentNullException(nameof(s)); - } - - ReadOnlySpan slicedString = s.AsSpan(index, count); - - if (slicedString.Length == 0) - { - return Array.Empty(); - } - - unsafe - { - fixed (char* str = slicedString) - { - int byteCount = encoding.GetByteCount(str, slicedString.Length); - byte[] bytes = new byte[byteCount]; - - fixed (byte* destArray = &bytes[0]) - { - int bytesWritten = encoding.GetBytes(str, slicedString.Length, destArray, bytes.Length); - - Debug.Assert(bytesWritten == byteCount); - return bytes; - } - } - } - } -} - -#endif