Skip to content

Commit c3ceb0b

Browse files
authored
Merge pull request #11 from cyverse-de/websocket-fix
Websocket/Reverse Proxy Update
2 parents b7446d9 + 129d703 commit c3ceb0b

File tree

4 files changed

+74
-48
lines changed

4 files changed

+74
-48
lines changed

Dockerfile

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,30 @@
1-
# First stage
1+
# First stage: Build the binary
22
FROM golang:1.24 AS build-root
33

44
WORKDIR /build
55

6-
COPY go.mod .
7-
COPY go.sum .
6+
# Copy dependency files first for better layer caching
7+
COPY go.mod go.sum ./
88

99
RUN go mod download
1010

11+
# Copy source code
1112
COPY . .
1213

14+
# Build static binary with optimizations
1315
ENV CGO_ENABLED=0
1416
ENV GOOS=linux
1517
ENV GOARCH=amd64
1618

17-
RUN go build ./...
19+
RUN go build -o vice-proxy -ldflags="-w -s" .
1820

19-
## Second stage
20-
FROM golang:1.24
21+
## Second stage: Minimal runtime image
22+
FROM alpine:3.20
2123

24+
# Copy CA certificates from build stage for HTTPS connections to Keycloak
25+
COPY --from=build-root /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
26+
27+
# Copy the binary from build stage
2228
COPY --from=build-root /build/vice-proxy /bin/vice-proxy
2329

2430
ENTRYPOINT ["vice-proxy"]

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ require (
1111
github.com/rs/cors v1.11.1
1212
github.com/sirupsen/logrus v1.9.3
1313
github.com/stretchr/testify v1.9.0
14-
github.com/yhat/wsutil v0.0.0-20170731153501-1d66fa95c997
1514
)
1615

1716
require (

go.sum

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
4242
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
4343
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
4444
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
45-
github.com/yhat/wsutil v0.0.0-20170731153501-1d66fa95c997 h1:1+FQ4Ns+UZtUiQ4lP0sTCyKSQ0EXoiwAdHZB0Pd5t9Q=
46-
github.com/yhat/wsutil v0.0.0-20170731153501-1d66fa95c997/go.mod h1:DIGbh/f5XMAessMV/uaIik81gkDVjUeQ9ApdaU7wRKE=
4745
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
4846
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
49-
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
50-
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
5147
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
5248
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
5349
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

main.go

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"github.com/pkg/errors"
2222
"github.com/rs/cors"
2323
"github.com/sirupsen/logrus"
24-
"github.com/yhat/wsutil"
2524
)
2625

2726
var log = logrus.WithFields(logrus.Fields{
@@ -487,23 +486,27 @@ func (c *VICEProxy) Session(r *http.Request, m *mux.RouteMatch) bool {
487486
}
488487

489488
// ReverseProxy returns a proxy that forwards requests to the configured
490-
// backend URL. It can act as a http.Handler.
489+
// backend URL. It can act as a http.Handler and properly handles WebSocket upgrades.
491490
func (c *VICEProxy) ReverseProxy() (*httputil.ReverseProxy, error) {
492491
backend, err := url.Parse(c.backendURL)
493492
if err != nil {
494493
return nil, errors.Wrapf(err, "failed to parse %s", c.backendURL)
495494
}
496-
return httputil.NewSingleHostReverseProxy(backend), nil
497-
}
498495

499-
// WSReverseProxy returns a proxy that forwards websocket request to the
500-
// configured backend URL. It can act as a http.Handler.
501-
func (c *VICEProxy) WSReverseProxy() (*wsutil.ReverseProxy, error) {
502-
w, err := url.Parse(c.wsbackendURL)
503-
if err != nil {
504-
return nil, errors.Wrapf(err, "failed to parse the websocket backend URL %s", c.wsbackendURL)
496+
proxy := httputil.NewSingleHostReverseProxy(backend)
497+
498+
// Customize the director to handle WebSocket upgrade properly
499+
originalDirector := proxy.Director
500+
proxy.Director = func(req *http.Request) {
501+
originalDirector(req)
502+
// For WebSocket requests, ensure proper scheme in target URL
503+
if c.isWebsocket(req) {
504+
// The backend URL stays http:// but the proxy will handle upgrade
505+
log.Infof("WebSocket upgrade request detected for %s", req.URL.Path)
506+
}
505507
}
506-
return wsutil.NewSingleHostReverseProxy(w), nil
508+
509+
return proxy, nil
507510
}
508511

509512
// isWebsocket returns true if the connection is a websocket request. Adapted
@@ -540,17 +543,21 @@ func (c *VICEProxy) backendIsReady(backendURL string) (bool, error) {
540543
// {"ready":boolean}, telling whether or not the underlying application is ready
541544
// for business yet.
542545
func (c *VICEProxy) URLIsReady(w http.ResponseWriter, r *http.Request) {
546+
log.Infof("checking backend readiness at %s", c.backendURL)
543547
ready, err := c.backendIsReady(c.backendURL)
544548
if err != nil {
545-
log.Error(err)
549+
log.Errorf("backend readiness check failed: %v", err)
546550
}
547551

552+
log.Infof("backend ready status: %v", ready)
553+
548554
data := map[string]bool{
549555
"ready": ready,
550556
}
551557

552558
body, err := json.Marshal(data)
553559
if err != nil {
560+
log.Errorf("failed to marshal readiness response: %v", err)
554561
http.Error(w, err.Error(), http.StatusInternalServerError)
555562
return
556563
}
@@ -574,11 +581,6 @@ func (c *VICEProxy) GetFrontendHost() (string, error) {
574581

575582
// Proxy returns a handler that can support both websockets and http requests.
576583
func (c *VICEProxy) Proxy() (http.Handler, error) {
577-
ws, err := c.WSReverseProxy()
578-
if err != nil {
579-
return nil, err
580-
}
581-
582584
rp, err := c.ReverseProxy()
583585
if err != nil {
584586
return nil, err
@@ -590,22 +592,25 @@ func (c *VICEProxy) Proxy() (http.Handler, error) {
590592
}
591593

592594
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
593-
//log.Debugf("handling request for %s from remote address %s", r.URL.String(), r.RemoteAddr)
595+
log.Infof("handling request for %s from remote address %s", r.URL.String(), r.RemoteAddr)
594596

595597
//Get the username from the cookie
596598
session, err := c.sessionStore.Get(r, sessionName)
597599
if err != nil {
598600
err = errors.Wrap(err, "failed to get session")
601+
log.Errorf("session error: %v", err)
599602
http.Error(w, err.Error(), http.StatusInternalServerError)
600603
return
601604
}
602605

603606
username := session.Values[sessionKey].(string)
604607
if username == "" {
605608
err = errors.Wrap(err, "username was empty")
609+
log.Errorf("authentication error: %v", err)
606610
http.Error(w, err.Error(), http.StatusForbidden)
607611
return
608612
}
613+
log.Infof("authenticated user: %s", username)
609614

610615
// Check to make sure the user can access the resource.
611616
allowed, err := c.IsAllowed(username, c.resourceName)
@@ -615,23 +620,27 @@ func (c *VICEProxy) Proxy() (http.Handler, error) {
615620
} else {
616621
err = errors.New("access denied")
617622
}
623+
log.Errorf("authorization error for user %s: %v", username, err)
618624
http.Error(w, err.Error(), http.StatusForbidden)
619625
return
620626
}
627+
log.Infof("user %s authorized for resource %s", username, c.resourceName)
621628

622629
// Override the X-Forwarded-Host header.
623630
r.Header.Set("X-Forwarded-Host", frontendHost)
624631

625-
if err = c.ResetSessionExpiration(w, r); err != nil {
626-
err = errors.Wrap(err, "error resetting session expiration")
627-
http.Error(w, err.Error(), http.StatusInternalServerError)
628-
return
632+
// CRITICAL: Don't reset session for WebSocket upgrades (would corrupt the upgrade handshake)
633+
if !c.isWebsocket(r) {
634+
if err = c.ResetSessionExpiration(w, r); err != nil {
635+
err = errors.Wrap(err, "error resetting session expiration")
636+
log.Errorf("session expiration error: %v", err)
637+
http.Error(w, err.Error(), http.StatusInternalServerError)
638+
return
639+
}
629640
}
630641

631-
if c.isWebsocket(r) {
632-
ws.ServeHTTP(w, r)
633-
return
634-
}
642+
// The reverse proxy handles both HTTP and WebSocket upgrade requests transparently
643+
log.Infof("proxying request to %s%s", c.backendURL, r.URL.Path)
635644
rp.ServeHTTP(w, r)
636645
}), nil
637646
}
@@ -669,6 +678,9 @@ func main() {
669678
checkResourceAccessBase = flag.String("check-resource-access-base", "http://check-resource-access", "The base URL for the check-resource-access service.")
670679
externalID = flag.String("external-id", "", "The external ID to pass to the apps service when looking up the analysis ID.")
671680
encodedSSOTimeout = flag.String("sso-timeout", "5s", "The timeout period for back-channel requests to the identity provider.")
681+
encodedReadTimeout = flag.String("read-timeout", "48h", "The maximum duration for reading the entire request, including the body.")
682+
encodedWriteTimeout = flag.String("write-timeout", "48h", "The maximum duration before timing out writes of the response.")
683+
encodedIdleTimeout = flag.String("idle-timeout", "5000s", "The maximum amount of time to wait for the next request when keep-alives are enabled.")
672684
)
673685

674686
flag.Var(&corsOrigins, "allowed-origins", "List of allowed origins, separated by commas.")
@@ -694,15 +706,6 @@ func main() {
694706
corsOrigins = originFlags{"*.cyverse.run", "*.cyverse.org", "*.cyverse.run:4343", "cyverse.run", "cyverse.run:4343"}
695707
}
696708

697-
if *wsbackendURL == "" {
698-
w, err := url.Parse(*backendURL)
699-
if err != nil {
700-
log.Fatal(err)
701-
}
702-
w.Scheme = "ws"
703-
*wsbackendURL = w.String()
704-
}
705-
706709
if *externalID == "" {
707710
log.Fatal("--external-id must be set.")
708711
}
@@ -715,6 +718,9 @@ func main() {
715718
log.Infof("Keycloak realm is %s", *keycloakRealm)
716719
log.Infof("Keycloak client ID is %s", *keycloakClientID)
717720
log.Infof("Keycloak client secret is %s", *keycloakClientSecret)
721+
log.Infof("read timeout is %s", *encodedReadTimeout)
722+
log.Infof("write timeout is %s", *encodedWriteTimeout)
723+
log.Infof("idle timeout is %s", *encodedIdleTimeout)
718724

719725
for _, c := range corsOrigins {
720726
log.Infof("Origin: %s\n", c)
@@ -739,6 +745,22 @@ func main() {
739745
log.Fatalf("invalid timeout duration for back-channel requests to the IdP: %s", err.Error())
740746
}
741747

748+
// Decode the timeout durations for the HTTP server.
749+
readTimeout, err := time.ParseDuration(*encodedReadTimeout)
750+
if err != nil {
751+
log.Fatalf("invalid read timeout duration: %s", err.Error())
752+
}
753+
754+
writeTimeout, err := time.ParseDuration(*encodedWriteTimeout)
755+
if err != nil {
756+
log.Fatalf("invalid write timeout duration: %s", err.Error())
757+
}
758+
759+
idleTimeout, err := time.ParseDuration(*encodedIdleTimeout)
760+
if err != nil {
761+
log.Fatalf("invalid idle timeout duration: %s", err.Error())
762+
}
763+
742764
// Create an HTTP client to use for back-channel requests to the identity provider.
743765
client := &http.Client{
744766
Timeout: ssoTimeout,
@@ -784,8 +806,11 @@ func main() {
784806
})
785807

786808
server := &http.Server{
787-
Handler: c.Handler(r),
788-
Addr: *listenAddr,
809+
Handler: c.Handler(r),
810+
Addr: *listenAddr,
811+
ReadTimeout: readTimeout,
812+
WriteTimeout: writeTimeout,
813+
IdleTimeout: idleTimeout,
789814
}
790815
if useSSL {
791816
err = server.ListenAndServeTLS(*sslCert, *sslKey)

0 commit comments

Comments
 (0)