From 807cd59f013a1671644b7e4236eb81c189682eb4 Mon Sep 17 00:00:00 2001 From: Cedric Staub Date: Mon, 20 Apr 2026 20:47:39 -0700 Subject: [PATCH 1/2] Fix two bugs: systemd socket check and proxy backend leak - Fix wrong variable in systemd socket length check (len(listeners) instead of len(listener)), which rejected valid single-socket names when multiple socket names existed - Fix backend connection leak when PROXY protocol header write fails - Extract systemdSocketFromMap helper for testability and add unit tests confirming both fixes --- proxy/proxy.go | 1 + proxy/proxy_test.go | 46 +++++++++++++++++++++++++++++++++++ socket/systemd_enabled.go | 10 +++++--- socket/systemd_test.go | 50 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 socket/systemd_test.go diff --git a/proxy/proxy.go b/proxy/proxy.go index fc104fe070..f920664b96 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -359,6 +359,7 @@ func (p *Proxy) Accept() { _, err = h.WriteTo(backend) if err != nil { p.logConditional(LogConnectionErrors, "error writing proxy header: %s", err) + backend.Close() return } } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 5c560b3de0..d5874cdcaf 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1008,3 +1008,49 @@ func TestProxyProtoHeaderConnMode(t *testing.T) { assert.Nil(t, err) assert.Empty(t, tlvs, "conn mode should have no TLVs even with TLS state") } + +// failWriteConn is a mock connection that tracks Close() calls and fails on Write(). +// Used to simulate a PROXY protocol header write failure. +type failWriteConn struct { + closed bool + mockConn +} + +func (f *failWriteConn) Write(b []byte) (int, error) { return 0, errors.New("write error") } +func (f *failWriteConn) Close() error { f.closed = true; return nil } + +func TestProxyProtocolWriteFailureClosesBackend(t *testing.T) { + // Incoming listener (plain TCP — forceHandshake is a no-op for non-TLS) + incoming, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer incoming.Close() + + // Backend mock that fails on Write (simulating PROXY header write failure) + backend := &failWriteConn{} + dialCalled := make(chan struct{}) + + dialer := func(_ context.Context) (net.Conn, error) { + close(dialCalled) + return backend, nil + } + + // Create proxy with PROXY protocol enabled + p := proxyForTestWithProxyProtocol(incoming, dialer) + go p.Accept() + + // Connect a client to trigger the handler + client, err := net.Dial("tcp", incoming.Addr().String()) + assert.Nil(t, err) + + // Wait for the handler to reach Dial, ensuring it's past the accept stage + <-dialCalled + client.Close() + + // Shut down and wait for all handlers to complete + p.Shutdown() + p.Wait() + + // BUG: backend connection should be closed when PROXY header write fails, + // but current code returns without closing it. + assert.True(t, backend.closed, "backend connection must be closed when PROXY protocol header write fails") +} diff --git a/socket/systemd_enabled.go b/socket/systemd_enabled.go index 597abde605..7bb9654bbe 100644 --- a/socket/systemd_enabled.go +++ b/socket/systemd_enabled.go @@ -31,11 +31,15 @@ func systemdSocket(name string) (net.Listener, error) { return nil, err } + return systemdSocketFromMap(name, listeners) +} + +func systemdSocketFromMap(name string, listeners map[string][]net.Listener) (net.Listener, error) { if listener, ok := listeners[name]; ok { - if len(listeners) != 1 { - return nil, fmt.Errorf("expected exactly 1 listening socket configured in systemd for name %s, found %d", name, len(listeners)) + if len(listener) != 1 { + return nil, fmt.Errorf("expected exactly 1 listening socket configured in systemd for name %s, found %d", name, len(listener)) } - return listener[0], err + return listener[0], nil } return nil, fmt.Errorf("expected listener with name %s, but found none", name) diff --git a/socket/systemd_test.go b/socket/systemd_test.go new file mode 100644 index 0000000000..131c7595c6 --- /dev/null +++ b/socket/systemd_test.go @@ -0,0 +1,50 @@ +//go:build linux + +package socket + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSystemdSocketMultipleNames(t *testing.T) { + // Two systemd socket names, each with exactly one listener. + // Requesting "web" should succeed since "web" has exactly 1 socket. + webListener, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer webListener.Close() + + apiListener, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer apiListener.Close() + + listeners := map[string][]net.Listener{ + "web": {webListener}, + "api": {apiListener}, + } + + result, err := systemdSocketFromMap("web", listeners) + assert.Nil(t, err, "requesting 'web' with 1 socket should succeed even when other names exist") + assert.Equal(t, webListener, result) +} + +func TestSystemdSocketMultipleSocketsSameName(t *testing.T) { + // One systemd socket name with two listeners. + // This should fail since we expect exactly 1 socket per name. + l1, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer l1.Close() + + l2, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer l2.Close() + + listeners := map[string][]net.Listener{ + "web": {l1, l2}, + } + + _, err = systemdSocketFromMap("web", listeners) + assert.NotNil(t, err, "should fail when a name has multiple sockets") +} From 46057e37cc014bf36544b3b6e17e9d1361013e17 Mon Sep 17 00:00:00 2001 From: Cedric Staub Date: Mon, 20 Apr 2026 21:04:56 -0700 Subject: [PATCH 2/2] Fix a few more smaller bugs: OPA client validation, status race - Add validateClientOPA() to reject partial --verify-policy/--verify-query flags - Move s.lastReload read under mutex in status() - Close fd after net.FileListener in launchd socket activation - Close listener on error paths in serverListen and clientListen --- main.go | 17 +++++++++++++++++ main_test.go | 25 +++++++++++++++++++++++++ proxy/proxy_test.go | 7 ++++--- socket/launchd_enabled.go | 1 + socket/systemd_test.go | 16 ++++++++++++---- status.go | 4 ++-- 6 files changed, 61 insertions(+), 9 deletions(-) diff --git a/main.go b/main.go index 9a582d5e46..b35f9a4e4e 100644 --- a/main.go +++ b/main.go @@ -433,6 +433,17 @@ func validateClientListen() error { return nil } +func validateClientOPA() error { + hasOPAFlags := len(*clientAllowPolicy) > 0 || len(*clientAllowQuery) > 0 + if !hasOPAFlags { + return nil + } + if *clientAllowPolicy == "" || *clientAllowQuery == "" { + return errors.New("--verify-policy and --verify-query have to be used together") + } + return nil +} + // Validate flags for client mode func clientValidateFlags() error { if err := validateClientCredentials(); err != nil { @@ -441,6 +452,9 @@ func clientValidateFlags() error { if err := validateClientListen(); err != nil { return err } + if err := validateClientOPA(); err != nil { + return err + } return validateCipherSuites() } @@ -696,6 +710,7 @@ func serverListen(env *Environment) error { serverConfig, err := getServerConfig(env.tlsConfigSource, config) if err != nil { + listener.Close() logger.Printf("error: unable to get server TLS config: %s", err) return err } @@ -715,6 +730,7 @@ func serverListen(env *Environment) error { if *statusAddress != "" { err := env.serveStatus() if err != nil { + listener.Close() logger.Printf("error serving /_status: %s", err) return err } @@ -760,6 +776,7 @@ func clientListen(env *Environment) error { if *statusAddress != "" { err := env.serveStatus() if err != nil { + listener.Close() logger.Printf("error serving /_status: %s", err) return err } diff --git a/main_test.go b/main_test.go index 7953215431..d998508b75 100644 --- a/main_test.go +++ b/main_test.go @@ -424,6 +424,31 @@ func TestClientFlagValidation(t *testing.T) { err = clientValidateFlags() assert.NotNil(t, err, "--key without --cert should be rejected") *keyPath = "" + + // Test: OPA flags must be used together + *keystorePath = "file" + *clientListenAddress = "127.0.0.1:8080" + *enabledCipherSuites = "AES,CHACHA" + + *clientAllowPolicy = "policy" + *clientAllowQuery = "" + err = clientValidateFlags() + assert.NotNil(t, err, "--verify-policy needs --verify-query") + + *clientAllowPolicy = "" + *clientAllowQuery = "query" + err = clientValidateFlags() + assert.NotNil(t, err, "--verify-query needs --verify-policy") + + *clientAllowPolicy = "policy" + *clientAllowQuery = "query" + err = clientValidateFlags() + assert.Nil(t, err, "--verify-policy and --verify-query together should be valid") + + *clientAllowPolicy = "" + *clientAllowQuery = "" + err = clientValidateFlags() + assert.Nil(t, err, "neither OPA flag set should be valid") } func TestAllowsLocalhost(t *testing.T) { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index d5874cdcaf..daf600fb8f 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1022,7 +1022,9 @@ func (f *failWriteConn) Close() error { f.closed = true; return n func TestProxyProtocolWriteFailureClosesBackend(t *testing.T) { // Incoming listener (plain TCP — forceHandshake is a no-op for non-TLS) incoming, err := net.Listen("tcp", "127.0.0.1:0") - assert.Nil(t, err) + if err != nil { + t.Fatal(err) + } defer incoming.Close() // Backend mock that fails on Write (simulating PROXY header write failure) @@ -1050,7 +1052,6 @@ func TestProxyProtocolWriteFailureClosesBackend(t *testing.T) { p.Shutdown() p.Wait() - // BUG: backend connection should be closed when PROXY header write fails, - // but current code returns without closing it. + // Regression: verify backend connection is closed when PROXY header write fails. assert.True(t, backend.closed, "backend connection must be closed when PROXY protocol header write fails") } diff --git a/socket/launchd_enabled.go b/socket/launchd_enabled.go index 2b806857c6..5a708bc784 100644 --- a/socket/launchd_enabled.go +++ b/socket/launchd_enabled.go @@ -51,6 +51,7 @@ func launchdSocket(address string) (net.Listener, error) { fds := (*[1]C.int)(ptr) file := os.NewFile(uintptr(fds[0]), "") + defer file.Close() return net.FileListener(file) } diff --git a/socket/systemd_test.go b/socket/systemd_test.go index 131c7595c6..786be3b289 100644 --- a/socket/systemd_test.go +++ b/socket/systemd_test.go @@ -13,11 +13,15 @@ func TestSystemdSocketMultipleNames(t *testing.T) { // Two systemd socket names, each with exactly one listener. // Requesting "web" should succeed since "web" has exactly 1 socket. webListener, err := net.Listen("tcp", "127.0.0.1:0") - assert.Nil(t, err) + if err != nil { + t.Fatal(err) + } defer webListener.Close() apiListener, err := net.Listen("tcp", "127.0.0.1:0") - assert.Nil(t, err) + if err != nil { + t.Fatal(err) + } defer apiListener.Close() listeners := map[string][]net.Listener{ @@ -34,11 +38,15 @@ func TestSystemdSocketMultipleSocketsSameName(t *testing.T) { // One systemd socket name with two listeners. // This should fail since we expect exactly 1 socket per name. l1, err := net.Listen("tcp", "127.0.0.1:0") - assert.Nil(t, err) + if err != nil { + t.Fatal(err) + } defer l1.Close() l2, err := net.Listen("tcp", "127.0.0.1:0") - assert.Nil(t, err) + if err != nil { + t.Fatal(err) + } defer l2.Close() listeners := map[string][]net.Listener{ diff --git a/status.go b/status.go index 9bf6d3f9e3..965ab3e01e 100644 --- a/status.go +++ b/status.go @@ -136,8 +136,7 @@ func (s *statusHandler) HandleWatchdog() { func (s *statusHandler) status(ctx context.Context) statusResponse { resp := statusResponse{ - Time: time.Now(), - LastReload: s.lastReload, + Time: time.Now(), } resp.Revision = version @@ -156,6 +155,7 @@ func (s *statusHandler) status(ctx context.Context) statusResponse { } s.mu.Lock() + resp.LastReload = s.lastReload resp.Ok = s.listening && resp.BackendOk if s.stopping { resp.Message = "stopping"