forked from slackhq/nebula
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdns_server.go
More file actions
449 lines (404 loc) · 11.3 KB
/
dns_server.go
File metadata and controls
449 lines (404 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
package nebula
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/miekg/dns"
"github.com/slackhq/nebula/config"
)
// dnsServer is the lighthouse DNS responder. It answers A/AAAA queries from
// the in-memory name maps and TXT queries with JSON-encoded certificates, and
// owns the lifecycle of the underlying miekg/dns UDP listener.
//
// Two locks are in play:
//   - the embedded RWMutex guards dnsMap4, dnsMap6 and selfHost;
//   - serverMu guards server, started and addr.
type dnsServer struct {
// Embedded lock for the record maps and selfHost (Query takes it for
// reading; Add, seedSelf and clearRecords take it for writing).
sync.RWMutex
// l receives all log output from the responder.
l *slog.Logger
// ctx is watched by Start; once it is done the active listener is shut down
// and no new listener is created.
ctx context.Context
// dnsMap4 maps a lower-cased host name to the IPv4 address served for A queries.
dnsMap4 map[string]netip.Addr
// dnsMap6 maps a lower-cased host name to the IPv6 address served for AAAA queries.
dnsMap6 map[string]netip.Addr
// hostMap is consulted by QueryCert to find a peer's certificate by VPN address.
hostMap *HostMap
// pki supplies the local cert state (see certState); may be nil.
pki *PKI
// selfHost is the cached FQDN we last seeded for ourselves
selfHost string
// mux routes every query ("." pattern) to handleDnsRequest.
mux *dns.ServeMux
// enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`.
// Start, Add, and reload consult it so callers don't need to know the
// gating rules. When it toggles off via reload, accumulated records are
// cleared so a later re-enable starts with a fresh map populated from
// new handshakes.
enabled atomic.Bool
// serverMu guards server, started and addr.
serverMu sync.Mutex
// server is the listener created by the most recent Start, or nil after Stop.
server *dns.Server
// started is closed once `server` has finished binding (or after
// ListenAndServe returns on a bind failure). Stop waits on it before
// calling Shutdown to avoid the miekg/dns "server not started" race
// where a Shutdown that arrives before bind completes is silently
// ignored, leaving the listener running forever.
started chan struct{}
// addr is the host:port listen address most recently read from config by reload.
addr string
}
// newDnsServerFromConfig constructs the DNS responder, hooks it into config
// reloads, and then applies the configuration currently held by c.
//
// The reload callback is registered before the initial configuration is
// applied, so a later reload can still enable, repair, or disable DNS even
// when the first application returned an error.
//
// Gating on `lighthouse.serve_dns && lighthouse.am_lighthouse` happens inside
// the dnsServer: Start and Add may be called unconditionally and simply return
// when DNS is not enabled. Every Start call runs its own ctx watcher that
// shuts the listener down when ctx is cancelled.
//
// The returned *dnsServer is never nil, including when the error is non-nil.
func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*dnsServer, error) {
	mux := dns.NewServeMux()
	ds := &dnsServer{
		l:       l,
		ctx:     ctx,
		dnsMap4: map[string]netip.Addr{},
		dnsMap6: map[string]netip.Addr{},
		hostMap: hostMap,
		pki:     pki,
		mux:     mux,
	}
	// Every name is handled by the same entry point.
	mux.HandleFunc(".", ds.handleDnsRequest)

	// Register for reloads first so a failed initial apply can be corrected later.
	onReload := func(c *config.C) {
		err := ds.reload(c, false)
		if err == nil {
			return
		}
		ds.l.Error("Failed to reload DNS responder from config", "error", err)
	}
	c.RegisterReloadCallback(onReload)

	err := ds.reload(c, true)
	if err != nil {
		return ds, err
	}

	ds.seedSelf()
	return ds, nil
}
// reload reads the DNS settings from c and brings the running state in line
// with them:
//   - newly enabled                      -> launch a listener
//   - newly disabled                     -> stop the listener and drop records
//   - listen address changed while up    -> restart on the new address
//   - anything else                      -> leave the listener alone
//
// When initial is true the settings are only recorded (plus a warning if DNS
// was requested on a non-lighthouse); the first listener is launched later by
// whoever calls Start (Control.Start via dnsStart).
//
// The returned error is currently always nil.
func (d *dnsServer) reload(c *config.C, initial bool) error {
	serveDns := c.GetBool("lighthouse.serve_dns", false)
	isLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
	active := serveDns && isLighthouse
	listenAddr := getDnsServerAddr(c)

	// Snapshot the previous listener state and publish the new settings in
	// one critical section.
	d.serverMu.Lock()
	prevServer, prevStarted := d.server, d.started
	addrChanged := d.addr != listenAddr
	d.addr = listenAddr
	d.enabled.Store(active)
	d.serverMu.Unlock()

	switch {
	case initial:
		if serveDns && !isLighthouse {
			d.l.Warn("DNS server refusing to run because this host is not a lighthouse.")
		}
		return nil

	case !active:
		if prevServer != nil {
			d.Stop()
		}
		// Forget everything learned while enabled; a future re-enable
		// repopulates from new handshakes and a new seedSelf.
		d.clearRecords()
		return nil

	case prevServer == nil:
		// Nothing was running (disabled or never started): bring it up.
		go d.Start()

	case addrChanged:
		// Shut the old listener down, then start a fresh one that picks up
		// the new address.
		d.shutdownServer(prevServer, prevStarted, "reload")
		go d.Start()
	}

	// Re-seed our own record on every enabled reload so a renewed cert with a
	// different name or different VPN addresses is reflected.
	d.seedSelf()
	return nil
}
// shutdownServer shuts srv down, first blocking until started is closed so
// that Shutdown is not issued before the server has bound (where it would
// have no effect). A nil srv is a no-op and a nil started skips the wait.
// A Shutdown failure is logged as a warning tagged with reason.
func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) {
	if srv == nil {
		return
	}

	// A nil channel would block forever, so only wait on a real one.
	if started != nil {
		<-started
	}

	err := srv.Shutdown()
	if err == nil {
		return
	}
	d.l.Warn("Failed to shut down the DNS responder", "reason", reason, "error", err)
}
// Start binds and serves the DNS responder. It blocks until the listener is
// shut down (via Stop, a reload, or ctx cancellation) or fails. It returns
// immediately when DNS is disabled or ctx is already done. This is what
// Control.dnsStart points at.
//
// Must be invoked after the tun device is active so that lighthouse.dns.host
// may bind to a nebula IP.
//
// On exit, Start clears d.server/d.started if they still refer to the server
// created by this invocation. Without that, a listener that failed to bind
// (or was torn down by ctx) would keep looking "running" to reload and Stop:
// reload would never retry on an unchanged address, and both would log a
// spurious warning when Shutdown is called on a server that is not running.
func (d *dnsServer) Start() {
	if !d.enabled.Load() {
		return
	}

	started := make(chan struct{})

	d.serverMu.Lock()
	if d.ctx.Err() != nil {
		// Shutting down; do not create a listener nobody will stop.
		d.serverMu.Unlock()
		return
	}
	addr := d.addr
	server := &dns.Server{
		Addr:              addr,
		Net:               "udp",
		Handler:           d.mux,
		NotifyStartedFunc: func() { close(started) },
	}
	d.server = server
	d.started = started
	d.serverMu.Unlock()

	// Per-invocation ctx watcher. Exits when Start does, so we don't leak a
	// watcher per reload-driven restart.
	done := make(chan struct{})
	go func() {
		select {
		case <-d.ctx.Done():
			d.shutdownServer(server, started, "shutdown")
		case <-done:
		}
	}()

	d.l.Info("Starting DNS responder", "dnsListener", addr)
	err := server.ListenAndServe()
	close(done)

	// If the listener never bound (bind error) NotifyStartedFunc never fires,
	// so close started here to release anyone waiting on it.
	select {
	case <-started:
	default:
		close(started)
	}

	// This server is finished. Forget it unless a newer Start (or Stop) has
	// already replaced the fields.
	d.serverMu.Lock()
	if d.server == server {
		d.server = nil
		d.started = nil
	}
	d.serverMu.Unlock()

	if err != nil {
		d.l.Warn("Failed to run the DNS responder", "error", err)
	}
}
// Stop detaches the current server (if any) from the dnsServer and shuts it
// down. Calling it again, or with nothing running, does nothing.
func (d *dnsServer) Stop() {
	// Take ownership of the server under the lock so a concurrent Stop sees nil.
	d.serverMu.Lock()
	active, ready := d.server, d.started
	d.server, d.started = nil, nil
	d.serverMu.Unlock()

	d.shutdownServer(active, ready, "stop")
}
// Query looks up data (case-insensitively) for query type q.
//
// The returned address is valid only when q is dns.TypeA or dns.TypeAAAA and
// a record of that family exists; otherwise it is the zero netip.Addr. The
// boolean reports whether the name has any record at all (A or AAAA), so a
// caller can tell NODATA (known name, wrong type) from NXDOMAIN (unknown name).
func (d *dnsServer) Query(q uint16, data string) (netip.Addr, bool) {
	name := strings.ToLower(data)

	d.RLock()
	v4, okV4 := d.dnsMap4[name]
	v6, okV6 := d.dnsMap6[name]
	d.RUnlock()

	known := okV4 || okV6
	if q == dns.TypeA && okV4 {
		return v4, known
	}
	if q == dns.TypeAAAA && okV6 {
		return v6, known
	}
	return netip.Addr{}, known
}
// QueryCert returns the JSON encoding of the certificate belonging to the VPN
// address named by data, or "" when nothing can be returned.
//
// data is expected to be an IP address followed by one trailing character
// (the root dot of a DNS name); that last character is dropped before parsing.
// An unparsable address, an unknown peer, a missing certificate, or a
// marshalling failure all yield "".
func (d *dnsServer) QueryCert(data string) string {
	if len(data) < 2 {
		return ""
	}
	vpnAddr, err := netip.ParseAddr(data[:len(data)-1])
	if err != nil {
		return ""
	}

	var encoded []byte
	cs := d.certState()
	if cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(vpnAddr) {
		// The hostmap only holds peers we have handshaked with, never
		// ourselves, so self lookups are served from the local cert state.
		own := cs.GetDefaultCertificate()
		if own == nil {
			return ""
		}
		encoded, err = own.MarshalJSON()
	} else {
		peer := d.hostMap.QueryVpnAddr(vpnAddr)
		if peer == nil {
			return ""
		}
		cached := peer.GetCert()
		if cached == nil {
			return ""
		}
		encoded, err = cached.Certificate.MarshalJSON()
	}

	if err != nil {
		return ""
	}
	return string(encoded)
}
// clearRecords empties both address maps and forgets the cached self host
// name, leaving the responder with no records at all.
func (d *dnsServer) clearRecords() {
	d.Lock()
	d.selfHost = ""
	clear(d.dnsMap6)
	clear(d.dnsMap4)
	d.Unlock()
}
// seedSelf (re)creates the record for our own certificate name, pointing it at
// the first IPv4 and first IPv6 entry of our VPN addresses. This lets a
// single-lighthouse network resolve the lighthouse's own host name without
// running a second process.
//
// It does nothing when DNS is disabled or when no cert state / default
// certificate is available. Records under a previously seeded, different name
// are removed, and the current name's records are rebuilt from scratch.
func (d *dnsServer) seedSelf() {
	if !d.enabled.Load() {
		return
	}
	state := d.certState()
	if state == nil {
		return
	}
	crt := state.GetDefaultCertificate()
	if crt == nil {
		return
	}
	fqdn := strings.ToLower(crt.Name()) + "."

	d.Lock()
	defer d.Unlock()

	// Drop what we seeded under an older name, then wipe the current name so
	// an address family that disappeared does not linger.
	if prev := d.selfHost; prev != "" && prev != fqdn {
		delete(d.dnsMap4, prev)
		delete(d.dnsMap6, prev)
	}
	delete(d.dnsMap4, fqdn)
	delete(d.dnsMap6, fqdn)
	d.selfHost = fqdn

	var seen4, seen6 bool
	for _, a := range state.myVpnAddrs {
		switch {
		case a.Is4() && !seen4:
			d.dnsMap4[fqdn] = a
			seen4 = true
		case a.Is6() && !seen6:
			d.dnsMap6[fqdn] = a
			seen6 = true
		}
		if seen4 && seen6 {
			break
		}
	}
}
// certState returns the current cert state from the PKI, or nil when the
// dnsServer was built without a PKI.
func (d *dnsServer) certState() *CertState {
	if pki := d.pki; pki != nil {
		return pki.getCertState()
	}
	return nil
}
// Add sets the records for host to the first IPv4 and the first IPv6 address
// found in addresses. host is lower-cased before it is stored. It is a no-op
// when DNS is disabled.
//
// The records are replaced rather than merged: if addresses contains no
// address of a given family, any existing record of that family for host is
// removed. This matches seedSelf and keeps the responder from serving a stale
// A/AAAA record after a host comes back with a different set of addresses.
func (d *dnsServer) Add(host string, addresses []netip.Addr) {
	if !d.enabled.Load() {
		return
	}
	host = strings.ToLower(host)

	// Pick the addresses before taking the lock; the zero Addr means "none".
	var v4, v6 netip.Addr
	for _, addr := range addresses {
		switch {
		case addr.Is4() && !v4.IsValid():
			v4 = addr
		case addr.Is6() && !v6.IsValid():
			v6 = addr
		}
		if v4.IsValid() && v6.IsValid() {
			break
		}
	}

	d.Lock()
	defer d.Unlock()

	if v4.IsValid() {
		d.dnsMap4[host] = v4
	} else {
		delete(d.dnsMap4, host)
	}
	if v6.IsValid() {
		d.dnsMap6[host] = v6
	} else {
		delete(d.dnsMap6, host)
	}
}
// isSelfNebulaOrLocalhost reports whether addr (a "host:port" string) is a
// loopback address or one of this node's own VPN addresses. Anything that
// cannot be split or parsed as an IP address yields false.
func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return false
	}

	ip, err := netip.ParseAddr(host)
	switch {
	case err != nil:
		return false
	case ip.IsLoopback():
		return true
	}

	// Not loopback: accept only if it is one of our own VPN addresses.
	cs := d.certState()
	return cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip)
}
// parseQuery fills m.Answer for every question in m and sets the response
// code.
//
// A and AAAA questions are answered from the record maps. TXT questions are
// answered with the JSON certificate for the queried address, but only when
// the request comes from loopback or one of our own VPN addresses; a TXT
// question from anywhere else makes parseQuery return immediately, leaving m
// as it is at that point.
//
// Per RFC 2308 §2.2 a name that exists without a record of the requested type
// is answered NOERROR with an empty answer section (NODATA). NXDOMAIN
// (RFC 2308 §2.1) is set only when there are no answers and none of the
// A/AAAA names exist.
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
	logDebug := d.l.Enabled(context.Background(), slog.LevelDebug)
	known := false

	for _, question := range m.Question {
		switch question.Qtype {
		case dns.TypeA, dns.TypeAAAA:
			typeName := dns.TypeToString[question.Qtype]
			if logDebug {
				d.l.Debug("DNS query", "type", typeName, "name", question.Name)
			}

			addr, exists := d.Query(question.Qtype, question.Name)
			known = known || exists
			if !addr.IsValid() {
				continue
			}
			rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", question.Name, typeName, addr))
			if err != nil {
				continue
			}
			m.Answer = append(m.Answer, rr)

		case dns.TypeTXT:
			// Certificates are only handed out to ourselves or localhost.
			if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
				return
			}
			if logDebug {
				d.l.Debug("DNS query", "type", "TXT", "name", question.Name)
			}

			certJSON := d.QueryCert(question.Name)
			if certJSON == "" {
				continue
			}
			rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", question.Name, certJSON))
			if err != nil {
				continue
			}
			m.Answer = append(m.Answer, rr)
		}
	}

	if !known && len(m.Answer) == 0 {
		m.Rcode = dns.RcodeNameError
	}
}
// handleDnsRequest is the mux handler for every incoming DNS message. It
// builds an uncompressed reply to r, lets parseQuery populate it when the
// opcode is a standard query, and writes it back on w. Other opcodes get the
// bare reply.
//
// A failure to write the reply is logged at debug level instead of being
// silently dropped; debug is used so remote clients cannot flood the log.
func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(r)
	m.Compress = false

	if r.Opcode == dns.OpcodeQuery {
		d.parseQuery(m, w)
	}

	if err := w.WriteMsg(m); err != nil {
		d.l.Debug("Failed to write DNS response", "error", err)
	}
}
// getDnsServerAddr builds the "host:port" listen address from
// `lighthouse.dns.host` (whitespace-trimmed, default empty) and
// `lighthouse.dns.port` (default 53).
func getDnsServerAddr(c *config.C) string {
	host := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))

	// Older guidance told users to write the literal `[::]`, which won't
	// resolve; net.JoinHostPort adds the brackets itself, so pass the bare form.
	if host == "[::]" {
		host = "::"
	}

	port := c.GetInt("lighthouse.dns.port", 53)
	return net.JoinHostPort(host, strconv.Itoa(port))
}