Skip to content

Commit 5915749

Browse files
committed
fix(port-forward): close dropped websocket tunnels
1 parent bc0ceb1 commit 5915749

4 files changed

Lines changed: 182 additions & 15 deletions

File tree

api/port_forward.go

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ type portForwardControlMessage struct {
3333
Type string `json:"type"`
3434
}
3535

36+
var errPortForwardHalfClose = errors.New("port-forward half-close")
37+
3638
func newPortForwardConfig() portForwardConfig {
3739
return portForwardConfig{
3840
enabled: parseBoolEnv("SPRITZ_PORT_FORWARD_ENABLED", true),
@@ -164,22 +166,35 @@ func proxyWebSocketNetConn(ws *websocket.Conn, upstream net.Conn) error {
164166
errCh <- copyNetConnToWebSocket(upstream, ws)
165167
}()
166168

167-
var firstErr error
169+
halfClosed := 0
168170
for completed := 0; completed < 2; completed++ {
169171
err := <-errCh
170-
if err == nil || errors.Is(err, io.EOF) || websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
171-
continue
172-
}
173-
if ne, ok := err.(net.Error); ok && ne.Timeout() {
174-
continue
175-
}
176-
if firstErr == nil {
177-
firstErr = err
172+
switch {
173+
case err == nil:
174+
closeAll()
175+
return nil
176+
case errors.Is(err, errPortForwardHalfClose):
177+
halfClosed++
178+
if halfClosed == 2 {
179+
closeAll()
180+
return nil
181+
}
182+
case errors.Is(err, io.EOF), errors.Is(err, context.Canceled), websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway):
183+
closeAll()
184+
return nil
185+
case func() bool {
186+
ne, ok := err.(net.Error)
187+
return ok && ne.Timeout()
188+
}():
189+
closeAll()
190+
return nil
191+
default:
178192
closeAll()
193+
return err
179194
}
180195
}
181196
closeAll()
182-
return firstErr
197+
return nil
183198
}
184199

185200
func (s *server) findPortForwardPod(ctx context.Context, namespace, name, container string) (*corev1.Pod, error) {
@@ -204,7 +219,7 @@ func copyWebSocketToNetConn(ws *websocket.Conn, upstream net.Conn) error {
204219
if err := closeConnWrite(upstream); err != nil {
205220
return err
206221
}
207-
return nil
222+
return errPortForwardHalfClose
208223
}
209224
continue
210225
}
@@ -234,6 +249,7 @@ func copyNetConnToWebSocket(upstream net.Conn, ws *websocket.Conn) error {
234249
if writeErr := ws.WriteMessage(websocket.TextMessage, mustMarshalPortForwardControl(portForwardControlMessage{Type: "eof"})); writeErr != nil {
235250
return writeErr
236251
}
252+
return errPortForwardHalfClose
237253
}
238254
return err
239255
}

api/port_forward_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"errors"
56
"io"
67
"net"
78
"net/http"
@@ -240,3 +241,92 @@ func TestOpenPortForwardPreservesEOFFramedExchange(t *testing.T) {
240241
t.Fatalf("upstream exchange failed: %v", err)
241242
}
242243
}
244+
245+
func TestOpenPortForwardClosesUpstreamWhenWebSocketCloses(t *testing.T) {
246+
scheme := newTestSpritzScheme(t)
247+
spritz := &spritzv1.Spritz{
248+
ObjectMeta: metav1.ObjectMeta{
249+
Name: "tidal-falcon",
250+
Namespace: "spritz-test",
251+
},
252+
Spec: spritzv1.SpritzSpec{
253+
Owner: spritzv1.SpritzOwner{ID: "user-1"},
254+
},
255+
}
256+
257+
upstreamListener, err := net.Listen("tcp", "127.0.0.1:0")
258+
if err != nil {
259+
t.Fatalf("listen upstream: %v", err)
260+
}
261+
defer upstreamListener.Close()
262+
263+
upstreamDone := make(chan error, 1)
264+
go func() {
265+
conn, err := upstreamListener.Accept()
266+
if err != nil {
267+
upstreamDone <- err
268+
return
269+
}
270+
defer conn.Close()
271+
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
272+
buffer := make([]byte, 1)
273+
_, err = conn.Read(buffer)
274+
if err == nil || !errors.Is(err, io.EOF) {
275+
upstreamDone <- err
276+
return
277+
}
278+
upstreamDone <- nil
279+
}()
280+
281+
s := &server{
282+
client: ctrlclientfake.NewClientBuilder().
283+
WithScheme(scheme).
284+
WithObjects(spritz).
285+
Build(),
286+
scheme: scheme,
287+
namespace: "spritz-test",
288+
auth: authConfig{
289+
mode: authModeHeader,
290+
headerID: "X-Spritz-User-Id",
291+
headerDefaultType: principalTypeHuman,
292+
},
293+
internalAuth: internalAuthConfig{enabled: false},
294+
portForward: portForwardConfig{enabled: true, containerName: "spritz"},
295+
findRunningPodFunc: func(ctx context.Context, namespace, name, container string) (*corev1.Pod, error) {
296+
return &corev1.Pod{
297+
ObjectMeta: metav1.ObjectMeta{
298+
Name: "tidal-falcon-pod",
299+
Namespace: namespace,
300+
},
301+
}, nil
302+
},
303+
openPodPortForwardFunc: func(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) {
304+
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", upstreamListener.Addr().String())
305+
if err != nil {
306+
return nil, nil, err
307+
}
308+
return conn, closeFunc(func() error { return nil }), nil
309+
},
310+
}
311+
312+
e := echo.New()
313+
e.GET("/api/spritzes/:name/port-forward", s.openPortForward)
314+
srv := httptest.NewServer(e)
315+
defer srv.Close()
316+
317+
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/api/spritzes/tidal-falcon/port-forward?port=3000"
318+
headers := http.Header{}
319+
headers.Set("X-Spritz-User-Id", "user-1")
320+
headers.Set("Origin", srv.URL)
321+
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
322+
if err != nil {
323+
t.Fatalf("dial websocket: %v", err)
324+
}
325+
326+
if err := conn.Close(); err != nil {
327+
t.Fatalf("close websocket: %v", err)
328+
}
329+
if err := <-upstreamDone; err != nil {
330+
t.Fatalf("expected upstream to close after websocket exit: %v", err)
331+
}
332+
}

cli/src/index.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,10 +1507,7 @@ async function bridgePortForwardSocket(socket: net.Socket, url: string, headers:
15071507
}
15081508
writePortForwardOutput(socket, data);
15091509
};
1510-
const onWsClose = () => {
1511-
wsEnded = true;
1512-
maybeFinish();
1513-
};
1510+
const onWsClose = () => finish();
15141511
const onWsError = (err: Error) => {
15151512
if (!opened) {
15161513
finish(err);

cli/test/port-forward.test.ts

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,70 @@ test('port-forward preserves EOF-framed exchanges over websocket', async (t) =>
410410
assert.equal(exitCode, 0, `spz port-forward should exit cleanly: ${stderr}`);
411411
});
412412

413+
test('port-forward closes the local client when the websocket tunnel drops after startup', async (t) => {
414+
const localPort = await getFreePort();
415+
const server = http.createServer();
416+
const wss = new WebSocketServer({ noServer: true });
417+
await listen(server);
418+
t.after(() => {
419+
wss.close();
420+
server.close();
421+
});
422+
const address = server.address();
423+
assert.ok(address && typeof address === 'object');
424+
425+
let connections = 0;
426+
server.on('upgrade', (req, socket, head) => {
427+
wss.handleUpgrade(req, socket, head, (ws) => {
428+
wss.emit('connection', ws, req);
429+
});
430+
});
431+
wss.on('connection', (ws) => {
432+
connections += 1;
433+
if (connections === 1) {
434+
return;
435+
}
436+
ws.close(1011, 'boom');
437+
});
438+
439+
const child = spawnCli(
440+
['port-forward', 'devbox1', '--transport', 'ws', '--namespace', 'spritz', '--local', String(localPort), '--remote', '4000'],
441+
buildTestEnv(`http://127.0.0.1:${address.port}/api`),
442+
);
443+
let stderr = '';
444+
const stderrBuffer = { value: '' };
445+
child.stderr.on('data', (chunk) => {
446+
const text = chunk.toString();
447+
stderr += text;
448+
stderrBuffer.value += text;
449+
});
450+
t.after(() => {
451+
child.kill('SIGTERM');
452+
});
453+
454+
await waitForPattern(stderrBuffer, new RegExp(`forwarding 127\\.0\\.0\\.1:${localPort}`));
455+
456+
const client = net.connect(localPort, '127.0.0.1');
457+
const clientClosed = new Promise<void>((resolve, reject) => {
458+
const timer = setTimeout(() => reject(new Error('timed out waiting for local socket to close')), 2000);
459+
client.on('error', () => {
460+
clearTimeout(timer);
461+
resolve();
462+
});
463+
client.on('close', () => {
464+
clearTimeout(timer);
465+
resolve();
466+
});
467+
});
468+
469+
await clientClosed;
470+
client.destroy();
471+
472+
child.kill('SIGTERM');
473+
const exitCode = await new Promise<number | null>((resolve) => child.on('exit', resolve));
474+
assert.equal(exitCode, 0, `spz port-forward should exit cleanly: ${stderr}`);
475+
});
476+
413477
test('port-forward falls back to SSH when websocket startup validation is rejected by default', async (t) => {
414478
const tempDir = mkdtempSync(path.join(os.tmpdir(), 'spz-port-forward-'));
415479
const fakeKeygen = path.join(tempDir, 'ssh-keygen');

0 commit comments

Comments
 (0)