Skip to content

Commit 45a6911

Browse files
committed
fix: error on multiple servers (#92)
1 parent 25606bc commit 45a6911

2 files changed

Lines changed: 86 additions & 7 deletions

File tree

main.go

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -440,42 +440,59 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
440440

441441
go func() {
442442
var entries []*output.Entry
443+
multiServer := len(opts.Server) > 1
443444
for _, serverStr := range opts.Server {
444445
// Parse server address and transport type
445446
server, transportType, err := parseServer(serverStr)
446447
if err != nil {
448+
if multiServer {
449+
log.Warnf("Skipping server %s: %v", serverStr, err)
450+
continue
451+
}
447452
errChan <- fmt.Errorf("parsing server %s: %s", serverStr, err)
453+
return
448454
}
449455
log.Debugf("Using server %s with transport %s", server, transportType)
450456

451457
// Recursive zone transfer
452458
if opts.RecAXFR {
453459
if opts.Name == "" {
454460
errChan <- fmt.Errorf("no name specified for AXFR")
461+
return
455462
}
456463
_ = RecAXFR(opts.Name, server, out)
457464
errChan <- nil // exit immediately
465+
return
458466
}
459467

460468
// Create transport
461469
txp, err := newTransport(server, transportType, tlsConfig)
462470
if err != nil {
471+
if multiServer {
472+
log.Warnf("Skipping server %s (transport error): %v", server, err)
473+
continue
474+
}
463475
errChan <- fmt.Errorf("creating transport: %s", err)
476+
return
464477
}
465478

466479
startTime := time.Now()
467480
var replies []*dns.Msg
481+
var serverFailed error
468482
for _, msg := range msgs {
469483
if txp == nil {
470-
errChan <- fmt.Errorf("transport is nil")
484+
serverFailed = fmt.Errorf("transport is nil")
485+
break
471486
}
472487
reply, err := (*txp).Exchange(&msg)
473488
if err != nil {
474-
errChan <- fmt.Errorf("exchange: %s", err)
489+
serverFailed = fmt.Errorf("exchange: %s", err)
490+
break
475491
}
476492

477493
if reply == nil {
478-
errChan <- fmt.Errorf("no reply from server")
494+
serverFailed = fmt.Errorf("no reply from server")
495+
break
479496
}
480497

481498
if opts.ShowOpt {
@@ -487,11 +504,23 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
487504
}
488505

489506
if transportType != transport.TypeQUIC && opts.IDCheck && reply.Id != msg.Id {
490-
errChan <- fmt.Errorf("ID mismatch: expected %d, got %d", msg.Id, reply.Id)
507+
serverFailed = fmt.Errorf("ID mismatch: expected %d, got %d", msg.Id, reply.Id)
508+
break
491509
}
492510
replies = append(replies, reply)
493511
}
494512

513+
// If this server failed at any point, either skip (multi) or exit (single)
514+
if serverFailed != nil {
515+
_ = (*txp).Close()
516+
if multiServer {
517+
log.Warnf("Server %s failed: %v", server, serverFailed)
518+
continue
519+
}
520+
errChan <- serverFailed
521+
return
522+
}
523+
495524
// Process TXT parsing
496525
if opts.TXTConcat {
497526
for _, reply := range replies {
@@ -522,10 +551,21 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
522551
entries = append(entries, e)
523552

524553
if err := (*txp).Close(); err != nil {
525-
errChan <- fmt.Errorf("closing transport: %s", err)
554+
if multiServer {
555+
log.Warnf("Server %s close error: %v", server, err)
556+
} else {
557+
errChan <- fmt.Errorf("closing transport: %s", err)
558+
return
559+
}
526560
}
527561
}
528562

563+
// If none of the servers succeeded, return an error in multi-server mode
564+
if len(entries) == 0 {
565+
errChan <- fmt.Errorf("all servers failed")
566+
return
567+
}
568+
529569
printer := output.Printer{
530570
Out: out,
531571
Opts: &opts,
@@ -538,6 +578,7 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
538578
// Skip printing if NSIDOnly is set
539579
if opts.NSIDOnly {
540580
errChan <- nil
581+
return
541582
}
542583

543584
switch opts.Format {
@@ -551,14 +592,24 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
551592
printer.PrintStructured(entries)
552593
default:
553594
errChan <- fmt.Errorf("invalid output format %s", opts.Format)
595+
return
554596
}
555597

556598
errChan <- nil
557599
}()
558600

601+
// When multiple servers are configured, queries are attempted sequentially.
602+
// Give the worker goroutine enough time to iterate through all servers by
603+
// scaling the overall timeout proportionally. This prevents the controller
604+
// from timing out before a later server succeeds.
605+
totalTimeout := opts.Timeout
606+
if len(opts.Server) > 1 {
607+
totalTimeout = opts.Timeout * time.Duration(len(opts.Server))
608+
}
609+
559610
select {
560-
case <-time.After(opts.Timeout):
561-
return fmt.Errorf("timeout after %s", opts.Timeout)
611+
case <-time.After(totalTimeout):
612+
return fmt.Errorf("timeout after %s", totalTimeout)
562613
case err := <-errChan:
563614
return err
564615
}

main_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,3 +666,31 @@ func TestMainEDE_Pretty(t *testing.T) {
666666
assert.Contains(t, s, "6 (DNSSEC Bogus)")
667667
assert.Contains(t, s, "This EDE was intentionally inserted by dnsdist")
668668
}
669+
670+
func TestMainMultipleServersSkipFailures(t *testing.T) {
671+
out, err := run(
672+
"--all",
673+
"-q", "example.com",
674+
"-t", "A",
675+
"-s", "127.127.127.127:1", // expected to fail
676+
"-s", "8.8.8.8", // expected to succeed
677+
)
678+
assert.Nil(t, err)
679+
s := out.String()
680+
assert.Regexp(t, regexp.MustCompile(`example\.com\. .* A .*`), s)
681+
// When multiple servers are used, the server suffix is appended to answers
682+
assert.Truef(t, strings.Contains(s, "(8.8.8.8:53)") || strings.Contains(s, "from 8.8.8.8:53"), "expected output to include successful server 8.8.8.8:53, got: %s", s)
683+
}
684+
685+
func TestMainMultipleServersAllFail(t *testing.T) {
686+
_, err := run(
687+
"--all",
688+
"-q", "example.com",
689+
"-t", "A",
690+
"-s", "127.127.127.127:1",
691+
"-s", "127.127.127.127:2",
692+
"--timeout", "1s",
693+
)
694+
assert.NotNil(t, err)
695+
assert.Contains(t, err.Error(), "timeout after 2s")
696+
}

0 commit comments

Comments
 (0)