-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathmain.go
148 lines (144 loc) · 3.46 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package main
import (
"context"
"encoding/base32"
"encoding/base64"
"flag"
"fmt"
"log"
"net"
"os"
"os/signal"
"syscall"
"time"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// Command-line flags. All three default to the empty string.
var (
	// deviceFlag holds the value of -device.
	deviceFlag = flag.String("device", "", "name of Wireguard device to manage")
	// dnsServerFlag holds the value of -dns.
	dnsServerFlag = flag.String("dns", "", "ip:port of DNS server")
	// dnsZoneFlag holds the value of -zone.
	dnsZoneFlag = flag.String("zone", "", "dns zone name")
)
// main updates the endpoint of every peer on a WireGuard device from DNS.
//
// It requires the -device, -zone and -dns flags and exits with a fatal log
// message if any is missing or if -dns is not in host:port form. For each peer
// currently configured on the device it sends an SRV query for
//
//	<base32(peer public key)>._wireguard._udp.<zone>
//
// to the -dns server. The port of the first SRV answer and the address of the
// first A/AAAA record in the additional section form the new endpoint, which
// is applied to that peer with UpdateOnly set. A peer whose lookup fails or
// returns an unusable response is logged and skipped; the remaining peers are
// still processed.
//
// The peer loop runs in a goroutine so that SIGINT/SIGTERM can cancel it: on a
// signal the context is cancelled and main waits for the goroutine to finish
// before returning.
func main() {
	flag.Parse()
	if len(*deviceFlag) < 1 {
		log.Fatal("missing device flag")
	}
	if len(*dnsZoneFlag) < 1 {
		log.Fatal("missing zone flag")
	}
	if len(*dnsServerFlag) < 1 {
		log.Fatal("missing dns flag")
	}
	// Only the host:port shape is checked; the parts themselves are unused.
	if _, _, err := net.SplitHostPort(*dnsServerFlag); err != nil {
		log.Fatalf("invalid dns flag value: %v", err)
	}

	wgClient, err := wgctrl.New()
	if err != nil {
		log.Fatalf("error constructing Wireguard control client: %v",
			err)
	}
	// Release the control client on every normal return path. The fatal-log
	// path below terminates the process without running this defer.
	defer wgClient.Close()

	wgDevice, err := wgClient.Device(*deviceFlag)
	if err != nil {
		log.Fatalf(
			"error retrieving Wireguard device '%s' info: %v",
			*deviceFlag, err)
	}
	if len(wgDevice.Peers) < 1 {
		log.Println("no peers found")
		// Returning from main exits with status 0 and lets the defer run.
		return
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	done := make(chan struct{})
	go func() {
		defer close(done)
		dnsClient := &dns.Client{
			Timeout: time.Second * 5,
		}
		for _, peer := range wgDevice.Peers {
			// Stop before starting another lookup once cancellation is requested.
			if ctx.Err() != nil {
				return
			}

			// Base32 is used for the DNS label, base64 for log messages.
			pubKeyBase32 := base32.StdEncoding.EncodeToString(peer.PublicKey[:])
			pubKeyBase64 := base64.StdEncoding.EncodeToString(peer.PublicKey[:])

			m := &dns.Msg{}
			question := fmt.Sprintf("%s._wireguard._udp.%s",
				pubKeyBase32, dns.Fqdn(*dnsZoneFlag))
			m.SetQuestion(question, dns.TypeSRV)

			srvCtx, srvCancel := context.WithCancel(ctx)
			r, _, err := dnsClient.ExchangeContext(srvCtx, m, *dnsServerFlag)
			srvCancel()
			if err != nil {
				log.Printf(
					"[%s] failed to lookup SRV: %v", pubKeyBase64, err)
				continue
			}
			if len(r.Answer) < 1 {
				log.Printf("[%s] no SRV records found", pubKeyBase64)
				continue
			}
			srv, ok := r.Answer[0].(*dns.SRV)
			if !ok {
				log.Printf(
					"[%s] non-SRV answer in response to SRV query: %s",
					pubKeyBase64, r.Answer[0].String())
				// srv is nil here; skip the peer instead of dereferencing it.
				continue
			}
			if len(r.Extra) < 1 {
				log.Printf("[%s] SRV response missing extra A/AAAA",
					pubKeyBase64)
				// Without an additional record there is no address to use, and
				// indexing r.Extra[0] below would panic.
				continue
			}

			// Only the first additional record is considered.
			var endpointIP net.IP
			switch host := r.Extra[0].(type) {
			case *dns.A:
				endpointIP = host.A
			case *dns.AAAA:
				endpointIP = host.AAAA
			default:
				log.Printf(
					"[%s] non-A/AAAA extra in SRV response: %s",
					pubKeyBase64, r.Extra[0].String())
				continue
			}

			peerConfig := wgtypes.PeerConfig{
				PublicKey:  peer.PublicKey,
				UpdateOnly: true,
				Endpoint: &net.UDPAddr{
					IP:   endpointIP,
					Port: int(srv.Port),
				},
			}
			// The device's current private key and, when non-zero, its firewall
			// mark are passed back unchanged alongside the single peer update.
			deviceConfig := wgtypes.Config{
				PrivateKey:   &wgDevice.PrivateKey,
				ReplacePeers: false,
				Peers:        []wgtypes.PeerConfig{peerConfig},
			}
			if wgDevice.FirewallMark > 0 {
				deviceConfig.FirewallMark = &wgDevice.FirewallMark
			}
			err = wgClient.ConfigureDevice(*deviceFlag, deviceConfig)
			if err != nil {
				log.Printf(
					"[%s] failed to configure peer on %s, error: %v",
					pubKeyBase64, *deviceFlag, err)
			}
		}
	}()

	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	select {
	case sig := <-sigCh:
		log.Printf("exiting due to signal %s", sig)
		cancel()
		// Wait for the worker to observe cancellation and finish.
		<-done
	case <-done:
	}
}