Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@ jobs:
# Split into parts
IFS='.' read -r -a VERSION_PARTS <<< "$CURRENT_VERSION"
MAJOR=${VERSION_PARTS[0]}
MINOR=${VERSION_PARTS[1]}
MINOR=0
PATCH=${VERSION_PARTS[2]}
# Increment patch version
NEW_MINOR=$((MINOR + 1))
NEW_VERSION="v${MAJOR}.${NEW_MINOR}.${PATCH}"
NEW_MAJOR=$((MAJOR + 1))
NEW_VERSION="v${NEW_MAJOR}.${MINOR}.${PATCH}"
fi
# Set the full tag name
Expand Down
7 changes: 7 additions & 0 deletions common/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
### Changelog

## 2.0.0 - 2026-02-11

### Changed (2)

- Updated `WebSocketCommon` connect method to accept streams parameter for subscribing upon connection.
- Updated `retry` logic in `utils.go` to use `SleepContext` for better handling of context cancellation during retries.

## 1.2.0 - 2026-01-23

### Added (1)
Expand Down
2 changes: 1 addition & 1 deletion common/common/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ type ConfigurationWebsocketStreamsOption func(*ConfigurationWebsocketStreams)
// @return A pointer to the newly created ConfigurationRestAPI.
func NewConfigurationRestAPI(opts ...ConfigurationRestAPIOption) *ConfigurationRestAPI {
basePath := "https://api.binance.com"
timeout := 5 * time.Millisecond
timeout := 5000 * time.Millisecond
retries := 3
backoff := 1000
keepAlive := true
Expand Down
46 changes: 41 additions & 5 deletions common/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,38 @@ func ParseRateLimitHeaders(header http.Header) ([]RateLimit, error) {
return rateLimits, nil
}

// SleepContext pauses the execution for the specified duration or until the context is done.
//
// @param ctx The context to observe for cancellation.
// @param duration The duration to sleep.
// @return An error if the context is done before the duration elapses.
func SleepContext(ctx context.Context, duration time.Duration) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

if duration <= 0 {
select {
case <-ctx.Done():
return ctx.Err()
default:
return nil
}
}

timer := time.NewTimer(duration)
defer timer.Stop()

select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}

// SendRequest sends an HTTP request and handles retries, response decoding, and error handling.
//
// @param ctx The context for the request.
Expand Down Expand Up @@ -439,7 +471,7 @@ func SendRequest[T any](ctx context.Context, path string, method string, queryPa

backoff := cfg.Backoff
if backoff <= 0 {
backoff = 1
backoff = 1000
}
var lastErr error

Expand All @@ -449,7 +481,9 @@ func SendRequest[T any](ctx context.Context, path string, method string, queryPa
if err != nil {
lastErr = err
if attempt < retries && ShouldRetryRequest(err, method, retries-attempt, resp) {
time.Sleep(time.Duration(backoff*attempt) * time.Second)
if err := SleepContext(ctx, time.Duration(backoff*(attempt+1))*time.Millisecond); err != nil {
return &RestApiResponse[T]{}, err
}
continue
}
return &RestApiResponse[T]{}, NewNetworkError(fmt.Sprintf("Network error: %v", err))
Expand Down Expand Up @@ -490,7 +524,9 @@ func SendRequest[T any](ctx context.Context, path string, method string, queryPa

if resp.StatusCode >= 500 && resp.StatusCode <= 504 {
if attempt < retries {
time.Sleep(time.Duration(backoff*attempt) * time.Second)
if err := SleepContext(ctx, time.Duration(backoff*(attempt+1))*time.Millisecond); err != nil {
return &RestApiResponse[T]{}, err
}
continue
}
return &RestApiResponse[T]{}, fmt.Errorf("request failed after %d retries: received status %d", retries, resp.StatusCode)
Expand Down Expand Up @@ -703,11 +739,11 @@ func SetupProxy(cfg *ConfigurationRestAPI) *http.Client {
transport := BuildTransport(cfg.HTTPSAgent, cfg)
return &http.Client{
Transport: transport,
Timeout: time.Duration(cfg.Timeout) * time.Millisecond,
Timeout: cfg.Timeout,
}
}
return &http.Client{
Timeout: time.Duration(cfg.Timeout) * time.Millisecond,
Timeout: cfg.Timeout,
}
}

Expand Down
84 changes: 43 additions & 41 deletions common/common/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,30 +362,44 @@ func (w *WebSocketCommon) initializePool() {
//
// @param config The WebSocketConfig containing configuration details.
// @param userAgent The user agent string to be used for the connection.
// @param streams A slice of stream names to subscribe to upon connection.
// @return An error if the connection fails, otherwise nil.
func (w *WebSocketCommon) Connect(config WebSocketConfig, userAgent string) error {
func (w *WebSocketCommon) Connect(config WebSocketConfig, userAgent string, streams []string) error {
if err := w.setupProxyDialer(config); err != nil {
return fmt.Errorf("proxy setup failed: %v", err)
}

BasePath := w.prepareBasePath(config)
headers := w.prepareHeaders(config, userAgent)
dialer := w.CreateWebSocketDialer(config)

if w.Mode == SINGLE {
BasePath := w.prepareBasePath(config, streams, true)
return w.connectSingleMode(BasePath, headers, dialer, config, userAgent)
}
return w.connectPoolMode(BasePath, headers, dialer, config, userAgent)
return w.connectPoolMode(headers, dialer, config, userAgent, streams)
}

// prepareBasePath constructs the base path for the WebSocket connection.
//
// @param config The WebSocketConfig containing configuration details.
// @param streams The list of streams to include in the base path.
// @param includeStreams A boolean indicating whether to include streams in the base path.
// @return The constructed base path string.
func (w *WebSocketCommon) prepareBasePath(config WebSocketConfig) string {
func (w *WebSocketCommon) prepareBasePath(config WebSocketConfig, streams []string, includeStreams bool) string {
BasePath := config.GetBasePath()
if includeStreams && streams != nil && len(streams) > 0 {
BasePath += "?streams="
for _, stream := range streams {
BasePath += stream + "/"
}
BasePath = strings.TrimSuffix(BasePath, "/")
}
if timeUnit := config.GetTimeUnit(); timeUnit != "" {
BasePath = BasePath + "?timeUnit=" + string(timeUnit)
if streams != nil && len(streams) > 0 {
BasePath += "&timeUnit=" + string(timeUnit)
} else {
BasePath += "?timeUnit=" + string(timeUnit)
}
}
return BasePath
}
Expand Down Expand Up @@ -485,13 +499,13 @@ func (w *WebSocketCommon) connectSingleMode(BasePath string, headers http.Header

// connectPoolMode establishes WebSocket connections in pool mode.
//
// @param BasePath The base URL for the WebSocket connection.
// @param headers The HTTP headers to include in the connection request.
// @param dialer The WebSocket dialer to use for the connection.
// @param config The WebSocketConfig containing configuration details.
// @param userAgent The user agent string to use for the connection.
// @param streams A slice of stream names to subscribe to upon connection.
// @return An error if the connection fails, otherwise nil.
func (w *WebSocketCommon) connectPoolMode(BasePath string, headers http.Header, dialer websocket.Dialer, config WebSocketConfig, userAgent string) error {
func (w *WebSocketCommon) connectPoolMode(headers http.Header, dialer websocket.Dialer, config WebSocketConfig, userAgent string, streams []string) error {
var wg sync.WaitGroup
successChan := make(chan bool, len(w.Connections))
errChan := make(chan error, len(w.Connections))
Expand All @@ -502,6 +516,7 @@ func (w *WebSocketCommon) connectPoolMode(BasePath string, headers http.Header,
go func(num int, conn *WebSocketConnection) {
defer wg.Done()

BasePath := w.prepareBasePath(config, streams, num == 0)
wsConn, _, err := dialer.Dial(BasePath, headers)
if err != nil {
log.Printf("WebSocket connection error: %v", err)
Expand Down Expand Up @@ -601,7 +616,7 @@ func (w *WebSocketCommon) KeepAlive(connection *WebSocketConnection, config WebS
// @param userAgent The user agent string to use for the connection.
// @return An error if the reconnection fails, otherwise nil.
func (w *WebSocketCommon) reconnect(conn *WebSocketConnection, config WebSocketConfig, userAgent string) error {
BasePath := w.prepareBasePath(config)
BasePath := w.prepareBasePath(config, conn.StreamConnectionMap, conn.SessionLogonRequest == nil)
headers := w.prepareHeaders(config, userAgent)
dialer := w.CreateWebSocketDialer(config)

Expand All @@ -617,8 +632,6 @@ func (w *WebSocketCommon) reconnect(conn *WebSocketConnection, config WebSocketC
w.restoreSessionIfNeeded(conn)
time.Sleep(1 * time.Second)
w.resubscribeUserDataStreams(conn)
} else if len(conn.StreamConnectionMap) > 0 {
w.resubscribeRegularStreams(conn)
}

return nil
Expand Down Expand Up @@ -713,34 +726,6 @@ func (w *WebSocketCommon) resubscribeUserDataStreams(connection *WebSocketConnec
}
}

// resubscribeRegularStreams resubscribes to all regular streams after reconnection.
//
// @param connection The WebSocketConnection to resubscribe streams on.
func (w *WebSocketCommon) resubscribeRegularStreams(connection *WebSocketConnection) {
for _, stream := range connection.StreamConnectionMap {
subscribePayload := map[string]interface{}{
"method": "SUBSCRIBE",
"params": []string{stream},
"id": GenerateUUID(),
}

message, err := json.Marshal(subscribePayload)
if err != nil {
log.Printf("Error during resubscription to stream %s: %v", stream, err)
continue
}

connection.mu.Lock()
err = connection.Websocket.WriteMessage(websocket.TextMessage, message)
connection.mu.Unlock()
if err != nil {
log.Printf("Error sending resubscription message for stream %s: %v", stream, err)
continue
}
log.Printf("Resubscribed to stream %s on reconnection", stream)
}
}

// isConnectionReady checks if the WebSocket connection is open.
//
// @param connection The WebSocketConnection to check.
Expand Down Expand Up @@ -867,7 +852,7 @@ func NewWebsocketAPI(cfg *ConfigurationWebsocketApi) (*WebsocketAPI, error) {
// @param userAgent The user agent string to be used for the connection.
// @return An error if the connection fails, otherwise nil.
func (w *WebsocketAPI) Connect(userAgent string) error {
return w.WsCommon.Connect(w.Cfg, userAgent)
return w.WsCommon.Connect(w.Cfg, userAgent, []string{})
}

// SendMessage sends a message over the WebSocket connection and returns channels for the response and error.
Expand Down Expand Up @@ -1121,8 +1106,25 @@ func NewWebsocketStreams(cfg *ConfigurationWebsocketStreams) (*WebsocketStreams,
}, nil
}

func (w *WebsocketStreams) Connect(userAgent string) error {
return w.WsCommon.Connect(w.Cfg, userAgent)
// Connect establishes the WebSocket connection using the provided user agent and streams.
//
// @param userAgent The user agent string to be used for the connection.
// @param streams A slice of stream names to subscribe to upon connection.
// @return An error if the connection fails, otherwise nil.
func (w *WebsocketStreams) Connect(userAgent string, streams []string) error {
err := w.WsCommon.Connect(w.Cfg, userAgent, streams)
if err != nil {
fmt.Println("WebSocket connection error:", err)
return err
}
if streams != nil && len(streams) > 0 {
conn := w.WsCommon.Connections[0]
for _, stream := range streams {
w.GlobalStreamConnectionMap[stream] = append(w.GlobalStreamConnectionMap[stream], conn)
conn.StreamConnectionMap = append(conn.StreamConnectionMap, stream)
}
}
return nil
}

// Subscribe subscribes to the specified streams.
Expand Down
86 changes: 86 additions & 0 deletions common/tests/unit/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,92 @@ func TestSendRequest_ContentEncodingGzip(t *testing.T) {
}
}

func TestSleepContext_CompletesNormally(t *testing.T) {
ctx := context.Background()
duration := 50 * time.Millisecond

start := time.Now()
err := common.SleepContext(ctx, duration)
elapsed := time.Since(start)

if err != nil {
t.Errorf("Expected no error, got %v", err)
}

if elapsed < duration {
t.Errorf("Expected sleep to last at least %v, but it lasted %v", duration, elapsed)
}
}

func TestSleepContext_CanceledContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

err := common.SleepContext(ctx, 1*time.Second)

if err != context.Canceled {
t.Errorf("Expected context.Canceled error, got %v", err)
}
}

func TestSleepContext_CanceledDuringSleep(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()

start := time.Now()
err := common.SleepContext(ctx, 1*time.Second)
elapsed := time.Since(start)

if err != context.Canceled {
t.Errorf("Expected context.Canceled error, got %v", err)
}

if elapsed >= 1*time.Second {
t.Errorf("Expected sleep to be interrupted before 1 second, but it lasted %v", elapsed)
}
}

func TestSleepContext_ZeroDuration(t *testing.T) {
ctx := context.Background()

err := common.SleepContext(ctx, 0)

if err != nil {
t.Errorf("Expected no error with zero duration, got %v", err)
}
}

func TestSleepContext_NegativeDuration(t *testing.T) {
ctx := context.Background()

err := common.SleepContext(ctx, -1*time.Second)

if err != nil {
t.Errorf("Expected no error with negative duration, got %v", err)
}
}

func TestSleepContext_TimeoutContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

start := time.Now()
err := common.SleepContext(ctx, 1*time.Second)
elapsed := time.Since(start)

if err != context.DeadlineExceeded {
t.Errorf("Expected context.DeadlineExceeded error, got %v", err)
}

if elapsed >= 1*time.Second {
t.Errorf("Expected sleep to be interrupted by timeout, but it lasted %v", elapsed)
}
}

func TestPrepareRequest_BasicHeadersAndQuery(t *testing.T) {
cfg := &common.ConfigurationRestAPI{
ApiKey: "apikey123",
Expand Down
Loading