Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,17 @@ func validateClientListen() error {
return nil
}

func validateClientOPA() error {
hasOPAFlags := len(*clientAllowPolicy) > 0 || len(*clientAllowQuery) > 0
if !hasOPAFlags {
return nil
}
if *clientAllowPolicy == "" || *clientAllowQuery == "" {
return errors.New("--verify-policy and --verify-query have to be used together")
}
return nil
}

// Validate flags for client mode
func clientValidateFlags() error {
if err := validateClientCredentials(); err != nil {
Expand All @@ -441,6 +452,9 @@ func clientValidateFlags() error {
if err := validateClientListen(); err != nil {
return err
}
if err := validateClientOPA(); err != nil {
return err
}
return validateCipherSuites()
}

Expand Down Expand Up @@ -696,6 +710,7 @@ func serverListen(env *Environment) error {

serverConfig, err := getServerConfig(env.tlsConfigSource, config)
if err != nil {
listener.Close()
logger.Printf("error: unable to get server TLS config: %s", err)
return err
}
Expand All @@ -715,6 +730,7 @@ func serverListen(env *Environment) error {
if *statusAddress != "" {
err := env.serveStatus()
if err != nil {
listener.Close()
logger.Printf("error serving /_status: %s", err)
return err
}
Expand Down Expand Up @@ -760,6 +776,7 @@ func clientListen(env *Environment) error {
if *statusAddress != "" {
err := env.serveStatus()
if err != nil {
listener.Close()
logger.Printf("error serving /_status: %s", err)
return err
}
Expand Down
25 changes: 25 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,31 @@ func TestClientFlagValidation(t *testing.T) {
err = clientValidateFlags()
assert.NotNil(t, err, "--key without --cert should be rejected")
*keyPath = ""

// Test: OPA flags must be used together
*keystorePath = "file"
*clientListenAddress = "127.0.0.1:8080"
*enabledCipherSuites = "AES,CHACHA"

*clientAllowPolicy = "policy"
*clientAllowQuery = ""
err = clientValidateFlags()
assert.NotNil(t, err, "--verify-policy needs --verify-query")

*clientAllowPolicy = ""
*clientAllowQuery = "query"
err = clientValidateFlags()
assert.NotNil(t, err, "--verify-query needs --verify-policy")

*clientAllowPolicy = "policy"
*clientAllowQuery = "query"
err = clientValidateFlags()
assert.Nil(t, err, "--verify-policy and --verify-query together should be valid")

*clientAllowPolicy = ""
*clientAllowQuery = ""
err = clientValidateFlags()
assert.Nil(t, err, "neither OPA flag set should be valid")
}

func TestAllowsLocalhost(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ func (p *Proxy) Accept() {
_, err = h.WriteTo(backend)
if err != nil {
p.logConditional(LogConnectionErrors, "error writing proxy header: %s", err)
backend.Close()
return
}
}
Expand Down
47 changes: 47 additions & 0 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1008,3 +1008,50 @@ func TestProxyProtoHeaderConnMode(t *testing.T) {
assert.Nil(t, err)
assert.Empty(t, tlvs, "conn mode should have no TLVs even with TLS state")
}

// failWriteConn is a mock connection that tracks Close() calls and fails on Write().
// Used to simulate a PROXY protocol header write failure.
type failWriteConn struct {
closed bool
mockConn
}

func (f *failWriteConn) Write(b []byte) (int, error) { return 0, errors.New("write error") }
func (f *failWriteConn) Close() error { f.closed = true; return nil }

func TestProxyProtocolWriteFailureClosesBackend(t *testing.T) {
// Incoming listener (plain TCP — forceHandshake is a no-op for non-TLS)
incoming, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer incoming.Close()

// Backend mock that fails on Write (simulating PROXY header write failure)
backend := &failWriteConn{}
dialCalled := make(chan struct{})

dialer := func(_ context.Context) (net.Conn, error) {
close(dialCalled)
return backend, nil
}

// Create proxy with PROXY protocol enabled
p := proxyForTestWithProxyProtocol(incoming, dialer)
go p.Accept()

// Connect a client to trigger the handler
client, err := net.Dial("tcp", incoming.Addr().String())
assert.Nil(t, err)

// Wait for the handler to reach Dial, ensuring it's past the accept stage
<-dialCalled
client.Close()

// Shut down and wait for all handlers to complete
p.Shutdown()
p.Wait()

// Regression: verify backend connection is closed when PROXY header write fails.
assert.True(t, backend.closed, "backend connection must be closed when PROXY protocol header write fails")
}
1 change: 1 addition & 0 deletions socket/launchd_enabled.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func launchdSocket(address string) (net.Listener, error) {

fds := (*[1]C.int)(ptr)
file := os.NewFile(uintptr(fds[0]), "")
defer file.Close()

return net.FileListener(file)
}
10 changes: 7 additions & 3 deletions socket/systemd_enabled.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ func systemdSocket(name string) (net.Listener, error) {
return nil, err
}

return systemdSocketFromMap(name, listeners)
}

func systemdSocketFromMap(name string, listeners map[string][]net.Listener) (net.Listener, error) {
if listener, ok := listeners[name]; ok {
if len(listeners) != 1 {
return nil, fmt.Errorf("expected exactly 1 listening socket configured in systemd for name %s, found %d", name, len(listeners))
if len(listener) != 1 {
return nil, fmt.Errorf("expected exactly 1 listening socket configured in systemd for name %s, found %d", name, len(listener))
}
return listener[0], err
return listener[0], nil
}

return nil, fmt.Errorf("expected listener with name %s, but found none", name)
Expand Down
58 changes: 58 additions & 0 deletions socket/systemd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//go:build linux

package socket

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSystemdSocketMultipleNames(t *testing.T) {
// Two systemd socket names, each with exactly one listener.
// Requesting "web" should succeed since "web" has exactly 1 socket.
webListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer webListener.Close()

apiListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer apiListener.Close()

listeners := map[string][]net.Listener{
"web": {webListener},
"api": {apiListener},
}

result, err := systemdSocketFromMap("web", listeners)
assert.Nil(t, err, "requesting 'web' with 1 socket should succeed even when other names exist")
assert.Equal(t, webListener, result)
}

func TestSystemdSocketMultipleSocketsSameName(t *testing.T) {
// One systemd socket name with two listeners.
// This should fail since we expect exactly 1 socket per name.
l1, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer l1.Close()

l2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer l2.Close()

listeners := map[string][]net.Listener{
"web": {l1, l2},
}

_, err = systemdSocketFromMap("web", listeners)
assert.NotNil(t, err, "should fail when a name has multiple sockets")
}
4 changes: 2 additions & 2 deletions status.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ func (s *statusHandler) HandleWatchdog() {

func (s *statusHandler) status(ctx context.Context) statusResponse {
resp := statusResponse{
Time: time.Now(),
LastReload: s.lastReload,
Time: time.Now(),
}

resp.Revision = version
Expand All @@ -156,6 +155,7 @@ func (s *statusHandler) status(ctx context.Context) statusResponse {
}

s.mu.Lock()
resp.LastReload = s.lastReload
resp.Ok = s.listening && resp.BackendOk
if s.stopping {
resp.Message = "stopping"
Expand Down
Loading