diff --git a/.gitignore b/.gitignore index 776cd950c..9c28d156d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ tags test.out a.out +.*.swp diff --git a/labels.go b/labels.go index f9faacfeb..664fe6f6e 100644 --- a/labels.go +++ b/labels.go @@ -61,20 +61,18 @@ func CompareDomainName(s1, s2 string) (n int) { i2 := len(l2) - 2 // the second check can be done here: last/only label // before we fall through into the for-loop below - if equal(s1[l1[j1]:], s2[l2[j2]:]) { - n++ - } else { + if !equal(s1[l1[j1]:], s2[l2[j2]:]) { return } + n++ for { if i1 < 0 || i2 < 0 { break } - if equal(s1[l1[i1]:l1[j1]], s2[l2[i2]:l2[j2]]) { - n++ - } else { + if !equal(s1[l1[i1]:l1[j1]], s2[l2[i2]:l2[j2]]) { break } + n++ j1-- i1-- j2-- @@ -186,16 +184,55 @@ func PrevLabel(s string, n int) (i int, start bool) { return 0, n > 1 } -// equal compares a and b while ignoring case. It returns true when equal otherwise false. -func equal(a, b string) bool { - // might be lifted into API function. +// Compare compares domains according to the canonical ordering specified in RFC4034 +// returns an integer value similar to strcmp +// (0 for equal values, -1 if s1 < s2, 1 if s1 > s2) +func Compare(s1, s2 string) int { + s1b := doDDD([]byte(s1)) + s2b := doDDD([]byte(s2)) + + s1 = string(s1b) + s2 = string(s2b) + + s1lend := len(s1) + s2lend := len(s2) + + for i := 0; ; i++ { + s1lstart, end1 := PrevLabel(s1, i) + s2lstart, end2 := PrevLabel(s2, i) + + if end1 && end2 { + return 0 + } + + s1l := string(s1b[s1lstart:s1lend]) + s2l := string(s2b[s2lstart:s2lend]) + + if cmp := labelCompare(s1l, s2l); cmp != 0 { + return cmp + } + + s1lend = s1lstart - 1 + s2lend = s2lstart - 1 + if s1lend == -1 { + s1lend = 0 + } + if s2lend == -1 { + s2lend = 0 + } + } +} + +// essentially strcasecmp +// (0 for equal values, -1 if s1 < s2, 1 if s1 > s2) +func labelCompare(a, b string) int { la := len(a) lb := len(b) - if la != lb { - return false + minLen := la + if lb < la { + minLen = lb } - - for i := la - 1; i >= 0; i-- { + for i := 0; i < minLen; i++ { ai := a[i] bi := b[i] if ai >= 'A' && ai <= 'Z' { @@ -205,8 +242,41 @@ func equal(a, b string) bool { bi |= 'a' - 'A' } if ai != bi { - return false + if ai > bi { + return 1 + } + return -1 + } + } + + if la > lb { + return 1 + } else if la < lb { + return -1 + } + return 0 +} + +// equal compares a and b while ignoring case. It returns true when equal otherwise false. +func equal(a, b string) bool { + // might be lifted into API function. + if len(a) != len(b) { + return false + } + + return labelCompare(a, b) == 0 +} + +func doDDD(b []byte) []byte { + lb := len(b) + for i := 0; i < lb; i++ { + if i+3 < lb && b[i] == '\\' && isDigit(b[i+1]) && isDigit(b[i+2]) && isDigit(b[i+3]) { + b[i] = dddToByte(b[i+1 : i+4]) + for j := i + 1; j < lb-3; j++ { + b[j] = b[j+3] + } + lb -= 3 } } - return true + return b[:lb] } diff --git a/labels_test.go b/labels_test.go index 3e672fec8..9982e3650 100644 --- a/labels_test.go +++ b/labels_test.go @@ -334,3 +334,51 @@ func BenchmarkPrevLabelMixed(b *testing.B) { PrevLabel(`www\\\.example.com`, 10) } } + +func BenchmarkCompare(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + Compare("\\097.", "A.") + } +} + +func TestCompare(t *testing.T) { + domains := []string{ // based on an exanple from RFC 4034 + "example.", + "a.example.", + "yljkjljk.a.example.", + "Z.a.example.", + "zABC.a.EXAMPLE.", + "a-.example.", + "z.example.", + "\001.z.example.", + "*.z.example.", + "\200.z.example.", + } + + len_domains := len(domains) + + for i, domain := range domains { + if i != 0 { + prev_domain := domains[i-1] + if !(Compare(prev_domain, domain) == -1 && Compare(domain, prev_domain) == 1) { + t.Fatalf("prev comparison failure between %s and %s", prev_domain, domain) + } + } + + if Compare(domain, domain) != 0 { + t.Fatalf("self comparison failure for %s", domain) + } + + if i != len_domains-1 { + next_domain := domains[i+1] + if !(Compare(domain, next_domain) == -1 && Compare(next_domain, domain) == 1) { + t.Fatalf("next comparison failure between %s and %s, %d and %d", domain, next_domain, Compare(domain, next_domain), Compare(next_domain, domain)) + } + } + } + + if Compare("\\097.", "A.") != 0 { + t.Fatal("failure to normalize DDD escape sequence") + } +} diff --git a/nsecx.go b/nsecx.go index f8826817b..b53ef35e0 100644 --- a/nsecx.go +++ b/nsecx.go @@ -93,3 +93,8 @@ func (rr *NSEC3) Match(name string) bool { } return false } + +// Match returns true if the given name is covered by the NSEC record +func (rr *NSEC) Cover(name string) bool { + return Compare(rr.Hdr.Name, name) <= 0 && Compare(name, rr.NextDomain) == -1 +} diff --git a/nsecx_test.go b/nsecx_test.go index ee9265334..ca42b37b7 100644 --- a/nsecx_test.go +++ b/nsecx_test.go @@ -168,3 +168,23 @@ func BenchmarkHashName(b *testing.B) { }) } } + +func TestNsecCover(t *testing.T) { + nsec := testRR("aaa.ee. 3600 IN NSEC aac.ee. NS RRSIG NSEC").(*NSEC) + + if !nsec.Cover("aaaa.ee.") { + t.Fatal("nsec cover not covering in-range name") + } + + if !nsec.Cover("aaa.ee.") { + t.Fatal("nsec cover not covering start of range") + } + + if nsec.Cover("aac.ee.") { + t.Fatal("nsec cover range end failure") + } + + if nsec.Cover("aad.ee.") { + t.Fatal("nsec cover covering out-of-range name") + } +}