@@ -3,15 +3,19 @@ package rabbitmqamqp
33import (
44 "context"
55 "crypto/tls"
6+ "encoding/base64"
67 "errors"
78 "fmt"
89 "math/rand"
10+ "net/http"
11+ "net/url"
912 "sync"
1013 "sync/atomic"
1114 "time"
1215
1316 "github.com/Azure/go-amqp"
1417 "github.com/google/uuid"
18+ "github.com/gorilla/websocket"
1519)
1620
1721var ErrConnectionClosed = errors .New ("connection is closed" )
@@ -453,10 +457,45 @@ func (a *AmqpConnection) open(ctx context.Context, address string, connOptions *
453457 TLSConfig : connOptions .TLSConfig ,
454458 WriteTimeout : connOptions .WriteTimeout ,
455459 }
456- azureConnection , err = amqp .Dial (ctx , address , amqpLiteConnOptions )
457- if err != nil && (connOptions .TLSConfig != nil || uri .Scheme == AMQPS ) {
458- Error ("Failed to open TLS connection" , fmt .Sprintf ("%s://%s" , uri .Scheme , uri .Host ), err , "ID" , connOptions .Id )
459- return fmt .Errorf ("failed to open TLS connection: %w" , err )
460+
461+ u , err := url .Parse (address )
462+ if err != nil {
463+ return err
464+ }
465+
466+ if u .Scheme == "ws" || u .Scheme == "wss" {
467+
468+ wsAddress , wsHeaders , err := sanitizeWebSocketURL (address )
469+ if err != nil {
470+ Error ("Failed to sanitize websocket URL" , "url" , ExtractWithoutPassword (address ), "error" , err , "ID" , connOptions .Id )
471+ return fmt .Errorf ("failed to sanitize websocket URL: %w" , err )
472+ }
473+
474+ // Create a WebSocket dialer
475+ dialer := websocket .DefaultDialer
476+ if u .Scheme == "wss" && connOptions .TLSConfig != nil {
477+ dialer .TLSClientConfig = connOptions .TLSConfig
478+ }
479+
480+ // Dial the WebSocket server
481+ wsConn , _ , err := dialer .Dial (wsAddress , wsHeaders )
482+ if err != nil {
483+ Error ("Failed to open a websocket connection" , "url" , ExtractWithoutPassword (wsAddress ), "error" , err , "ID" , connOptions .Id )
484+ return fmt .Errorf ("failed to open a websocket connection: %w" , err )
485+ }
486+
487+ // Wrap the WebSocket connection in a WebSocketConn
488+ neConn := NewWebSocketConn (wsConn )
489+ azureConnection , err = amqp .NewConn (ctx , neConn , amqpLiteConnOptions )
490+ if err != nil {
491+ Error ("Failed to open AMQP over WebSocket connection" , "url" , ExtractWithoutPassword (address ), "error" , err , "ID" , connOptions .Id )
492+ }
493+ } else {
494+ azureConnection , err = amqp .Dial (ctx , address , amqpLiteConnOptions )
495+ if err != nil && (connOptions .TLSConfig != nil || uri .Scheme == AMQPS ) {
496+ Error ("Failed to open TLS connection" , fmt .Sprintf ("%s://%s" , uri .Scheme , uri .Host ), err , "ID" , connOptions .Id )
497+ return fmt .Errorf ("failed to open TLS connection: %w" , err )
498+ }
460499 }
461500 if err != nil {
462501 Error ("Failed to open connection" , "url" , ExtractWithoutPassword (address ), "error" , err , "ID" , connOptions .Id )
@@ -701,3 +740,36 @@ func (a *AmqpConnection) RefreshToken(background context.Context, token string)
701740}
702741
703742//*** end management section ***
743+
744+ // sanitizeWebSocketURL ensures the URL is correctly formatted for the Gorilla websocket dialer.
745+ func sanitizeWebSocketURL (rawURL string ) (string , http.Header , error ) {
746+ u , err := url .Parse (rawURL )
747+ if err != nil {
748+ return "" , nil , err
749+ }
750+
751+ if u .Scheme != "ws" && u .Scheme != "wss" {
752+ return "" , nil , fmt .Errorf ("invalid websocket scheme: %s" , u .Scheme )
753+ }
754+
755+ // Prepare Headers for Auth
756+ headers := http.Header {}
757+ if u .User != nil {
758+ username := u .User .Username ()
759+ password , _ := u .User .Password ()
760+
761+ // Construct Basic Auth Header manually
762+ auth := base64 .StdEncoding .EncodeToString ([]byte (username + ":" + password ))
763+ headers .Add ("Authorization" , "Basic " + auth )
764+
765+ u .User = nil
766+ }
767+
768+ if u .Path == "" {
769+ u .Path = "/"
770+ } else if u .Path [0 ] != '/' {
771+ u .Path = "/" + u .Path
772+ }
773+
774+ return u .String (), headers , nil
775+ }
0 commit comments