diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 7344d2bf..bbe040a9 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -172,12 +172,12 @@ jobs: # Split into parts IFS='.' read -r -a VERSION_PARTS <<< "$CURRENT_VERSION" MAJOR=${VERSION_PARTS[0]} - MINOR=${VERSION_PARTS[1]} + MINOR=0 PATCH=${VERSION_PARTS[2]} # Increment patch version - NEW_MINOR=$((MINOR + 1)) - NEW_VERSION="v${MAJOR}.${NEW_MINOR}.${PATCH}" + NEW_MAJOR=$((MAJOR + 1)) + NEW_VERSION="v${NEW_MAJOR}.${MINOR}.${PATCH}" fi # Set the full tag name diff --git a/common/CHANGELOG.md b/common/CHANGELOG.md index 33655299..8c1f6038 100644 --- a/common/CHANGELOG.md +++ b/common/CHANGELOG.md @@ -1,5 +1,12 @@ ### Changelog +## 2.0.0 - 2026-02-11 + +### Changed (2) + +- Updated `WebSocketCommon` connect method to accept streams parameter for subscribing upon connection. +- Updated `retry` logic in `utils.go` to use `SleepContext` for better handling of context cancellation during retries. + ## 1.2.0 - 2026-01-23 ### Added (1) diff --git a/common/common/configuration.go b/common/common/configuration.go index 527017f8..37cc24aa 100644 --- a/common/common/configuration.go +++ b/common/common/configuration.go @@ -111,7 +111,7 @@ type ConfigurationWebsocketStreamsOption func(*ConfigurationWebsocketStreams) // @return A pointer to the newly created ConfigurationRestAPI. func NewConfigurationRestAPI(opts ...ConfigurationRestAPIOption) *ConfigurationRestAPI { basePath := "https://api.binance.com" - timeout := 5 * time.Millisecond + timeout := 5000 * time.Millisecond retries := 3 backoff := 1000 keepAlive := true diff --git a/common/common/utils.go b/common/common/utils.go index 22258e95..067e1e1d 100644 --- a/common/common/utils.go +++ b/common/common/utils.go @@ -408,6 +408,38 @@ func ParseRateLimitHeaders(header http.Header) ([]RateLimit, error) { return rateLimits, nil } +// SleepContext pauses the execution for the specified duration or until the context is done. +// +// @param ctx The context to observe for cancellation. +// @param duration The duration to sleep. +// @return An error if the context is done before the duration elapses. +func SleepContext(ctx context.Context, duration time.Duration) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if duration <= 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + } + + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + // SendRequest sends an HTTP request and handles retries, response decoding, and error handling. // // @param ctx The context for the request. @@ -439,7 +471,7 @@ func SendRequest[T any](ctx context.Context, path string, method string, queryPa backoff := cfg.Backoff if backoff <= 0 { - backoff = 1 + backoff = 1000 } var lastErr error @@ -449,7 +481,9 @@ func SendRequest[T any](ctx context.Context, path string, method string, queryPa if err != nil { lastErr = err if attempt < retries && ShouldRetryRequest(err, method, retries-attempt, resp) { - time.Sleep(time.Duration(backoff*attempt) * time.Second) + if err := SleepContext(ctx, time.Duration(backoff*(attempt+1))*time.Millisecond); err != nil { + return &RestApiResponse[T]{}, err + } continue } return &RestApiResponse[T]{}, NewNetworkError(fmt.Sprintf("Network error: %v", err)) @@ -490,7 +524,9 @@ func SendRequest[T any](ctx context.Context, path string, method string, queryPa if resp.StatusCode >= 500 && resp.StatusCode <= 504 { if attempt < retries { - time.Sleep(time.Duration(backoff*attempt) * time.Second) + if err := SleepContext(ctx, time.Duration(backoff*(attempt+1))*time.Millisecond); err != nil { + return &RestApiResponse[T]{}, err + } continue } return &RestApiResponse[T]{}, fmt.Errorf("request failed after %d retries: received status %d", retries, resp.StatusCode) @@ -703,11 +739,11 @@ func SetupProxy(cfg *ConfigurationRestAPI) *http.Client { transport := BuildTransport(cfg.HTTPSAgent, cfg) return &http.Client{ Transport: transport, - Timeout: time.Duration(cfg.Timeout) * time.Millisecond, + Timeout: cfg.Timeout, } } return &http.Client{ - Timeout: time.Duration(cfg.Timeout) * time.Millisecond, + Timeout: cfg.Timeout, } } diff --git a/common/common/websocket.go b/common/common/websocket.go index 756cebe9..b13de884 100644 --- a/common/common/websocket.go +++ b/common/common/websocket.go @@ -362,30 +362,44 @@ func (w *WebSocketCommon) initializePool() { // // @param config The WebSocketConfig containing configuration details. // @param userAgent The user agent string to be used for the connection. +// @param streams A slice of stream names to subscribe to upon connection. // @return An error if the connection fails, otherwise nil. -func (w *WebSocketCommon) Connect(config WebSocketConfig, userAgent string) error { +func (w *WebSocketCommon) Connect(config WebSocketConfig, userAgent string, streams []string) error { if err := w.setupProxyDialer(config); err != nil { return fmt.Errorf("proxy setup failed: %v", err) } - BasePath := w.prepareBasePath(config) headers := w.prepareHeaders(config, userAgent) dialer := w.CreateWebSocketDialer(config) if w.Mode == SINGLE { + BasePath := w.prepareBasePath(config, streams, true) return w.connectSingleMode(BasePath, headers, dialer, config, userAgent) } - return w.connectPoolMode(BasePath, headers, dialer, config, userAgent) + return w.connectPoolMode(headers, dialer, config, userAgent, streams) } // prepareBasePath constructs the base path for the WebSocket connection. // // @param config The WebSocketConfig containing configuration details. +// @param streams The list of streams to include in the base path. +// @param includeStreams A boolean indicating whether to include streams in the base path. // @return The constructed base path string. -func (w *WebSocketCommon) prepareBasePath(config WebSocketConfig) string { +func (w *WebSocketCommon) prepareBasePath(config WebSocketConfig, streams []string, includeStreams bool) string { BasePath := config.GetBasePath() + if includeStreams && streams != nil && len(streams) > 0 { + BasePath += "?streams=" + for _, stream := range streams { + BasePath += stream + "/" + } + BasePath = strings.TrimSuffix(BasePath, "/") + } if timeUnit := config.GetTimeUnit(); timeUnit != "" { - BasePath = BasePath + "?timeUnit=" + string(timeUnit) + if streams != nil && len(streams) > 0 { + BasePath += "&timeUnit=" + string(timeUnit) + } else { + BasePath += "?timeUnit=" + string(timeUnit) + } } return BasePath } @@ -485,13 +499,13 @@ func (w *WebSocketCommon) connectSingleMode(BasePath string, headers http.Header // connectPoolMode establishes WebSocket connections in pool mode. // -// @param BasePath The base URL for the WebSocket connection. // @param headers The HTTP headers to include in the connection request. // @param dialer The WebSocket dialer to use for the connection. // @param config The WebSocketConfig containing configuration details. // @param userAgent The user agent string to use for the connection. +// @param streams A slice of stream names to subscribe to upon connection. // @return An error if the connection fails, otherwise nil. -func (w *WebSocketCommon) connectPoolMode(BasePath string, headers http.Header, dialer websocket.Dialer, config WebSocketConfig, userAgent string) error { +func (w *WebSocketCommon) connectPoolMode(headers http.Header, dialer websocket.Dialer, config WebSocketConfig, userAgent string, streams []string) error { var wg sync.WaitGroup successChan := make(chan bool, len(w.Connections)) errChan := make(chan error, len(w.Connections)) @@ -502,6 +516,7 @@ func (w *WebSocketCommon) connectPoolMode(BasePath string, headers http.Header, go func(num int, conn *WebSocketConnection) { defer wg.Done() + BasePath := w.prepareBasePath(config, streams, num == 0) wsConn, _, err := dialer.Dial(BasePath, headers) if err != nil { log.Printf("WebSocket connection error: %v", err) @@ -601,7 +616,7 @@ func (w *WebSocketCommon) KeepAlive(connection *WebSocketConnection, config WebS // @param userAgent The user agent string to use for the connection. // @return An error if the reconnection fails, otherwise nil. func (w *WebSocketCommon) reconnect(conn *WebSocketConnection, config WebSocketConfig, userAgent string) error { - BasePath := w.prepareBasePath(config) + BasePath := w.prepareBasePath(config, conn.StreamConnectionMap, conn.SessionLogonRequest == nil) headers := w.prepareHeaders(config, userAgent) dialer := w.CreateWebSocketDialer(config) @@ -617,8 +632,6 @@ func (w *WebSocketCommon) reconnect(conn *WebSocketConnection, config WebSocketC w.restoreSessionIfNeeded(conn) time.Sleep(1 * time.Second) w.resubscribeUserDataStreams(conn) - } else if len(conn.StreamConnectionMap) > 0 { - w.resubscribeRegularStreams(conn) } return nil @@ -713,34 +726,6 @@ func (w *WebSocketCommon) resubscribeUserDataStreams(connection *WebSocketConnec } } -// resubscribeRegularStreams resubscribes to all regular streams after reconnection. -// -// @param connection The WebSocketConnection to resubscribe streams on. -func (w *WebSocketCommon) resubscribeRegularStreams(connection *WebSocketConnection) { - for _, stream := range connection.StreamConnectionMap { - subscribePayload := map[string]interface{}{ - "method": "SUBSCRIBE", - "params": []string{stream}, - "id": GenerateUUID(), - } - - message, err := json.Marshal(subscribePayload) - if err != nil { - log.Printf("Error during resubscription to stream %s: %v", stream, err) - continue - } - - connection.mu.Lock() - err = connection.Websocket.WriteMessage(websocket.TextMessage, message) - connection.mu.Unlock() - if err != nil { - log.Printf("Error sending resubscription message for stream %s: %v", stream, err) - continue - } - log.Printf("Resubscribed to stream %s on reconnection", stream) - } -} - // isConnectionReady checks if the WebSocket connection is open. // // @param connection The WebSocketConnection to check. @@ -867,7 +852,7 @@ func NewWebsocketAPI(cfg *ConfigurationWebsocketApi) (*WebsocketAPI, error) { // @param userAgent The user agent string to be used for the connection. // @return An error if the connection fails, otherwise nil. func (w *WebsocketAPI) Connect(userAgent string) error { - return w.WsCommon.Connect(w.Cfg, userAgent) + return w.WsCommon.Connect(w.Cfg, userAgent, []string{}) } // SendMessage sends a message over the WebSocket connection and returns channels for the response and error. @@ -1121,8 +1106,25 @@ func NewWebsocketStreams(cfg *ConfigurationWebsocketStreams) (*WebsocketStreams, }, nil } -func (w *WebsocketStreams) Connect(userAgent string) error { - return w.WsCommon.Connect(w.Cfg, userAgent) +// Connect establishes the WebSocket connection using the provided user agent and streams. +// +// @param userAgent The user agent string to be used for the connection. +// @param streams A slice of stream names to subscribe to upon connection. +// @return An error if the connection fails, otherwise nil. +func (w *WebsocketStreams) Connect(userAgent string, streams []string) error { + err := w.WsCommon.Connect(w.Cfg, userAgent, streams) + if err != nil { + fmt.Println("WebSocket connection error:", err) + return err + } + if streams != nil && len(streams) > 0 { + conn := w.WsCommon.Connections[0] + for _, stream := range streams { + w.GlobalStreamConnectionMap[stream] = append(w.GlobalStreamConnectionMap[stream], conn) + conn.StreamConnectionMap = append(conn.StreamConnectionMap, stream) + } + } + return nil } // Subscribe subscribes to the specified streams. diff --git a/common/tests/unit/utils_test.go b/common/tests/unit/utils_test.go index cfcf9c52..dcf3f734 100644 --- a/common/tests/unit/utils_test.go +++ b/common/tests/unit/utils_test.go @@ -513,6 +513,92 @@ func TestSendRequest_ContentEncodingGzip(t *testing.T) { } } +func TestSleepContext_CompletesNormally(t *testing.T) { + ctx := context.Background() + duration := 50 * time.Millisecond + + start := time.Now() + err := common.SleepContext(ctx, duration) + elapsed := time.Since(start) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if elapsed < duration { + t.Errorf("Expected sleep to last at least %v, but it lasted %v", duration, elapsed) + } +} + +func TestSleepContext_CanceledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := common.SleepContext(ctx, 1*time.Second) + + if err != context.Canceled { + t.Errorf("Expected context.Canceled error, got %v", err) + } +} + +func TestSleepContext_CanceledDuringSleep(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := common.SleepContext(ctx, 1*time.Second) + elapsed := time.Since(start) + + if err != context.Canceled { + t.Errorf("Expected context.Canceled error, got %v", err) + } + + if elapsed >= 1*time.Second { + t.Errorf("Expected sleep to be interrupted before 1 second, but it lasted %v", elapsed) + } +} + +func TestSleepContext_ZeroDuration(t *testing.T) { + ctx := context.Background() + + err := common.SleepContext(ctx, 0) + + if err != nil { + t.Errorf("Expected no error with zero duration, got %v", err) + } +} + +func TestSleepContext_NegativeDuration(t *testing.T) { + ctx := context.Background() + + err := common.SleepContext(ctx, -1*time.Second) + + if err != nil { + t.Errorf("Expected no error with negative duration, got %v", err) + } +} + +func TestSleepContext_TimeoutContext(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + start := time.Now() + err := common.SleepContext(ctx, 1*time.Second) + elapsed := time.Since(start) + + if err != context.DeadlineExceeded { + t.Errorf("Expected context.DeadlineExceeded error, got %v", err) + } + + if elapsed >= 1*time.Second { + t.Errorf("Expected sleep to be interrupted by timeout, but it lasted %v", elapsed) + } +} + func TestPrepareRequest_BasicHeadersAndQuery(t *testing.T) { cfg := &common.ConfigurationRestAPI{ ApiKey: "apikey123", diff --git a/common/tests/unit/websocket_test.go b/common/tests/unit/websocket_test.go index dc9b71fd..35823104 100644 --- a/common/tests/unit/websocket_test.go +++ b/common/tests/unit/websocket_test.go @@ -1440,7 +1440,7 @@ func TestWebSocketCommon_Connect(t *testing.T) { }, }) - err := wsc.Connect(config, "test-agent") + err := wsc.Connect(config, "test-agent", []string{"stream1", "stream2"}) if err != nil { t.Fatalf("Connect failed: %v", err) @@ -1462,8 +1462,14 @@ func TestWebSocketCommon_Connect(t *testing.T) { var connectionCount int var mu sync.Mutex + requestURLs := make([]string, 0) s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requestURLs = append(requestURLs, r.URL.String()) + connectionCount++ + mu.Unlock() + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Logf("WebSocket upgrade failed: %v", err) @@ -1471,10 +1477,6 @@ func TestWebSocketCommon_Connect(t *testing.T) { } defer func() { _ = conn.Close() }() - mu.Lock() - connectionCount++ - mu.Unlock() - for { _, _, err := conn.ReadMessage() if err != nil { @@ -1504,12 +1506,15 @@ func TestWebSocketCommon_Connect(t *testing.T) { config := NewMockWebSocketConfig() config.basePath = u.String() - err := wsc.Connect(config, "test-agent") + streams := []string{"stream1", "stream2"} + err := wsc.Connect(config, "test-agent", streams) if err != nil { t.Fatalf("Connect failed: %v", err) } + time.Sleep(200 * time.Millisecond) + if len(wsc.Connections) != 3 { t.Errorf("Expected 3 connections, got %d", len(wsc.Connections)) } @@ -1527,6 +1532,99 @@ func TestWebSocketCommon_Connect(t *testing.T) { if connectionCount != 3 { t.Errorf("Expected 3 server connections, got %d", connectionCount) } + + hasStreamsCount := 0 + var firstURLWithStreams string + for _, url := range requestURLs { + if strings.Contains(url, "streams=") { + hasStreamsCount++ + if firstURLWithStreams == "" { + firstURLWithStreams = url + } + } + } + + if hasStreamsCount != 1 { + t.Errorf("Expected exactly 1 connection with streams parameter, got %d. URLs: %v", hasStreamsCount, requestURLs) + } + + if firstURLWithStreams != "" { + if !strings.Contains(firstURLWithStreams, "streams=stream1/stream2") { + t.Errorf("Expected streams parameter to be 'streams=stream1/stream2', got: %s", firstURLWithStreams) + } + } + mu.Unlock() + }) + + t.Run("pool mode connection without streams", func(t *testing.T) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + + var connectionCount int + var mu sync.Mutex + requestURLs := make([]string, 0) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requestURLs = append(requestURLs, r.URL.String()) + connectionCount++ + mu.Unlock() + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("WebSocket upgrade failed: %v", err) + return + } + defer func() { _ = conn.Close() }() + + for { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + })) + defer s.Close() + + u, _ := url.Parse(s.URL) + u.Scheme = "ws" + + wsc, _ := common.NewWebSocketCommon(&common.ConfigurationWrapper{ + APIConfig: &common.ConfigurationWebsocketApi{ + PoolSize: 3, + Mode: common.POOL, + }, + }) + + if len(wsc.Connections) == 0 { + wsc.Connections = make([]*common.WebSocketConnection, 3) + for i := 0; i < 3; i++ { + wsc.Connections[i] = &common.WebSocketConnection{} + } + } + + config := NewMockWebSocketConfig() + config.basePath = u.String() + + err := wsc.Connect(config, "test-agent", []string{}) + + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + + time.Sleep(200 * time.Millisecond) + + mu.Lock() + if connectionCount != 3 { + t.Errorf("Expected 3 server connections, got %d", connectionCount) + } + + for _, url := range requestURLs { + if strings.Contains(url, "streams=") { + t.Errorf("Expected no connection with streams parameter when empty slice provided, but found: %s", url) + } + } mu.Unlock() }) } @@ -2172,12 +2270,110 @@ func TestNewWebsocketStreams(t *testing.T) { } func TestWebsocketStreams_Connect(t *testing.T) { - mockConn := NewMockWebSocketConn() - ws := createTestWebsocketStreams(mockConn) + t.Run("connect without streams", func(t *testing.T) { + mockConn := NewMockWebSocketConn() + ws := createTestWebsocketStreams(mockConn) - err := ws.Connect("test-user-agent") + err := ws.Connect("test-user-agent", []string{}) - assert.NoError(t, err) + assert.NoError(t, err) + assert.Empty(t, ws.GlobalStreamConnectionMap) + if len(ws.WsCommon.Connections) > 0 { + assert.Empty(t, ws.WsCommon.Connections[0].StreamConnectionMap) + } + }) + + t.Run("connect with nil streams", func(t *testing.T) { + mockConn := NewMockWebSocketConn() + ws := createTestWebsocketStreams(mockConn) + + err := ws.Connect("test-user-agent", nil) + + assert.NoError(t, err) + assert.Empty(t, ws.GlobalStreamConnectionMap) + if len(ws.WsCommon.Connections) > 0 { + assert.Empty(t, ws.WsCommon.Connections[0].StreamConnectionMap) + } + }) + + t.Run("connect with single stream", func(t *testing.T) { + mockConn := NewMockWebSocketConn() + ws := createTestWebsocketStreams(mockConn) + + streams := []string{"stream1"} + err := ws.Connect("test-user-agent", streams) + + assert.NoError(t, err) + assert.Len(t, ws.GlobalStreamConnectionMap, 1) + assert.Contains(t, ws.GlobalStreamConnectionMap, "stream1") + assert.Len(t, ws.GlobalStreamConnectionMap["stream1"], 1) + assert.Equal(t, ws.WsCommon.Connections[0], ws.GlobalStreamConnectionMap["stream1"][0]) + + assert.Len(t, ws.WsCommon.Connections[0].StreamConnectionMap, 1) + assert.Contains(t, ws.WsCommon.Connections[0].StreamConnectionMap, "stream1") + }) + + t.Run("connect with multiple streams", func(t *testing.T) { + mockConn := NewMockWebSocketConn() + ws := createTestWebsocketStreams(mockConn) + + streams := []string{"stream1", "stream2", "stream3"} + err := ws.Connect("test-user-agent", streams) + + assert.NoError(t, err) + assert.Len(t, ws.GlobalStreamConnectionMap, 3) + + for _, stream := range streams { + assert.Contains(t, ws.GlobalStreamConnectionMap, stream) + assert.Len(t, ws.GlobalStreamConnectionMap[stream], 1) + assert.Equal(t, ws.WsCommon.Connections[0], ws.GlobalStreamConnectionMap[stream][0]) + } + + assert.Len(t, ws.WsCommon.Connections[0].StreamConnectionMap, 3) + for _, stream := range streams { + assert.Contains(t, ws.WsCommon.Connections[0].StreamConnectionMap, stream) + } + }) + + + t.Run("verify first connection is used for stream mapping", func(t *testing.T) { + mockConn := NewMockWebSocketConn() + ws := createTestWebsocketStreams(mockConn) + + if len(ws.WsCommon.Connections) > 1 { + streams := []string{"stream1", "stream2"} + err := ws.Connect("test-user-agent", streams) + + assert.NoError(t, err) + + firstConn := ws.WsCommon.Connections[0] + assert.Len(t, firstConn.StreamConnectionMap, 2) + + for i := 1; i < len(ws.WsCommon.Connections); i++ { + if ws.WsCommon.Connections[i].StreamConnectionMap != nil { + assert.Empty(t, ws.WsCommon.Connections[i].StreamConnectionMap) + } + } + } + }) + + t.Run("connect with duplicate streams", func(t *testing.T) { + mockConn := NewMockWebSocketConn() + ws := createTestWebsocketStreams(mockConn) + + streams := []string{"stream1", "stream1", "stream2"} + err := ws.Connect("test-user-agent", streams) + + assert.NoError(t, err) + + assert.Contains(t, ws.GlobalStreamConnectionMap, "stream1") + assert.Contains(t, ws.GlobalStreamConnectionMap, "stream2") + + assert.Len(t, ws.GlobalStreamConnectionMap["stream1"], 2) + assert.Len(t, ws.GlobalStreamConnectionMap["stream2"], 1) + + assert.Len(t, ws.WsCommon.Connections[0].StreamConnectionMap, 3) + }) } func TestWebsocketStreams_Subscribe_Success(t *testing.T) {