Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 59 additions & 35 deletions provider/pihole/pihole.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,59 +120,84 @@ func (p *PiholeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, err

// ApplyChanges implements Provider, syncing desired state with the Pi-hole server Local DNS.
func (p *PiholeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
// Handle pure deletes first.
for _, ep := range changes.Delete {
if err := p.applyDeletes(ctx, changes.Delete); err != nil {
return err
}

updateNew := p.buildUpdateMap(changes.UpdateNew)

if err := p.applyUpdateOld(ctx, changes.UpdateOld, updateNew); err != nil {
return err
}

return p.applyCreates(ctx, changes.Create, updateNew)
}

// applyDeletes performs pure deletes from the plan.
func (p *PiholeProvider) applyDeletes(ctx context.Context, deletes []*endpoint.Endpoint) error {
for _, ep := range deletes {
if err := p.api.deleteRecord(ctx, ep); err != nil {
return err
}
}
return nil
}

// Handle updated state - there are no endpoints for updating in place.
updateNew := make(map[piholeEntryKey]*endpoint.Endpoint)
for _, ep := range changes.UpdateNew {
// buildUpdateMap collapses UpdateNew endpoints by (DNSName, RecordType). For
// the v6 API, endpoints that share a key have their Targets merged and
// deduplicated so a single createRecord call carries all desired targets.
func (p *PiholeProvider) buildUpdateMap(updateNew []*endpoint.Endpoint) map[piholeEntryKey]*endpoint.Endpoint {
m := make(map[piholeEntryKey]*endpoint.Endpoint, len(updateNew))
for _, ep := range updateNew {
key := piholeEntryKey{ep.DNSName, ep.RecordType}

// If the API version is 6, we need to handle multiple targets for the same DNS name.
if p.apiVersion == "6" {
if existing, ok := updateNew[key]; ok {
if existing, ok := m[key]; ok {
existing.Targets = append(existing.Targets, ep.Targets...)

// Deduplicate targets
slices.Sort(existing.Targets)
existing.Targets = slices.Compact(existing.Targets)

ep = existing
}
}
updateNew[key] = ep
m[key] = ep
}
return m
}

for _, ep := range changes.UpdateOld {
// Check if this existing entry has an exact match for an updated entry and skip it if so.
// applyUpdateOld walks the old side of in-place updates. For each old entry
// whose paired new entry is unchanged, the update is dropped from updateNew
// (nothing to do). Otherwise the old record is deleted so the new record can
// be created in the subsequent phase.
func (p *PiholeProvider) applyUpdateOld(ctx context.Context, updateOld []*endpoint.Endpoint, updateNew map[piholeEntryKey]*endpoint.Endpoint) error {
for _, ep := range updateOld {
key := piholeEntryKey{ep.DNSName, ep.RecordType}
if newRecord := updateNew[key]; newRecord != nil {
// If the API version is 6, we need to handle multiple targets for the same DNS name.
if p.apiVersion == "6" {
if cmp.Diff(ep.Targets, newRecord.Targets) == "" {
delete(updateNew, key)
continue
}
} else {
// For API version <= 5, we only check the first target.
if newRecord.Targets[0] == ep.Targets[0] {
delete(updateNew, key)
continue
}
}

if err := p.api.deleteRecord(ctx, ep); err != nil {
return err
}
newRecord, ok := updateNew[key]
if !ok {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’d prefer to keep checking for a nil newRecord.

continue
}
if p.updateIsNoOp(ep, newRecord) {
delete(updateNew, key)
continue
}
if err := p.api.deleteRecord(ctx, ep); err != nil {
return err
}
}
return nil
}

// updateIsNoOp reports whether an update is actually a change. For v6 all
// targets must match; for older APIs only the first target is compared, which
// matches the historical single-target semantics of the v5 local-DNS endpoint.
func (p *PiholeProvider) updateIsNoOp(oldEP, newEP *endpoint.Endpoint) bool {
if p.apiVersion == "6" {
return cmp.Diff(oldEP.Targets, newEP.Targets) == ""
}
return newEP.Targets[0] == oldEP.Targets[0]
}

// Handle pure creates before applying new updated state.
for _, ep := range changes.Create {
// applyCreates runs pure creates followed by the (possibly pruned) updates.
func (p *PiholeProvider) applyCreates(ctx context.Context, creates []*endpoint.Endpoint, updateNew map[piholeEntryKey]*endpoint.Endpoint) error {
for _, ep := range creates {
if err := p.api.createRecord(ctx, ep); err != nil {
return err
}
Expand All @@ -182,6 +207,5 @@ func (p *PiholeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes
return err
}
}

return nil
}
Loading