@@ -31,8 +31,10 @@ import (
3131 "sync"
3232
3333 "github.com/asaskevich/EventBus"
34+ "github.com/gorilla/websocket"
3435
3536 "github.com/nitrictech/cli/pkg/netx"
37+ "github.com/nitrictech/cli/pkg/system"
3638 deploymentspb "github.com/nitrictech/nitric/core/pkg/proto/deployments/v1"
3739)
3840
@@ -55,10 +57,11 @@ type (
5557)
5658
5759type LocalWebsiteService struct {
58- websiteRegLock sync.RWMutex
59- state State
60- getApiAddress GetApiAddress
61- isStartCmd bool
60+ websiteRegLock sync.RWMutex
61+ state State
62+ getApiAddress GetApiAddress
63+ getWebsocketAddress GetApiAddress
64+ isStartCmd bool
6265
6366 bus EventBus.Bus
6467}
@@ -74,6 +77,22 @@ func (l *LocalWebsiteService) SubscribeToState(fn func(State)) {
7477 _ = l .bus .Subscribe (localWebsitesTopic , fn )
7578}
7679
80+ func proxyWebSocketMessages (src , dst * websocket.Conn , errChan chan error ) {
81+ for {
82+ messageType , message , err := src .ReadMessage ()
83+ if err != nil {
84+ errChan <- err
85+ return
86+ }
87+
88+ err = dst .WriteMessage (messageType , message )
89+ if err != nil {
90+ errChan <- err
91+ return
92+ }
93+ }
94+ }
95+
7796// register - Register a new website
7897func (l * LocalWebsiteService ) register (website Website , port int ) {
7998 l .websiteRegLock .Lock ()
@@ -182,25 +201,72 @@ func (h staticSiteHandler) ServeHTTP(res http.ResponseWriter, req *http.Request)
182201 h .serveStatic (res , req )
183202}
184203
185- // createAPIPathHandler creates a handler for API proxy requests
186- func (l * LocalWebsiteService ) createAPIPathHandler () http.HandlerFunc {
187- return func ( res http. ResponseWriter , req * http. Request ) {
188- apiName := req .PathValue ("name" )
204+ // websocketPathHandler creates a handler for WebSocket proxy requests
205+ func (l * LocalWebsiteService ) websocketPathHandler ( w http.ResponseWriter , r * http. Request ) {
206+ // Get the WebSocket API name from the request path
207+ apiName := r .PathValue ("name" )
189208
190- apiAddress := l .getApiAddress (apiName )
191- if apiAddress == "" {
192- http .Error (res , fmt .Sprintf ("api %s not found" , apiName ), http .StatusNotFound )
193- return
194- }
209+ // Get the address of the WebSocket API
210+ apiAddress := l .getWebsocketAddress (apiName )
211+ if apiAddress == "" {
212+ http .Error (w , fmt .Sprintf ("WebSocket API %s not found" , apiName ), http .StatusNotFound )
213+ return
214+ }
195215
196- targetPath := strings .TrimPrefix (req .URL .Path , fmt .Sprintf ("/api/%s" , apiName ))
197- targetUrl , _ := url .Parse (apiAddress )
216+ // Dial the backend WebSocket server
217+ targetURL := fmt .Sprintf ("ws://%s%s" , apiAddress , r .URL .Path )
218+ if r .URL .RawQuery != "" {
219+ targetURL = fmt .Sprintf ("%s?%s" , targetURL , r .URL .RawQuery )
220+ }
221+
222+ targetConn , _ , err := websocket .DefaultDialer .Dial (targetURL , nil )
223+ if err != nil {
224+ http .Error (w , fmt .Sprintf ("Failed to connect to backend WebSocket server: %v" , err ), http .StatusInternalServerError )
225+ return
226+ }
227+ defer targetConn .Close ()
198228
199- proxy := httputil . NewSingleHostReverseProxy ( targetUrl )
200- req . URL . Path = targetPath
229+ // Upgrade the HTTP connection to a WebSocket connection
230+ upgrader := websocket. Upgrader {}
201231
202- proxy .ServeHTTP (res , req )
232+ clientConn , err := upgrader .Upgrade (w , r , nil )
233+ if err != nil {
234+ http .Error (w , fmt .Sprintf ("Failed to upgrade to WebSocket: %v" , err ), http .StatusInternalServerError )
235+ return
236+ }
237+
238+ defer clientConn .Close ()
239+
240+ // Proxy messages between the client and the backend WebSocket server
241+ errChan := make (chan error , 2 )
242+ go proxyWebSocketMessages (clientConn , targetConn , errChan )
243+ go proxyWebSocketMessages (targetConn , clientConn , errChan )
244+
245+ // Wait for an error to occur
246+ err = <- errChan
247+ if err != nil && ! errors .Is (err , websocket .ErrCloseSent ) {
248+ // Because the error is already proxied through by the connection we can just log the error here
249+ system .Logf ("received error on websocket %s: %v" , apiName , err )
250+ }
251+ }
252+
253+ // apiPathHandler creates a handler for API proxy requests
254+ func (l * LocalWebsiteService ) apiPathHandler (res http.ResponseWriter , req * http.Request ) {
255+ apiName := req .PathValue ("name" )
256+
257+ apiAddress := l .getApiAddress (apiName )
258+ if apiAddress == "" {
259+ http .Error (res , fmt .Sprintf ("api %s not found" , apiName ), http .StatusNotFound )
260+ return
203261 }
262+
263+ targetPath := strings .TrimPrefix (req .URL .Path , fmt .Sprintf ("/api/%s" , apiName ))
264+ targetUrl , _ := url .Parse (apiAddress )
265+
266+ proxy := httputil .NewSingleHostReverseProxy (targetUrl )
267+ req .URL .Path = targetPath
268+
269+ proxy .ServeHTTP (res , req )
204270}
205271
206272// createServer creates and configures an HTTP server with the given mux
@@ -250,7 +316,10 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
250316 mux := http .NewServeMux ()
251317
252318 // Register the API proxy handler for this website
253- mux .HandleFunc ("/api/{name}/" , l .createAPIPathHandler ())
319+ mux .HandleFunc ("/api/{name}/" , l .apiPathHandler )
320+
321+ // Register the WebSocket proxy handler for this website
322+ mux .HandleFunc ("/ws/{name}" , l .websocketPathHandler )
254323
255324 // Create the SPA handler for this website
256325 spa := staticSiteHandler {
@@ -287,7 +356,10 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
287356 mux := http .NewServeMux ()
288357
289358 // Register the API proxy handler
290- mux .HandleFunc ("/api/{name}/" , l .createAPIPathHandler ())
359+ mux .HandleFunc ("/api/{name}/" , l .apiPathHandler )
360+
361+ // Register the WebSocket proxy handler for this website
362+ mux .HandleFunc ("/ws/{name}" , l .websocketPathHandler )
291363
292364 // Register the SPA handler for each website
293365 for i := range websites {
@@ -325,11 +397,12 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
325397 return nil
326398}
327399
328- func NewLocalWebsitesService (getApiAddress GetApiAddress , isStartCmd bool ) * LocalWebsiteService {
400+ func NewLocalWebsitesService (getApiAddress GetApiAddress , getWebsocketAddress GetApiAddress , isStartCmd bool ) * LocalWebsiteService {
329401 return & LocalWebsiteService {
330- state : State {},
331- bus : EventBus .New (),
332- getApiAddress : getApiAddress ,
333- isStartCmd : isStartCmd ,
402+ state : State {},
403+ bus : EventBus .New (),
404+ getApiAddress : getApiAddress ,
405+ getWebsocketAddress : getWebsocketAddress ,
406+ isStartCmd : isStartCmd ,
334407 }
335408}
0 commit comments