@@ -2,6 +2,7 @@ package main
22
33import (
44 "context"
5+ "errors"
56 "io"
67 "net"
78 "net/http"
@@ -240,3 +241,92 @@ func TestOpenPortForwardPreservesEOFFramedExchange(t *testing.T) {
240241 t .Fatalf ("upstream exchange failed: %v" , err )
241242 }
242243}
244+
245+ func TestOpenPortForwardClosesUpstreamWhenWebSocketCloses (t * testing.T ) {
246+ scheme := newTestSpritzScheme (t )
247+ spritz := & spritzv1.Spritz {
248+ ObjectMeta : metav1.ObjectMeta {
249+ Name : "tidal-falcon" ,
250+ Namespace : "spritz-test" ,
251+ },
252+ Spec : spritzv1.SpritzSpec {
253+ Owner : spritzv1.SpritzOwner {ID : "user-1" },
254+ },
255+ }
256+
257+ upstreamListener , err := net .Listen ("tcp" , "127.0.0.1:0" )
258+ if err != nil {
259+ t .Fatalf ("listen upstream: %v" , err )
260+ }
261+ defer upstreamListener .Close ()
262+
263+ upstreamDone := make (chan error , 1 )
264+ go func () {
265+ conn , err := upstreamListener .Accept ()
266+ if err != nil {
267+ upstreamDone <- err
268+ return
269+ }
270+ defer conn .Close ()
271+ _ = conn .SetReadDeadline (time .Now ().Add (2 * time .Second ))
272+ buffer := make ([]byte , 1 )
273+ _ , err = conn .Read (buffer )
274+ if err == nil || ! errors .Is (err , io .EOF ) {
275+ upstreamDone <- err
276+ return
277+ }
278+ upstreamDone <- nil
279+ }()
280+
281+ s := & server {
282+ client : ctrlclientfake .NewClientBuilder ().
283+ WithScheme (scheme ).
284+ WithObjects (spritz ).
285+ Build (),
286+ scheme : scheme ,
287+ namespace : "spritz-test" ,
288+ auth : authConfig {
289+ mode : authModeHeader ,
290+ headerID : "X-Spritz-User-Id" ,
291+ headerDefaultType : principalTypeHuman ,
292+ },
293+ internalAuth : internalAuthConfig {enabled : false },
294+ portForward : portForwardConfig {enabled : true , containerName : "spritz" },
295+ findRunningPodFunc : func (ctx context.Context , namespace , name , container string ) (* corev1.Pod , error ) {
296+ return & corev1.Pod {
297+ ObjectMeta : metav1.ObjectMeta {
298+ Name : "tidal-falcon-pod" ,
299+ Namespace : namespace ,
300+ },
301+ }, nil
302+ },
303+ openPodPortForwardFunc : func (ctx context.Context , pod * corev1.Pod , remotePort uint32 ) (net.Conn , io.Closer , error ) {
304+ conn , err := (& net.Dialer {}).DialContext (ctx , "tcp" , upstreamListener .Addr ().String ())
305+ if err != nil {
306+ return nil , nil , err
307+ }
308+ return conn , closeFunc (func () error { return nil }), nil
309+ },
310+ }
311+
312+ e := echo .New ()
313+ e .GET ("/api/spritzes/:name/port-forward" , s .openPortForward )
314+ srv := httptest .NewServer (e )
315+ defer srv .Close ()
316+
317+ wsURL := "ws" + strings .TrimPrefix (srv .URL , "http" ) + "/api/spritzes/tidal-falcon/port-forward?port=3000"
318+ headers := http.Header {}
319+ headers .Set ("X-Spritz-User-Id" , "user-1" )
320+ headers .Set ("Origin" , srv .URL )
321+ conn , _ , err := websocket .DefaultDialer .Dial (wsURL , headers )
322+ if err != nil {
323+ t .Fatalf ("dial websocket: %v" , err )
324+ }
325+
326+ if err := conn .Close (); err != nil {
327+ t .Fatalf ("close websocket: %v" , err )
328+ }
329+ if err := <- upstreamDone ; err != nil {
330+ t .Fatalf ("expected upstream to close after websocket exit: %v" , err )
331+ }
332+ }
0 commit comments