From 3db441518aa71e70b4fa99259fd65973a5ad3059 Mon Sep 17 00:00:00 2001 From: Cedric Staub Date: Mon, 12 Jan 2026 09:09:14 -0800 Subject: [PATCH 1/2] More unit tests to improve test coverage --- .golangci.yml | 3 +- AGENTS.md | 96 ++++++++++++ certloader/acmetlsconfig_test.go | 129 ++++++++++++++++ certloader/dialer_test.go | 181 ++++++++++++++++++++++ certloader/listener_test.go | 228 +++++++++++++++++++++++++++ proxy/proxy_test.go | 254 +++++++++++++++++++++++++++++++ 6 files changed, 889 insertions(+), 2 deletions(-) create mode 100644 AGENTS.md create mode 100644 certloader/acmetlsconfig_test.go create mode 100644 certloader/dialer_test.go create mode 100644 certloader/listener_test.go diff --git a/.golangci.yml b/.golangci.yml index 82d257a2bb..75eea7dd83 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -10,10 +10,9 @@ linters: - std-error-handling - common-false-positives rules: - # Additional exclusions for os.Remove() in tests - cleanup that's best-effort + # Ignore errcheck in test files - test code often has best-effort operations - linters: - errcheck - text: "os.Remove" path: "_test\\.go" settings: errcheck: diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..6c569d0907 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,96 @@ +# AGENTS.md + +This file provides guidance to agents when working with code in this repository. + +## Project Overview + +Ghostunnel is a TLS proxy with mutual authentication support for securing non-TLS backend applications. It operates in two modes: +- **Server mode**: Accepts TLS connections and proxies to insecure backends (TCP or UNIX sockets) +- **Client mode**: Accepts insecure connections and proxies to TLS-secured services + +## Build Commands + +This project uses [mage](https://magefile.org) as the build system: + +```bash +# Build the binary +mage go:build + +# Run all tests (unit + integration) +mage test:all + +# Run only unit tests +mage test:unit + +# Run only integration tests (requires Python 3.5+) +mage test:integration + +# Run tests in Docker (includes PKCS#11 tests with SoftHSM) +mage test:docker + +# Generate test certificates for development +mage test:keys + +# View coverage +go tool cover -html coverage/all.profile + +# Build Docker images +mage docker:build + +# List all available targets +mage -l +``` + +### Running a Single Test + +For unit tests: +```bash +go test -v -run TestName ./... +go test -v ./auth/... # Run all tests in a package +``` + +For integration tests (Python): +```bash +cd tests && python3 test-name.py +``` + +### Linting + +```bash +golangci-lint run +``` + +The project uses golangci-lint with configuration in `.golangci.yml`. Standard linters are enabled with exclusions for common error handling patterns. + +## Architecture + +### Package Structure + +- **main** (`main.go`, `doc.go`): Entry point, CLI flag parsing (using kingpin), mode dispatch (server/client) +- **auth**: Authorization via X.509 certificate validation (CN, OU, DNS SAN, URI SAN, IP SAN checks) +- **certloader**: Certificate loading abstractions supporting PEM files, PKCS#12 keystores, PKCS#11 HSMs, SPIFFE Workload API, ACME, and macOS/Windows keychain +- **proxy**: Connection forwarding with configurable timeouts, connection limits, and PROXY protocol support +- **policy**: Open Policy Agent (OPA) integration for declarative access control policies +- **socket**: Network socket utilities including systemd/launchd socket activation +- **wildcard**: Pattern matching for URI-based access control +- **certstore**: Platform-specific keychain integration (macOS/Windows) + +### Key Design Patterns + +1. **TLSConfigSource interface**: Abstracts certificate sources (files, SPIFFE, ACME, keychain) behind a common interface for hot-reloading +2. **Conditional compilation**: Platform-specific features (PKCS#11, keychain, Landlock) use build tags (`pkcs11_enabled.go`/`pkcs11_disabled.go`) +3. **Signal handling**: SIGHUP triggers certificate reload; SIGTERM/SIGINT trigger graceful shutdown + +### Testing + +- Unit tests: Go standard testing in `*_test.go` files +- Integration tests: Python scripts in `tests/` directory using `tests/common.py` helper module +- Test certificates are generated in `test-keys/` via `mage test:keys` + +## Key Flags + +Server mode requires access control: `--allow-all`, `--allow-cn`, `--allow-ou`, `--allow-dns`, `--allow-uri`, `--allow-policy`, or `--disable-authentication` + +Certificate sources are mutually exclusive: `--keystore`, `--cert/--key`, `--keychain-identity`, `--use-workload-api`, `--auto-acme-cert` + +Safe addresses (localhost, 127.0.0.1, [::1], unix:, systemd:, launchd:) don't require `--unsafe-target` or `--unsafe-listen` diff --git a/certloader/acmetlsconfig_test.go b/certloader/acmetlsconfig_test.go new file mode 100644 index 0000000000..ee179122f9 --- /dev/null +++ b/certloader/acmetlsconfig_test.go @@ -0,0 +1,129 @@ +/*- + * Copyright 2025 Ghostunnel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/tls" + "testing" + + "github.com/caddyserver/certmagic" + "github.com/mholt/acmez" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Note: Full ACME testing requires external ACME server interaction. +// These tests cover the code paths that can be tested without external dependencies. + +func TestACMETLSConfigSourceGetClientConfigError(t *testing.T) { + // GetClientConfig should always fail for ACME sources + // (ACME is server-only feature) + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + _, err := source.GetClientConfig(nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not supported in client mode") +} + +func TestACMETLSConfigSourceCanServe(t *testing.T) { + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + // CanServe should always return true (certmagic manages validity) + assert.True(t, source.CanServe()) +} + +func TestACMETLSConfigSourceReload(t *testing.T) { + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + // Reload should be a no-op (certmagic auto-refreshes) + err := source.Reload() + assert.NoError(t, err) +} + +func TestACMETLSConfigSourceGetServerConfigNilBase(t *testing.T) { + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + // GetServerConfig should work with nil base config + config, err := source.GetServerConfig(nil) + require.NoError(t, err) + require.NotNil(t, config) + + tlsConfig := config.GetServerConfig() + require.NotNil(t, tlsConfig) + assert.NotNil(t, tlsConfig.GetCertificate, "GetCertificate should be set") + assert.Contains(t, tlsConfig.NextProtos, acmez.ACMETLS1Protocol, "ACME-TLS protocol should be in NextProtos") +} + +func TestACMETLSConfigSourceGetServerConfigWithBase(t *testing.T) { + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + // GetServerConfig should preserve base config settings + base := &tls.Config{ + MinVersion: tls.VersionTLS13, + NextProtos: []string{"h2", "http/1.1"}, + } + + config, err := source.GetServerConfig(base) + require.NoError(t, err) + require.NotNil(t, config) + + tlsConfig := config.GetServerConfig() + require.NotNil(t, tlsConfig) + assert.Equal(t, uint16(tls.VersionTLS13), tlsConfig.MinVersion, "MinVersion should be preserved from base") + assert.Contains(t, tlsConfig.NextProtos, "h2", "base NextProtos should be preserved") + assert.Contains(t, tlsConfig.NextProtos, "http/1.1", "base NextProtos should be preserved") + assert.Contains(t, tlsConfig.NextProtos, acmez.ACMETLS1Protocol, "ACME-TLS protocol should be added") +} + +func TestACMETLSConfigGetServerConfig(t *testing.T) { + magicConfig := certmagic.NewDefault() + base := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + acmeConfig := &acmeTLSConfig{ + magicConfig: magicConfig, + base: base, + } + + tlsConfig := acmeConfig.GetServerConfig() + require.NotNil(t, tlsConfig) + + // Verify it's a clone (not the same pointer) + assert.NotSame(t, base, tlsConfig, "GetServerConfig should return a clone") + + // Verify GetCertificate is set + assert.NotNil(t, tlsConfig.GetCertificate) + + // Verify ACME-TLS protocol is added + assert.Contains(t, tlsConfig.NextProtos, acmez.ACMETLS1Protocol) +} diff --git a/certloader/dialer_test.go b/certloader/dialer_test.go new file mode 100644 index 0000000000..095c29128f --- /dev/null +++ b/certloader/dialer_test.go @@ -0,0 +1,181 @@ +/*- + * Copyright 2025 Ghostunnel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "context" + "crypto/tls" + "net" + "testing" + "time" + + spiffetest "github.com/ghostunnel/ghostunnel/certloader/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockTLSClientConfig implements TLSClientConfig for testing +type mockTLSClientConfig struct { + config *tls.Config +} + +func (m *mockTLSClientConfig) GetClientConfig() *tls.Config { + return m.config +} + +func TestDialerWithCertificate(t *testing.T) { + // Create test CA and certificates + rootPool, serverCert := spiffetest.CreateWebCredentials(t) + _, clientCert := spiffetest.CreateWebCredentials(t) + + // Start a TLS server + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + ClientAuth: tls.NoClientCert, + } + listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig) + require.NoError(t, err) + defer listener.Close() + + // Accept connections in goroutine + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + // Write something back to confirm connection works + conn.Write([]byte("OK")) + }() + + // Create dialer with client cert + clientConfig := &tls.Config{ + Certificates: []tls.Certificate{*clientCert}, + RootCAs: rootPool, + InsecureSkipVerify: true, // Skip verification for test simplicity + } + mockConfig := &mockTLSClientConfig{config: clientConfig} + dialer := DialerWithCertificate(mockConfig, 5*time.Second, &net.Dialer{}) + + // Test successful dial + conn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + + // Verify it's a TLS connection + _, ok := conn.(*tls.Conn) + assert.True(t, ok, "returned connection should be TLS") + + // Read response to confirm connection works + buf := make([]byte, 2) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "OK", string(buf[:n])) +} + +func TestDialWithDialerRawConnFailure(t *testing.T) { + // Test when raw connection fails (e.g., connection refused) + config := &tls.Config{InsecureSkipVerify: true} + + _, err := dialWithDialer(&net.Dialer{}, context.Background(), + 100*time.Millisecond, "tcp", "127.0.0.1:1", config) // Port 1 should be closed + assert.Error(t, err) +} + +func TestDialWithDialerHandshakeFailure(t *testing.T) { + // Start a plain TCP server (not TLS) to cause handshake failure + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + // Accept and immediately close to trigger handshake failure + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + // Close immediately to cause handshake failure + conn.Close() + }() + + config := &tls.Config{InsecureSkipVerify: true} + + _, err = dialWithDialer(&net.Dialer{}, context.Background(), + 1*time.Second, "tcp", listener.Addr().String(), config) + assert.Error(t, err) +} + +func TestDialWithDialerContextCancellation(t *testing.T) { + // Start a TLS server that delays + _, serverCert := spiffetest.CreateWebCredentials(t) + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + } + listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig) + require.NoError(t, err) + defer listener.Close() + + // Accept but delay handshake + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + // Delay to allow context cancellation + time.Sleep(2 * time.Second) + conn.Close() + }() + + // Cancel context immediately + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + config := &tls.Config{InsecureSkipVerify: true} + + _, err = dialWithDialer(&net.Dialer{}, ctx, + 5*time.Second, "tcp", listener.Addr().String(), config) + assert.Error(t, err) +} + +func TestDialWithDialerTimeout(t *testing.T) { + // Start a TLS server that delays handshake + _, serverCert := spiffetest.CreateWebCredentials(t) + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + } + listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig) + require.NoError(t, err) + defer listener.Close() + + // Accept but don't complete handshake + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + // Hold connection open but delay to trigger timeout + time.Sleep(5 * time.Second) + conn.Close() + }() + + config := &tls.Config{InsecureSkipVerify: true} + + // Use very short timeout + _, err = dialWithDialer(&net.Dialer{}, context.Background(), + 50*time.Millisecond, "tcp", listener.Addr().String(), config) + assert.Error(t, err) +} diff --git a/certloader/listener_test.go b/certloader/listener_test.go new file mode 100644 index 0000000000..80a982f776 --- /dev/null +++ b/certloader/listener_test.go @@ -0,0 +1,228 @@ +/*- + * Copyright 2025 Ghostunnel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/tls" + "errors" + "net" + "testing" + "time" + + spiffetest "github.com/ghostunnel/ghostunnel/certloader/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockTLSServerConfig implements TLSServerConfig for testing +type mockTLSServerConfig struct { + config *tls.Config +} + +func (m *mockTLSServerConfig) GetServerConfig() *tls.Config { + return m.config +} + +// failingListener is a mock listener that always fails on Accept +type failingListener struct{} + +func (f *failingListener) Accept() (net.Conn, error) { + return nil, errors.New("mock accept error") +} + +func (f *failingListener) Close() error { + return nil +} + +func (f *failingListener) Addr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func TestNewListener(t *testing.T) { + inner, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer inner.Close() + + mockConfig := &mockTLSServerConfig{config: &tls.Config{}} + listener := NewListener(inner, mockConfig) + + assert.Equal(t, inner.Addr(), listener.Addr()) + assert.Equal(t, inner, listener.Listener) +} + +func TestListenerAccept(t *testing.T) { + // Create test certificates + _, serverCert := spiffetest.CreateWebCredentials(t) + + // Create inner TCP listener + inner, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + // Create TLS config + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + } + mockConfig := &mockTLSServerConfig{config: serverConfig} + listener := NewListener(inner, mockConfig) + defer listener.Close() + + // Accept in goroutine (server side) + acceptDone := make(chan net.Conn, 1) + acceptErr := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + acceptErr <- err + return + } + // Complete handshake on server side + tlsConn := conn.(*tls.Conn) + if err := tlsConn.Handshake(); err != nil { + conn.Close() + acceptErr <- err + return + } + acceptDone <- conn + }() + + // Connect from client side + clientConn, err := tls.Dial("tcp", listener.Addr().String(), + &tls.Config{InsecureSkipVerify: true}) + require.NoError(t, err) + defer clientConn.Close() + + // Wait for server to accept + select { + case conn := <-acceptDone: + defer conn.Close() + // Verify it's a TLS connection + _, ok := conn.(*tls.Conn) + assert.True(t, ok, "returned connection should be TLS") + case err := <-acceptErr: + t.Fatalf("accept failed: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("test timed out") + } +} + +func TestListenerAcceptError(t *testing.T) { + // Test error propagation when inner listener fails + mockListener := &failingListener{} + mockConfig := &mockTLSServerConfig{config: &tls.Config{}} + listener := NewListener(mockListener, mockConfig) + + _, err := listener.Accept() + assert.Error(t, err) + assert.Contains(t, err.Error(), "mock accept error") +} + +func TestListenerClose(t *testing.T) { + inner, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + mockConfig := &mockTLSServerConfig{config: &tls.Config{}} + listener := NewListener(inner, mockConfig) + + // Close should close the inner listener + err = listener.Close() + assert.NoError(t, err) + + // Trying to accept should fail now + _, err = inner.Accept() + assert.Error(t, err) +} + +func TestListenerConfigReload(t *testing.T) { + // Create test certificates + _, serverCert1 := spiffetest.CreateWebCredentials(t) + _, serverCert2 := spiffetest.CreateWebCredentials(t) + + // Create inner TCP listener + inner, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + // Start with first config + currentConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert1}, + } + + // Create a mock config that can be updated + mockConfig := &reloadableMockConfig{config: currentConfig} + listener := NewListener(inner, mockConfig) + defer listener.Close() + + // Helper to do one connection cycle + doConnection := func() error { + acceptDone := make(chan net.Conn, 1) + acceptErr := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + acceptErr <- err + return + } + // Complete handshake + tlsConn := conn.(*tls.Conn) + if err := tlsConn.Handshake(); err != nil { + conn.Close() + acceptErr <- err + return + } + acceptDone <- conn + }() + + // Connect from client + clientConn, err := tls.Dial("tcp", listener.Addr().String(), + &tls.Config{InsecureSkipVerify: true}) + if err != nil { + return err + } + clientConn.Close() + + select { + case conn := <-acceptDone: + conn.Close() + return nil + case err := <-acceptErr: + return err + case <-time.After(5 * time.Second): + return errors.New("timeout") + } + } + + // First connection with first config + err = doConnection() + require.NoError(t, err) + + // Update config (simulating reload) + mockConfig.config = &tls.Config{ + Certificates: []tls.Certificate{*serverCert2}, + } + + // Second connection should use new config + err = doConnection() + require.NoError(t, err) +} + +// reloadableMockConfig allows config changes between accepts +type reloadableMockConfig struct { + config *tls.Config +} + +func (m *reloadableMockConfig) GetServerConfig() *tls.Config { + return m.config +} diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 33fb76c9db..d16d6e74e9 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -315,3 +315,257 @@ func TestCopyData(t *testing.T) { t.Fatalf("input and output were different after copy") } } + +// mockNetError implements net.Error for testing isTimeoutError +type mockNetError struct { + timeout bool + temporary bool + msg string +} + +func (e *mockNetError) Error() string { return e.msg } +func (e *mockNetError) Timeout() bool { return e.timeout } +func (e *mockNetError) Temporary() bool { return e.temporary } + +func TestIsTimeoutError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "regular error", + err: errors.New("test error"), + expected: false, + }, + { + name: "net.Error with timeout=true", + err: &mockNetError{timeout: true, msg: "timeout error"}, + expected: true, + }, + { + name: "net.Error with timeout=false", + err: &mockNetError{timeout: false, msg: "non-timeout error"}, + expected: false, + }, + { + name: "context.DeadlineExceeded", + err: context.DeadlineExceeded, + expected: true, + }, + { + name: "context.Canceled", + err: context.Canceled, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isTimeoutError(tc.err) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestIsClosedConnectionError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "regular error", + err: errors.New("test error"), + expected: false, + }, + { + name: "closed pipe error", + err: errors.New("read/write on closed pipe"), + expected: true, + }, + { + name: "net.OpError read closed", + err: &net.OpError{ + Op: "read", + Err: errors.New("use of closed network connection"), + }, + expected: true, + }, + { + name: "net.OpError write closed", + err: &net.OpError{ + Op: "write", + Err: errors.New("use of closed network connection"), + }, + expected: true, + }, + { + name: "net.OpError readfrom closed", + err: &net.OpError{ + Op: "readfrom", + Err: errors.New("use of closed network connection"), + }, + expected: true, + }, + { + name: "net.OpError writeto closed", + err: &net.OpError{ + Op: "writeto", + Err: errors.New("use of closed network connection"), + }, + expected: true, + }, + { + name: "net.OpError other op", + err: &net.OpError{ + Op: "dial", + Err: errors.New("use of closed network connection"), + }, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isClosedConnectionError(tc.err) + assert.Equal(t, tc.expected, result) + }) + } +} + +// mockConn is a minimal net.Conn implementation that is neither TCP nor Unix +type mockConn struct { + closed bool +} + +func (m *mockConn) Read(b []byte) (n int, err error) { return 0, nil } +func (m *mockConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockConn) Close() error { m.closed = true; return nil } +func (m *mockConn) LocalAddr() net.Addr { return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} } +func (m *mockConn) RemoteAddr() net.Addr { return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} } +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestCloseReadNonTCPConnection(t *testing.T) { + conn := &mockConn{} + closeRead(conn) + assert.True(t, conn.closed, "non-TCP/Unix conn should be closed via Close()") +} + +func TestCloseWriteNonTCPConnection(t *testing.T) { + conn := &mockConn{} + closeWrite(conn) + assert.True(t, conn.closed, "non-TCP/Unix conn should be closed via Close()") +} + +func TestCloseReadTCPConnection(t *testing.T) { + // Create a TCP connection pair + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer listener.Close() + + go func() { + conn, _ := listener.Accept() + if conn != nil { + time.Sleep(100 * time.Millisecond) + conn.Close() + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + assert.Nil(t, err) + defer conn.Close() + + // closeRead should not panic and should work on TCP + closeRead(conn) +} + +func TestCloseWriteTCPConnection(t *testing.T) { + // Create a TCP connection pair + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer listener.Close() + + go func() { + conn, _ := listener.Accept() + if conn != nil { + time.Sleep(100 * time.Millisecond) + conn.Close() + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + assert.Nil(t, err) + defer conn.Close() + + // closeWrite should not panic and should work on TCP + closeWrite(conn) +} + +func TestForceHandshakeNonTLSConn(t *testing.T) { + // Create a regular TCP connection (non-TLS) + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer listener.Close() + + go func() { + conn, _ := listener.Accept() + if conn != nil { + conn.Close() + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + assert.Nil(t, err) + defer conn.Close() + + // forceHandshake should be a no-op for non-TLS connections + ctx := context.Background() + err = forceHandshake(ctx, conn) + assert.Nil(t, err, "forceHandshake should succeed for non-TLS conn") +} + +func TestLogConnectionMessageDisabled(t *testing.T) { + // Test with LogConnections disabled + p := New(nil, 5*time.Second, 5*time.Second, 0, 0, nil, &testLogger{}, 0, false) + + // Create pipe connections + src, dst := net.Pipe() + defer src.Close() + defer dst.Close() + + // Should not panic even with logging disabled + p.logConnectionMessage("test", src, dst, 0, 0, time.Time{}) +} + +func TestLogConditional(t *testing.T) { + logged := false + logger := &callbackLogger{callback: func(format string, v ...interface{}) { + logged = true + }} + + // Test with flag enabled + p := New(nil, 5*time.Second, 5*time.Second, 0, 0, nil, logger, LogConnectionErrors, false) + p.logConditional(LogConnectionErrors, "test message") + assert.True(t, logged, "should log when flag is enabled") + + // Test with flag disabled + logged = false + p.logConditional(LogHandshakeErrors, "test message") + assert.False(t, logged, "should not log when flag is disabled") +} + +type callbackLogger struct { + callback func(format string, v ...interface{}) +} + +func (c *callbackLogger) Printf(format string, v ...interface{}) { + c.callback(format, v...) +} From 2773cfaf100d2249b9df14ee7d2aa7289ced1a46 Mon Sep 17 00:00:00 2001 From: Cedric Staub Date: Wed, 28 Jan 2026 16:15:13 -0800 Subject: [PATCH 2/2] Add macOS codesigning, notarization, and Docker tag improvements Add apple:codesign and apple:notarize mage targets for signing and notarizing macOS binaries. Codesign uses the macOS codesign tool with hardened runtime enabled. Notarize submits to Apple's notary service via xcrun notarytool with App Store Connect API key auth. For CI, CODESIGN_CERTIFICATE (base64 .p12) triggers temporary keychain creation with automatic cleanup. NOTARIZE_KEY (base64 .p8) is written to a temp file for notarytool. Sensitive security commands use runSilent to suppress argument echoing in CI logs. Add codesign and notarize steps to the Darwin release workflow. Change Docker tagging so "latest" is only applied on release tag pushes (e.g. v1.9.0), not on master branch pushes. Master pushes now produce "master" tagged images instead. Co-Authored-By: Claude Opus 4.5 --- .github/workflows/release.yml | 18 +++ magefile.go | 296 ++++++++++++++++++++++++++++++++-- 2 files changed, 298 insertions(+), 16 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index daf95864fb..d57e081f18 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -72,6 +72,24 @@ jobs: CGO_ENABLED=1 GOARCH=arm64 ./mage-bin -v go:build mv ghostunnel ghostunnel-darwin-arm64 lipo -create -output ghostunnel-darwin-universal ghostunnel-darwin-amd64 ghostunnel-darwin-arm64 + - name: Codesign binaries + env: + CODESIGN_IDENTITY: ${{ secrets.CODESIGN_IDENTITY }} + CODESIGN_CERTIFICATE: ${{ secrets.CODESIGN_CERTIFICATE }} + CODESIGN_CERTIFICATE_PASSWORD: ${{ secrets.CODESIGN_CERTIFICATE_PASSWORD }} + run: | + ./mage-bin -v apple:codesign ghostunnel-darwin-amd64 + ./mage-bin -v apple:codesign ghostunnel-darwin-arm64 + ./mage-bin -v apple:codesign ghostunnel-darwin-universal + - name: Notarize binaries + env: + NOTARIZE_ISSUER_ID: ${{ secrets.NOTARIZE_ISSUER_ID }} + NOTARIZE_KEY_ID: ${{ secrets.NOTARIZE_KEY_ID }} + NOTARIZE_KEY: ${{ secrets.NOTARIZE_KEY }} + run: | + ./mage-bin -v apple:notarize ghostunnel-darwin-amd64 + ./mage-bin -v apple:notarize ghostunnel-darwin-arm64 + ./mage-bin -v apple:notarize ghostunnel-darwin-universal - name: Upload artifact uses: actions/upload-artifact@v6 with: diff --git a/magefile.go b/magefile.go index 761e24d2c7..97a5722fc5 100644 --- a/magefile.go +++ b/magefile.go @@ -5,6 +5,8 @@ package main import ( "bytes" "context" + "crypto/rand" + "encoding/base64" "fmt" "os" "os/exec" @@ -18,12 +20,22 @@ import ( ) type Go mg.Namespace +type Apple mg.Namespace type Git mg.Namespace type Test mg.Namespace type Docker mg.Namespace var Default = Go.Build +// runSilent executes a command without echoing it to stdout, to avoid +// leaking sensitive arguments (passwords, secrets) in CI logs. +func runSilent(name string, args ...string) error { + cmd := exec.Command(name, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} + // printf prints the given format and args if verbose mode is enabled. func printf(format string, args ...interface{}) { if mg.Verbose() { @@ -77,6 +89,252 @@ func (Go) Man(ctx context.Context) error { return nil } +// Codesign signs a macOS binary using the codesign tool. The binary argument +// specifies which file to sign. Requires macOS. If CODESIGN_CERTIFICATE is +// set, a temporary keychain is created, the certificate is imported, and the +// keychain is cleaned up after signing. +// +// Environment variables: +// - CODESIGN_IDENTITY: Signing identity (required, e.g. "Developer ID Application: Name (TEAMID)") +// - CODESIGN_CERTIFICATE: Base64-encoded .p12 certificate to import into a temporary keychain (optional, for CI) +// - CODESIGN_CERTIFICATE_PASSWORD: Password for the .p12 certificate (required if CODESIGN_CERTIFICATE is set) +func (Apple) Codesign(ctx context.Context, binary string) error { + if runtime.GOOS != "darwin" { + return fmt.Errorf("codesigning is only supported on macOS") + } + + identity := os.Getenv("CODESIGN_IDENTITY") + if identity == "" { + return fmt.Errorf("CODESIGN_IDENTITY must be set") + } + + certData := os.Getenv("CODESIGN_CERTIFICATE") + if certData != "" { + cleanup, err := setupCodesignKeychain(certData) + if err != nil { + return err + } + defer cleanup() + } + + printf("Signing binary %s with identity %s\n", binary, identity) + + if err := sh.Run("codesign", "--force", "--options", "runtime", "--sign", identity, binary); err != nil { + return fmt.Errorf("codesign %s failed: %w", binary, err) + } + + if err := sh.Run("codesign", "--verify", "--verbose", binary); err != nil { + return fmt.Errorf("codesign verification of %s failed: %w", binary, err) + } + + printf("Binary %s signed and verified successfully\n", binary) + return nil +} + +// Notarize submits a signed macOS binary to Apple's notary service. Requires +// macOS. If NOTARIZE_KEY is set, the .p8 key is written to a temp file and +// cleaned up after notarization. +// +// The binary is zipped for submission and the zip is removed afterward. +// Note: stapling only works for .app, .pkg, and .dmg — for bare binaries the +// notarization is registered with Apple but cannot be stapled. The staple step +// is attempted but a failure is not treated as an error. +// +// Environment variables: +// - NOTARIZE_ISSUER_ID: App Store Connect API issuer ID (required) +// - NOTARIZE_KEY_ID: App Store Connect API key ID (required) +// - NOTARIZE_KEY: Base64-encoded .p8 private key (optional, for CI; if not set, key must already exist) +func (Apple) Notarize(ctx context.Context, binary string) error { + if runtime.GOOS != "darwin" { + return fmt.Errorf("notarization is only supported on macOS") + } + + issuerID := os.Getenv("NOTARIZE_ISSUER_ID") + keyID := os.Getenv("NOTARIZE_KEY_ID") + if issuerID == "" || keyID == "" { + return fmt.Errorf("NOTARIZE_ISSUER_ID and NOTARIZE_KEY_ID must be set") + } + + // If NOTARIZE_KEY is set, write the .p8 key to a temp file for notarytool + keyPath, err := setupNotarizeKey(keyID) + if err != nil { + return err + } + if keyPath != "" { + defer os.Remove(keyPath) + } + + // Create zip for submission + zipPath := binary + ".zip" + if err := sh.Run("ditto", "-c", "-k", "--sequesterRsrc", binary, zipPath); err != nil { + return fmt.Errorf("failed to create zip for %s: %w", binary, err) + } + + printf("Submitting %s for notarization...\n", binary) + + submitArgs := []string{"notarytool", "submit", zipPath, + "--issuer", issuerID, + "--key-id", keyID, + } + if keyPath != "" { + submitArgs = append(submitArgs, "--key", keyPath) + } + submitArgs = append(submitArgs, "--wait") + + err = sh.Run("xcrun", submitArgs...) + os.Remove(zipPath) + if err != nil { + return fmt.Errorf("notarization of %s failed: %w", binary, err) + } + + // Attempt to staple — this only works for .app/.pkg/.dmg, not bare binaries + if err := sh.Run("xcrun", "stapler", "staple", binary); err != nil { + printf("Stapling skipped for %s (not supported for bare binaries): %v\n", binary, err) + } + + printf("Notarization of %s completed successfully\n", binary) + return nil +} + +// setupCodesignKeychain creates a temporary keychain, imports the signing +// certificate, and configures the keychain search list. Returns a cleanup +// function that removes the temporary keychain and restores the original +// search list. +func setupCodesignKeychain(certBase64 string) (func(), error) { + password := os.Getenv("CODESIGN_CERTIFICATE_PASSWORD") + if password == "" { + return nil, fmt.Errorf("CODESIGN_CERTIFICATE_PASSWORD must be set when CODESIGN_CERTIFICATE is set") + } + + // Decode certificate + certBytes, err := base64.StdEncoding.DecodeString(certBase64) + if err != nil { + return nil, fmt.Errorf("failed to decode CODESIGN_CERTIFICATE: %w", err) + } + + // Write certificate to temp file + certFile, err := os.CreateTemp("", "codesign-*.p12") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + if _, err := certFile.Write(certBytes); err != nil { + os.Remove(certFile.Name()) + return nil, fmt.Errorf("failed to write certificate: %w", err) + } + certFile.Close() + + // Generate random keychain password + keychainPassBytes := make([]byte, 32) + if _, err := rand.Read(keychainPassBytes); err != nil { + os.Remove(certFile.Name()) + return nil, fmt.Errorf("failed to generate keychain password: %w", err) + } + keychainPassword := base64.StdEncoding.EncodeToString(keychainPassBytes) + + keychainPath := "ghostunnel-signing.keychain-db" + + // Save original keychain search list + originalKeychains, err := sh.Output("security", "list-keychains", "-d", "user") + if err != nil { + os.Remove(certFile.Name()) + return nil, fmt.Errorf("failed to list keychains: %w", err) + } + + cleanup := func() { + // Restore original keychain search list + restoreArgs := []string{"list-keychains", "-d", "user", "-s"} + restoreArgs = append(restoreArgs, parseKeychainPaths(originalKeychains)...) + sh.Run("security", restoreArgs...) + sh.Run("security", "delete-keychain", keychainPath) + os.Remove(certFile.Name()) + } + + // Create temporary keychain (suppress command echo to avoid leaking keychain password) + if err := runSilent("security", "create-keychain", "-p", keychainPassword, keychainPath); err != nil { + cleanup() + return nil, fmt.Errorf("failed to create keychain: %w", err) + } + + // Set keychain settings (no auto-lock) + if err := sh.Run("security", "set-keychain-settings", keychainPath); err != nil { + cleanup() + return nil, fmt.Errorf("failed to set keychain settings: %w", err) + } + + // Unlock keychain (suppress command echo to avoid leaking keychain password) + if err := runSilent("security", "unlock-keychain", "-p", keychainPassword, keychainPath); err != nil { + cleanup() + return nil, fmt.Errorf("failed to unlock keychain: %w", err) + } + + // Import certificate into keychain (suppress command echo to avoid leaking certificate password) + if err := runSilent("security", "import", certFile.Name(), "-k", keychainPath, "-f", "pkcs12", "-P", password, "-T", "/usr/bin/codesign"); err != nil { + cleanup() + return nil, fmt.Errorf("failed to import certificate: %w", err) + } + + // Set key partition list to allow codesign access (suppress command echo to avoid leaking keychain password) + if err := runSilent("security", "set-key-partition-list", "-S", "apple-tool:,apple:,codesign:", "-s", "-k", keychainPassword, keychainPath); err != nil { + cleanup() + return nil, fmt.Errorf("failed to set key partition list: %w", err) + } + + // Add temporary keychain to search list (prepend to existing) + keychainArgs := []string{"list-keychains", "-d", "user", "-s", keychainPath} + keychainArgs = append(keychainArgs, parseKeychainPaths(originalKeychains)...) + if err := sh.Run("security", keychainArgs...); err != nil { + cleanup() + return nil, fmt.Errorf("failed to update keychain search list: %w", err) + } + + return cleanup, nil +} + +// setupNotarizeKey writes the NOTARIZE_KEY env var (base64-encoded .p8) to a +// temp file and returns its path. Returns an empty path if NOTARIZE_KEY is not +// set (assumes the key file is already available locally). +func setupNotarizeKey(keyID string) (string, error) { + keyData := os.Getenv("NOTARIZE_KEY") + if keyData == "" { + return "", nil + } + + keyBytes, err := base64.StdEncoding.DecodeString(keyData) + if err != nil { + return "", fmt.Errorf("failed to decode NOTARIZE_KEY: %w", err) + } + + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + + keyDir := filepath.Join(homeDir, "private_keys") + if err := os.MkdirAll(keyDir, 0700); err != nil { + return "", fmt.Errorf("failed to create private_keys directory: %w", err) + } + + keyPath := filepath.Join(keyDir, fmt.Sprintf("AuthKey_%s.p8", keyID)) + if err := os.WriteFile(keyPath, keyBytes, 0600); err != nil { + return "", fmt.Errorf("failed to write API key: %w", err) + } + + return keyPath, nil +} + +// parseKeychainPaths parses the output of `security list-keychains` into +// a list of unquoted keychain paths. +func parseKeychainPaths(output string) []string { + var paths []string + for _, line := range strings.Split(output, "\n") { + kc := strings.TrimSpace(strings.Trim(strings.TrimSpace(line), "\"")) + if kc != "" { + paths = append(paths, kc) + } + } + return paths +} + // Clean removes build artifacts. func (Git) Clean(ctx context.Context) error { return sh.Run("git", "clean", "-Xdf") @@ -365,19 +623,23 @@ func (Docker) Push(ctx context.Context) error { // buildDocker builds and tags all Docker containers, optionally pushing them to Docker Hub. func buildDocker(ctx context.Context, push bool) error { - // Determine base tag (latest for master, version tag otherwise) - baseTag, err := getDockerTag() + baseTags, err := getDockerTags() if err != nil { return err } - builds := map[string][]string{ - "Dockerfile-alpine": []string{ + builds := map[string][]string{} + for _, baseTag := range baseTags { + builds["Dockerfile-alpine"] = append(builds["Dockerfile-alpine"], fmt.Sprintf("ghostunnel/ghostunnel:%s", baseTag), fmt.Sprintf("ghostunnel/ghostunnel:%s-alpine", baseTag), - }, - "Dockerfile-debian": []string{fmt.Sprintf("ghostunnel/ghostunnel:%s-debian", baseTag)}, - "Dockerfile-distroless": []string{fmt.Sprintf("ghostunnel/ghostunnel:%s-distroless", baseTag)}, + ) + builds["Dockerfile-debian"] = append(builds["Dockerfile-debian"], + fmt.Sprintf("ghostunnel/ghostunnel:%s-debian", baseTag), + ) + builds["Dockerfile-distroless"] = append(builds["Dockerfile-distroless"], + fmt.Sprintf("ghostunnel/ghostunnel:%s-distroless", baseTag), + ) } for dockerfile, tags := range builds { @@ -425,37 +687,39 @@ func getVersion() string { return strings.TrimSpace(output) } -// getDockerTag determines the Docker tag to use based on git state. -// Returns "latest" if on master branch, otherwise returns the most recent tag. -func getDockerTag() (string, error) { +// getDockerTags determines the Docker tags to use based on git state. +// For release tags (refs/tags/v*), returns both the version tag and "latest". +// For master branch, returns "master". For local non-master branches, returns +// the most recent git tag. +func getDockerTags() ([]string, error) { // Check if we're on a tag (for GitHub Actions when triggered by tag push) // In GitHub Actions, GITHUB_REF will be set, but locally we check git githubRef := os.Getenv("GITHUB_REF") if githubRef != "" { // GitHub Actions: refs/heads/master or refs/tags/v1.2.3 if strings.HasPrefix(githubRef, "refs/heads/master") { - return "latest", nil + return []string{"master"}, nil } if strings.HasPrefix(githubRef, "refs/tags/") { tag := strings.TrimPrefix(githubRef, "refs/tags/") - return tag, nil + return []string{tag, "latest"}, nil } } // Check current branch branch, err := sh.Output("git", "rev-parse", "--abbrev-ref", "HEAD") if err != nil { - return "", fmt.Errorf("failed to determine git ref: %w", err) + return nil, fmt.Errorf("failed to determine git ref: %w", err) } if strings.TrimSpace(branch) == "master" { - return "latest", nil + return []string{"master"}, nil } // Not on master, get the most recent tag tag, err := sh.Output("git", "describe", "--tags", "--abbrev=0") if err != nil { - return "", fmt.Errorf("failed to get git tag: %w", err) + return nil, fmt.Errorf("failed to get git tag: %w", err) } - return strings.TrimSpace(tag), nil + return []string{strings.TrimSpace(tag)}, nil }