diff --git a/main.go b/main.go index 9a582d5e46..b35f9a4e4e 100644 --- a/main.go +++ b/main.go @@ -433,6 +433,17 @@ func validateClientListen() error { return nil } +func validateClientOPA() error { + hasOPAFlags := len(*clientAllowPolicy) > 0 || len(*clientAllowQuery) > 0 + if !hasOPAFlags { + return nil + } + if *clientAllowPolicy == "" || *clientAllowQuery == "" { + return errors.New("--verify-policy and --verify-query have to be used together") + } + return nil +} + // Validate flags for client mode func clientValidateFlags() error { if err := validateClientCredentials(); err != nil { @@ -441,6 +452,9 @@ func clientValidateFlags() error { if err := validateClientListen(); err != nil { return err } + if err := validateClientOPA(); err != nil { + return err + } return validateCipherSuites() } @@ -696,6 +710,7 @@ func serverListen(env *Environment) error { serverConfig, err := getServerConfig(env.tlsConfigSource, config) if err != nil { + listener.Close() logger.Printf("error: unable to get server TLS config: %s", err) return err } @@ -715,6 +730,7 @@ func serverListen(env *Environment) error { if *statusAddress != "" { err := env.serveStatus() if err != nil { + listener.Close() logger.Printf("error serving /_status: %s", err) return err } @@ -760,6 +776,7 @@ func clientListen(env *Environment) error { if *statusAddress != "" { err := env.serveStatus() if err != nil { + listener.Close() logger.Printf("error serving /_status: %s", err) return err } diff --git a/main_test.go b/main_test.go index 7953215431..d998508b75 100644 --- a/main_test.go +++ b/main_test.go @@ -424,6 +424,31 @@ func TestClientFlagValidation(t *testing.T) { err = clientValidateFlags() assert.NotNil(t, err, "--key without --cert should be rejected") *keyPath = "" + + // Test: OPA flags must be used together + *keystorePath = "file" + *clientListenAddress = "127.0.0.1:8080" + *enabledCipherSuites = "AES,CHACHA" + + *clientAllowPolicy = "policy" + *clientAllowQuery = "" + err = clientValidateFlags() + assert.NotNil(t, err, "--verify-policy needs --verify-query") + + *clientAllowPolicy = "" + *clientAllowQuery = "query" + err = clientValidateFlags() + assert.NotNil(t, err, "--verify-query needs --verify-policy") + + *clientAllowPolicy = "policy" + *clientAllowQuery = "query" + err = clientValidateFlags() + assert.Nil(t, err, "--verify-policy and --verify-query together should be valid") + + *clientAllowPolicy = "" + *clientAllowQuery = "" + err = clientValidateFlags() + assert.Nil(t, err, "neither OPA flag set should be valid") } func TestAllowsLocalhost(t *testing.T) { diff --git a/proxy/proxy.go b/proxy/proxy.go index fc104fe070..f920664b96 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -359,6 +359,7 @@ func (p *Proxy) Accept() { _, err = h.WriteTo(backend) if err != nil { p.logConditional(LogConnectionErrors, "error writing proxy header: %s", err) + backend.Close() return } } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 5c560b3de0..daf600fb8f 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1008,3 +1008,50 @@ func TestProxyProtoHeaderConnMode(t *testing.T) { assert.Nil(t, err) assert.Empty(t, tlvs, "conn mode should have no TLVs even with TLS state") } + +// failWriteConn is a mock connection that tracks Close() calls and fails on Write(). +// Used to simulate a PROXY protocol header write failure. +type failWriteConn struct { + closed bool + mockConn +} + +func (f *failWriteConn) Write(b []byte) (int, error) { return 0, errors.New("write error") } +func (f *failWriteConn) Close() error { f.closed = true; return nil } + +func TestProxyProtocolWriteFailureClosesBackend(t *testing.T) { + // Incoming listener (plain TCP — forceHandshake is a no-op for non-TLS) + incoming, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer incoming.Close() + + // Backend mock that fails on Write (simulating PROXY header write failure) + backend := &failWriteConn{} + dialCalled := make(chan struct{}) + + dialer := func(_ context.Context) (net.Conn, error) { + close(dialCalled) + return backend, nil + } + + // Create proxy with PROXY protocol enabled + p := proxyForTestWithProxyProtocol(incoming, dialer) + go p.Accept() + + // Connect a client to trigger the handler + client, err := net.Dial("tcp", incoming.Addr().String()) + assert.Nil(t, err) + + // Wait for the handler to reach Dial, ensuring it's past the accept stage + <-dialCalled + client.Close() + + // Shut down and wait for all handlers to complete + p.Shutdown() + p.Wait() + + // Regression: verify backend connection is closed when PROXY header write fails. + assert.True(t, backend.closed, "backend connection must be closed when PROXY protocol header write fails") +} diff --git a/socket/launchd_enabled.go b/socket/launchd_enabled.go index 2b806857c6..5a708bc784 100644 --- a/socket/launchd_enabled.go +++ b/socket/launchd_enabled.go @@ -51,6 +51,7 @@ func launchdSocket(address string) (net.Listener, error) { fds := (*[1]C.int)(ptr) file := os.NewFile(uintptr(fds[0]), "") + defer file.Close() return net.FileListener(file) } diff --git a/socket/systemd_enabled.go b/socket/systemd_enabled.go index 597abde605..7bb9654bbe 100644 --- a/socket/systemd_enabled.go +++ b/socket/systemd_enabled.go @@ -31,11 +31,15 @@ func systemdSocket(name string) (net.Listener, error) { return nil, err } + return systemdSocketFromMap(name, listeners) +} + +func systemdSocketFromMap(name string, listeners map[string][]net.Listener) (net.Listener, error) { if listener, ok := listeners[name]; ok { - if len(listeners) != 1 { - return nil, fmt.Errorf("expected exactly 1 listening socket configured in systemd for name %s, found %d", name, len(listeners)) + if len(listener) != 1 { + return nil, fmt.Errorf("expected exactly 1 listening socket configured in systemd for name %s, found %d", name, len(listener)) } - return listener[0], err + return listener[0], nil } return nil, fmt.Errorf("expected listener with name %s, but found none", name) diff --git a/socket/systemd_test.go b/socket/systemd_test.go new file mode 100644 index 0000000000..786be3b289 --- /dev/null +++ b/socket/systemd_test.go @@ -0,0 +1,58 @@ +//go:build linux + +package socket + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSystemdSocketMultipleNames(t *testing.T) { + // Two systemd socket names, each with exactly one listener. + // Requesting "web" should succeed since "web" has exactly 1 socket. + webListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer webListener.Close() + + apiListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer apiListener.Close() + + listeners := map[string][]net.Listener{ + "web": {webListener}, + "api": {apiListener}, + } + + result, err := systemdSocketFromMap("web", listeners) + assert.Nil(t, err, "requesting 'web' with 1 socket should succeed even when other names exist") + assert.Equal(t, webListener, result) +} + +func TestSystemdSocketMultipleSocketsSameName(t *testing.T) { + // One systemd socket name with two listeners. + // This should fail since we expect exactly 1 socket per name. + l1, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l1.Close() + + l2, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l2.Close() + + listeners := map[string][]net.Listener{ + "web": {l1, l2}, + } + + _, err = systemdSocketFromMap("web", listeners) + assert.NotNil(t, err, "should fail when a name has multiple sockets") +} diff --git a/status.go b/status.go index 9bf6d3f9e3..965ab3e01e 100644 --- a/status.go +++ b/status.go @@ -136,8 +136,7 @@ func (s *statusHandler) HandleWatchdog() { func (s *statusHandler) status(ctx context.Context) statusResponse { resp := statusResponse{ - Time: time.Now(), - LastReload: s.lastReload, + Time: time.Now(), } resp.Revision = version @@ -156,6 +155,7 @@ func (s *statusHandler) status(ctx context.Context) statusResponse { } s.mu.Lock() + resp.LastReload = s.lastReload resp.Ok = s.listening && resp.BackendOk if s.stopping { resp.Message = "stopping"