Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions management/server/types/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
Expand Down Expand Up @@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool {
}

// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
peerIPs := make(map[string]struct{})

Expand All @@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
peerIPs[peerToConnect.IP.String()] = struct{}{}
}

for _, expiredPeer := range expiredPeers {
peerIPs[expiredPeer.IP.String()] = struct{}{}
}

for _, record := range customZone.Records {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)
Expand Down
34 changes: 31 additions & 3 deletions management/server/types/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
peer *nbpeer.Peer
customZone nbdns.CustomZone
peersToConnect []*nbpeer.Peer
expiredPeers []*nbpeer.Peer
expectedRecords []nbdns.SimpleRecord
}{
{
Expand All @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
},
},
peersToConnect: []*nbpeer.Peer{},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
Expand Down Expand Up @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
}
return peers
}(),
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
Expand Down Expand Up @@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expiredPeers: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
Expand All @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
{
name: "expired peers are included in DNS entries",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
peersToConnect: []*nbpeer.Peer{
{ID: "peer1", IP: net.ParseIP("10.0.0.1")},
},
expiredPeers: []*nbpeer.Peer{
{ID: "expired-peer", IP: net.ParseIP("10.0.0.99")},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers)
assert.Equal(t, len(tt.expectedRecords), len(result))
assert.ElementsMatch(t, tt.expectedRecords, result)
})
Expand Down
Loading