Skip to content

feat(pihole): add support for IPv6 Dual format #5253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 43 additions & 13 deletions provider/pihole/clientV6.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"io"
"net/http"
"net/netip"
"net/url"
"strconv"
"strings"
Expand Down Expand Up @@ -63,6 +64,7 @@ func newPiholeClientV6(cfg PiholeConfig) (piholeAPI, error) {
},
},
}

cl := instrumented_http.NewClient(httpClient, &instrumented_http.Callbacks{})

p := &piholeClientV6{
Expand Down Expand Up @@ -114,6 +116,32 @@ func (p *piholeClientV6) getConfigValue(ctx context.Context, rtype string) ([]st
return results, nil
}

/**
* isValidIPv4 checks if the given IP address is a valid IPv4 address.
* It returns true if the IP address is valid, false otherwise.
* If the IP address is in IPv6 format, it will return false.
*/
func isValidIPv4(ip string) bool {
addr, err := netip.ParseAddr(ip)
if err != nil {
return false
}
return addr.Is4()
}

/**
* isValidIPv6 checks if the given IP address is a valid IPv6 address.
* It returns true if the IP address is valid, false otherwise.
* If the IP address is in IPv6 with dual format y:y:y:y:y:y:x.x.x.x. , it will return true.
*/
func isValidIPv6(ip string) bool {
addr, err := netip.ParseAddr(ip)
if err != nil {
return false
}
return addr.Is6()
}

func (p *piholeClientV6) listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error) {
out := make([]*endpoint.Endpoint, 0)
results, err := p.getConfigValue(ctx, rtype)
Expand All @@ -126,21 +154,22 @@ func (p *piholeClientV6) listRecords(ctx context.Context, rtype string) ([]*endp
return r == ' ' || r == ','
})
if len(recs) < 2 {
log.Warnf("skipping record %s: invalid format", rec)
log.Warnf("skipping record %s: invalid format received from PiHole", rec)
continue
}
var DNSName, Target string
var Ttl endpoint.TTL = 0
var Ttl = endpoint.TTL(0)
// A/AAAA record format is target(IP) DNSName
DNSName, Target = recs[1], recs[0]

switch rtype {
case endpoint.RecordTypeA:
if strings.Contains(Target, ":") {
if !isValidIPv4(Target) {
log.Warnf("skipping A record %s: invalid format received from PiHole", rec)
continue
}
case endpoint.RecordTypeAAAA:
if strings.Contains(Target, ".") {
if !isValidIPv6(Target) {
log.Warnf("skipping AAAA record %s: invalid format received from PiHole", rec)
continue
}
case endpoint.RecordTypeCNAME:
Expand All @@ -151,17 +180,12 @@ func (p *piholeClientV6) listRecords(ctx context.Context, rtype string) ([]*endp
if ttlInt, err := strconv.ParseInt(recs[2], 10, 64); err == nil {
Ttl = endpoint.TTL(ttlInt)
} else {
log.Warnf("failed to parse TTL value '%s': %v; using a TTL of %d", recs[2], err, Ttl)
log.Warnf("failed to parse TTL value received from PiHole '%s': %v; using a TTL of %d", recs[2], err, Ttl)
}
}
}

out = append(out, &endpoint.Endpoint{
DNSName: DNSName,
Targets: []string{Target},
RecordTTL: Ttl,
RecordType: rtype,
})
out = append(out, endpoint.NewEndpointWithTTL(DNSName, rtype, Ttl, Target))
}
return out, nil
}
Expand Down Expand Up @@ -375,7 +399,13 @@ func (p *piholeClientV6) do(req *http.Request) ([]byte, error) {
if err := json.Unmarshal(jRes, &apiError); err != nil {
return nil, fmt.Errorf("failed to unmarshal error response: %w", err)
}
log.Debugf("Error on request %s", req.Body)
if log.IsLevelEnabled(log.DebugLevel) {
log.Debugf("Error on request %s", req.URL)
if req.Body != nil {
log.Debugf("Body of the request %s", req.Body)
}
}

if res.StatusCode == http.StatusUnauthorized && p.token != "" {
tryCount := 1
maxRetries := 3
Expand Down
116 changes: 91 additions & 25 deletions provider/pihole/clientV6_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,62 @@ import (
"sigs.k8s.io/external-dns/endpoint"
)

func TestIsValidIPv4(t *testing.T) {
tests := []struct {
ip string
expected bool
}{
{"192.168.1.1", true},
{"255.255.255.255", true},
{"0.0.0.0", true},
{"", false},
{"256.256.256.256", false},
{"192.168.0.1/22", false},
{"192.168.1", false},
{"abc.def.ghi.jkl", false},
{"::ffff:192.168.20.3", false},
}

for _, test := range tests {
t.Run(test.ip, func(t *testing.T) {
if got := isValidIPv4(test.ip); got != test.expected {
t.Errorf("isValidIPv4(%s) = %v; want %v", test.ip, got, test.expected)
}
})
}
}

func TestIsValidIPv6(t *testing.T) {
tests := []struct {
ip string
expected bool
}{
{"2001:0db8:85a3:0000:0000:8a2e:0370:7334", true},
{"2001:db8:85a3::8a2e:370:7334", true},
//IPV6 dual, the format is y:y:y:y:y:y:x.x.x.x.
{"::ffff:192.168.20.3", true},
{"::1", true},
{"::", true},
{"2001:db8::", true},
{"", false},
{":", false},
{"::ffff:", false},
{"192.168.20.3", false},
{"2001:db8:85a3:0:0:8a2e:370:7334:1234", false},
{"2001:db8:85a3::8a2e:370g:7334", false},
{"2001:db8:85a3::8a2e:370:7334::", false},
{"2001:db8:85a3::8a2e:370:7334::1", false},
}

for _, test := range tests {
t.Run(test.ip, func(t *testing.T) {
if got := isValidIPv6(test.ip); got != test.expected {
t.Errorf("isValidIPv6(%s) = %v; want %v", test.ip, got, test.expected)
}
})
}
}

func newTestServerV6(t *testing.T, hdlr http.HandlerFunc) *httptest.Server {
t.Helper()
svr := httptest.NewServer(hdlr)
Expand Down Expand Up @@ -137,7 +193,9 @@ func TestListRecordsV6(t *testing.T) {
"192.168.178.34 service3.example.com",
"fc00::1:192:168:1:1 service4.example.com",
"fc00::1:192:168:1:2 service5.example.com",
"fc00::1:192:168:1:3 service6.example.com"
"fc00::1:192:168:1:3 service6.example.com",
"::ffff:192.168.20.3 service7.example.com",
"192.168.20.3 service7.example.com"
]
}
},
Expand Down Expand Up @@ -177,20 +235,22 @@ func TestListRecordsV6(t *testing.T) {
t.Fatal(err)
}

// Ensure A records were parsed correctly
expected := [][]string{
{"service1.example.com", "192.168.178.33"},
{"service2.example.com", "192.168.178.34"},
{"service3.example.com", "192.168.178.34"},
{"service7.example.com", "192.168.20.3"},
}
// Test retrieve A records unfiltered
arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA)
if err != nil {
t.Fatal(err)
}
if len(arecs) != 3 {
t.Fatal("Expected 3 A records returned, got:", len(arecs))
}
// Ensure records were parsed correctly
expected := [][]string{
{"service1.example.com", "192.168.178.33"},
{"service2.example.com", "192.168.178.34"},
{"service3.example.com", "192.168.178.34"},
if len(arecs) != len(expected) {
t.Fatalf("Expected %d A records returned, got: %d", len(expected), len(arecs))
}

for idx, rec := range arecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
Expand All @@ -200,20 +260,23 @@ func TestListRecordsV6(t *testing.T) {
}
}

// Ensure AAAA records were parsed correctly
expected = [][]string{
{"service4.example.com", "fc00::1:192:168:1:1"},
{"service5.example.com", "fc00::1:192:168:1:2"},
{"service6.example.com", "fc00::1:192:168:1:3"},
{"service7.example.com", "::ffff:192.168.20.3"},
}
// Test retrieve AAAA records unfiltered
arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA)
if err != nil {
t.Fatal(err)
}
if len(arecs) != 3 {
t.Fatal("Expected 3 AAAA records returned, got:", len(arecs))
}
// Ensure records were parsed correctly
expected = [][]string{
{"service4.example.com", "fc00::1:192:168:1:1"},
{"service5.example.com", "fc00::1:192:168:1:2"},
{"service6.example.com", "fc00::1:192:168:1:3"},

if len(arecs) != len(expected) {
t.Fatalf("Expected %d AAAA records returned, got: %d", len(expected), len(arecs))
}

for idx, rec := range arecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
Expand All @@ -223,20 +286,22 @@ func TestListRecordsV6(t *testing.T) {
}
}

// Ensure CNAME records were parsed correctly
expected = [][]string{
{"source1.example.com", "target1.domain.com", "1000"},
{"source2.example.com", "target2.domain.com", "50"},
{"source3.example.com", "target3.domain.com"},
}

// Test retrieve CNAME records unfiltered
cnamerecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeCNAME)
if err != nil {
t.Fatal(err)
}
if len(cnamerecs) != 3 {
t.Fatal("Expected 3 CAME records returned, got:", len(cnamerecs))
}
// Ensure records were parsed correctly
expected = [][]string{
{"source1.example.com", "target1.domain.com", "1000"},
{"source2.example.com", "target2.domain.com", "50"},
{"source3.example.com", "target3.domain.com"},
if len(cnamerecs) != len(expected) {
t.Fatalf("Expected %d CAME records returned, got: %d", len(expected), len(cnamerecs))
}

for idx, rec := range cnamerecs {
if rec.DNSName != expected[idx][0] {
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
Expand All @@ -261,6 +326,7 @@ func TestListRecordsV6(t *testing.T) {
t.Fatal("Expected error for using unsupported record type")
}
}

func TestErrorsV6(t *testing.T) {
//Error test cases

Expand Down
Loading