Skip to content

Commit 16529c8

Browse files
authored
Merge pull request #2938 from headlamp-k8s/handle-token-multiplexer
backend: frontend: Pass authentication token to websocket multiplexer
2 parents 1b2e8da + e6cb0bc commit 16529c8

File tree

4 files changed

+41
-14
lines changed

4 files changed

+41
-14
lines changed

backend/cmd/multiplexer.go

+25-4
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ type Connection struct {
7272
writeMu sync.Mutex
7373
// closed is a flag to indicate if the connection is closed.
7474
closed bool
75+
// Authentication token.
76+
Token *string
7577
}
7678

7779
// Message represents a WebSocket message structure.
@@ -90,6 +92,8 @@ type Message struct {
9092
Binary bool `json:"binary,omitempty"`
9193
// Type is the type of the message.
9294
Type string `json:"type"`
95+
// Authentication token.
96+
Token *string `json:"token"`
9397
}
9498

9599
// Multiplexer manages multiple WebSocket connections.
@@ -246,14 +250,15 @@ func (m *Multiplexer) establishClusterConnection(
246250
path,
247251
query string,
248252
clientConn *WSConnLock,
253+
token *string,
249254
) (*Connection, error) {
250255
config, err := m.getClusterConfigWithFallback(clusterID, userID)
251256
if err != nil {
252257
logger.Log(logger.LevelError, map[string]string{"clusterID": clusterID}, err, "getting cluster config")
253258
return nil, err
254259
}
255260

256-
connection := m.createConnection(clusterID, userID, path, query, clientConn)
261+
connection := m.createConnection(clusterID, userID, path, query, clientConn, token)
257262

258263
wsURL := createWebSocketURL(config.Host, path, query)
259264

@@ -264,7 +269,7 @@ func (m *Multiplexer) establishClusterConnection(
264269
return nil, fmt.Errorf("failed to get TLS config: %v", err)
265270
}
266271

267-
conn, err := m.dialWebSocket(wsURL, tlsConfig, config.Host)
272+
conn, err := m.dialWebSocket(wsURL, tlsConfig, config.Host, token)
268273
if err != nil {
269274
connection.updateStatus(StateError, err)
270275

@@ -309,6 +314,7 @@ func (m *Multiplexer) createConnection(
309314
path,
310315
query string,
311316
clientConn *WSConnLock,
317+
token *string,
312318
) *Connection {
313319
return &Connection{
314320
ClusterID: clusterID,
@@ -321,16 +327,29 @@ func (m *Multiplexer) createConnection(
321327
State: StateConnecting,
322328
LastMsg: time.Now(),
323329
},
330+
Token: token,
324331
}
325332
}
326333

327334
// dialWebSocket establishes a WebSocket connection.
328-
func (m *Multiplexer) dialWebSocket(wsURL string, tlsConfig *tls.Config, host string) (*websocket.Conn, error) {
335+
func (m *Multiplexer) dialWebSocket(
336+
wsURL string,
337+
tlsConfig *tls.Config,
338+
host string,
339+
token *string,
340+
) (*websocket.Conn, error) {
329341
dialer := websocket.Dialer{
330342
TLSClientConfig: tlsConfig,
331343
HandshakeTimeout: HandshakeTimeout,
332344
}
333345

346+
if token != nil {
347+
dialer.Subprotocols = []string{
348+
"base64.binary.k8s.io",
349+
"base64url.bearer.authorization.k8s.io." + base64.RawStdEncoding.EncodeToString([]byte(*token)),
350+
}
351+
}
352+
334353
conn, resp, err := dialer.Dial(
335354
wsURL,
336355
http.Header{
@@ -339,6 +358,7 @@ func (m *Multiplexer) dialWebSocket(wsURL string, tlsConfig *tls.Config, host st
339358
)
340359
if err != nil {
341360
logger.Log(logger.LevelError, nil, err, "dialing WebSocket")
361+
logger.Log(logger.LevelError, nil, resp, "WebSocket response")
342362
// We only attempt to close the response body if there was an error and resp is not nil.
343363
// In the successful case (when err is nil), the resp will actually be nil for WebSocket connections,
344364
// so we don't need to close anything.
@@ -393,6 +413,7 @@ func (m *Multiplexer) reconnect(conn *Connection) (*Connection, error) {
393413
conn.Path,
394414
conn.Query,
395415
conn.Client,
416+
conn.Token,
396417
)
397418
if err != nil {
398419
logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, err, "reconnecting to cluster")
@@ -482,7 +503,7 @@ func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *WSConnLock)
482503
if !exists {
483504
var err error
484505

485-
conn, err = m.establishClusterConnection(msg.ClusterID, msg.UserID, msg.Path, msg.Query, clientConn)
506+
conn, err = m.establishClusterConnection(msg.ClusterID, msg.UserID, msg.Path, msg.Query, clientConn, msg.Token)
486507
if err != nil {
487508
logger.Log(
488509
logger.LevelError,

backend/cmd/multiplexer_test.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func TestCreateConnection(t *testing.T) {
112112
clientConn, _ := createTestWebSocketConnection()
113113

114114
// Add RequestID to the createConnection call
115-
conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
115+
conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn, nil)
116116
assert.NotNil(t, conn)
117117
assert.Equal(t, "test-cluster", conn.ClusterID)
118118
assert.Equal(t, "test-user", conn.UserID)
@@ -153,7 +153,7 @@ func TestDialWebSocket(t *testing.T) {
153153
defer server.Close()
154154

155155
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
156-
conn, err := m.dialWebSocket(wsURL, &tls.Config{InsecureSkipVerify: true}, server.URL) //nolint:gosec
156+
conn, err := m.dialWebSocket(wsURL, &tls.Config{InsecureSkipVerify: true}, server.URL, nil) //nolint:gosec
157157

158158
assert.NoError(t, err)
159159
assert.NotNil(t, conn)
@@ -170,12 +170,12 @@ func TestDialWebSocket_Errors(t *testing.T) {
170170
// Test invalid URL
171171
tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec
172172

173-
ws, err := m.dialWebSocket("invalid-url", tlsConfig, "")
173+
ws, err := m.dialWebSocket("invalid-url", tlsConfig, "", nil)
174174
assert.Error(t, err)
175175
assert.Nil(t, ws)
176176

177177
// Test unreachable URL
178-
ws, err = m.dialWebSocket("ws://localhost:12345", tlsConfig, "")
178+
ws, err = m.dialWebSocket("ws://localhost:12345", tlsConfig, "", nil)
179179
assert.Error(t, err)
180180
assert.Nil(t, ws)
181181
}
@@ -535,7 +535,7 @@ func TestEstablishClusterConnection(t *testing.T) {
535535
defer clientServer.Close()
536536

537537
// Test successful connection establishment
538-
conn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn)
538+
conn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn, nil)
539539
assert.NoError(t, err)
540540
assert.NotNil(t, conn)
541541
assert.Equal(t, "test-cluster", conn.ClusterID)
@@ -544,7 +544,7 @@ func TestEstablishClusterConnection(t *testing.T) {
544544
assert.Equal(t, "watch=true", conn.Query)
545545

546546
// Test with invalid cluster
547-
conn, err = m.establishClusterConnection("non-existent", "test-user", "/api/v1/pods", "watch=true", clientConn)
547+
conn, err = m.establishClusterConnection("non-existent", "test-user", "/api/v1/pods", "watch=true", clientConn, nil)
548548
assert.Error(t, err)
549549
assert.Nil(t, conn)
550550
}
@@ -572,7 +572,7 @@ func TestReconnect(t *testing.T) {
572572
defer clientServer.Close()
573573

574574
// Create initial connection
575-
conn := m.createConnection("test-cluster", "test-user", "/api/v1/services", "watch=true", clientConn)
575+
conn := m.createConnection("test-cluster", "test-user", "/api/v1/services", "watch=true", clientConn, nil)
576576
wsConn, wsServer := createTestWebSocketConnection()
577577

578578
defer wsServer.Close()
@@ -598,7 +598,7 @@ func TestReconnect(t *testing.T) {
598598
assert.Contains(t, err.Error(), "getting context: key not found")
599599

600600
// Test reconnection with closed connection
601-
conn = m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn)
601+
conn = m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn, nil)
602602
wsConn2, wsServer2 := createTestWebSocketConnection()
603603

604604
defer wsServer2.Close()
@@ -829,7 +829,7 @@ func TestMonitorConnection_ReconnectFailure(t *testing.T) {
829829
clientConn, clientServer := createTestWebSocketConnection()
830830
defer clientServer.Close()
831831

832-
conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn)
832+
conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn, nil)
833833
wsConn, wsServer := createTestWebSocketConn()
834834

835835
defer wsServer.Close()
@@ -1097,7 +1097,7 @@ func TestMonitorConnection_Reconnect(t *testing.T) {
10971097
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
10981098
tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec
10991099

1100-
ws, err := m.dialWebSocket(wsURL, tlsConfig, "")
1100+
ws, err := m.dialWebSocket(wsURL, tlsConfig, "", nil)
11011101
require.NoError(t, err)
11021102

11031103
conn.WSConn = ws

frontend/src/lib/k8s/api/v2/webSocket.test.ts

+2
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ describe('WebSocket Tests', () => {
8888
path,
8989
query,
9090
userId,
91+
token: 'test-token',
9192
type: 'REQUEST',
9293
});
9394

@@ -130,6 +131,7 @@ describe('WebSocket Tests', () => {
130131
path: sub.path,
131132
query: sub.query,
132133
userId,
134+
token: 'test-token',
133135
type: 'REQUEST',
134136
});
135137

frontend/src/lib/k8s/api/v2/webSocket.ts

+4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ interface WebSocketMessage {
5050
* - COMPLETE: Server indicates the watch request has completed (e.g., due to timeout or error)
5151
*/
5252
type: 'REQUEST' | 'CLOSE' | 'COMPLETE';
53+
54+
/** Authentication token */
55+
token?: string;
5356
}
5457

5558
/**
@@ -210,6 +213,7 @@ export const WebSocketManager = {
210213
query,
211214
userId: userId || '',
212215
type: 'REQUEST',
216+
token: getToken(clusterId),
213217
};
214218
socket.send(JSON.stringify(requestMsg));
215219

0 commit comments

Comments
 (0)