diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 121621109fc..1be1530fc01 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -53,6 +53,7 @@ var ( certLockMethod string wgPort int proxyProtocol bool + preSharedKey string ) var rootCmd = &cobra.Command{ @@ -84,6 +85,7 @@ func init() { rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease") rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments") rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies") + rootCmd.Flags().StringVar(&preSharedKey, "pre-shared-key", envStringOrDefault("NB_PROXY_PRE_SHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers") } // Execute runs the root command. @@ -156,6 +158,7 @@ func runServer(cmd *cobra.Command, args []string) error { CertLockMethod: nbacme.CertLockMethod(certLockMethod), WireguardPort: wgPort, ProxyProtocol: proxyProtocol, + PreSharedKey: preSharedKey, } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index d7fd2746f49..481b42d2b01 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -86,6 +86,13 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo } } +// ClientConfig holds configuration for the embedded NetBird client. +type ClientConfig struct { + MgmtAddr string + WGPort int + PreSharedKey string +} + type statusNotifier interface { NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error } @@ -98,10 +105,9 @@ type managementClient interface { // backed by underlying NetBird connections. // Clients are keyed by AccountID, allowing multiple domains to share the same connection. type NetBird struct { - mgmtAddr string proxyID string proxyAddr string - wgPort int + clientCfg ClientConfig logger *log.Logger mgmtClient managementClient transportCfg transportConfig @@ -229,11 +235,12 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account // The peer has already been created via CreateProxyPeer RPC with the public key. client, err := embed.New(embed.Options{ DeviceName: deviceNamePrefix + n.proxyID, - ManagementURL: n.mgmtAddr, + ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), BlockInbound: true, - WireguardPort: &n.wgPort, + WireguardPort: &n.clientCfg.WGPort, + PreSharedKey: n.clientCfg.PreSharedKey, }) if err != nil { return nil, fmt.Errorf("create netbird client: %w", err) @@ -536,18 +543,17 @@ func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client { return result } -// NewNetBird creates a new NetBird transport. Set wgPort to 0 for a random +// NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random // OS-assigned port. A fixed port only works with single-account deployments; // multiple accounts will fail to bind the same port. -func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird { +func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird { if logger == nil { logger = log.StandardLogger() } return &NetBird{ - mgmtAddr: mgmtAddr, proxyID: proxyID, proxyAddr: proxyAddr, - wgPort: wgPort, + clientCfg: clientCfg, logger: logger, clients: make(map[types.AccountID]*clientEntry), statusNotifier: notifier, diff --git a/proxy/internal/roundtrip/netbird_test.go b/proxy/internal/roundtrip/netbird_test.go index 3e76af9da5f..0a742c2fa61 100644 --- a/proxy/internal/roundtrip/netbird_test.go +++ b/proxy/internal/roundtrip/netbird_test.go @@ -49,7 +49,11 @@ func (m *mockStatusNotifier) calls() []statusCall { // mockNetBird creates a NetBird instance for testing without actually connecting. // It uses an invalid management URL to prevent real connections. func mockNetBird() *NetBird { - return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, &mockMgmtClient{}) + return NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, nil, &mockMgmtClient{}) } func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) { @@ -282,7 +286,11 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) { func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { notifier := &mockStatusNotifier{} - nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) + nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") // Add first domain — creates a new client entry. @@ -308,7 +316,11 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) { func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) { notifier := &mockStatusNotifier{} - nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{}) + nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{ + MgmtAddr: "http://invalid.test:9999", + WGPort: 0, + PreSharedKey: "", + }, nil, notifier, &mockMgmtClient{}) accountID := types.AccountID("account-1") err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1") diff --git a/proxy/server.go b/proxy/server.go index 60811e53b83..48a876899a4 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -114,6 +114,8 @@ type Server struct { // When enabled, the real client IP is extracted from the PROXY header // sent by upstream L4 proxies that support PROXY protocol. ProxyProtocol bool + // PreSharedKey used for tunnel between proxy and peers (set globally not per account) + PreSharedKey string } // NotifyStatus sends a status update to management about tunnel connectivity @@ -163,7 +165,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { // Initialize the netbird client, this is required to build peer connections // to proxy over. - s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient) + s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{ + MgmtAddr: s.ManagementAddress, + WGPort: s.WireguardPort, + PreSharedKey: s.PreSharedKey, + }, s.Logger, s, s.mgmtClient) tlsConfig, err := s.configureTLS(ctx) if err != nil {