Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved IPv6 link local resolver error and add unit testing for parsing function #526

Merged
merged 4 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/cli/worker_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,9 @@ func populateIPTransportMode(gc *CLIConf, config *zdns.ResolverConfig) (*zdns.Re
// check OS' default resolver(s) to determine if we support IPv4 or IPv6
ipv4NSStrings, ipv6NSStrings, err = zdns.GetDNSServers(config.DNSConfigFilePath)
if err != nil {
log.Fatal("unable to parse resolvers file, please use '--name-servers': ", err)
log.Fatalf("ZDNS is unable to parse resolvers file. ZDNS only supports IPv4 and IPv6 addresses with an optional port, "+
" either 111.222.333.444:9953 or [1111:2222::3333]:9953. "+
"Please either modify your %s file or use '--name-servers'. Error: %v", config.DNSConfigFilePath, err)
}
if len(ipv4NSStrings) == 0 && len(ipv6NSStrings) == 0 {
return nil, errors.New("no nameservers found with OS defaults. Please specify desired nameservers with --name-servers")
Expand Down
3 changes: 2 additions & 1 deletion src/internal/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const (
DefaultDNSPort = "53"
DefaultHTTPSPort = "443"
DefaultTLSPort = "853"
InvalidPortErrorMsg = "invalid port"
)

func SplitHostPort(inaddr string) (net.IP, int, error) {
Expand All @@ -44,7 +45,7 @@ func SplitHostPort(inaddr string) (net.IP, int, error) {

portInt, err := strconv.Atoi(port)
if err != nil {
return nil, 0, errors.Wrap(err, "invalid port")
return nil, 0, errors.Wrap(err, InvalidPortErrorMsg)
}

return ip, portInt, nil
Expand Down
65 changes: 47 additions & 18 deletions src/zdns/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"io"
"math/rand"
"net"
"os"
"regexp"
"strings"

Expand All @@ -35,33 +36,61 @@ import (

var ErrorContextExpired = errors.New("context expired")

// GetDNSServers returns a list of IPv4, IPv6 DNS servers from a file, or an error if one occurs
func GetDNSServers(path string) (ipv4, ipv6 []string, err error) {
c, err := dns.ClientConfigFromFile(path)
file, err := os.Open(path)
if err != nil {
return []string{}, []string{}, fmt.Errorf("error reading DNS config file (%s): %w", path, err)
return nil, nil, fmt.Errorf("error opening DNS config file (%s): %w", path, err)
}
servers := make([]string, 0, len(c.Servers))
for _, s := range c.Servers {
if s[0:1] != "[" && strings.Contains(s, ":") {
s = "[" + s + "]"
defer func(file *os.File) {
err := file.Close()
if err != nil {
log.Errorf("error closing DNS config file (%s): %s", path, err)
}
full := strings.Join([]string{s, c.Port}, ":")
servers = append(servers, full)
}(file)
return getDNSServersFromReader(file)
}

// getDNSServersFromReader returns a list of IPv4, IPv6 DNS servers from an io.Reader, or an error if one occurs
func getDNSServersFromReader(resolvReader io.Reader) (ipv4, ipv6 []string, err error) {
c, err := dns.ClientConfigFromReader(resolvReader)
if err != nil {
return []string{}, []string{}, fmt.Errorf("error parsing DNS config file: %v", err)
}
servers := make([]string, 0, len(c.Servers))
ipv4 = make([]string, 0, len(servers))
ipv6 = make([]string, 0, len(servers))
for _, s := range servers {
ip, _, err := util.SplitHostPort(s)
if err != nil {
return []string{}, []string{}, fmt.Errorf("could not parse IP address (%s) from file: %w", s, err)
for _, s := range c.Servers {
// We don't support specifying link-local IPv6 addresses with %interface or domain names with #domain
// See https://man7.org/linux/man-pages/man1/resolvectl.1.html
if strings.Contains(s, "%") {
return []string{}, []string{}, fmt.Errorf("could not parse IP address (%s) from config. We do not support specifying link-local IPv6 addresses or per-interface nameservers", s)
}
if strings.Contains(s, "#") {
return []string{}, []string{}, fmt.Errorf("could not parse IP address (%s) from config. We do not support specifying domain names for nameservers", s)
}
// We need to check if there is a port specified, and add the default if not
ipStr, _, err := net.SplitHostPort(s)
if err == nil {
// port specified, determine IP type
ip := net.ParseIP(ipStr)
if ip != nil && ip.To4() != nil {
ipv4 = append(ipv4, s)
} else if ip != nil {
ipv6 = append(ipv6, s)
}
continue
}
if ip.To4() != nil {
ipv4 = append(ipv4, s)
} else if util.IsIPv6(&ip) {
ipv6 = append(ipv6, s)
// no port specified, check if s is an IP
ip := net.ParseIP(s)
if ip == nil {
return []string{}, []string{}, fmt.Errorf("could not parse IP address (%s) from config", s)
} else if ip.To4() != nil {
// IPv4, use default port
ipv4 = append(ipv4, strings.Join([]string{s, c.Port}, ":"))
} else {
return []string{}, []string{}, fmt.Errorf("could not parse IP address (%s) from file: %s", s, path)
// IPv6, use default port
s = "[" + s + "]"
ipv6 = append(ipv6, strings.Join([]string{s, c.Port}, ":"))
}
}
return ipv4, ipv6, nil
Expand Down
104 changes: 104 additions & 0 deletions src/zdns/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"reflect"
"regexp"
"sort"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -1972,6 +1973,109 @@ func TestInvalidInputsLookup(t *testing.T) {
})
}

func TestGetDNSServersFromReader(t *testing.T) {
tests := []struct {
name string
input string
wantIPv4 []string
wantIPv6 []string
wantErr bool
}{
{
name: "Valid IPv4 address with default port",
input: "nameserver 1.2.3.4",
wantIPv4: []string{"1.2.3.4:53"},
wantIPv6: nil,
wantErr: false,
},
{
name: "Valid IPv6 address with default port",
input: "nameserver 2001:db8::1",
wantIPv4: nil,
wantIPv6: []string{"[2001:db8::1]:53"},
wantErr: false,
},
{
name: "Valid IPv6 compressed address with default port",
input: "nameserver ::1",
wantIPv4: nil,
wantIPv6: []string{"[::1]:53"},
wantErr: false,
},
{
name: "Valid IPv6 partially-compressed address with default port",
input: "nameserver 2001:db8:0:0:0::1",
wantIPv4: nil,
wantIPv6: []string{"[2001:db8:0:0:0::1]:53"},
wantErr: false,
},
{
name: "Valid IPv4 with custom port",
input: "nameserver 1.2.3.4:35",
wantIPv4: []string{"1.2.3.4:35"},
wantIPv6: nil,
wantErr: false,
},
{
name: "Valid IPv6 with custom port",
input: "nameserver [2001:db8::1]:35",
wantIPv4: nil,
wantIPv6: []string{"[2001:db8::1]:35"},
wantErr: false,
},
{
name: "Invalid IPv4 address",
input: "nameserver 1.2.3",
wantIPv4: nil,
wantIPv6: nil,
wantErr: true,
},
{
name: "IPv6 link-local address (should error)",
input: "nameserver fe80::1%eth0",
wantIPv4: nil,
wantIPv6: nil,
wantErr: true,
},
{
name: "Invalid format - interface specified on IPv4",
input: "nameserver 111.222.333.444:9953%ifname",
wantIPv4: nil,
wantIPv6: nil,
wantErr: true,
},
{
name: "Invalid format - interface specified on IPv6",
input: "nameserver [2001:db8::1]]:9953%ifname",
wantIPv4: nil,
wantIPv6: nil,
wantErr: true,
},
{
name: "Invalid format - domain specified",
input: "nameserver 111.222.333.444:9953#example.com",
wantIPv4: nil,
wantIPv6: nil,
wantErr: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ipv4, ipv6, err := getDNSServersFromReader(strings.NewReader(test.input))
if (err != nil) != test.wantErr {
t.Errorf("getDNSServersFromReader() received error = %v, wantErr %v", err, test.wantErr)
}
if fmt.Sprintf("%v", ipv4) != fmt.Sprintf("%v", test.wantIPv4) {
t.Errorf("getDNSServersFromReader() received ipv4 = %v, want %v", ipv4, test.wantIPv4)
}
if fmt.Sprintf("%v", ipv6) != fmt.Sprintf("%v", test.wantIPv6) {
t.Errorf("getDNSServersFromReader() received ipv6 = %v, want %v", ipv6, test.wantIPv6)
}
})
}
}

func verifyNsResult(t *testing.T, servers []NSRecord, expectedServersMap map[string]IPResult) {
serversLength := len(servers)
expectedServersLength := len(expectedServersMap)
Expand Down