diff --git a/pkg/gateway/tunnel/wireguard/device.go b/pkg/gateway/tunnel/wireguard/device.go index f40d54c1b0..ef7a83882e 100644 --- a/pkg/gateway/tunnel/wireguard/device.go +++ b/pkg/gateway/tunnel/wireguard/device.go @@ -42,6 +42,11 @@ func configureDevice(wgcl *wgctrl.Client, options *Options, peerPubKey wgtypes.K switch options.GwOptions.Mode { case gateway.ModeServer: confdev.ListenPort = &options.ListenPort + + endpoint := getExistingEndpoint(wgcl, peerPubKey) + if endpoint != nil { + confdev.Peers[0].Endpoint = endpoint + } case gateway.ModeClient: confdev.Peers[0].Endpoint = &net.UDPAddr{ IP: options.EndpointIP, @@ -56,3 +61,48 @@ func configureDevice(wgcl *wgctrl.Client, options *Options, peerPubKey wgtypes.K } return nil } + +func getExistingEndpoint(wgcl *wgctrl.Client, peerPubKey wgtypes.Key) *net.UDPAddr { + peer := getExistingPeer(wgcl, peerPubKey) + + if peer == nil { + return nil + } + + if peer.Endpoint != nil { + klog.Infof("Discovered endpoint %s for peer %s", peer.Endpoint, peerPubKey.String()) + return peer.Endpoint + } + + return nil +} + +func getExistingPeer(wgcl *wgctrl.Client, peerPubKey wgtypes.Key) *wgtypes.Peer { + dev := getExistingDevice(wgcl) + + if dev == nil { + return nil + } + + for i := range dev.Peers { + if dev.Peers[i].PublicKey == peerPubKey { + klog.Infof("Found existing peer for key %s", peerPubKey.String()) + return &dev.Peers[i] + } + } + + klog.Infof("No existing peer %s found", peerPubKey.String()) + return nil +} + +func getExistingDevice(wgcl *wgctrl.Client) *wgtypes.Device { + dev, err := wgcl.Device(tunnel.TunnelInterfaceName) + + if err == nil { + klog.Infof("Found existing device %s", tunnel.TunnelInterfaceName) + return dev + } + + klog.Infof("No existing device %s found", tunnel.TunnelInterfaceName) + return nil +}