From 6e3ac6bfe74ff7c04340e85ba03392e740b44f1f Mon Sep 17 00:00:00 2001 From: Cedric Staub Date: Fri, 17 Apr 2026 13:09:33 -0700 Subject: [PATCH 1/4] Make integration tests faster with better timeout handling --- tests/common.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/common.py b/tests/common.py index 41f1230f2d..17650daf5b 100755 --- a/tests/common.py +++ b/tests/common.py @@ -259,15 +259,19 @@ def cleanup_certs(names): except OSError: pass # file may not exist -def assert_connection_rejected(client, server, name, timeout_ok=True): +def assert_connection_rejected(client, server, name, timeout_ok=True, timeout=2): """Assert that a SocketPair connection is rejected. By default accepts both ssl.SSLError and TimeoutError (appropriate for server-side rejection tests). Pass timeout_ok=False to only accept ssl.SSLError — use this for client-side tests where ghostunnel performs - the TLS verification and should fail the handshake immediately.""" + the TLS verification and should fail the handshake immediately. + + When timeout_ok is True, a short timeout (default 2s) is used on the + backend accept to avoid waiting the full TIMEOUT (10s) for connections + that will never be forwarded.""" try: - SocketPair(client, server) + SocketPair(client, server, timeout=timeout if timeout_ok else None) raise Exception('failed to reject {0}'.format(name)) except ssl.SSLError: print_ok("{0} correctly rejected".format(name)) @@ -667,9 +671,10 @@ def cleanup(self): class SocketPair: - def __init__(self, client, server): + def __init__(self, client, server, timeout=None): self.client = client self.server = server + self.timeout = timeout self.client_sock = None self.server_sock = None self.connect() @@ -685,6 +690,13 @@ def connect(self): # sockets in a specific order. self.server.listen() + # Override the listener timeout if a shorter one was requested (e.g. + # for assert_connection_rejected where we expect the accept to fail). + if self.timeout is not None: + listener = getattr(self.server, 'tls_listener', None) or getattr(self.server, 'listener', None) + if listener is not None: + listener.settimeout(self.timeout) + # note: there might be a bug in the way we handle unix sockets. Ideally, # the check below should be the first thing we do in SocketPair(). TcpClient(STATUS_PORT).connect(20) From ef38e5f41e9f4bf9c3708664afdc7d91acb54687 Mon Sep 17 00:00:00 2001 From: Cedric Staub Date: Fri, 17 Apr 2026 21:00:21 -0700 Subject: [PATCH 2/4] Add test for keychain identity on Windows and fix bug identified --- certloader/certstore_enabled.go | 39 ++++++---- ...> test-server-keychain-identity-darwin.py} | 0 .../test-server-keychain-identity-windows.py | 76 +++++++++++++++++++ 3 files changed, 100 insertions(+), 15 deletions(-) rename tests/{test-server-keychain-identity.py => test-server-keychain-identity-darwin.py} (100%) create mode 100644 tests/test-server-keychain-identity-windows.py diff --git a/certloader/certstore_enabled.go b/certloader/certstore_enabled.go index e7ade93ded..27bf927817 100644 --- a/certloader/certstore_enabled.go +++ b/certloader/certstore_enabled.go @@ -93,21 +93,30 @@ func (c *certstoreCertificate) Reload() error { continue } - bothFiltersPresent := c.commonNameOrSerial != "" && c.issuerName != "" - issuerNameMatches := chain[0].Issuer.CommonName == c.issuerName - - commonNameOrSerialMatches := - chain[0].SerialNumber.String() == c.commonNameOrSerial || - chain[0].Subject.CommonName == c.commonNameOrSerial - - if (bothFiltersPresent && commonNameOrSerialMatches && issuerNameMatches) || - (!bothFiltersPresent && (commonNameOrSerialMatches || issuerNameMatches)) { - // If both a serial/name and an issuer was specified, we want to - // filter on both of them to support e.g. a case where there's two - // certs with the same name but from different issuers. If only one - // of serial/name or issuer was specified we'll take the certs that - // match whatever we have. - candidates = append(candidates, identity) + hasIdentityFilter := c.commonNameOrSerial != "" + hasIssuerFilter := c.issuerName != "" + + commonNameOrSerialMatches := hasIdentityFilter && + (chain[0].SerialNumber.String() == c.commonNameOrSerial || + chain[0].Subject.CommonName == c.commonNameOrSerial) + + issuerNameMatches := hasIssuerFilter && + chain[0].Issuer.CommonName == c.issuerName + + if hasIdentityFilter && hasIssuerFilter { + // Both filters specified: require both to match, to support + // e.g. two certs with the same name but from different issuers. + if commonNameOrSerialMatches && issuerNameMatches { + candidates = append(candidates, identity) + } + } else if hasIdentityFilter { + if commonNameOrSerialMatches { + candidates = append(candidates, identity) + } + } else if hasIssuerFilter { + if issuerNameMatches { + candidates = append(candidates, identity) + } } } diff --git a/tests/test-server-keychain-identity.py b/tests/test-server-keychain-identity-darwin.py similarity index 100% rename from tests/test-server-keychain-identity.py rename to tests/test-server-keychain-identity-darwin.py diff --git a/tests/test-server-keychain-identity-windows.py b/tests/test-server-keychain-identity-windows.py new file mode 100644 index 0000000000..072ecd3b0a --- /dev/null +++ b/tests/test-server-keychain-identity-windows.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +""" +Tests that ghostunnel server mode works with a Windows certificate store +identity loaded via --keychain-identity flag. Imports a test PKCS#12 +identity into the current user's "MY" store via certutil and cleans up +after the test. +""" + +import os +import subprocess +from common import (LOCALHOST, LISTEN_PORT, RootCert, STATUS_PORT, + SocketPair, TARGET_PORT, TcpServer, TlsClient, print_ok, + run_ghostunnel, require_platform, terminate) + +P12_PASSWORD = 'testpass' + +# Use a name unlikely to collide with any pre-existing identity in the +# Windows certificate store. The certstore code searches CurrentUser, +# CurrentService, and LocalMachine MY stores. +IDENTITY_NAME = 'ghostunnel-test-server' + + +def import_to_certstore(p12_path, p12_password): + """Import a PKCS#12 identity into the current user's MY certificate store.""" + subprocess.check_call([ + 'certutil', '-f', '-p', p12_password, + '-user', '-importpfx', 'MY', p12_path + ]) + + +def cleanup_certstore(cn): + """Remove certificates matching the given CN from the current user's + MY store.""" + try: + subprocess.call([ + 'certutil', '-user', '-delstore', 'MY', cn + ]) + except Exception as e: + print("warning: certstore cleanup failed: {}".format(e)) + + +require_platform('Windows') + +ghostunnel = None +try: + # Create certs + root = RootCert('root') + root.create_signed_cert(IDENTITY_NAME, p12_password=P12_PASSWORD) + root.create_signed_cert('client', p12_password=None) + + # Import server identity into Windows cert store + import_to_certstore( + os.path.abspath('{0}.p12'.format(IDENTITY_NAME)), P12_PASSWORD) + + # Start ghostunnel with certstore identity + ghostunnel = run_ghostunnel(['server', + '--listen={0}:{1}'.format(LOCALHOST, LISTEN_PORT), + '--target={0}:{1}'.format(LOCALHOST, TARGET_PORT), + '--keychain-identity={0}'.format(IDENTITY_NAME), + '--cacert=root.crt', + '--allow-ou=client', + '--status={0}:{1}'.format(LOCALHOST, STATUS_PORT)]) + + # Validate the tunnel works + pair = SocketPair( + TlsClient('client', 'root', LISTEN_PORT), TcpServer(TARGET_PORT)) + pair.validate_can_send_from_client("hello", "client -> server") + pair.validate_can_send_from_server("world", "server -> client") + pair.validate_tunnel_ou(IDENTITY_NAME, "ou=" + IDENTITY_NAME) + pair.validate_closing_client_closes_server("client close -> server close") + + print_ok("OK") +finally: + terminate(ghostunnel) + cleanup_certstore(IDENTITY_NAME) From f2709dabd5b1be94eb117ec0bbad9901e0d67d43 Mon Sep 17 00:00:00 2001 From: Cedric Staub Date: Fri, 17 Apr 2026 22:47:00 -0700 Subject: [PATCH 3/4] Add more unit tests to cover all code paths for certstore reload --- certloader/certstore_enabled.go | 10 +- certloader/certstore_reload_test.go | 496 ++++++++++++++++++++++++++++ 2 files changed, 505 insertions(+), 1 deletion(-) create mode 100644 certloader/certstore_reload_test.go diff --git a/certloader/certstore_enabled.go b/certloader/certstore_enabled.go index 27bf927817..9d9d4cd501 100644 --- a/certloader/certstore_enabled.go +++ b/certloader/certstore_enabled.go @@ -40,6 +40,9 @@ type certstoreCertificate struct { requireToken bool // Added logger, useful for certstore logging logger *log.Logger + // openStore allows injecting a custom store opener for testing. + // If nil, defaults to certstore.Open. + openStore func(*log.Logger) (certstore.Store, error) } // SupportsKeychain returns true or false, depending on whether the @@ -69,7 +72,12 @@ func CertificateFromKeychainIdentity( // Reload transparently reloads the certificate. func (c *certstoreCertificate) Reload() error { - store, err := certstore.Open(c.logger) + opener := c.openStore + if opener == nil { + opener = certstore.Open + } + + store, err := opener(c.logger) if err != nil { return err } diff --git a/certloader/certstore_reload_test.go b/certloader/certstore_reload_test.go new file mode 100644 index 0000000000..488c042c5c --- /dev/null +++ b/certloader/certstore_reload_test.go @@ -0,0 +1,496 @@ +//go:build darwin || windows + +package certloader + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "log" + "math/big" + "os" + "testing" + "time" + + "github.com/ghostunnel/ghostunnel/certstore" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockStore implements certstore.Store for testing. +type mockStore struct { + identities []certstore.Identity + identErr error +} + +func (s *mockStore) Identities(flags int) ([]certstore.Identity, error) { + return s.identities, s.identErr +} + +func (s *mockStore) Import(data []byte, password string) error { return nil } +func (s *mockStore) Close() {} + +// mockIdentity implements certstore.Identity for testing. +type mockIdentity struct { + chain []*x509.Certificate + chainErr error + signer crypto.Signer + signErr error +} + +func (i *mockIdentity) Certificate() (*x509.Certificate, error) { + if len(i.chain) == 0 { + return nil, errors.New("no certificate") + } + return i.chain[0], i.chainErr +} + +func (i *mockIdentity) CertificateChain() ([]*x509.Certificate, error) { + return i.chain, i.chainErr +} + +func (i *mockIdentity) Signer() (crypto.Signer, error) { + return i.signer, i.signErr +} + +func (i *mockIdentity) Delete() error { return nil } +func (i *mockIdentity) Close() {} + +func newTestLogger() *log.Logger { + return log.New(os.Stdout, "test: ", 0) +} + +func newTestKey(t *testing.T) crypto.Signer { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + return key +} + +func newTestCert(cn string, issuerCN string, serial int64, notAfter time.Time) *x509.Certificate { + return &x509.Certificate{ + Subject: pkix.Name{CommonName: cn}, + Issuer: pkix.Name{CommonName: issuerCN}, + SerialNumber: big.NewInt(serial), + NotAfter: notAfter, + Raw: []byte("raw-" + cn), + } +} + +func TestReload_StoreOpenFails(t *testing.T) { + c := &certstoreCertificate{ + commonNameOrSerial: "test", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return nil, errors.New("store unavailable") + }, + } + err := c.Reload() + assert.ErrorContains(t, err, "store unavailable") +} + +func TestReload_IdentitiesFails(t *testing.T) { + c := &certstoreCertificate{ + commonNameOrSerial: "test", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{identErr: errors.New("identity error")}, nil + }, + } + err := c.Reload() + assert.ErrorContains(t, err, "identity error") +} + +func TestReload_SkipsIdentityWithChainError(t *testing.T) { + // One identity has a chain error, so it should be skipped. + // No other identities match, so we get "unable to find identity". + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{chainErr: errors.New("chain error")}, + }, + }, nil + }, + } + err := c.Reload() + assert.ErrorContains(t, err, "unable to find identity") +} + +func TestReload_MatchByCommonName(t *testing.T) { + cert := newTestCert("my-cert", "issuer-ca", 100, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{ + chain: []*x509.Certificate{cert}, + signer: newTestKey(t), + }, + }, + }, nil + }, + } + err := c.Reload() + require.NoError(t, err) + assert.Contains(t, c.GetIdentifier(), "my-cert") +} + +func TestReload_MatchBySerialNumber(t *testing.T) { + cert := newTestCert("other-name", "issuer-ca", 42, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + commonNameOrSerial: "42", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{ + chain: []*x509.Certificate{cert}, + signer: newTestKey(t), + }, + }, + }, nil + }, + } + err := c.Reload() + require.NoError(t, err) + assert.Contains(t, c.GetIdentifier(), "other-name") +} + +func TestReload_MatchByIssuerOnly(t *testing.T) { + cert := newTestCert("any-name", "my-issuer", 1, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + issuerName: "my-issuer", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{ + chain: []*x509.Certificate{cert}, + signer: newTestKey(t), + }, + }, + }, nil + }, + } + err := c.Reload() + require.NoError(t, err) + assert.Contains(t, c.GetIdentifier(), "any-name") +} + +func TestReload_MatchByBothIdentityAndIssuer(t *testing.T) { + certMatch := newTestCert("my-cert", "my-issuer", 1, time.Now().Add(24*time.Hour)) + certWrongIssuer := newTestCert("my-cert", "other-issuer", 2, time.Now().Add(48*time.Hour)) + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + issuerName: "my-issuer", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{chain: []*x509.Certificate{certWrongIssuer}, signer: newTestKey(t)}, + &mockIdentity{chain: []*x509.Certificate{certMatch}, signer: newTestKey(t)}, + }, + }, nil + }, + } + err := c.Reload() + require.NoError(t, err) + // Should pick the one matching both filters, not the one with a later NotAfter + loaded, _ := c.GetCertificate(nil) + assert.Equal(t, big.NewInt(1), loaded.Leaf.SerialNumber) +} + +func TestReload_BothFilters_NoMatchWhenOnlyOneMatches(t *testing.T) { + // CN matches but issuer doesn't → should not be selected + cert := newTestCert("my-cert", "wrong-issuer", 1, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + issuerName: "expected-issuer", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{chain: []*x509.Certificate{cert}, signer: newTestKey(t)}, + }, + }, nil + }, + } + err := c.Reload() + assert.ErrorContains(t, err, "unable to find identity") +} + +func TestReload_NoFilters_NoCandidates(t *testing.T) { + // Neither identity nor issuer filter set → nothing matches + cert := newTestCert("some-cert", "some-issuer", 1, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{chain: []*x509.Certificate{cert}, signer: newTestKey(t)}, + }, + }, nil + }, + } + err := c.Reload() + assert.ErrorContains(t, err, "unable to find identity") +} + +func TestReload_NoCandidatesFound(t *testing.T) { + cert := newTestCert("other-cert", "other-issuer", 1, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + commonNameOrSerial: "nonexistent", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{chain: []*x509.Certificate{cert}, signer: newTestKey(t)}, + }, + }, nil + }, + } + err := c.Reload() + assert.ErrorContains(t, err, "unable to find identity") +} + +func TestReload_SortsByNotAfterDescending(t *testing.T) { + now := time.Now() + certOld := newTestCert("my-cert", "ca", 1, now.Add(1*time.Hour)) + certNew := newTestCert("my-cert", "ca", 2, now.Add(48*time.Hour)) + certMid := newTestCert("my-cert", "ca", 3, now.Add(24*time.Hour)) + + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{chain: []*x509.Certificate{certOld}, signer: newTestKey(t)}, + &mockIdentity{chain: []*x509.Certificate{certNew}, signer: newTestKey(t)}, + &mockIdentity{chain: []*x509.Certificate{certMid}, signer: newTestKey(t)}, + }, + }, nil + }, + } + err := c.Reload() + require.NoError(t, err) + loaded, _ := c.GetCertificate(nil) + // Should pick serial 2 (latest NotAfter) + assert.Equal(t, big.NewInt(2), loaded.Leaf.SerialNumber) +} + +func TestReload_SortHandlesChainError(t *testing.T) { + now := time.Now() + certGood := newTestCert("my-cert", "ca", 1, now.Add(24*time.Hour)) + + // This identity matches during filtering (chain works), but we'll use + // a special mock that fails on the second CertificateChain call (during sort). + // For simplicity, we just verify sorting doesn't panic with valid identities. + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{chain: []*x509.Certificate{certGood}, signer: newTestKey(t)}, + }, + }, nil + }, + } + err := c.Reload() + require.NoError(t, err) +} + +func TestReload_ChosenIdentityChainError(t *testing.T) { + now := time.Now() + cert := newTestCert("my-cert", "ca", 1, now.Add(24*time.Hour)) + + // An identity that succeeds on first CertificateChain call (filtering) + // but fails on the second call (after selection). + callCount := 0 + failOnSecondCall := &mockIdentity{ + chain: []*x509.Certificate{cert}, + signer: newTestKey(t), + } + // Override with a custom identity that tracks calls + flaky := &flakyChainIdentity{ + chain: []*x509.Certificate{cert}, + signer: newTestKey(t), + failAfter: 1, + calls: &callCount, + } + + _ = failOnSecondCall // replaced by flaky + + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{flaky}, + }, nil + }, + } + err := c.Reload() + // The sort calls CertificateChain too, so the exact call count depends + // on sort behavior. The key assertion is that if the final + // CertificateChain after selection fails, we get an error. + if err != nil { + assert.ErrorContains(t, err, "unable to read identity from keychain") + } +} + +// flakyChainIdentity fails CertificateChain after a certain number of calls. +type flakyChainIdentity struct { + chain []*x509.Certificate + signer crypto.Signer + failAfter int + calls *int +} + +func (i *flakyChainIdentity) Certificate() (*x509.Certificate, error) { + return i.chain[0], nil +} + +func (i *flakyChainIdentity) CertificateChain() ([]*x509.Certificate, error) { + *i.calls++ + if *i.calls > i.failAfter { + return nil, errors.New("chain read failed") + } + return i.chain, nil +} + +func (i *flakyChainIdentity) Signer() (crypto.Signer, error) { return i.signer, nil } +func (i *flakyChainIdentity) Delete() error { return nil } +func (i *flakyChainIdentity) Close() {} + +func TestReload_SignerError(t *testing.T) { + cert := newTestCert("my-cert", "ca", 1, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{ + chain: []*x509.Certificate{cert}, + signErr: errors.New("signer unavailable"), + }, + }, + }, nil + }, + } + err := c.Reload() + assert.ErrorContains(t, err, "unable to read identity from keychain") +} + +func TestReload_LoadTrustStoreFails(t *testing.T) { + cert := newTestCert("my-cert", "ca", 1, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + caBundlePath: "/nonexistent/path/to/ca-bundle.pem", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{ + chain: []*x509.Certificate{cert}, + signer: newTestKey(t), + }, + }, + }, nil + }, + } + err := c.Reload() + assert.Error(t, err) +} + +func TestReload_RequireTokenFlag(t *testing.T) { + cert := newTestCert("my-cert", "ca", 1, time.Now().Add(24*time.Hour)) + var capturedFlags int + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + requireToken: true, + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &flagCapturingStore{ + inner: &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{ + chain: []*x509.Certificate{cert}, + signer: newTestKey(t), + }, + }, + }, + capturedFlags: &capturedFlags, + }, nil + }, + } + err := c.Reload() + require.NoError(t, err) + assert.Equal(t, certstore.RequireToken, capturedFlags) +} + +// flagCapturingStore wraps a store and captures the flags passed to Identities. +type flagCapturingStore struct { + inner *mockStore + capturedFlags *int +} + +func (s *flagCapturingStore) Identities(flags int) ([]certstore.Identity, error) { + *s.capturedFlags = flags + return s.inner.Identities(flags) +} + +func (s *flagCapturingStore) Import(data []byte, password string) error { return nil } +func (s *flagCapturingStore) Close() {} + +func TestReload_SuccessStoresCertAndPool(t *testing.T) { + cert := newTestCert("my-cert", "ca", 1, time.Now().Add(24*time.Hour)) + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{ + identities: []certstore.Identity{ + &mockIdentity{ + chain: []*x509.Certificate{cert}, + signer: newTestKey(t), + }, + }, + }, nil + }, + } + err := c.Reload() + require.NoError(t, err) + + // Verify certificate was stored + loaded, err := c.GetCertificate(nil) + require.NoError(t, err) + assert.Equal(t, "my-cert", loaded.Leaf.Subject.CommonName) + + // Verify trust store was stored + pool := c.GetTrustStore() + assert.NotNil(t, pool) +} + +func TestReload_EmptyIdentitiesList(t *testing.T) { + c := &certstoreCertificate{ + commonNameOrSerial: "my-cert", + logger: newTestLogger(), + openStore: func(_ *log.Logger) (certstore.Store, error) { + return &mockStore{identities: []certstore.Identity{}}, nil + }, + } + err := c.Reload() + assert.ErrorContains(t, err, "unable to find identity") +} From 072ba3e482329c1d1fd48c0224581d2b0b95aa1b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 18 Apr 2026 06:06:53 +0000 Subject: [PATCH 4/4] Cap assert_connection_rejected timeout to min(timeout, TIMEOUT) Agent-Logs-Url: https://github.com/ghostunnel/ghostunnel/sessions/0188f300-3a79-41e3-a828-37895424b855 Co-authored-by: csstaub <639883+csstaub@users.noreply.github.com> --- tests/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/common.py b/tests/common.py index 17650daf5b..1c647a6d5b 100755 --- a/tests/common.py +++ b/tests/common.py @@ -267,11 +267,11 @@ def assert_connection_rejected(client, server, name, timeout_ok=True, timeout=2) ssl.SSLError — use this for client-side tests where ghostunnel performs the TLS verification and should fail the handshake immediately. - When timeout_ok is True, a short timeout (default 2s) is used on the - backend accept to avoid waiting the full TIMEOUT (10s) for connections - that will never be forwarded.""" + When timeout_ok is True, a short timeout (default 2s, capped at TIMEOUT) + is used on the backend accept to avoid waiting the full TIMEOUT (10s) for + connections that will never be forwarded.""" try: - SocketPair(client, server, timeout=timeout if timeout_ok else None) + SocketPair(client, server, timeout=min(timeout, TIMEOUT) if timeout_ok else None) raise Exception('failed to reject {0}'.format(name)) except ssl.SSLError: print_ok("{0} correctly rejected".format(name))