Skip to content

Commit 5ac1f33

Browse files
authored
feat: Add local websocket proxy (#860)
1 parent 294568d commit 5ac1f33

File tree

4 files changed

+135
-33
lines changed

4 files changed

+135
-33
lines changed

pkg/cloud/cloud.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ func New(projectName string, opts LocalCloudOptions) (*LocalCloud, error) {
321321
return nil, err
322322
}
323323

324-
localWebsites := websites.NewLocalWebsitesService(localGateway.GetApiAddress, opts.LocalCloudMode == StartMode)
324+
localWebsites := websites.NewLocalWebsitesService(localGateway.GetApiAddress, localGateway.GetWebsocketAddress, opts.LocalCloudMode == StartMode)
325325

326326
return &LocalCloud{
327327
servers: make(map[string]*server.NitricServer),

pkg/cloud/gateway/gateway.go

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ import (
4646
"github.com/nitrictech/cli/pkg/netx"
4747
"github.com/nitrictech/cli/pkg/project/localconfig"
4848
"github.com/nitrictech/cli/pkg/system"
49-
"github.com/nitrictech/cli/pkg/view/tui"
5049

5150
base_http "github.com/nitrictech/nitric/cloud/common/runtime/gateway"
5251

@@ -157,6 +156,19 @@ func (s *LocalGatewayService) GetApiAddress(apiName string) string {
157156
return ""
158157
}
159158

159+
func (s *LocalGatewayService) GetWebsocketAddress(socketName string) string {
160+
s.lock.RLock()
161+
defer s.lock.RUnlock()
162+
163+
addresses := s.GetWebsocketAddresses()
164+
165+
if address, ok := addresses[socketName]; ok {
166+
return address
167+
}
168+
169+
return ""
170+
}
171+
160172
func (s *LocalGatewayService) GetHttpWorkerAddresses() map[string]string {
161173
s.lock.RLock()
162174
defer s.lock.RUnlock()
@@ -349,14 +361,14 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx
349361
SocketName: socketName,
350362
})
351363
if err != nil {
352-
tui.Error.Println(err.Error())
364+
system.Logf("Websocket error: %s", err.Error())
353365
return
354366
}
355367
}()
356368

357369
err = s.websocketPlugin.RegisterConnection(socketName, connectionId, ws)
358370
if err != nil {
359-
tui.Error.Println(err.Error())
371+
system.Logf("Websocket error: %s", err.Error())
360372
return
361373
}
362374

@@ -372,7 +384,7 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx
372384
if err != nil && websocket.IsCloseError(err, 1001, 1005) {
373385
break
374386
} else if err != nil {
375-
log.Println("read:", err)
387+
system.Logf("websocket read error: %v", err)
376388
break
377389
}
378390

@@ -390,7 +402,7 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx
390402
},
391403
})
392404
if err != nil {
393-
tui.Error.Println(err.Error())
405+
system.Logf("Websocket error: %s", err.Error())
394406
return
395407
}
396408
}
@@ -407,13 +419,13 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx
407419
},
408420
})
409421
if err != nil {
410-
tui.Error.Println(err.Error())
422+
system.Logf("Websocket error: %s", err.Error())
411423
return
412424
}
413425
})
414426
if err != nil {
415427
if _, ok := err.(websocket.HandshakeError); ok {
416-
tui.Error.Println(err.Error())
428+
system.Logf("Websocket error: %s", err.Error())
417429
}
418430

419431
return

pkg/cloud/websites/websites.go

Lines changed: 98 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ import (
3131
"sync"
3232

3333
"github.com/asaskevich/EventBus"
34+
"github.com/gorilla/websocket"
3435

3536
"github.com/nitrictech/cli/pkg/netx"
37+
"github.com/nitrictech/cli/pkg/system"
3638
deploymentspb "github.com/nitrictech/nitric/core/pkg/proto/deployments/v1"
3739
)
3840

@@ -55,10 +57,11 @@ type (
5557
)
5658

5759
type LocalWebsiteService struct {
58-
websiteRegLock sync.RWMutex
59-
state State
60-
getApiAddress GetApiAddress
61-
isStartCmd bool
60+
websiteRegLock sync.RWMutex
61+
state State
62+
getApiAddress GetApiAddress
63+
getWebsocketAddress GetApiAddress
64+
isStartCmd bool
6265

6366
bus EventBus.Bus
6467
}
@@ -74,6 +77,22 @@ func (l *LocalWebsiteService) SubscribeToState(fn func(State)) {
7477
_ = l.bus.Subscribe(localWebsitesTopic, fn)
7578
}
7679

80+
func proxyWebSocketMessages(src, dst *websocket.Conn, errChan chan error) {
81+
for {
82+
messageType, message, err := src.ReadMessage()
83+
if err != nil {
84+
errChan <- err
85+
return
86+
}
87+
88+
err = dst.WriteMessage(messageType, message)
89+
if err != nil {
90+
errChan <- err
91+
return
92+
}
93+
}
94+
}
95+
7796
// register - Register a new website
7897
func (l *LocalWebsiteService) register(website Website, port int) {
7998
l.websiteRegLock.Lock()
@@ -182,25 +201,72 @@ func (h staticSiteHandler) ServeHTTP(res http.ResponseWriter, req *http.Request)
182201
h.serveStatic(res, req)
183202
}
184203

185-
// createAPIPathHandler creates a handler for API proxy requests
186-
func (l *LocalWebsiteService) createAPIPathHandler() http.HandlerFunc {
187-
return func(res http.ResponseWriter, req *http.Request) {
188-
apiName := req.PathValue("name")
204+
// websocketPathHandler creates a handler for WebSocket proxy requests
205+
func (l *LocalWebsiteService) websocketPathHandler(w http.ResponseWriter, r *http.Request) {
206+
// Get the WebSocket API name from the request path
207+
apiName := r.PathValue("name")
189208

190-
apiAddress := l.getApiAddress(apiName)
191-
if apiAddress == "" {
192-
http.Error(res, fmt.Sprintf("api %s not found", apiName), http.StatusNotFound)
193-
return
194-
}
209+
// Get the address of the WebSocket API
210+
apiAddress := l.getWebsocketAddress(apiName)
211+
if apiAddress == "" {
212+
http.Error(w, fmt.Sprintf("WebSocket API %s not found", apiName), http.StatusNotFound)
213+
return
214+
}
195215

196-
targetPath := strings.TrimPrefix(req.URL.Path, fmt.Sprintf("/api/%s", apiName))
197-
targetUrl, _ := url.Parse(apiAddress)
216+
// Dial the backend WebSocket server
217+
targetURL := fmt.Sprintf("ws://%s%s", apiAddress, r.URL.Path)
218+
if r.URL.RawQuery != "" {
219+
targetURL = fmt.Sprintf("%s?%s", targetURL, r.URL.RawQuery)
220+
}
221+
222+
targetConn, _, err := websocket.DefaultDialer.Dial(targetURL, nil)
223+
if err != nil {
224+
http.Error(w, fmt.Sprintf("Failed to connect to backend WebSocket server: %v", err), http.StatusInternalServerError)
225+
return
226+
}
227+
defer targetConn.Close()
198228

199-
proxy := httputil.NewSingleHostReverseProxy(targetUrl)
200-
req.URL.Path = targetPath
229+
// Upgrade the HTTP connection to a WebSocket connection
230+
upgrader := websocket.Upgrader{}
201231

202-
proxy.ServeHTTP(res, req)
232+
clientConn, err := upgrader.Upgrade(w, r, nil)
233+
if err != nil {
234+
http.Error(w, fmt.Sprintf("Failed to upgrade to WebSocket: %v", err), http.StatusInternalServerError)
235+
return
236+
}
237+
238+
defer clientConn.Close()
239+
240+
// Proxy messages between the client and the backend WebSocket server
241+
errChan := make(chan error, 2)
242+
go proxyWebSocketMessages(clientConn, targetConn, errChan)
243+
go proxyWebSocketMessages(targetConn, clientConn, errChan)
244+
245+
// Wait for an error to occur
246+
err = <-errChan
247+
if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
248+
// Because the error is already proxied through by the connection we can just log the error here
249+
system.Logf("received error on websocket %s: %v", apiName, err)
250+
}
251+
}
252+
253+
// apiPathHandler creates a handler for API proxy requests
254+
func (l *LocalWebsiteService) apiPathHandler(res http.ResponseWriter, req *http.Request) {
255+
apiName := req.PathValue("name")
256+
257+
apiAddress := l.getApiAddress(apiName)
258+
if apiAddress == "" {
259+
http.Error(res, fmt.Sprintf("api %s not found", apiName), http.StatusNotFound)
260+
return
203261
}
262+
263+
targetPath := strings.TrimPrefix(req.URL.Path, fmt.Sprintf("/api/%s", apiName))
264+
targetUrl, _ := url.Parse(apiAddress)
265+
266+
proxy := httputil.NewSingleHostReverseProxy(targetUrl)
267+
req.URL.Path = targetPath
268+
269+
proxy.ServeHTTP(res, req)
204270
}
205271

206272
// createServer creates and configures an HTTP server with the given mux
@@ -250,7 +316,10 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
250316
mux := http.NewServeMux()
251317

252318
// Register the API proxy handler for this website
253-
mux.HandleFunc("/api/{name}/", l.createAPIPathHandler())
319+
mux.HandleFunc("/api/{name}/", l.apiPathHandler)
320+
321+
// Register the WebSocket proxy handler for this website
322+
mux.HandleFunc("/ws/{name}", l.websocketPathHandler)
254323

255324
// Create the SPA handler for this website
256325
spa := staticSiteHandler{
@@ -287,7 +356,10 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
287356
mux := http.NewServeMux()
288357

289358
// Register the API proxy handler
290-
mux.HandleFunc("/api/{name}/", l.createAPIPathHandler())
359+
mux.HandleFunc("/api/{name}/", l.apiPathHandler)
360+
361+
// Register the WebSocket proxy handler for this website
362+
mux.HandleFunc("/ws/{name}", l.websocketPathHandler)
291363

292364
// Register the SPA handler for each website
293365
for i := range websites {
@@ -325,11 +397,12 @@ func (l *LocalWebsiteService) Start(websites []Website) error {
325397
return nil
326398
}
327399

328-
func NewLocalWebsitesService(getApiAddress GetApiAddress, isStartCmd bool) *LocalWebsiteService {
400+
func NewLocalWebsitesService(getApiAddress GetApiAddress, getWebsocketAddress GetApiAddress, isStartCmd bool) *LocalWebsiteService {
329401
return &LocalWebsiteService{
330-
state: State{},
331-
bus: EventBus.New(),
332-
getApiAddress: getApiAddress,
333-
isStartCmd: isStartCmd,
402+
state: State{},
403+
bus: EventBus.New(),
404+
getApiAddress: getApiAddress,
405+
getWebsocketAddress: getWebsocketAddress,
406+
isStartCmd: isStartCmd,
334407
}
335408
}

pkg/dashboard/frontend/src/lib/utils/generate-architecture-data.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,23 @@ export function generateArchitectureData(data: WebSocketResponse): {
609609
label: `Rewrites to /api/${api.name}`,
610610
})
611611
})
612+
613+
data.websockets.forEach((websocket) => {
614+
edges.push({
615+
id: `e-${websocket.name}-websites`,
616+
source: websitesNode.id,
617+
target: `websocket-${websocket.name}`,
618+
animated: true,
619+
markerEnd: {
620+
type: MarkerType.ArrowClosed,
621+
},
622+
markerStart: {
623+
type: MarkerType.ArrowClosed,
624+
orient: 'auto-start-reverse',
625+
},
626+
label: `Rewrites to /ws/${websocket.name}`,
627+
})
628+
})
612629
}
613630

614631
data.services.forEach((service) => {

0 commit comments

Comments
 (0)