diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index daf95864fb..d57e081f18 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -72,6 +72,24 @@ jobs: CGO_ENABLED=1 GOARCH=arm64 ./mage-bin -v go:build mv ghostunnel ghostunnel-darwin-arm64 lipo -create -output ghostunnel-darwin-universal ghostunnel-darwin-amd64 ghostunnel-darwin-arm64 + - name: Codesign binaries + env: + CODESIGN_IDENTITY: ${{ secrets.CODESIGN_IDENTITY }} + CODESIGN_CERTIFICATE: ${{ secrets.CODESIGN_CERTIFICATE }} + CODESIGN_CERTIFICATE_PASSWORD: ${{ secrets.CODESIGN_CERTIFICATE_PASSWORD }} + run: | + ./mage-bin -v apple:codesign ghostunnel-darwin-amd64 + ./mage-bin -v apple:codesign ghostunnel-darwin-arm64 + ./mage-bin -v apple:codesign ghostunnel-darwin-universal + - name: Notarize binaries + env: + NOTARIZE_ISSUER_ID: ${{ secrets.NOTARIZE_ISSUER_ID }} + NOTARIZE_KEY_ID: ${{ secrets.NOTARIZE_KEY_ID }} + NOTARIZE_KEY: ${{ secrets.NOTARIZE_KEY }} + run: | + ./mage-bin -v apple:notarize ghostunnel-darwin-amd64 + ./mage-bin -v apple:notarize ghostunnel-darwin-arm64 + ./mage-bin -v apple:notarize ghostunnel-darwin-universal - name: Upload artifact uses: actions/upload-artifact@v6 with: diff --git a/.golangci.yml b/.golangci.yml index 82d257a2bb..75eea7dd83 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -10,10 +10,9 @@ linters: - std-error-handling - common-false-positives rules: - # Additional exclusions for os.Remove() in tests - cleanup that's best-effort + # Ignore errcheck in test files - test code often has best-effort operations - linters: - errcheck - text: "os.Remove" path: "_test\\.go" settings: errcheck: diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..6c569d0907 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,96 @@ +# AGENTS.md + +This file provides guidance to agents when working with code in this repository. + +## Project Overview + +Ghostunnel is a TLS proxy with mutual authentication support for securing non-TLS backend applications. It operates in two modes: +- **Server mode**: Accepts TLS connections and proxies to insecure backends (TCP or UNIX sockets) +- **Client mode**: Accepts insecure connections and proxies to TLS-secured services + +## Build Commands + +This project uses [mage](https://magefile.org) as the build system: + +```bash +# Build the binary +mage go:build + +# Run all tests (unit + integration) +mage test:all + +# Run only unit tests +mage test:unit + +# Run only integration tests (requires Python 3.5+) +mage test:integration + +# Run tests in Docker (includes PKCS#11 tests with SoftHSM) +mage test:docker + +# Generate test certificates for development +mage test:keys + +# View coverage +go tool cover -html coverage/all.profile + +# Build Docker images +mage docker:build + +# List all available targets +mage -l +``` + +### Running a Single Test + +For unit tests: +```bash +go test -v -run TestName ./... +go test -v ./auth/... # Run all tests in a package +``` + +For integration tests (Python): +```bash +cd tests && python3 test-name.py +``` + +### Linting + +```bash +golangci-lint run +``` + +The project uses golangci-lint with configuration in `.golangci.yml`. Standard linters are enabled with exclusions for common error handling patterns. + +## Architecture + +### Package Structure + +- **main** (`main.go`, `doc.go`): Entry point, CLI flag parsing (using kingpin), mode dispatch (server/client) +- **auth**: Authorization via X.509 certificate validation (CN, OU, DNS SAN, URI SAN, IP SAN checks) +- **certloader**: Certificate loading abstractions supporting PEM files, PKCS#12 keystores, PKCS#11 HSMs, SPIFFE Workload API, ACME, and macOS/Windows keychain +- **proxy**: Connection forwarding with configurable timeouts, connection limits, and PROXY protocol support +- **policy**: Open Policy Agent (OPA) integration for declarative access control policies +- **socket**: Network socket utilities including systemd/launchd socket activation +- **wildcard**: Pattern matching for URI-based access control +- **certstore**: Platform-specific keychain integration (macOS/Windows) + +### Key Design Patterns + +1. **TLSConfigSource interface**: Abstracts certificate sources (files, SPIFFE, ACME, keychain) behind a common interface for hot-reloading +2. **Conditional compilation**: Platform-specific features (PKCS#11, keychain, Landlock) use build tags (`pkcs11_enabled.go`/`pkcs11_disabled.go`) +3. **Signal handling**: SIGHUP triggers certificate reload; SIGTERM/SIGINT trigger graceful shutdown + +### Testing + +- Unit tests: Go standard testing in `*_test.go` files +- Integration tests: Python scripts in `tests/` directory using `tests/common.py` helper module +- Test certificates are generated in `test-keys/` via `mage test:keys` + +## Key Flags + +Server mode requires access control: `--allow-all`, `--allow-cn`, `--allow-ou`, `--allow-dns`, `--allow-uri`, `--allow-policy`, or `--disable-authentication` + +Certificate sources are mutually exclusive: `--keystore`, `--cert/--key`, `--keychain-identity`, `--use-workload-api`, `--auto-acme-cert` + +Safe addresses (localhost, 127.0.0.1, [::1], unix:, systemd:, launchd:) don't require `--unsafe-target` or `--unsafe-listen` diff --git a/certloader/acmetlsconfig_test.go b/certloader/acmetlsconfig_test.go new file mode 100644 index 0000000000..ee179122f9 --- /dev/null +++ b/certloader/acmetlsconfig_test.go @@ -0,0 +1,129 @@ +/*- + * Copyright 2025 Ghostunnel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/tls" + "testing" + + "github.com/caddyserver/certmagic" + "github.com/mholt/acmez" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Note: Full ACME testing requires external ACME server interaction. +// These tests cover the code paths that can be tested without external dependencies. + +func TestACMETLSConfigSourceGetClientConfigError(t *testing.T) { + // GetClientConfig should always fail for ACME sources + // (ACME is server-only feature) + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + _, err := source.GetClientConfig(nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not supported in client mode") +} + +func TestACMETLSConfigSourceCanServe(t *testing.T) { + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + // CanServe should always return true (certmagic manages validity) + assert.True(t, source.CanServe()) +} + +func TestACMETLSConfigSourceReload(t *testing.T) { + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + // Reload should be a no-op (certmagic auto-refreshes) + err := source.Reload() + assert.NoError(t, err) +} + +func TestACMETLSConfigSourceGetServerConfigNilBase(t *testing.T) { + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + // GetServerConfig should work with nil base config + config, err := source.GetServerConfig(nil) + require.NoError(t, err) + require.NotNil(t, config) + + tlsConfig := config.GetServerConfig() + require.NotNil(t, tlsConfig) + assert.NotNil(t, tlsConfig.GetCertificate, "GetCertificate should be set") + assert.Contains(t, tlsConfig.NextProtos, acmez.ACMETLS1Protocol, "ACME-TLS protocol should be in NextProtos") +} + +func TestACMETLSConfigSourceGetServerConfigWithBase(t *testing.T) { + source := &acmeTLSConfigSource{ + magicConfig: certmagic.NewDefault(), + gtACMEConfig: &ACMEConfig{}, + } + + // GetServerConfig should preserve base config settings + base := &tls.Config{ + MinVersion: tls.VersionTLS13, + NextProtos: []string{"h2", "http/1.1"}, + } + + config, err := source.GetServerConfig(base) + require.NoError(t, err) + require.NotNil(t, config) + + tlsConfig := config.GetServerConfig() + require.NotNil(t, tlsConfig) + assert.Equal(t, uint16(tls.VersionTLS13), tlsConfig.MinVersion, "MinVersion should be preserved from base") + assert.Contains(t, tlsConfig.NextProtos, "h2", "base NextProtos should be preserved") + assert.Contains(t, tlsConfig.NextProtos, "http/1.1", "base NextProtos should be preserved") + assert.Contains(t, tlsConfig.NextProtos, acmez.ACMETLS1Protocol, "ACME-TLS protocol should be added") +} + +func TestACMETLSConfigGetServerConfig(t *testing.T) { + magicConfig := certmagic.NewDefault() + base := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + acmeConfig := &acmeTLSConfig{ + magicConfig: magicConfig, + base: base, + } + + tlsConfig := acmeConfig.GetServerConfig() + require.NotNil(t, tlsConfig) + + // Verify it's a clone (not the same pointer) + assert.NotSame(t, base, tlsConfig, "GetServerConfig should return a clone") + + // Verify GetCertificate is set + assert.NotNil(t, tlsConfig.GetCertificate) + + // Verify ACME-TLS protocol is added + assert.Contains(t, tlsConfig.NextProtos, acmez.ACMETLS1Protocol) +} diff --git a/certloader/dialer_test.go b/certloader/dialer_test.go new file mode 100644 index 0000000000..095c29128f --- /dev/null +++ b/certloader/dialer_test.go @@ -0,0 +1,181 @@ +/*- + * Copyright 2025 Ghostunnel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "context" + "crypto/tls" + "net" + "testing" + "time" + + spiffetest "github.com/ghostunnel/ghostunnel/certloader/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockTLSClientConfig implements TLSClientConfig for testing +type mockTLSClientConfig struct { + config *tls.Config +} + +func (m *mockTLSClientConfig) GetClientConfig() *tls.Config { + return m.config +} + +func TestDialerWithCertificate(t *testing.T) { + // Create test CA and certificates + rootPool, serverCert := spiffetest.CreateWebCredentials(t) + _, clientCert := spiffetest.CreateWebCredentials(t) + + // Start a TLS server + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + ClientAuth: tls.NoClientCert, + } + listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig) + require.NoError(t, err) + defer listener.Close() + + // Accept connections in goroutine + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + // Write something back to confirm connection works + conn.Write([]byte("OK")) + }() + + // Create dialer with client cert + clientConfig := &tls.Config{ + Certificates: []tls.Certificate{*clientCert}, + RootCAs: rootPool, + InsecureSkipVerify: true, // Skip verification for test simplicity + } + mockConfig := &mockTLSClientConfig{config: clientConfig} + dialer := DialerWithCertificate(mockConfig, 5*time.Second, &net.Dialer{}) + + // Test successful dial + conn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() + + // Verify it's a TLS connection + _, ok := conn.(*tls.Conn) + assert.True(t, ok, "returned connection should be TLS") + + // Read response to confirm connection works + buf := make([]byte, 2) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Equal(t, "OK", string(buf[:n])) +} + +func TestDialWithDialerRawConnFailure(t *testing.T) { + // Test when raw connection fails (e.g., connection refused) + config := &tls.Config{InsecureSkipVerify: true} + + _, err := dialWithDialer(&net.Dialer{}, context.Background(), + 100*time.Millisecond, "tcp", "127.0.0.1:1", config) // Port 1 should be closed + assert.Error(t, err) +} + +func TestDialWithDialerHandshakeFailure(t *testing.T) { + // Start a plain TCP server (not TLS) to cause handshake failure + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + // Accept and immediately close to trigger handshake failure + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + // Close immediately to cause handshake failure + conn.Close() + }() + + config := &tls.Config{InsecureSkipVerify: true} + + _, err = dialWithDialer(&net.Dialer{}, context.Background(), + 1*time.Second, "tcp", listener.Addr().String(), config) + assert.Error(t, err) +} + +func TestDialWithDialerContextCancellation(t *testing.T) { + // Start a TLS server that delays + _, serverCert := spiffetest.CreateWebCredentials(t) + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + } + listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig) + require.NoError(t, err) + defer listener.Close() + + // Accept but delay handshake + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + // Delay to allow context cancellation + time.Sleep(2 * time.Second) + conn.Close() + }() + + // Cancel context immediately + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + config := &tls.Config{InsecureSkipVerify: true} + + _, err = dialWithDialer(&net.Dialer{}, ctx, + 5*time.Second, "tcp", listener.Addr().String(), config) + assert.Error(t, err) +} + +func TestDialWithDialerTimeout(t *testing.T) { + // Start a TLS server that delays handshake + _, serverCert := spiffetest.CreateWebCredentials(t) + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + } + listener, err := tls.Listen("tcp", "127.0.0.1:0", serverConfig) + require.NoError(t, err) + defer listener.Close() + + // Accept but don't complete handshake + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + // Hold connection open but delay to trigger timeout + time.Sleep(5 * time.Second) + conn.Close() + }() + + config := &tls.Config{InsecureSkipVerify: true} + + // Use very short timeout + _, err = dialWithDialer(&net.Dialer{}, context.Background(), + 50*time.Millisecond, "tcp", listener.Addr().String(), config) + assert.Error(t, err) +} diff --git a/certloader/listener_test.go b/certloader/listener_test.go new file mode 100644 index 0000000000..80a982f776 --- /dev/null +++ b/certloader/listener_test.go @@ -0,0 +1,228 @@ +/*- + * Copyright 2025 Ghostunnel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package certloader + +import ( + "crypto/tls" + "errors" + "net" + "testing" + "time" + + spiffetest "github.com/ghostunnel/ghostunnel/certloader/internal/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockTLSServerConfig implements TLSServerConfig for testing +type mockTLSServerConfig struct { + config *tls.Config +} + +func (m *mockTLSServerConfig) GetServerConfig() *tls.Config { + return m.config +} + +// failingListener is a mock listener that always fails on Accept +type failingListener struct{} + +func (f *failingListener) Accept() (net.Conn, error) { + return nil, errors.New("mock accept error") +} + +func (f *failingListener) Close() error { + return nil +} + +func (f *failingListener) Addr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func TestNewListener(t *testing.T) { + inner, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer inner.Close() + + mockConfig := &mockTLSServerConfig{config: &tls.Config{}} + listener := NewListener(inner, mockConfig) + + assert.Equal(t, inner.Addr(), listener.Addr()) + assert.Equal(t, inner, listener.Listener) +} + +func TestListenerAccept(t *testing.T) { + // Create test certificates + _, serverCert := spiffetest.CreateWebCredentials(t) + + // Create inner TCP listener + inner, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + // Create TLS config + serverConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + } + mockConfig := &mockTLSServerConfig{config: serverConfig} + listener := NewListener(inner, mockConfig) + defer listener.Close() + + // Accept in goroutine (server side) + acceptDone := make(chan net.Conn, 1) + acceptErr := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + acceptErr <- err + return + } + // Complete handshake on server side + tlsConn := conn.(*tls.Conn) + if err := tlsConn.Handshake(); err != nil { + conn.Close() + acceptErr <- err + return + } + acceptDone <- conn + }() + + // Connect from client side + clientConn, err := tls.Dial("tcp", listener.Addr().String(), + &tls.Config{InsecureSkipVerify: true}) + require.NoError(t, err) + defer clientConn.Close() + + // Wait for server to accept + select { + case conn := <-acceptDone: + defer conn.Close() + // Verify it's a TLS connection + _, ok := conn.(*tls.Conn) + assert.True(t, ok, "returned connection should be TLS") + case err := <-acceptErr: + t.Fatalf("accept failed: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("test timed out") + } +} + +func TestListenerAcceptError(t *testing.T) { + // Test error propagation when inner listener fails + mockListener := &failingListener{} + mockConfig := &mockTLSServerConfig{config: &tls.Config{}} + listener := NewListener(mockListener, mockConfig) + + _, err := listener.Accept() + assert.Error(t, err) + assert.Contains(t, err.Error(), "mock accept error") +} + +func TestListenerClose(t *testing.T) { + inner, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + mockConfig := &mockTLSServerConfig{config: &tls.Config{}} + listener := NewListener(inner, mockConfig) + + // Close should close the inner listener + err = listener.Close() + assert.NoError(t, err) + + // Trying to accept should fail now + _, err = inner.Accept() + assert.Error(t, err) +} + +func TestListenerConfigReload(t *testing.T) { + // Create test certificates + _, serverCert1 := spiffetest.CreateWebCredentials(t) + _, serverCert2 := spiffetest.CreateWebCredentials(t) + + // Create inner TCP listener + inner, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + // Start with first config + currentConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert1}, + } + + // Create a mock config that can be updated + mockConfig := &reloadableMockConfig{config: currentConfig} + listener := NewListener(inner, mockConfig) + defer listener.Close() + + // Helper to do one connection cycle + doConnection := func() error { + acceptDone := make(chan net.Conn, 1) + acceptErr := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + acceptErr <- err + return + } + // Complete handshake + tlsConn := conn.(*tls.Conn) + if err := tlsConn.Handshake(); err != nil { + conn.Close() + acceptErr <- err + return + } + acceptDone <- conn + }() + + // Connect from client + clientConn, err := tls.Dial("tcp", listener.Addr().String(), + &tls.Config{InsecureSkipVerify: true}) + if err != nil { + return err + } + clientConn.Close() + + select { + case conn := <-acceptDone: + conn.Close() + return nil + case err := <-acceptErr: + return err + case <-time.After(5 * time.Second): + return errors.New("timeout") + } + } + + // First connection with first config + err = doConnection() + require.NoError(t, err) + + // Update config (simulating reload) + mockConfig.config = &tls.Config{ + Certificates: []tls.Certificate{*serverCert2}, + } + + // Second connection should use new config + err = doConnection() + require.NoError(t, err) +} + +// reloadableMockConfig allows config changes between accepts +type reloadableMockConfig struct { + config *tls.Config +} + +func (m *reloadableMockConfig) GetServerConfig() *tls.Config { + return m.config +} diff --git a/magefile.go b/magefile.go index 761e24d2c7..97a5722fc5 100644 --- a/magefile.go +++ b/magefile.go @@ -5,6 +5,8 @@ package main import ( "bytes" "context" + "crypto/rand" + "encoding/base64" "fmt" "os" "os/exec" @@ -18,12 +20,22 @@ import ( ) type Go mg.Namespace +type Apple mg.Namespace type Git mg.Namespace type Test mg.Namespace type Docker mg.Namespace var Default = Go.Build +// runSilent executes a command without echoing it to stdout, to avoid +// leaking sensitive arguments (passwords, secrets) in CI logs. +func runSilent(name string, args ...string) error { + cmd := exec.Command(name, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() +} + // printf prints the given format and args if verbose mode is enabled. func printf(format string, args ...interface{}) { if mg.Verbose() { @@ -77,6 +89,252 @@ func (Go) Man(ctx context.Context) error { return nil } +// Codesign signs a macOS binary using the codesign tool. The binary argument +// specifies which file to sign. Requires macOS. If CODESIGN_CERTIFICATE is +// set, a temporary keychain is created, the certificate is imported, and the +// keychain is cleaned up after signing. +// +// Environment variables: +// - CODESIGN_IDENTITY: Signing identity (required, e.g. "Developer ID Application: Name (TEAMID)") +// - CODESIGN_CERTIFICATE: Base64-encoded .p12 certificate to import into a temporary keychain (optional, for CI) +// - CODESIGN_CERTIFICATE_PASSWORD: Password for the .p12 certificate (required if CODESIGN_CERTIFICATE is set) +func (Apple) Codesign(ctx context.Context, binary string) error { + if runtime.GOOS != "darwin" { + return fmt.Errorf("codesigning is only supported on macOS") + } + + identity := os.Getenv("CODESIGN_IDENTITY") + if identity == "" { + return fmt.Errorf("CODESIGN_IDENTITY must be set") + } + + certData := os.Getenv("CODESIGN_CERTIFICATE") + if certData != "" { + cleanup, err := setupCodesignKeychain(certData) + if err != nil { + return err + } + defer cleanup() + } + + printf("Signing binary %s with identity %s\n", binary, identity) + + if err := sh.Run("codesign", "--force", "--options", "runtime", "--sign", identity, binary); err != nil { + return fmt.Errorf("codesign %s failed: %w", binary, err) + } + + if err := sh.Run("codesign", "--verify", "--verbose", binary); err != nil { + return fmt.Errorf("codesign verification of %s failed: %w", binary, err) + } + + printf("Binary %s signed and verified successfully\n", binary) + return nil +} + +// Notarize submits a signed macOS binary to Apple's notary service. Requires +// macOS. If NOTARIZE_KEY is set, the .p8 key is written to a temp file and +// cleaned up after notarization. +// +// The binary is zipped for submission and the zip is removed afterward. +// Note: stapling only works for .app, .pkg, and .dmg — for bare binaries the +// notarization is registered with Apple but cannot be stapled. The staple step +// is attempted but a failure is not treated as an error. +// +// Environment variables: +// - NOTARIZE_ISSUER_ID: App Store Connect API issuer ID (required) +// - NOTARIZE_KEY_ID: App Store Connect API key ID (required) +// - NOTARIZE_KEY: Base64-encoded .p8 private key (optional, for CI; if not set, key must already exist) +func (Apple) Notarize(ctx context.Context, binary string) error { + if runtime.GOOS != "darwin" { + return fmt.Errorf("notarization is only supported on macOS") + } + + issuerID := os.Getenv("NOTARIZE_ISSUER_ID") + keyID := os.Getenv("NOTARIZE_KEY_ID") + if issuerID == "" || keyID == "" { + return fmt.Errorf("NOTARIZE_ISSUER_ID and NOTARIZE_KEY_ID must be set") + } + + // If NOTARIZE_KEY is set, write the .p8 key to a temp file for notarytool + keyPath, err := setupNotarizeKey(keyID) + if err != nil { + return err + } + if keyPath != "" { + defer os.Remove(keyPath) + } + + // Create zip for submission + zipPath := binary + ".zip" + if err := sh.Run("ditto", "-c", "-k", "--sequesterRsrc", binary, zipPath); err != nil { + return fmt.Errorf("failed to create zip for %s: %w", binary, err) + } + + printf("Submitting %s for notarization...\n", binary) + + submitArgs := []string{"notarytool", "submit", zipPath, + "--issuer", issuerID, + "--key-id", keyID, + } + if keyPath != "" { + submitArgs = append(submitArgs, "--key", keyPath) + } + submitArgs = append(submitArgs, "--wait") + + err = sh.Run("xcrun", submitArgs...) + os.Remove(zipPath) + if err != nil { + return fmt.Errorf("notarization of %s failed: %w", binary, err) + } + + // Attempt to staple — this only works for .app/.pkg/.dmg, not bare binaries + if err := sh.Run("xcrun", "stapler", "staple", binary); err != nil { + printf("Stapling skipped for %s (not supported for bare binaries): %v\n", binary, err) + } + + printf("Notarization of %s completed successfully\n", binary) + return nil +} + +// setupCodesignKeychain creates a temporary keychain, imports the signing +// certificate, and configures the keychain search list. Returns a cleanup +// function that removes the temporary keychain and restores the original +// search list. +func setupCodesignKeychain(certBase64 string) (func(), error) { + password := os.Getenv("CODESIGN_CERTIFICATE_PASSWORD") + if password == "" { + return nil, fmt.Errorf("CODESIGN_CERTIFICATE_PASSWORD must be set when CODESIGN_CERTIFICATE is set") + } + + // Decode certificate + certBytes, err := base64.StdEncoding.DecodeString(certBase64) + if err != nil { + return nil, fmt.Errorf("failed to decode CODESIGN_CERTIFICATE: %w", err) + } + + // Write certificate to temp file + certFile, err := os.CreateTemp("", "codesign-*.p12") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + if _, err := certFile.Write(certBytes); err != nil { + os.Remove(certFile.Name()) + return nil, fmt.Errorf("failed to write certificate: %w", err) + } + certFile.Close() + + // Generate random keychain password + keychainPassBytes := make([]byte, 32) + if _, err := rand.Read(keychainPassBytes); err != nil { + os.Remove(certFile.Name()) + return nil, fmt.Errorf("failed to generate keychain password: %w", err) + } + keychainPassword := base64.StdEncoding.EncodeToString(keychainPassBytes) + + keychainPath := "ghostunnel-signing.keychain-db" + + // Save original keychain search list + originalKeychains, err := sh.Output("security", "list-keychains", "-d", "user") + if err != nil { + os.Remove(certFile.Name()) + return nil, fmt.Errorf("failed to list keychains: %w", err) + } + + cleanup := func() { + // Restore original keychain search list + restoreArgs := []string{"list-keychains", "-d", "user", "-s"} + restoreArgs = append(restoreArgs, parseKeychainPaths(originalKeychains)...) + sh.Run("security", restoreArgs...) + sh.Run("security", "delete-keychain", keychainPath) + os.Remove(certFile.Name()) + } + + // Create temporary keychain (suppress command echo to avoid leaking keychain password) + if err := runSilent("security", "create-keychain", "-p", keychainPassword, keychainPath); err != nil { + cleanup() + return nil, fmt.Errorf("failed to create keychain: %w", err) + } + + // Set keychain settings (no auto-lock) + if err := sh.Run("security", "set-keychain-settings", keychainPath); err != nil { + cleanup() + return nil, fmt.Errorf("failed to set keychain settings: %w", err) + } + + // Unlock keychain (suppress command echo to avoid leaking keychain password) + if err := runSilent("security", "unlock-keychain", "-p", keychainPassword, keychainPath); err != nil { + cleanup() + return nil, fmt.Errorf("failed to unlock keychain: %w", err) + } + + // Import certificate into keychain (suppress command echo to avoid leaking certificate password) + if err := runSilent("security", "import", certFile.Name(), "-k", keychainPath, "-f", "pkcs12", "-P", password, "-T", "/usr/bin/codesign"); err != nil { + cleanup() + return nil, fmt.Errorf("failed to import certificate: %w", err) + } + + // Set key partition list to allow codesign access (suppress command echo to avoid leaking keychain password) + if err := runSilent("security", "set-key-partition-list", "-S", "apple-tool:,apple:,codesign:", "-s", "-k", keychainPassword, keychainPath); err != nil { + cleanup() + return nil, fmt.Errorf("failed to set key partition list: %w", err) + } + + // Add temporary keychain to search list (prepend to existing) + keychainArgs := []string{"list-keychains", "-d", "user", "-s", keychainPath} + keychainArgs = append(keychainArgs, parseKeychainPaths(originalKeychains)...) + if err := sh.Run("security", keychainArgs...); err != nil { + cleanup() + return nil, fmt.Errorf("failed to update keychain search list: %w", err) + } + + return cleanup, nil +} + +// setupNotarizeKey writes the NOTARIZE_KEY env var (base64-encoded .p8) to a +// temp file and returns its path. Returns an empty path if NOTARIZE_KEY is not +// set (assumes the key file is already available locally). +func setupNotarizeKey(keyID string) (string, error) { + keyData := os.Getenv("NOTARIZE_KEY") + if keyData == "" { + return "", nil + } + + keyBytes, err := base64.StdEncoding.DecodeString(keyData) + if err != nil { + return "", fmt.Errorf("failed to decode NOTARIZE_KEY: %w", err) + } + + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + + keyDir := filepath.Join(homeDir, "private_keys") + if err := os.MkdirAll(keyDir, 0700); err != nil { + return "", fmt.Errorf("failed to create private_keys directory: %w", err) + } + + keyPath := filepath.Join(keyDir, fmt.Sprintf("AuthKey_%s.p8", keyID)) + if err := os.WriteFile(keyPath, keyBytes, 0600); err != nil { + return "", fmt.Errorf("failed to write API key: %w", err) + } + + return keyPath, nil +} + +// parseKeychainPaths parses the output of `security list-keychains` into +// a list of unquoted keychain paths. +func parseKeychainPaths(output string) []string { + var paths []string + for _, line := range strings.Split(output, "\n") { + kc := strings.TrimSpace(strings.Trim(strings.TrimSpace(line), "\"")) + if kc != "" { + paths = append(paths, kc) + } + } + return paths +} + // Clean removes build artifacts. func (Git) Clean(ctx context.Context) error { return sh.Run("git", "clean", "-Xdf") @@ -365,19 +623,23 @@ func (Docker) Push(ctx context.Context) error { // buildDocker builds and tags all Docker containers, optionally pushing them to Docker Hub. func buildDocker(ctx context.Context, push bool) error { - // Determine base tag (latest for master, version tag otherwise) - baseTag, err := getDockerTag() + baseTags, err := getDockerTags() if err != nil { return err } - builds := map[string][]string{ - "Dockerfile-alpine": []string{ + builds := map[string][]string{} + for _, baseTag := range baseTags { + builds["Dockerfile-alpine"] = append(builds["Dockerfile-alpine"], fmt.Sprintf("ghostunnel/ghostunnel:%s", baseTag), fmt.Sprintf("ghostunnel/ghostunnel:%s-alpine", baseTag), - }, - "Dockerfile-debian": []string{fmt.Sprintf("ghostunnel/ghostunnel:%s-debian", baseTag)}, - "Dockerfile-distroless": []string{fmt.Sprintf("ghostunnel/ghostunnel:%s-distroless", baseTag)}, + ) + builds["Dockerfile-debian"] = append(builds["Dockerfile-debian"], + fmt.Sprintf("ghostunnel/ghostunnel:%s-debian", baseTag), + ) + builds["Dockerfile-distroless"] = append(builds["Dockerfile-distroless"], + fmt.Sprintf("ghostunnel/ghostunnel:%s-distroless", baseTag), + ) } for dockerfile, tags := range builds { @@ -425,37 +687,39 @@ func getVersion() string { return strings.TrimSpace(output) } -// getDockerTag determines the Docker tag to use based on git state. -// Returns "latest" if on master branch, otherwise returns the most recent tag. -func getDockerTag() (string, error) { +// getDockerTags determines the Docker tags to use based on git state. +// For release tags (refs/tags/v*), returns both the version tag and "latest". +// For master branch, returns "master". For local non-master branches, returns +// the most recent git tag. +func getDockerTags() ([]string, error) { // Check if we're on a tag (for GitHub Actions when triggered by tag push) // In GitHub Actions, GITHUB_REF will be set, but locally we check git githubRef := os.Getenv("GITHUB_REF") if githubRef != "" { // GitHub Actions: refs/heads/master or refs/tags/v1.2.3 if strings.HasPrefix(githubRef, "refs/heads/master") { - return "latest", nil + return []string{"master"}, nil } if strings.HasPrefix(githubRef, "refs/tags/") { tag := strings.TrimPrefix(githubRef, "refs/tags/") - return tag, nil + return []string{tag, "latest"}, nil } } // Check current branch branch, err := sh.Output("git", "rev-parse", "--abbrev-ref", "HEAD") if err != nil { - return "", fmt.Errorf("failed to determine git ref: %w", err) + return nil, fmt.Errorf("failed to determine git ref: %w", err) } if strings.TrimSpace(branch) == "master" { - return "latest", nil + return []string{"master"}, nil } // Not on master, get the most recent tag tag, err := sh.Output("git", "describe", "--tags", "--abbrev=0") if err != nil { - return "", fmt.Errorf("failed to get git tag: %w", err) + return nil, fmt.Errorf("failed to get git tag: %w", err) } - return strings.TrimSpace(tag), nil + return []string{strings.TrimSpace(tag)}, nil } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 33fb76c9db..d16d6e74e9 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -315,3 +315,257 @@ func TestCopyData(t *testing.T) { t.Fatalf("input and output were different after copy") } } + +// mockNetError implements net.Error for testing isTimeoutError +type mockNetError struct { + timeout bool + temporary bool + msg string +} + +func (e *mockNetError) Error() string { return e.msg } +func (e *mockNetError) Timeout() bool { return e.timeout } +func (e *mockNetError) Temporary() bool { return e.temporary } + +func TestIsTimeoutError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "regular error", + err: errors.New("test error"), + expected: false, + }, + { + name: "net.Error with timeout=true", + err: &mockNetError{timeout: true, msg: "timeout error"}, + expected: true, + }, + { + name: "net.Error with timeout=false", + err: &mockNetError{timeout: false, msg: "non-timeout error"}, + expected: false, + }, + { + name: "context.DeadlineExceeded", + err: context.DeadlineExceeded, + expected: true, + }, + { + name: "context.Canceled", + err: context.Canceled, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isTimeoutError(tc.err) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestIsClosedConnectionError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "regular error", + err: errors.New("test error"), + expected: false, + }, + { + name: "closed pipe error", + err: errors.New("read/write on closed pipe"), + expected: true, + }, + { + name: "net.OpError read closed", + err: &net.OpError{ + Op: "read", + Err: errors.New("use of closed network connection"), + }, + expected: true, + }, + { + name: "net.OpError write closed", + err: &net.OpError{ + Op: "write", + Err: errors.New("use of closed network connection"), + }, + expected: true, + }, + { + name: "net.OpError readfrom closed", + err: &net.OpError{ + Op: "readfrom", + Err: errors.New("use of closed network connection"), + }, + expected: true, + }, + { + name: "net.OpError writeto closed", + err: &net.OpError{ + Op: "writeto", + Err: errors.New("use of closed network connection"), + }, + expected: true, + }, + { + name: "net.OpError other op", + err: &net.OpError{ + Op: "dial", + Err: errors.New("use of closed network connection"), + }, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := isClosedConnectionError(tc.err) + assert.Equal(t, tc.expected, result) + }) + } +} + +// mockConn is a minimal net.Conn implementation that is neither TCP nor Unix +type mockConn struct { + closed bool +} + +func (m *mockConn) Read(b []byte) (n int, err error) { return 0, nil } +func (m *mockConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockConn) Close() error { m.closed = true; return nil } +func (m *mockConn) LocalAddr() net.Addr { return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} } +func (m *mockConn) RemoteAddr() net.Addr { return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} } +func (m *mockConn) SetDeadline(t time.Time) error { return nil } +func (m *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestCloseReadNonTCPConnection(t *testing.T) { + conn := &mockConn{} + closeRead(conn) + assert.True(t, conn.closed, "non-TCP/Unix conn should be closed via Close()") +} + +func TestCloseWriteNonTCPConnection(t *testing.T) { + conn := &mockConn{} + closeWrite(conn) + assert.True(t, conn.closed, "non-TCP/Unix conn should be closed via Close()") +} + +func TestCloseReadTCPConnection(t *testing.T) { + // Create a TCP connection pair + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer listener.Close() + + go func() { + conn, _ := listener.Accept() + if conn != nil { + time.Sleep(100 * time.Millisecond) + conn.Close() + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + assert.Nil(t, err) + defer conn.Close() + + // closeRead should not panic and should work on TCP + closeRead(conn) +} + +func TestCloseWriteTCPConnection(t *testing.T) { + // Create a TCP connection pair + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer listener.Close() + + go func() { + conn, _ := listener.Accept() + if conn != nil { + time.Sleep(100 * time.Millisecond) + conn.Close() + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + assert.Nil(t, err) + defer conn.Close() + + // closeWrite should not panic and should work on TCP + closeWrite(conn) +} + +func TestForceHandshakeNonTLSConn(t *testing.T) { + // Create a regular TCP connection (non-TLS) + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.Nil(t, err) + defer listener.Close() + + go func() { + conn, _ := listener.Accept() + if conn != nil { + conn.Close() + } + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + assert.Nil(t, err) + defer conn.Close() + + // forceHandshake should be a no-op for non-TLS connections + ctx := context.Background() + err = forceHandshake(ctx, conn) + assert.Nil(t, err, "forceHandshake should succeed for non-TLS conn") +} + +func TestLogConnectionMessageDisabled(t *testing.T) { + // Test with LogConnections disabled + p := New(nil, 5*time.Second, 5*time.Second, 0, 0, nil, &testLogger{}, 0, false) + + // Create pipe connections + src, dst := net.Pipe() + defer src.Close() + defer dst.Close() + + // Should not panic even with logging disabled + p.logConnectionMessage("test", src, dst, 0, 0, time.Time{}) +} + +func TestLogConditional(t *testing.T) { + logged := false + logger := &callbackLogger{callback: func(format string, v ...interface{}) { + logged = true + }} + + // Test with flag enabled + p := New(nil, 5*time.Second, 5*time.Second, 0, 0, nil, logger, LogConnectionErrors, false) + p.logConditional(LogConnectionErrors, "test message") + assert.True(t, logged, "should log when flag is enabled") + + // Test with flag disabled + logged = false + p.logConditional(LogHandshakeErrors, "test message") + assert.False(t, logged, "should not log when flag is disabled") +} + +type callbackLogger struct { + callback func(format string, v ...interface{}) +} + +func (c *callbackLogger) Printf(format string, v ...interface{}) { + c.callback(format, v...) +}