Skip to content

Commit d9229d8

Browse files
committed
Refactor Google provider to process Endpoints as a sequence of changes.
1 parent 238fcf3 commit d9229d8

File tree

2 files changed

+105
-214
lines changed

2 files changed

+105
-214
lines changed

provider/google/google.go

+87-210
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package google
1919
import (
2020
"context"
2121
"fmt"
22-
"sort"
22+
"slices"
2323
"strings"
2424
"time"
2525

@@ -156,12 +156,12 @@ func NewGoogleProvider(ctx context.Context, project string, domainFilter endpoin
156156
zoneTypeFilter := provider.NewZoneTypeFilter(zoneVisibility)
157157

158158
provider := &GoogleProvider{
159-
dryRun: dryRun,
160-
batchChangeSize: batchChangeSize,
159+
dryRun: dryRun,
160+
batchChangeSize: batchChangeSize,
161161
batchChangeInterval: batchChangeInterval,
162-
domainFilter: domainFilter,
163-
zoneTypeFilter: zoneTypeFilter,
164-
zoneIDFilter: zoneIDFilter,
162+
domainFilter: domainFilter,
163+
zoneTypeFilter: zoneTypeFilter,
164+
zoneIDFilter: zoneIDFilter,
165165
managedZonesClient: managedZonesService{
166166
project: project,
167167
service: dnsClient.ManagedZones,
@@ -244,232 +244,109 @@ func (p *GoogleProvider) Records(ctx context.Context) (endpoints []*endpoint.End
244244
return endpoints, nil
245245
}
246246

247-
// ApplyChanges applies a given set of changes in a given zone.
247+
// ApplyChanges applies a given set of changes.
248248
func (p *GoogleProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
249-
change := &dns.Change{}
250-
251-
change.Additions = append(change.Additions, p.newFilteredRecords(changes.Create)...)
252-
253-
change.Additions = append(change.Additions, p.newFilteredRecords(changes.UpdateNew)...)
254-
change.Deletions = append(change.Deletions, p.newFilteredRecords(changes.UpdateOld)...)
255-
256-
change.Deletions = append(change.Deletions, p.newFilteredRecords(changes.Delete)...)
257-
258-
return p.submitChange(ctx, change)
259-
}
260-
261-
// SupportedRecordType returns true if the record type is supported by the provider
262-
func (p *GoogleProvider) SupportedRecordType(recordType string) bool {
263-
switch recordType {
264-
case "MX":
265-
return true
266-
default:
267-
return provider.SupportedRecordType(recordType)
268-
}
269-
}
270-
271-
// newFilteredRecords returns a collection of RecordSets based on the given endpoints and domainFilter.
272-
func (p *GoogleProvider) newFilteredRecords(endpoints []*endpoint.Endpoint) []*dns.ResourceRecordSet {
273-
records := []*dns.ResourceRecordSet{}
274-
275-
for _, endpoint := range endpoints {
276-
if p.domainFilter.Match(endpoint.DNSName) {
277-
records = append(records, newRecord(endpoint))
278-
}
279-
}
280-
281-
return records
282-
}
283-
284-
// submitChange takes a zone and a Change and sends it to Google.
285-
func (p *GoogleProvider) submitChange(ctx context.Context, change *dns.Change) error {
286-
if len(change.Additions) == 0 && len(change.Deletions) == 0 {
287-
log.Info("All records are already up to date")
288-
return nil
289-
}
290-
291249
zones, err := p.Zones(ctx)
292250
if err != nil {
293251
return err
294252
}
295-
296-
// separate into per-zone change sets to be passed to the API.
297-
changes := separateChange(zones, change)
298-
299-
for zone, change := range changes {
300-
for batch, c := range batchChange(change, p.batchChangeSize) {
301-
log.Infof("Change zone: %v batch #%d", zone, batch)
302-
for _, del := range c.Deletions {
303-
log.Infof("Del records: %s %s %s %d", del.Name, del.Type, del.Rrdatas, del.Ttl)
253+
zoneMap := provider.ZoneIDName{}
254+
for _, z := range zones {
255+
zoneMap.Add(z.Name, z.DnsName)
256+
}
257+
zoneBatches := map[string][]*dns.Change{}
258+
for rrSetChange := range changes.All() {
259+
if zone, _ := zoneMap.FindZone(string(rrSetChange.Name)); zone != "" {
260+
change := p.newChange(rrSetChange)
261+
changeSize := len(change.Additions) + len(change.Deletions)
262+
if changeSize == 0 {
263+
continue
264+
}
265+
if _, ok := zoneBatches[zone]; !ok {
266+
zoneBatches[zone] = []*dns.Change{{}}
304267
}
305-
for _, add := range c.Additions {
306-
log.Infof("Add records: %s %s %s %d", add.Name, add.Type, add.Rrdatas, add.Ttl)
268+
batch := zoneBatches[zone][len(zoneBatches[zone])-1]
269+
if p.batchChangeSize > 0 && len(batch.Additions)+len(batch.Deletions)+changeSize > p.batchChangeSize {
270+
batch = &dns.Change{}
271+
zoneBatches[zone] = append(zoneBatches[zone], batch)
307272
}
273+
batch.Additions = append(batch.Additions, change.Additions...)
274+
batch.Deletions = append(batch.Deletions, change.Deletions...)
275+
}
276+
}
308277

278+
for zone, batches := range zoneBatches {
279+
for index, batch := range batches {
280+
log.Infof("Change zone: %v batch #%d", zone, index)
281+
for _, record := range batch.Deletions {
282+
log.Infof("Del records: %s %s %s %d", record.Name, record.Type, record.Rrdatas, record.Ttl)
283+
}
284+
for _, record := range batch.Additions {
285+
log.Infof("Add records: %s %s %s %d", record.Name, record.Type, record.Rrdatas, record.Ttl)
286+
}
309287
if p.dryRun {
310288
continue
311289
}
312-
313-
if _, err := p.changesClient.Create(zone, c).Do(); err != nil {
290+
if index > 0 {
291+
time.Sleep(p.batchChangeInterval)
292+
}
293+
if _, err := p.changesClient.Create(zone, batch).Do(); err != nil {
314294
return provider.NewSoftError(fmt.Errorf("failed to create changes: %w", err))
315295
}
316-
317-
time.Sleep(p.batchChangeInterval)
318296
}
319297
}
320298

321299
return nil
322300
}
323301

324-
// batchChange separates a zone in multiple transaction.
325-
func batchChange(change *dns.Change, batchSize int) []*dns.Change {
326-
changes := []*dns.Change{}
327-
328-
if batchSize == 0 {
329-
return append(changes, change)
330-
}
331-
332-
type dnsChange struct {
333-
additions []*dns.ResourceRecordSet
334-
deletions []*dns.ResourceRecordSet
335-
}
336-
337-
changesByName := map[string]*dnsChange{}
338-
339-
for _, a := range change.Additions {
340-
change, ok := changesByName[a.Name]
341-
if !ok {
342-
change = &dnsChange{}
343-
changesByName[a.Name] = change
344-
}
345-
346-
change.additions = append(change.additions, a)
347-
}
348-
349-
for _, a := range change.Deletions {
350-
change, ok := changesByName[a.Name]
351-
if !ok {
352-
change = &dnsChange{}
353-
changesByName[a.Name] = change
354-
}
355-
356-
change.deletions = append(change.deletions, a)
357-
}
358-
359-
names := make([]string, 0)
360-
for v := range changesByName {
361-
names = append(names, v)
362-
}
363-
sort.Strings(names)
364-
365-
currentChange := &dns.Change{}
366-
var totalChanges int
367-
for _, name := range names {
368-
c := changesByName[name]
369-
370-
totalChangesByName := len(c.additions) + len(c.deletions)
371-
372-
if totalChangesByName > batchSize {
373-
log.Warnf("Total changes for %s exceeds max batch size of %d, total changes: %d", name,
374-
batchSize, totalChangesByName)
375-
continue
376-
}
377-
378-
if totalChanges+totalChangesByName > batchSize {
379-
totalChanges = 0
380-
changes = append(changes, currentChange)
381-
currentChange = &dns.Change{}
382-
}
383-
384-
currentChange.Additions = append(currentChange.Additions, c.additions...)
385-
currentChange.Deletions = append(currentChange.Deletions, c.deletions...)
386-
387-
totalChanges += totalChangesByName
388-
}
389-
390-
if totalChanges > 0 {
391-
changes = append(changes, currentChange)
392-
}
393-
394-
return changes
395-
}
396-
397-
// separateChange separates a multi-zone change into a single change per zone.
398-
func separateChange(zones map[string]*dns.ManagedZone, change *dns.Change) map[string]*dns.Change {
399-
changes := make(map[string]*dns.Change)
400-
zoneNameIDMapper := provider.ZoneIDName{}
401-
for _, z := range zones {
402-
zoneNameIDMapper[z.Name] = z.DnsName
403-
changes[z.Name] = &dns.Change{
404-
Additions: []*dns.ResourceRecordSet{},
405-
Deletions: []*dns.ResourceRecordSet{},
406-
}
407-
}
408-
for _, a := range change.Additions {
409-
if zoneName, _ := zoneNameIDMapper.FindZone(provider.EnsureTrailingDot(a.Name)); zoneName != "" {
410-
changes[zoneName].Additions = append(changes[zoneName].Additions, a)
411-
} else {
412-
log.Warnf("No matching zone for record addition: %s %s %s %d", a.Name, a.Type, a.Rrdatas, a.Ttl)
413-
}
414-
}
415-
416-
for _, d := range change.Deletions {
417-
if zoneName, _ := zoneNameIDMapper.FindZone(provider.EnsureTrailingDot(d.Name)); zoneName != "" {
418-
changes[zoneName].Deletions = append(changes[zoneName].Deletions, d)
419-
} else {
420-
log.Warnf("No matching zone for record deletion: %s %s %s %d", d.Name, d.Type, d.Rrdatas, d.Ttl)
421-
}
422-
}
423-
424-
// separating a change could lead to empty sub changes, remove them here.
425-
for zone, change := range changes {
426-
if len(change.Additions) == 0 && len(change.Deletions) == 0 {
427-
delete(changes, zone)
428-
}
302+
// SupportedRecordType returns true if the record type is supported by the provider
303+
func (p *GoogleProvider) SupportedRecordType(recordType string) bool {
304+
switch recordType {
305+
case "MX":
306+
return true
307+
default:
308+
return provider.SupportedRecordType(recordType)
429309
}
430-
431-
return changes
432310
}
433311

434-
// newRecord returns a RecordSet based on the given endpoint.
435-
func newRecord(ep *endpoint.Endpoint) *dns.ResourceRecordSet {
436-
// TODO(linki): works around appending a trailing dot to TXT records. I think
437-
// we should go back to storing DNS names with a trailing dot internally. This
438-
// way we can use it has is here and trim it off if it exists when necessary.
439-
targets := make([]string, len(ep.Targets))
440-
copy(targets, []string(ep.Targets))
441-
if ep.RecordType == endpoint.RecordTypeCNAME {
442-
targets[0] = provider.EnsureTrailingDot(targets[0])
443-
}
444-
445-
if ep.RecordType == endpoint.RecordTypeMX {
446-
for i, mxRecord := range ep.Targets {
447-
targets[i] = provider.EnsureTrailingDot(mxRecord)
448-
}
449-
}
450-
451-
if ep.RecordType == endpoint.RecordTypeSRV {
452-
for i, srvRecord := range ep.Targets {
453-
targets[i] = provider.EnsureTrailingDot(srvRecord)
454-
}
455-
}
456-
457-
if ep.RecordType == endpoint.RecordTypeNS {
458-
for i, nsRecord := range ep.Targets {
459-
targets[i] = provider.EnsureTrailingDot(nsRecord)
312+
// newChange returns a DNS change based upon the given resource record set change.
313+
func (p *GoogleProvider) newChange(rrSetChange *plan.RRSetChange) *dns.Change {
314+
change := dns.Change{}
315+
for index, endpoints := range [][]*endpoint.Endpoint{rrSetChange.Delete, rrSetChange.Create} {
316+
for _, ep := range endpoints {
317+
record := dns.ResourceRecordSet{
318+
Name: provider.EnsureTrailingDot(ep.DNSName),
319+
Ttl: googleRecordTTL,
320+
Type: ep.RecordType,
321+
}
322+
if ep.RecordTTL.IsConfigured() {
323+
record.Ttl = int64(ep.RecordTTL)
324+
}
325+
// TODO(linki): works around appending a trailing dot to TXT records. I think
326+
// we should go back to storing DNS names with a trailing dot internally. This
327+
// way we can use it has is here and trim it off if it exists when necessary.
328+
switch record.Type {
329+
case endpoint.RecordTypeCNAME:
330+
record.Rrdatas = []string{provider.EnsureTrailingDot(ep.Targets[0])}
331+
case endpoint.RecordTypeMX:
332+
fallthrough
333+
case endpoint.RecordTypeNS:
334+
fallthrough
335+
case endpoint.RecordTypeSRV:
336+
record.Rrdatas = make([]string, len(ep.Targets))
337+
for i, target := range ep.Targets {
338+
record.Rrdatas[i] = provider.EnsureTrailingDot(target)
339+
}
340+
default:
341+
record.Rrdatas = slices.Clone(ep.Targets)
342+
}
343+
switch index {
344+
case 0:
345+
change.Deletions = append(change.Deletions, &record)
346+
case 1:
347+
change.Additions = append(change.Additions, &record)
348+
}
460349
}
461350
}
462-
463-
// no annotation results in a Ttl of 0, default to 300 for backwards-compatibility
464-
var ttl int64 = googleRecordTTL
465-
if ep.RecordTTL.IsConfigured() {
466-
ttl = int64(ep.RecordTTL)
467-
}
468-
469-
return &dns.ResourceRecordSet{
470-
Name: provider.EnsureTrailingDot(ep.DNSName),
471-
Rrdatas: targets,
472-
Ttl: ttl,
473-
Type: ep.RecordType,
474-
}
351+
return &change
475352
}

provider/google/google_test.go

+18-4
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ func TestGoogleApplyChanges(t *testing.T) {
321321
endpoint.NewEndpoint("create-test.zone-1.local", endpoint.RecordTypeA, "8.8.8.8"),
322322
endpoint.NewEndpointWithTTL("create-test-ttl.zone-2.local", endpoint.RecordTypeA, endpoint.TTL(15), "8.8.4.4"),
323323
endpoint.NewEndpoint("create-test-cname.zone-1.local", endpoint.RecordTypeCNAME, "create-test-cname"),
324+
endpoint.NewEndpoint("create-test-ns.zone-1.local", endpoint.RecordTypeNS, "create-test-ns"),
324325
endpoint.NewEndpoint("filter-create-test.zone-3.local", endpoint.RecordTypeA, "4.2.2.2"),
325326
endpoint.NewEndpoint("nomatch-create-test.zone-0.local", endpoint.RecordTypeA, "4.2.2.1"),
326327
},
@@ -374,6 +375,12 @@ func TestGoogleApplyChanges(t *testing.T) {
374375
Ttl: googleRecordTTL,
375376
Rrdatas: []string{"updated-test-cname."},
376377
},
378+
&dns.ResourceRecordSet{
379+
Name: "create-test-ns.zone-1.local.",
380+
Type: "NS",
381+
Ttl: googleRecordTTL,
382+
Rrdatas: []string{"create-test-ns."},
383+
},
377384
},
378385
&dns.ManagedZone{
379386
Name: "zone-2",
@@ -477,10 +484,11 @@ func TestGoogleApplyChangesEmpty(t *testing.T) {
477484
assert.NoError(t, p.ApplyChanges(context.Background(), &plan.Changes{}))
478485
}
479486

480-
func TestNewFilteredRecords(t *testing.T) {
487+
func TestNewChange(t *testing.T) {
481488
p := newGoogleProvider().WithMockClients(newMockClients(t))
482489

483-
records := p.newFilteredRecords([]*endpoint.Endpoint{
490+
records := []*dns.ResourceRecordSet{}
491+
for _, ep := range []*endpoint.Endpoint{
484492
endpoint.NewEndpointWithTTL("update-test.zone-2.local", endpoint.RecordTypeA, 1, "8.8.4.4"),
485493
endpoint.NewEndpointWithTTL("delete-test.zone-2.local", endpoint.RecordTypeA, 120, "8.8.4.4"),
486494
endpoint.NewEndpointWithTTL("update-test-cname.zone-1.local", endpoint.RecordTypeCNAME, 4000, "update-test-cname"),
@@ -489,7 +497,13 @@ func TestNewFilteredRecords(t *testing.T) {
489497
endpoint.NewEndpointWithTTL("update-test-mx.zone-1.local", endpoint.RecordTypeMX, 6000, "10 mail"),
490498
endpoint.NewEndpoint("delete-test.zone-1.local", endpoint.RecordTypeA, "8.8.8.8"),
491499
endpoint.NewEndpoint("delete-test-cname.zone-1.local", endpoint.RecordTypeCNAME, "delete-test-cname"),
492-
})
500+
} {
501+
records = append(records, p.newChange(&plan.RRSetChange{
502+
Name: plan.RRName(provider.EnsureTrailingDot(ep.DNSName)),
503+
Type: plan.RRType(ep.RecordType),
504+
Create: []*endpoint.Endpoint{ep},
505+
}).Additions...)
506+
}
493507

494508
validateChangeRecords(t, records, []*dns.ResourceRecordSet{
495509
{Name: "update-test.zone-2.local.", Rrdatas: []string{"8.8.4.4"}, Type: "A", Ttl: 1},
@@ -746,7 +760,7 @@ func isValidRecordSet(recordSet *dns.ResourceRecordSet) bool {
746760
}
747761

748762
switch recordSet.Type {
749-
case endpoint.RecordTypeCNAME:
763+
case endpoint.RecordTypeCNAME, endpoint.RecordTypeMX, endpoint.RecordTypeNS, endpoint.RecordTypeSRV:
750764
for _, rrd := range recordSet.Rrdatas {
751765
if !hasTrailingDot(rrd) {
752766
return false

0 commit comments

Comments
 (0)