Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 209 additions & 17 deletions internal/discovery/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,102 @@

import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"

"github.com/miekg/dns"
"github.com/projectdiscovery/dnsx/libs/dnsx"
"github.com/projectdiscovery/retryabledns"
dnsv1alpha1 "go.miloapis.com/dns-operator/api/v1alpha1"
)

const (
discoveryLookupConcurrency = 8
discoveryLookupTimeout = 5 * time.Second
)

var queryRecordsFunc = queryRecordsForName

// commonDiscoverySubdomains captures a curated list of host/service labels that we want
// to probe in addition to the zone apex. The goal is to cover the most frequent records
// seen across customer zones (web/mobile properties, mail/exchange, voice/video, etc.).
// The list is intentionally opinionated and can grow as we observe additional patterns.
var commonDiscoverySubdomains = []string{
"", // zone apex

// Web/mobile properties
"www",
"m",
"api",
"app",
"beta",
"dev",
"stage",
"staging",
"test",
"preview",
"admin",
"portal",
"dashboard",
"login",
"auth",
"sso",
"cdn",
"static",
"assets",
"media",
"img",
"files",
"support",
"help",
"status",

// Remote access / infra
"vpn",
"remote",
"intranet",
"edge",

// Email / collaboration
"mail",
"smtp",
"imap",
"pop",
"pop3",
"autodiscover",
"_autodiscover._tcp",
"_dmarc",
"_mta-sts",

// SIP/voice and chat services
"_sip._tcp",
"_sipfederationtls._tcp",
"_sipinternaltls._tcp",
"_xmpp-client._tcp",
"_xmpp-server._tcp",

// Secure mail submission / IMAP
"_imap._tcp",
"_imaps._tcp",
"_submission._tcp",

// Legacy transfer
"ftp",
"sftp",
}

// DiscoverZoneRecords performs best-effort discovery of common RR types for the given domain
// and returns RecordSets grouped by RecordType, with typed fields populated where available.
func DiscoverZoneRecords(ctx context.Context, domain string) ([]dnsv1alpha1.DiscoveredRecordSet, error) {
fmt.Printf("starting discovery for domain=%q\n", domain)
candidateNames := candidateDiscoveryNames(domain)
if len(candidateNames) == 0 {
return nil, fmt.Errorf("no discovery candidates produced for domain %q", domain)
}
fmt.Printf("querying %d candidate names for domain=%q\n", len(candidateNames), domain)
// Exclude NS and SOA per requirements.
options := dnsx.DefaultOptions
qtypes := []uint16{
Expand All @@ -36,26 +121,66 @@
if err != nil {
return nil, err
}
fmt.Println("querying multiple record types")
resp, err := client.QueryMultiple(domain)
if err != nil {
return nil, err

typeToRRs := make(map[uint16][]dns.RR)
var mu sync.Mutex
mergeResponses := func(rrs map[uint16][]dns.RR) {
if len(rrs) == 0 {
return
}
mu.Lock()
defer mu.Unlock()
for rt, entries := range rrs {
typeToRRs[rt] = append(typeToRRs[rt], entries...)
}
}
if resp == nil {
return nil, fmt.Errorf("dnsx returned nil response")

apexRRs, err := queryRecordsFunc(ctx, candidateNames[0], client, qtypes)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
fmt.Printf("timed out querying apex fqdn=%q after %s; treating as empty\n", candidateNames[0], discoveryLookupTimeout)
} else {
return nil, err
}
} else {
mergeResponses(apexRRs)
}
// Print quick summary for debugging
fmt.Printf("resolver status=%s answers=%d types=%v\n", resp.StatusCode, len(resp.AllRecords), qtypes)

// Build RRs from textual records since QueryMultiple does not populate RawResp
typeToRRs := make(map[uint16][]dns.RR)
for _, rec := range resp.AllRecords {
rr, perr := dns.NewRR(rec)
if perr != nil || rr == nil {
continue
if len(candidateNames) > 1 {
var wg sync.WaitGroup
sem := make(chan struct{}, discoveryLookupConcurrency)

for _, name := range candidateNames[1:] {
name := name

Check failure on line 154 in internal/discovery/resolver.go

View workflow job for this annotation

GitHub Actions / Run on Ubuntu

The copy of the 'for' variable "name" can be deleted (Go 1.22+) (copyloopvar)

Check failure on line 154 in internal/discovery/resolver.go

View workflow job for this annotation

GitHub Actions / Run on Ubuntu

The copy of the 'for' variable "name" can be deleted (Go 1.22+) (copyloopvar)
wg.Add(1)
go func() {
defer wg.Done()

select {
case sem <- struct{}{}:
case <-ctx.Done():
fmt.Printf("context canceled before querying fqdn=%q\n", name)
return
}
defer func() { <-sem }()

rrs, qerr := queryRecordsFunc(ctx, name, client, qtypes)
if qerr != nil {
switch {
case errors.Is(qerr, context.DeadlineExceeded):
fmt.Printf("timed out querying fqdn=%q after %s; treating as empty\n", name, discoveryLookupTimeout)
case errors.Is(qerr, context.Canceled):
fmt.Printf("context canceled before completing lookup for fqdn=%q\n", name)
default:
fmt.Printf("skipping fqdn=%q due to query error: %v\n", name, qerr)
}
return
}
mergeResponses(rrs)
}()
}
rt := rr.Header().Rrtype
typeToRRs[rt] = append(typeToRRs[rt], rr)

wg.Wait()
}

typeToEntries := make(map[dnsv1alpha1.RRType][]dnsv1alpha1.RecordEntry)
Expand All @@ -66,7 +191,6 @@
}
entries := mapAnswersToEntries(domain, answers)
if len(entries) == 0 {

continue
}
if rt, ok := mapQtypeToRRType(qt); ok {
Expand Down Expand Up @@ -114,3 +238,71 @@
return "", false
}
}

func queryRecordsForName(ctx context.Context, name string, client *dnsx.DNSX, qtypes []uint16) (map[uint16][]dns.RR, error) {
fmt.Printf("querying multiple record types for fqdn=%q\n", name)

type queryResult struct {
resp *retryabledns.DNSData
err error
}

resCh := make(chan queryResult, 1)
go func() {
resp, err := client.QueryMultiple(name)
resCh <- queryResult{resp: resp, err: err}
}()

timeoutCtx, cancel := context.WithTimeout(ctx, discoveryLookupTimeout)
defer cancel()

select {
case <-timeoutCtx.Done():
return nil, timeoutCtx.Err()
case res := <-resCh:
if res.err != nil {
return nil, res.err
}
if res.resp == nil {
return nil, fmt.Errorf("dnsx returned nil response")
}
fmt.Printf("resolver status=%s answers=%d fqdn=%s types=%v\n", res.resp.StatusCode, len(res.resp.AllRecords), name, qtypes)

typeToRRs := make(map[uint16][]dns.RR)
for _, rec := range res.resp.AllRecords {
rr, perr := dns.NewRR(rec)
if perr != nil || rr == nil {
continue
}
rt := rr.Header().Rrtype
typeToRRs[rt] = append(typeToRRs[rt], rr)
}
return typeToRRs, nil
}
}

func candidateDiscoveryNames(domain string) []string {
base := strings.TrimSpace(domain)
base = strings.TrimSuffix(base, ".")
if base == "" {
return nil
}

seen := make(map[string]struct{}, len(commonDiscoverySubdomains))
names := make([]string, 0, len(commonDiscoverySubdomains))
for _, label := range commonDiscoverySubdomains {
var fqdn string
switch label {
case "", "@":
fqdn = base
default:
fqdn = fmt.Sprintf("%s.%s", label, base)
}
if _, exists := seen[fqdn]; exists {
continue
}
seen[fqdn] = struct{}{}
names = append(names, fqdn)
}
return names
}
131 changes: 131 additions & 0 deletions internal/discovery/resolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package discovery

import (
"context"
"errors"
"sync"
"testing"

"github.com/miekg/dns"
"github.com/projectdiscovery/dnsx/libs/dnsx"
"github.com/stretchr/testify/require"

dnsv1alpha1 "go.miloapis.com/dns-operator/api/v1alpha1"
)

func TestDiscoverZoneRecordsContinuesAfterTimeout(t *testing.T) {
t.Cleanup(func() {
queryRecordsFunc = queryRecordsForName
})

stub := newStubRecordLookup()
stub.responses["example.com"] = map[uint16][]dns.RR{
dns.TypeA: {mustRR(t, "example.com. 60 IN A 1.2.3.4")},
}
stub.errors["www.example.com"] = context.DeadlineExceeded

queryRecordsFunc = stub.query

results, err := DiscoverZoneRecords(context.Background(), "example.com")
require.NoError(t, err)

rs := findRecordSet(results, dnsv1alpha1.RRTypeA)
require.NotNil(t, rs, "expected A recordset to be present")
require.Len(t, rs.Records, 1)
entry := rs.Records[0]
require.Equal(t, "@", entry.Name)
require.NotNil(t, entry.A)
require.Equal(t, "1.2.3.4", entry.A.Content)
}

func TestDiscoverZoneRecordsFailsWhenApexLookupErrors(t *testing.T) {
t.Cleanup(func() {
queryRecordsFunc = queryRecordsForName
})

stub := newStubRecordLookup()
stub.errors["example.com"] = errors.New("dns failure")
queryRecordsFunc = stub.query

_, err := DiscoverZoneRecords(context.Background(), "example.com")
require.Error(t, err)
require.ErrorContains(t, err, "dns failure")
}

func TestDiscoverZoneRecordsAllowsApexTimeoutWithOtherRecords(t *testing.T) {
t.Cleanup(func() {
queryRecordsFunc = queryRecordsForName
})

stub := newStubRecordLookup()
stub.errors["example.com"] = context.DeadlineExceeded
stub.responses["www.example.com"] = map[uint16][]dns.RR{
dns.TypeA: {mustRR(t, "www.example.com. 30 IN A 5.6.7.8")},
}
queryRecordsFunc = stub.query

results, err := DiscoverZoneRecords(context.Background(), "example.com")
require.NoError(t, err)

rs := findRecordSet(results, dnsv1alpha1.RRTypeA)
require.NotNil(t, rs)
entry := findRecordByName(rs.Records, "www")
require.NotNil(t, entry, "expected www entry when apex times out")
require.NotNil(t, entry.A)
require.Equal(t, "5.6.7.8", entry.A.Content)
}

type stubRecordLookup struct {
mu sync.Mutex
responses map[string]map[uint16][]dns.RR
errors map[string]error
calls []string
}

func newStubRecordLookup() *stubRecordLookup {
return &stubRecordLookup{
responses: make(map[string]map[uint16][]dns.RR),
errors: make(map[string]error),
}
}

func (s *stubRecordLookup) query(_ context.Context, name string, _ *dnsx.DNSX, _ []uint16) (map[uint16][]dns.RR, error) {
s.mu.Lock()
s.calls = append(s.calls, name)
resp, respOK := s.responses[name]
err, errOK := s.errors[name]
s.mu.Unlock()

if errOK {
return nil, err
}
if respOK {
return resp, nil
}
return nil, nil
}

func mustRR(t *testing.T, rr string) dns.RR {
t.Helper()
record, err := dns.NewRR(rr)
require.NoError(t, err)
return record
}

func findRecordSet(sets []dnsv1alpha1.DiscoveredRecordSet, rt dnsv1alpha1.RRType) *dnsv1alpha1.DiscoveredRecordSet {
for i := range sets {
if sets[i].RecordType == rt {
return &sets[i]
}
}
return nil
}

func findRecordByName(entries []dnsv1alpha1.RecordEntry, name string) *dnsv1alpha1.RecordEntry {
for i := range entries {
if entries[i].Name == name {
return &entries[i]
}
}
return nil
}
Loading