diff --git a/src/Driver/Client/Websocket/WSClient.cs b/src/Driver/Client/Websocket/WSClient.cs index 993df0c..9c01f01 100644 --- a/src/Driver/Client/Websocket/WSClient.cs +++ b/src/Driver/Client/Websocket/WSClient.cs @@ -1,6 +1,5 @@ using System; using System.Diagnostics; -using System.Net.WebSockets; using TDengine.Driver.Impl.WebSocketMethods; namespace TDengine.Driver.Client.Websocket @@ -10,6 +9,8 @@ public class WSClient : ITDengineClient private Connection _connection; private readonly TimeZoneInfo _tz; private readonly ConnectionStringBuilder _builder; + private readonly object _reconnectLock = new object(); + public WSClient(ConnectionStringBuilder builder) { @@ -66,42 +67,47 @@ private void Reconnect() { if (!_builder.AutoReconnect) return; - - Connection connection = null; - for (int i = 0; i < _builder.ReconnectRetryCount; i++) + lock (_reconnectLock) { - try - { - // sleep - System.Threading.Thread.Sleep(_builder.ReconnectIntervalMs); - connection = new Connection(GetUrl(_builder), _builder.Username, _builder.Password, - _builder.Database, _builder.ConnTimeout, _builder.ReadTimeout, _builder.WriteTimeout, - _builder.EnableCompression); - connection.Connect(); - break; - } - catch (Exception) + if (_connection != null && _connection.IsAvailable()) // connection is available, no need to reconnect + return; + + Connection connection = null; + for (int i = 0; i < _builder.ReconnectRetryCount; i++) { - if (connection != null) + try { - connection.Close(); - connection = null; + // sleep + System.Threading.Thread.Sleep(_builder.ReconnectIntervalMs); + connection = new Connection(GetUrl(_builder), _builder.Username, _builder.Password, + _builder.Database, _builder.ConnTimeout, _builder.ReadTimeout, _builder.WriteTimeout, + _builder.EnableCompression); + connection.Connect(); + break; + } + catch (Exception) + { + if (connection != null) + { + connection.Close(); + connection = null; + } } } - } - if (connection == null) - { - throw new TDengineError((int)TDengineError.InternalErrorCode.WS_RECONNECT_FAILED, - "websocket connection reconnect failed"); - } + if (connection == null) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_RECONNECT_FAILED, + "websocket connection reconnect failed"); + } - if (_connection != null) - { - _connection.Close(); - } + if (_connection != null) + { + _connection.Close(); + } - _connection = connection; + _connection = connection; + } } public IStmt StmtInit() diff --git a/src/Driver/Client/Websocket/WSRows.cs b/src/Driver/Client/Websocket/WSRows.cs index dee8942..3c95bb6 100644 --- a/src/Driver/Client/Websocket/WSRows.cs +++ b/src/Driver/Client/Websocket/WSRows.cs @@ -97,16 +97,18 @@ private List ParseMetas(WSStmtUseResultResp result) public void Dispose() { - if (_freed) - { - return; - } + if (_freed) return; _freed = true; - if (_connection != null && _connection.IsAvailable()) + if (_connection == null || !_connection.IsAvailable()) return; + try { _connection.FreeResult(_resultId); } + catch (Exception) + { + // ignored + } } public long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) diff --git a/src/Driver/Client/Websocket/WSStmt.cs b/src/Driver/Client/Websocket/WSStmt.cs index 6552c16..af3df65 100644 --- a/src/Driver/Client/Websocket/WSStmt.cs +++ b/src/Driver/Client/Websocket/WSStmt.cs @@ -5,11 +5,11 @@ namespace TDengine.Driver.Client.Websocket { public class WSStmt : IStmt { - private ulong _stmt; + private readonly ulong _stmt; private readonly TimeZoneInfo _tz; - private Connection _connection; - private bool closed; - private long lastAffected; + private readonly Connection _connection; + private bool _closed; + private long _lastAffected; private bool _isInsert; public WSStmt(ulong stmt, TimeZoneInfo tz, Connection connection) @@ -22,13 +22,18 @@ public WSStmt(ulong stmt, TimeZoneInfo tz, Connection connection) public void Dispose() { - if (closed) + if (_closed) return; + + _closed = true; + if (_connection == null || !_connection.IsAvailable()) return; + try { - return; + _connection.StmtClose(_stmt); + } + catch (Exception) + { + // ignored } - - _connection.StmtClose(_stmt); - closed = true; } public void Prepare(string query) @@ -219,12 +224,12 @@ public void AddBatch() public void Exec() { var resp = _connection.StmtExec(_stmt); - lastAffected = resp.Affected; + _lastAffected = resp.Affected; } public long Affected() { - return lastAffected; + return _lastAffected; } public IRows Result() @@ -233,6 +238,7 @@ public IRows Result() { return new WSRows((int)Affected()); } + var resp = _connection.StmtUseResult(_stmt); return new WSRows(resp, _connection, _tz); } diff --git a/src/Driver/Impl/WebSocketMethods/BaseConnection.cs b/src/Driver/Impl/WebSocketMethods/BaseConnection.cs index 68492f4..53545b6 100644 --- a/src/Driver/Impl/WebSocketMethods/BaseConnection.cs +++ b/src/Driver/Impl/WebSocketMethods/BaseConnection.cs @@ -1,5 +1,8 @@ using System; +using System.Collections.Concurrent; +using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net.WebSockets; using System.Text; using System.Threading; @@ -9,6 +12,20 @@ namespace TDengine.Driver.Impl.WebSocketMethods { + internal class WsMessage + { + internal readonly byte[] Message; + internal readonly WebSocketMessageType MessageType; + internal readonly Exception Exception; + + internal WsMessage(byte[] message, WebSocketMessageType messageType, Exception exception) + { + Message = message; + MessageType = messageType; + Exception = exception; + } + } + public class BaseConnection { private readonly ClientWebSocket _client; @@ -16,11 +33,34 @@ public class BaseConnection private readonly TimeSpan _readTimeout; private readonly TimeSpan _writeTimeout; - private ulong _reqId; private readonly TimeSpan _defaultConnTimeout = TimeSpan.FromMinutes(1); private readonly TimeSpan _defaultReadTimeout = TimeSpan.FromMinutes(5); private readonly TimeSpan _defaultWriteTimeout = TimeSpan.FromSeconds(10); + private readonly ConcurrentDictionary> _pendingRequests = + new ConcurrentDictionary>(); + + private readonly SemaphoreSlim _sendSemaphore = new SemaphoreSlim(1, 1); + + private bool _exit = false; + private readonly ReaderWriterLockSlim _exitLock = new ReaderWriterLockSlim(); + + private bool IsExit + { + get + { + _exitLock.EnterReadLock(); + try + { + return _exit; + } + finally + { + _exitLock.ExitReadLock(); + } + } + } + protected BaseConnection(string addr, TimeSpan connectTimeout = default, TimeSpan readTimeout = default, TimeSpan writeTimeout = default, bool enableCompression = false) { @@ -67,12 +107,13 @@ protected BaseConnection(string addr, TimeSpan connectTimeout = default, throw new TDengineError((int)TDengineError.InternalErrorCode.WS_CONNEC_FAILED, $"connect to {addr} fail"); } + + Task.Run(async () => { await ReceiveLoop().ConfigureAwait(false); }); } - protected ulong _GetReqId() + protected static ulong _GetReqId() { - _reqId += 1; - return _reqId; + return (ulong)ReqId.GetReqId(); } @@ -88,7 +129,7 @@ protected static void WriteUInt64ToBytes(byte[] byteArray, ulong value, int offs byteArray[offset + 7] = (byte)(value >> 56); } - protected static void WriteUInt32ToBytes(byte[] byteArray, UInt32 value, int offset) + protected static void WriteUInt32ToBytes(byte[] byteArray, uint value, int offset) { byteArray[offset + 0] = (byte)value; byteArray[offset + 1] = (byte)(value >> 8); @@ -96,30 +137,124 @@ protected static void WriteUInt32ToBytes(byte[] byteArray, UInt32 value, int off byteArray[offset + 3] = (byte)(value >> 24); } - protected static void WriteUInt16ToBytes(byte[] byteArray, UInt16 value, int offset) + protected static void WriteUInt16ToBytes(byte[] byteArray, ushort value, int offset) { byteArray[offset + 0] = (byte)value; byteArray[offset + 1] = (byte)(value >> 8); } - protected byte[] SendBinaryBackBytes(byte[] request) + + protected byte[] SendBinaryBackBytes(byte[] request, ulong reqId) + { + var task = Task.Run(async () => await AsyncSendBinaryBackByte(request, reqId).ConfigureAwait(false)); + WaitAndThrowOriginalException(task); + return task.Result; + } + + private static void WaitAndThrowOriginalException(Task task) { - SendBinary(request); - var respBytes = Receive(out var messageType); + try + { + task.Wait(); + } + catch (AggregateException ex) + { + var firstException = ex.Flatten().InnerExceptions.First(); + throw firstException; + } + } + + private async Task AsyncSendBinaryBackByte(byte[] request, ulong reqId) + { + var tcs = AddTask(reqId); + // send request + try + { + await AsyncSendBinary(request).ConfigureAwait(false); + } + catch (Exception) + { + _pendingRequests.TryRemove(reqId, out _); + throw; + } + + await WaitForResponseWithTimeout(reqId, tcs).ConfigureAwait(false); + + // get response + var responseMessage = await tcs.Task.ConfigureAwait(false); + if (responseMessage.Exception != null) throw responseMessage.Exception; + + var respBytes = responseMessage.Message; + var messageType = responseMessage.MessageType; if (messageType == WebSocketMessageType.Binary) { return respBytes; } - var resp = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(respBytes)); + WSBaseResp resp; + try + { + resp = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(respBytes)); + } + catch (Exception e) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, + "receive unexpected message", e.Message); + } + throw new TDengineError(resp.Code, resp.Message, request, Encoding.UTF8.GetString(respBytes)); } + private async Task WaitForResponseWithTimeout(ulong reqId, TaskCompletionSource tcs) + { + using (var cts = new CancellationTokenSource()) + { + // wait for timeout + var timeoutTask = Task.Delay(_readTimeout, cts.Token); + // wait for response + var completedTask = await Task.WhenAny(tcs.Task, timeoutTask).ConfigureAwait(false); + // timeout + if (completedTask == timeoutTask) + { + if (_pendingRequests.TryRemove(reqId, out var removedTcs)) removedTcs.TrySetCanceled(); + throw new TimeoutException($"Request timed out. reqId: 0x{reqId:x}"); + } + + cts.Cancel(); + } + } + + protected T SendBinaryBackJson(byte[] request, ulong reqId) where T : IWSBaseResp + { + var task = Task.Run(async () => await AsyncSendBinaryBackJson(request, reqId).ConfigureAwait(false)); + WaitAndThrowOriginalException(task); + return task.Result; + } - protected T SendBinaryBackJson(byte[] request) where T : IWSBaseResp + private async Task AsyncSendBinaryBackJson(byte[] request, ulong reqId) where T : IWSBaseResp { - SendBinary(request); - var respBytes = Receive(out var messageType); + var tcs = AddTask(reqId); + // send request + try + { + await AsyncSendBinary(request).ConfigureAwait(false); + } + catch (Exception) + { + _pendingRequests.TryRemove(reqId, out _); + throw; + } + + await WaitForResponseWithTimeout(reqId, tcs).ConfigureAwait(false); + // get response + var responseMessage = await tcs.Task.ConfigureAwait(false); + if (responseMessage.Exception != null) + { + throw responseMessage.Exception; + } + + var respBytes = responseMessage.Message; + var messageType = responseMessage.MessageType; if (messageType != WebSocketMessageType.Text) { throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, @@ -131,18 +266,58 @@ protected T SendBinaryBackJson(byte[] request) where T : IWSBaseResp throw new TDengineError(resp.Code, resp.Message); } - protected T2 SendJsonBackJson(string action, T1 req) where T2 : IWSBaseResp + protected T2 SendJsonBackJson(string action, T1 req, ulong reqId) where T2 : IWSBaseResp + { + var task = Task.Run(async () => + await AsyncSendJsonBackJson(action, req, reqId).ConfigureAwait(false)); + WaitAndThrowOriginalException(task); + return task.Result; + } + + private async Task AsyncSendJsonBackJson(string action, T1 req, ulong reqId) where T2 : IWSBaseResp { - var reqStr = SendJson(action, req); - var respBytes = Receive(out var messageType); + var tcs = AddTask(reqId); + // send request + string reqStr; + try + { + reqStr = await AsyncSendJson(action, req).ConfigureAwait(false); + } + catch (Exception) + { + _pendingRequests.TryRemove(reqId, out _); + throw; + } + + await WaitForResponseWithTimeout(reqId, tcs).ConfigureAwait(false); + + // get response + var responseMessage = await tcs.Task.ConfigureAwait(false); + if (responseMessage.Exception != null) + { + throw responseMessage.Exception; + } + + var respBytes = responseMessage.Message; + var messageType = responseMessage.MessageType; if (messageType != WebSocketMessageType.Text) { throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, "receive unexpected binary message", respBytes, reqStr); } - var resp = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(respBytes)); - // Console.WriteLine(Encoding.UTF8.GetString(respBytes)); + T2 resp; + try + { + resp = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(respBytes)); + } + catch (Exception e) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, + $"receive unexpected message: {e}", + "req:" + reqStr + ";resp:" + Encoding.UTF8.GetString(respBytes)); + } + if (resp.Action != action) { throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, @@ -154,28 +329,89 @@ protected T2 SendJsonBackJson(string action, T1 req) where T2 : IWSBaseR throw new TDengineError(resp.Code, resp.Message); } - protected byte[] SendJsonBackBytes(string action, T req) + + protected byte[] SendJsonBackBytes(string action, T req, ulong reqId) + { + var task = Task.Run(async () => await AsyncSendJsonBackBytes(action, req, reqId).ConfigureAwait(false)); + WaitAndThrowOriginalException(task); + return task.Result; + } + + private async Task AsyncSendJsonBackBytes(string action, T req, ulong reqId) { - var reqStr = SendJson(action, req); - var respBytes = Receive(out var messageType); + var tcs = AddTask(reqId); + // send request + string reqStr; + try + { + reqStr = await AsyncSendJson(action, req).ConfigureAwait(false); + } + catch (Exception) + { + _pendingRequests.TryRemove(reqId, out _); + throw; + } + + await WaitForResponseWithTimeout(reqId, tcs).ConfigureAwait(false); + + // get response + var responseMessage = await tcs.Task.ConfigureAwait(false); + if (responseMessage.Exception != null) + { + throw responseMessage.Exception; + } + + var respBytes = responseMessage.Message; + var messageType = responseMessage.MessageType; if (messageType == WebSocketMessageType.Binary) { return respBytes; } - var resp = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(respBytes)); - throw new TDengineError(resp.Code, resp.Message, reqStr); + WSBaseResp resp; + try + { + resp = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(respBytes)); + } + catch (Exception) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, + "receive unexpected message", "req:" + reqStr + ";resp:" + Encoding.UTF8.GetString(respBytes)); + } + + throw new TDengineError(resp.Code, resp.Message, Encoding.UTF8.GetString(respBytes)); } - protected string SendJson(string action, T req) + protected string SendJson(string action, T req, ulong reqId) { - var request = JsonConvert.SerializeObject(new WSActionReq + var task = Task.Run(async () => await AsyncSendJson(action, req).ConfigureAwait(false)); + WaitAndThrowOriginalException(task); + return task.Result; + } + + private TaskCompletionSource AddTask(ulong reqId) + { + _exitLock.EnterReadLock(); + try { - Action = action, - Args = req - }); - SendText(request); - return request; + if (_exit) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_CONNECTION_CLOSED, + "websocket connection is closed"); + } + + var tcs = new TaskCompletionSource(); + if (!_pendingRequests.TryAdd(reqId, tcs)) + { + throw new InvalidOperationException($"Request with reqId '0x{reqId:x}' already exists."); + } + + return tcs; + } + finally + { + _exitLock.ExitReadLock(); + } } private async Task SendAsync(ArraySegment data, WebSocketMessageType messageType) @@ -201,70 +437,162 @@ private async Task SendAsync(ArraySegment data, WebSocketMessageType messa } } - private void SendText(string request) - { - var data = new ArraySegment(Encoding.UTF8.GetBytes(request)); - Task.Run(async () => { await SendAsync(data, WebSocketMessageType.Text).ConfigureAwait(true); }).Wait(); - } - - private void SendBinary(byte[] request) + private async Task AsyncSendJson(string action, T req) { - var data = new ArraySegment(request); - Task.Run(async () => { await SendAsync(data, WebSocketMessageType.Binary).ConfigureAwait(true); }).Wait(); + var request = JsonConvert.SerializeObject(new WSActionReq + { + Action = action, + Args = req + }); + await AsyncSendText(request).ConfigureAwait(false); + return request; } - private byte[] Receive(out WebSocketMessageType messageType) + private async Task AsyncSendText(string request) { - var task = Task.Run(async () => await ReceiveAsync().ConfigureAwait(true)); - task.Wait(); - messageType = task.Result.Item2; - return task.Result.Item1; + await _sendSemaphore.WaitAsync().ConfigureAwait(false); + try + { + var data = new ArraySegment(Encoding.UTF8.GetBytes(request)); + await SendAsync(data, WebSocketMessageType.Text).ConfigureAwait(false); + } + finally + { + _sendSemaphore.Release(); + } } - private async Task> ReceiveAsync() + private async Task AsyncSendBinary(byte[] request) { - if (!IsAvailable()) + await _sendSemaphore.WaitAsync().ConfigureAwait(false); + try { - throw new TDengineError((int)TDengineError.InternalErrorCode.WS_CONNECTION_CLOSED, - "websocket connection is closed"); + var data = new ArraySegment(request); + await SendAsync(data, WebSocketMessageType.Binary).ConfigureAwait(false); } + finally + { + _sendSemaphore.Release(); + } + } - using (var cts = new CancellationTokenSource()) + private async Task ReceiveLoop() + { + Exception exception = null; + try { - cts.CancelAfter(_readTimeout); - using (MemoryStream memoryStream = new MemoryStream()) + var buffer = new byte[1024 * 8]; + while (_client.State == WebSocketState.Open) { - int bufferSize = 1024 * 4; - byte[] buffer = new byte[bufferSize]; WebSocketReceiveResult result; - - do + using (MemoryStream memoryStream = new MemoryStream()) { - result = await _client.ReceiveAsync(new ArraySegment(buffer), cts.Token) - .ConfigureAwait(false); - - if (result.MessageType == WebSocketMessageType.Close) + do { - await _client - .CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None) + result = await _client.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None) .ConfigureAwait(false); - throw new TDengineError((int)TDengineError.InternalErrorCode.WS_RECEIVE_CLOSE_FRAME, - "receive websocket close frame"); + if (result.MessageType == WebSocketMessageType.Close) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_RECEIVE_CLOSE_FRAME, + "receive websocket close frame"); + } + + memoryStream.Write(buffer, 0, result.Count); + } while (!result.EndOfMessage); + + var bs = memoryStream.ToArray(); + TaskCompletionSource tcs; + switch (result.MessageType) + { + case WebSocketMessageType.Binary: + if (bs.Length < 16) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, + $"binary message length is less than 16, length:{bs.Length}"); + } + + var flag = BitConverter.ToUInt64(bs, 0); + var reqId = BitConverter.ToUInt64(bs, 8); + // new query response + if (flag == 0xffffffffffffffff) + { + reqId = BitConverter.ToUInt64(bs, 26); + } + + if (_pendingRequests.TryRemove(reqId, out tcs)) + { + tcs.TrySetResult(new WsMessage(bs, result.MessageType, null)); + } + + break; + + case WebSocketMessageType.Text: + WSBaseResp resp; + try + { + resp = JsonConvert.DeserializeObject( + Encoding.UTF8.GetString(bs)); + } + catch (Exception e) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, + "receive unexpected message", e.Message); + } + + if (_pendingRequests.TryRemove(resp.ReqId, out tcs)) + { + tcs.TrySetResult(new WsMessage(bs, result.MessageType, null)); + } + + break; + default: + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE, + "receive unexpected message type"); } - - memoryStream.Write(buffer, 0, result.Count); - } while (!result.EndOfMessage); - - return Tuple.Create(memoryStream.ToArray(), result.MessageType); + } } } + catch (Exception e) + { + exception = e; + } + finally + { + DoClose(exception); + } } - public void Close() + private void DoClose(Exception e = null) { + _exitLock.EnterWriteLock(); try { - _client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None).Wait(); + if (_exit) return; + _exit = true; + foreach (var kvp in _pendingRequests) + { + if (e != null) + { + kvp.Value.TrySetResult(new WsMessage(null, WebSocketMessageType.Close, e)); + } + else + { + kvp.Value.TrySetCanceled(); + } + } + + _pendingRequests.Clear(); + } + finally + { + _exitLock.ExitWriteLock(); + } + + try + { + Task.Run(() => + _client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None) + .ConfigureAwait(false)).Wait(); } catch (Exception) { @@ -272,8 +600,14 @@ public void Close() } } + public void Close() + { + DoClose(); + } + public bool IsAvailable(Exception e = null) { + if (IsExit) return false; if (_client.State != WebSocketState.Open) return false; @@ -284,14 +618,26 @@ public bool IsAvailable(Exception e = null) case WebSocketException _: return false; case AggregateException ae: - return !(ae.InnerException is WebSocketException); + if (ae.InnerException is WebSocketException) return false; + if (ae.InnerException is TDengineError tInnerException) + { + return TDengineErrorIsConnectionAvailable(tInnerException); + } + + return true; case TDengineError te: - return te.Code != (int)TDengineError.InternalErrorCode.WS_CONNECTION_CLOSED && - te.Code != (int)TDengineError.InternalErrorCode.WS_RECEIVE_CLOSE_FRAME && - te.Code != (int)TDengineError.InternalErrorCode.WS_WRITE_TIMEOUT; + return TDengineErrorIsConnectionAvailable(te); default: return true; } } + + private bool TDengineErrorIsConnectionAvailable(TDengineError te) + { + return te.Code != (int)TDengineError.InternalErrorCode.WS_CONNECTION_CLOSED && + te.Code != (int)TDengineError.InternalErrorCode.WS_RECEIVE_CLOSE_FRAME && + te.Code != (int)TDengineError.InternalErrorCode.WS_UNEXPECTED_MESSAGE && + te.Code != (int)TDengineError.InternalErrorCode.WS_RECONNECT_FAILED; + } } } \ No newline at end of file diff --git a/src/Driver/Impl/WebSocketMethods/Connection.cs b/src/Driver/Impl/WebSocketMethods/Connection.cs index b99e50d..1be0f94 100644 --- a/src/Driver/Impl/WebSocketMethods/Connection.cs +++ b/src/Driver/Impl/WebSocketMethods/Connection.cs @@ -21,18 +21,19 @@ public Connection(string addr, string user, string password, string db, TimeSpan public void Connect() { + var reqId = _GetReqId(); SendJsonBackJson(WSAction.Conn, new WSConnReq { - ReqId = _GetReqId(), + ReqId = reqId, User = _user, Password = _password, Db = _db - }); + },reqId); } - public WSQueryResp BinaryQuery(string sql, ulong reqid = default) + public WSQueryResp BinaryQuery(string sql, ulong reqid = 0) { - if (reqid == default) + if (reqid == 0) { reqid = _GetReqId(); } @@ -52,7 +53,7 @@ public WSQueryResp BinaryQuery(string sql, ulong reqid = default) WriteUInt32ToBytes(req, (uint)src.Length, 26); Buffer.BlockCopy(src, 0, req, 30, src.Length); - return SendBinaryBackJson(req); + return SendBinaryBackJson(req,reqid); } public byte[] FetchRawBlockBinary(ulong resultId) @@ -62,20 +63,22 @@ public byte[] FetchRawBlockBinary(ulong resultId) //p0+16 uint64 action //p0+24 uint16 version var req = new byte[32]; - WriteUInt64ToBytes(req, _GetReqId(), 0); + var reqId = _GetReqId(); + WriteUInt64ToBytes(req, reqId, 0); WriteUInt64ToBytes(req, resultId, 8); WriteUInt64ToBytes(req, WSActionBinary.FetchRawBlockMessage, 16); WriteUInt64ToBytes(req, 1, 24); - return SendBinaryBackBytes(req); + return SendBinaryBackBytes(req,reqId); } public void FreeResult(ulong resultId) { + var reqId = _GetReqId(); SendJson(WSAction.FreeResult, new WSFreeResultReq { - ReqId = _GetReqId(), + ReqId = reqId, ResultId = resultId - }); + },reqId); } } } \ No newline at end of file diff --git a/src/Driver/Impl/WebSocketMethods/Protocol/WSBaseResp.cs b/src/Driver/Impl/WebSocketMethods/Protocol/WSBaseResp.cs index 7cbe345..b4ea120 100644 --- a/src/Driver/Impl/WebSocketMethods/Protocol/WSBaseResp.cs +++ b/src/Driver/Impl/WebSocketMethods/Protocol/WSBaseResp.cs @@ -14,4 +14,17 @@ public interface IWSBaseResp [JsonProperty("timing")] long Timing { get; set; } } + + public class WSBaseResp : IWSBaseResp + { + [JsonProperty("code")] public int Code { get; set; } + + [JsonProperty("message")] public string Message { get; set; } + + [JsonProperty("action")] public string Action { get; set; } + + [JsonProperty("req_id")] public ulong ReqId { get; set; } + + [JsonProperty("timing")] public long Timing { get; set; } + } } \ No newline at end of file diff --git a/src/Driver/Impl/WebSocketMethods/Schemaless.cs b/src/Driver/Impl/WebSocketMethods/Schemaless.cs index 8f298fe..152e325 100644 --- a/src/Driver/Impl/WebSocketMethods/Schemaless.cs +++ b/src/Driver/Impl/WebSocketMethods/Schemaless.cs @@ -8,14 +8,15 @@ public WSSchemalessResp SchemalessInsert(string lines, TDengineSchemalessProtoco TDengineSchemalessPrecision precision, int ttl, long reqId) { + var uReqId = (ulong)reqId; return SendJsonBackJson(WSAction.SchemalessWrite, new WSSchemalessReq { - ReqId = (ulong)reqId, + ReqId = uReqId, Protocol = (int)protocol, Precision = TDengineConstant.SchemalessPrecisionString(precision), TTL = ttl, Data = lines, - }); + },uReqId); } } } \ No newline at end of file diff --git a/src/Driver/Impl/WebSocketMethods/Statement.cs b/src/Driver/Impl/WebSocketMethods/Statement.cs index a373d18..fa0f437 100644 --- a/src/Driver/Impl/WebSocketMethods/Statement.cs +++ b/src/Driver/Impl/WebSocketMethods/Statement.cs @@ -11,27 +11,29 @@ public WSStmtInitResp StmtInit(ulong reqId) return SendJsonBackJson(WSAction.STMTInit, new WSStmtInitReq { ReqId = reqId, - }); + },reqId); } public WSStmtPrepareResp StmtPrepare(ulong stmtId,string sql) { + var reqId = _GetReqId(); return SendJsonBackJson(WSAction.STMTPrepare, new WSStmtPrepareReq { - ReqId = _GetReqId(), + ReqId = reqId, StmtId = stmtId, SQL = sql - }); + },reqId); } public WSStmtSetTableNameResp StmtSetTableName(ulong stmtId,string tablename) { + var reqId = _GetReqId(); return SendJsonBackJson(WSAction.STMTSetTableName, new WSStmtSetTableNameReq { - ReqId = _GetReqId(), + ReqId = reqId, StmtId = stmtId, Name = tablename, - }); + },reqId); } public WSStmtSetTagsResp StmtSetTags(ulong stmtId,TaosFieldE[] fields, object[] tags) @@ -45,7 +47,6 @@ public WSStmtSetTagsResp StmtSetTags(ulong stmtId,TaosFieldE[] fields, object[] { if (tags[i] == null) { - var a = new object[1]{123}; Array newArray = Array.CreateInstance(TDengineConstant.ScanNullableType(fields[i].type), 1); newArray.SetValue(null, 0); param[i] = newArray; @@ -60,11 +61,12 @@ public WSStmtSetTagsResp StmtSetTags(ulong stmtId,TaosFieldE[] fields, object[] var bytes = BlockWriter.Serialize(1, fields, param); var req = new byte[24 +bytes.Length]; - WriteUInt64ToBytes(req, _GetReqId(),0); + var reqId = _GetReqId(); + WriteUInt64ToBytes(req, reqId,0); WriteUInt64ToBytes(req,stmtId,8); WriteUInt64ToBytes(req,WSActionBinary.SetTagsMessage,16); Buffer.BlockCopy(bytes, 0, req, 24, bytes.Length); - return SendBinaryBackJson(req); + return SendBinaryBackJson(req,reqId); } public WSStmtBindResp StmtBind(ulong stmtId,TaosFieldE[] fields, object[] row) @@ -92,11 +94,12 @@ public WSStmtBindResp StmtBind(ulong stmtId,TaosFieldE[] fields, object[] row) var bytes = BlockWriter.Serialize(1, fields, param); var req = new byte[24 +bytes.Length]; - WriteUInt64ToBytes(req, _GetReqId(),0); + var reqId = _GetReqId(); + WriteUInt64ToBytes(req, reqId,0); WriteUInt64ToBytes(req,stmtId,8); WriteUInt64ToBytes(req,WSActionBinary.BindMessage,16); Buffer.BlockCopy(bytes, 0, req, 24, bytes.Length); - return SendBinaryBackJson(req); + return SendBinaryBackJson(req,reqId); } public WSStmtBindResp StmtBind(ulong stmtId,TaosFieldE[] fields, params Array[] param) { @@ -107,64 +110,71 @@ public WSStmtBindResp StmtBind(ulong stmtId,TaosFieldE[] fields, params Array[] var bytes = BlockWriter.Serialize(param[0].Length, fields, param); var req = new byte[24 +bytes.Length]; - WriteUInt64ToBytes(req, _GetReqId(),0); + var reqId = _GetReqId(); + WriteUInt64ToBytes(req, reqId,0); WriteUInt64ToBytes(req,stmtId,8); WriteUInt64ToBytes(req,WSActionBinary.BindMessage,16); Buffer.BlockCopy(bytes, 0, req, 24, bytes.Length); - return SendBinaryBackJson(req); + return SendBinaryBackJson(req,reqId); } public WSStmtAddBatchResp StmtAddBatch(ulong stmtId) { + var reqId = _GetReqId(); return SendJsonBackJson(WSAction.STMTAddBatch, new WSStmtAddBatchReq { - ReqId = _GetReqId(), + ReqId = reqId, StmtId = stmtId - }); + },reqId); } public WSStmtExecResp StmtExec(ulong stmtId) { + var reqId = _GetReqId(); return SendJsonBackJson(WSAction.STMTExec, new WSStmtExecReq { - ReqId = _GetReqId(), + ReqId =reqId, StmtId = stmtId - }); + },reqId); } public WSStmtGetColFieldsResp StmtGetColFields(ulong stmtId) { + var reqId = _GetReqId(); return SendJsonBackJson(WSAction.STMTGetColFields, new WSStmtGetColFieldsReq { - ReqId = _GetReqId(), + ReqId =reqId , StmtId = stmtId - }); + },reqId); } public WSStmtGetTagFieldsResp StmtGetTagFields(ulong stmtId) { + var reqId = _GetReqId(); return SendJsonBackJson(WSAction.STMTGetTagFields, new WSStmtGetTagFieldsReq { - ReqId = _GetReqId(), + ReqId = reqId, StmtId = stmtId - }); + },reqId); } public WSStmtUseResultResp StmtUseResult(ulong stmtId) { + var reqId = _GetReqId(); return SendJsonBackJson(WSAction.STMTUseResult, new WSStmtUseResultReq { - ReqId = _GetReqId(), + ReqId = reqId, StmtId = stmtId - }); + },reqId); } public void StmtClose(ulong stmtId) { + var reqId = _GetReqId(); SendJson(WSAction.STMTClose, new WSStmtCloseReq { - ReqId = _GetReqId(), + ReqId = reqId, StmtId = stmtId - }); + },reqId); } } diff --git a/src/Driver/Impl/WebSocketMethods/TMQ.cs b/src/Driver/Impl/WebSocketMethods/TMQ.cs index e3bb159..a349999 100644 --- a/src/Driver/Impl/WebSocketMethods/TMQ.cs +++ b/src/Driver/Impl/WebSocketMethods/TMQ.cs @@ -66,7 +66,7 @@ public WSTMQSubscribeResp Subscribe(ulong reqId, List topics, TMQOptions WithTableName = options.MsgWithTableName, SessionTimeoutMs = options.SessionTimeoutMs, MaxPollIntervalMs = options.MaxPollIntervalMs - }); + },reqId); } public WSTMQPollResp Poll(long blockingTime) @@ -80,7 +80,7 @@ public WSTMQPollResp Poll(ulong reqId, long blockingTime) { ReqId = reqId, BlockingTime = blockingTime - }); + },reqId); } public byte[] FetchBlock(ulong reqId, ulong messageId) @@ -89,7 +89,7 @@ public byte[] FetchBlock(ulong reqId, ulong messageId) { ReqId = reqId, MessageId = messageId - }); + },reqId); } public byte[] FetchRawBlock(ulong messageId) @@ -103,7 +103,7 @@ public byte[] FetchRawBlock(ulong reqId, ulong messageId) { ReqId = reqId, MessageId = messageId - }); + },reqId); } public WSTMQCommitResp Commit() @@ -116,7 +116,7 @@ public WSTMQCommitResp Commit(ulong reqId) return SendJsonBackJson(WSTMQAction.TMQCommit, new WSTMQCommitReq { ReqId = reqId, - }); + },reqId); } public WSTMQUnsubscribeResp Unsubscribe() @@ -130,7 +130,7 @@ public WSTMQUnsubscribeResp Unsubscribe(ulong reqId) new WSTMQUnsubscribeReq { ReqId = reqId - }); + },reqId); } public WSTMQGetTopicAssignmentResp Assignment(string topic) @@ -145,7 +145,7 @@ public WSTMQGetTopicAssignmentResp Assignment(ulong reqId, string topic) { ReqId = reqId, Topic = topic - }); + },reqId); } public WSTMQOffsetSeekResp Seek(string topic, int vgroupId, long offset) @@ -162,7 +162,7 @@ public WSTMQOffsetSeekResp Seek(ulong reqId, string topic, int vgroupId, long of Topic = topic, VGroupId = vgroupId, Offset = offset - }); + },reqId); } public WSTMQCommitOffsetResp CommitOffset(string topic, int vgroupId, long offset) @@ -179,7 +179,7 @@ public WSTMQCommitOffsetResp CommitOffset(ulong reqId, string topic, int vgroupI Topic = topic, VGroupId = vgroupId, Offset = offset - }); + },reqId); } public WSTMQCommittedResp Committed(List tvIds) @@ -194,7 +194,7 @@ public WSTMQCommittedResp Committed(ulong reqId, List tvIds) { ReqId = reqId, TopicVgroupIds = tvIds, - }); + },reqId); } public WSTMQPositionResp Position(List tvIds) @@ -209,7 +209,7 @@ public WSTMQPositionResp Position(ulong reqId, List tvIds) { ReqId = reqId, TopicVgroupIds = tvIds, - }); + },reqId); } public WSTMQListTopicsResp Subscription() @@ -223,7 +223,7 @@ public WSTMQListTopicsResp Subscription(ulong reqId) new WSTMQListTopicsReq { ReqId = reqId - }); + },reqId); } } diff --git a/src/TMQ/WebSocket/Consumer.cs b/src/TMQ/WebSocket/Consumer.cs index 97f1265..953db04 100644 --- a/src/TMQ/WebSocket/Consumer.cs +++ b/src/TMQ/WebSocket/Consumer.cs @@ -26,6 +26,8 @@ public class Consumer : IConsumer { typeof(Dictionary), DictionaryDeserializer.Dictionary }, }; + private readonly object _reconnectLock = new object(); + public Consumer(ConsumerBuilder builder) { _options = new TMQOptions(builder.Config); @@ -75,42 +77,51 @@ private void Reconnect() { if (!_reconnect) return; - TMQConnection connection = null; - for (int i = 0; i < _reconnectRetryCount; i++) + lock (_reconnectLock) { - try + if (_connection != null) { - System.Threading.Thread.Sleep(_reconnectRetryIntervalMs); - connection = new TMQConnection(_options); - if (_topics != null) - { - connection.Subscribe(_topics, _options); - } - - break; + // connection is available, no need to reconnect + if (_connection.IsAvailable()) return; } - catch (Exception) + + TMQConnection connection = null; + for (int i = 0; i < _reconnectRetryCount; i++) { - if (connection != null) + try { - connection.Close(); - connection = null; + System.Threading.Thread.Sleep(_reconnectRetryIntervalMs); + connection = new TMQConnection(_options); + if (_topics != null) + { + connection.Subscribe(_topics, _options); + } + + break; + } + catch (Exception) + { + if (connection != null) + { + connection.Close(); + connection = null; + } } } - } - if (connection == null) - { - throw new TDengineError((int)TDengineError.InternalErrorCode.WS_RECONNECT_FAILED, - "websocket connection reconnect failed"); - } + if (connection == null) + { + throw new TDengineError((int)TDengineError.InternalErrorCode.WS_RECONNECT_FAILED, + "websocket connection reconnect failed"); + } - if (_connection != null) - { - _connection.Close(); - } + if (_connection != null) + { + _connection.Close(); + } - _connection = connection; + _connection = connection; + } } public ConsumeResult Consume(int millisecondsTimeout) diff --git a/test/Driver.Test/Client/Query/Client.cs b/test/Driver.Test/Client/Query/Client.cs index 9c20565..678181b 100644 --- a/test/Driver.Test/Client/Query/Client.cs +++ b/test/Driver.Test/Client/Query/Client.cs @@ -669,5 +669,74 @@ private void AssertValue(IRows rows, object?[][] data) Assert.Equal(Encoding.UTF8.GetBytes("{\"a\":\"b\"}"), rows.GetValue(data[i].Length)); } } + + private void QueryConcurrencyTest(string connectString, string db) + { + var precision = TDenginePrecision.TSDB_TIME_PRECISION_MILLI; + var builder = new ConnectionStringBuilder(connectString); + var client = DbDriver.Open(builder); + var count = 30; + try + { + client.Exec($"drop database if exists {db}"); + client.Exec($"create database {db} precision '{PrecisionString(precision)}'"); + client.Exec($"use {db}"); + client.Exec("create table t1 (ts timestamp, a int, b float, c binary(10))"); + var ts = new long[count]; + var dateTime = DateTime.Now; + var tsv = new DateTime[count]; + for (int i = 0; i < count; i++) + { + ts[i] = (dateTime.Add(TimeSpan.FromSeconds(i)).ToUniversalTime().Ticks - + TDengineConstant.TimeZero.Ticks) / 10000; + tsv[i] = TDengineConstant.ConvertTimeToDatetime(ts[i], precision); + } + + var valuesStr = ""; + for (int i = 0; i < count; i++) + { + valuesStr += $"({ts[i]}, {i}, {i}, '中文')"; + } + + client.Exec($"insert into t1 values {valuesStr}"); + var tasks = new System.Collections.Generic.List(); + for (var i = 0; i < count; i++) + { + int localI = i; + string query = "select * from t1 where ts = " + ts[localI]; + tasks.Add(System.Threading.Tasks.Task.Run(() => + { + using (var rows = client.Query(query)) + { + Assert.Equal(1, rows.GetOrdinal("a")); + var fieldCount = rows.FieldCount; + Assert.Equal(4, fieldCount); + Assert.Equal("ts", rows.GetName(0)); + Assert.Equal("a", rows.GetName(1)); + Assert.Equal("b", rows.GetName(2)); + Assert.Equal("c", rows.GetName(3)); + var haveNext = rows.Read(); + Assert.True(haveNext); + Assert.Equal(tsv[localI], rows.GetValue(0)); + Assert.Equal(localI, rows.GetValue(1)); + Assert.Equal((float)localI, rows.GetValue(2)); + Assert.Equal(Encoding.UTF8.GetBytes("中文"), rows.GetValue(3)); + } + })); + } + + System.Threading.Tasks.Task.WaitAll(tasks.ToArray()); + } + catch (Exception e) + { + _output.WriteLine(e.ToString()); + throw; + } + finally + { + client.Exec($"drop database if exists {db}"); + client.Dispose(); + } + } } } \ No newline at end of file diff --git a/test/Driver.Test/Client/Query/Native.cs b/test/Driver.Test/Client/Query/Native.cs index 7f60529..285fadc 100644 --- a/test/Driver.Test/Client/Query/Native.cs +++ b/test/Driver.Test/Client/Query/Native.cs @@ -137,5 +137,12 @@ public void NativeSMLJsonTest() var db = "sml_json_test"; this.SMLJsonTest(this._nativeConnectString, db); } + + [Fact] + public void NativeQueryConcurrencyTest() + { + var db = "query_concurrency_test"; + this.QueryConcurrencyTest(this._nativeConnectString, db); + } } } \ No newline at end of file diff --git a/test/Driver.Test/Client/Query/WS.cs b/test/Driver.Test/Client/Query/WS.cs index a208693..214a3ce 100644 --- a/test/Driver.Test/Client/Query/WS.cs +++ b/test/Driver.Test/Client/Query/WS.cs @@ -1,4 +1,7 @@ -using TDengine.Driver; +using System; +using System.Text; +using TDengine.Driver; +using TDengine.Driver.Client; using Xunit; namespace Driver.Test.Client.Query @@ -137,5 +140,120 @@ public void WebSocketSMLJsonTest() var db = "ws_sml_json_test"; this.SMLJsonTest(this._wsConnectString, db); } + + [Fact] + public void WebSocketQueryConcurrencyTest() + { + var db = "ws_query_concurrency_test"; + this.QueryConcurrencyTest(this._wsConnectString, db); + } + + [Fact] + public void WebSocketQueryInvalidReqIdTest() + { + var db = "ws_invalid_reqid_test"; + + var precision = TDenginePrecision.TSDB_TIME_PRECISION_MILLI; + var builder = new ConnectionStringBuilder(_wsConnectString); + var client = DbDriver.Open(builder); + var count = 10; + try + { + client.Exec($"drop database if exists {db}"); + client.Exec($"create database {db} precision '{PrecisionString(precision)}'"); + client.Exec($"use {db}"); + client.Exec("create table t1 (ts timestamp, a int, b float, c binary(10))"); + var ts = new long[count]; + var dateTime = DateTime.Now; + var tsv = new DateTime[count]; + for (int i = 0; i < count; i++) + { + ts[i] = (dateTime.Add(TimeSpan.FromSeconds(i)).ToUniversalTime().Ticks - + TDengineConstant.TimeZero.Ticks) / 10000; + tsv[i] = TDengineConstant.ConvertTimeToDatetime(ts[i], precision); + } + + var valuesStr = ""; + for (int i = 0; i < count; i++) + { + valuesStr += $"({ts[i]}, {i}, {i}, '中文')"; + } + + client.Exec($"insert into t1 values {valuesStr}"); + var tasks = new System.Collections.Generic.List(); + long reqid = 0x123456; + bool haveException = false; + for (var i = 0; i < count; i++) + { + int localI = i; + string query = "select * from t1 where ts = " + ts[localI]; + tasks.Add(System.Threading.Tasks.Task.Run(() => + { + try + { + using (var rows = client.Query(query, reqid)) + { + Assert.Equal(1, rows.GetOrdinal("a")); + var fieldCount = rows.FieldCount; + Assert.Equal(4, fieldCount); + Assert.Equal("ts", rows.GetName(0)); + Assert.Equal("a", rows.GetName(1)); + Assert.Equal("b", rows.GetName(2)); + Assert.Equal("c", rows.GetName(3)); + var haveNext = rows.Read(); + Assert.True(haveNext); + Assert.Equal(tsv[localI], rows.GetValue(0)); + Assert.Equal(localI, rows.GetValue(1)); + Assert.Equal((float)localI, rows.GetValue(2)); + Assert.Equal(Encoding.UTF8.GetBytes("中文"), rows.GetValue(3)); + } + } + catch (InvalidOperationException e) + { + Assert.Equal($"Request with reqId '0x{reqid:x}' already exists.", e.Message); + haveException = true; + } + })); + } + + System.Threading.Tasks.Task.WaitAll(tasks.ToArray()); + Assert.True(haveException); + } + catch (Exception e) + { + _output.WriteLine(e.ToString()); + throw; + } + finally + { + client.Exec($"drop database if exists {db}"); + client.Dispose(); + } + } + + [Fact] + public void WebSocketTimeoutTest() + { + var builder = new ConnectionStringBuilder(_wsConnectString); + builder.ReadTimeout = TimeSpan.FromTicks(100); + var timeout = false; + try + { + var client = DbDriver.Open(builder); + } + catch (TimeoutException e) + { + timeout = true; + } + catch (Exception e) + { + _output.WriteLine(e.ToString()); + throw; + } + finally + { + Assert.True(timeout); + } + } } } \ No newline at end of file