Skip to content

Commit 3416ca5

Browse files
committed
Added UDPSize properties to client and server
1 parent b5bcf75 commit 3416ca5

File tree

8 files changed

+167
-15
lines changed

8 files changed

+167
-15
lines changed

client.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ import (
1616
type Client struct {
1717
Net string // protocol (can be "udp" or "tcp", by default - "udp")
1818
Timeout time.Duration // read/write timeout
19+
20+
// UDPSize is the maximum size of a DNS response (or query) this client can
21+
// sent or receive. If not set, we use dns.MinMsgSize by default.
22+
UDPSize int
1923
}
2024

2125
// ResolverInfo contains DNSCrypt resolver information necessary for decryption/encryption
@@ -158,7 +162,11 @@ func (c *Client) readResponse(conn net.Conn) ([]byte, error) {
158162
}
159163

160164
if proto == "udp" {
161-
response := make([]byte, maxQueryLen)
165+
bufSize := c.UDPSize
166+
if bufSize == 0 {
167+
bufSize = dns.MinMsgSize
168+
}
169+
response := make([]byte, bufSize)
162170
n, err := conn.Read(response)
163171
if err != nil {
164172
return nil, err
@@ -182,7 +190,12 @@ func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error)
182190
if err != nil {
183191
return nil, err
184192
}
185-
return q.Encrypt(query, resolverInfo.SharedKey)
193+
b, err := q.Encrypt(query, resolverInfo.SharedKey)
194+
if len(b) > c.maxQuerySize() {
195+
return nil, ErrQueryTooLarge
196+
}
197+
198+
return b, err
186199
}
187200

188201
// decrypts decrypts a DNS message using a shared key from the resolver info
@@ -212,7 +225,8 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
212225

213226
query := new(dns.Msg)
214227
query.SetQuestion(providerName, dns.TypeTXT)
215-
client := dns.Client{Net: c.Net, UDPSize: uint16(maxQueryLen), Timeout: c.Timeout}
228+
// use 1252 as a UDPSize for this client to make sure the buffer is not too small
229+
client := dns.Client{Net: c.Net, UDPSize: uint16(1252), Timeout: c.Timeout}
216230
r, _, err := client.Exchange(query, stamp.ServerAddrStr)
217231
if err != nil {
218232
return nil, err
@@ -284,3 +298,15 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
284298

285299
return nil, certErr
286300
}
301+
302+
func (c *Client) maxQuerySize() int {
303+
if c.Net == "tcp" {
304+
return dns.MaxMsgSize
305+
}
306+
307+
if c.UDPSize > 0 {
308+
return c.UDPSize
309+
}
310+
311+
return dns.MinMsgSize
312+
}

constants.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ const (
6161
// Some servers do not work if padded length is less than 256. Example: Quad9
6262
minUDPQuestionSize = 256
6363

64-
// <max-query-len> is the maximum allowed query length
65-
maxQueryLen = 1252
66-
6764
// Minimum possible DNS packet size
6865
minDNSPacketSize = 12 + 5
6966

encrypted_query.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,6 @@ func (q *EncryptedQuery) Encrypt(packet []byte, sharedKey [sharedKeySize]byte) (
7272
return nil, ErrEsVersion
7373
}
7474

75-
if len(query) > maxQueryLen {
76-
return nil, ErrQueryTooLarge
77-
}
78-
7975
return query, nil
8076
}
8177

server.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ type Server struct {
4848
// ResolverCert contains resolver certificate.
4949
ResolverCert *Cert
5050

51+
// UDPSize is the default buffer size to use to read incoming UDP messages.
52+
// If not set it defaults to dns.MinMsgSize (512 B).
53+
UDPSize int
54+
5155
// Handler to invoke. If nil, uses DefaultHandler.
5256
Handler Handler
5357

@@ -148,6 +152,10 @@ func (s *Server) init() {
148152
s.tcpConns = map[net.Conn]struct{}{}
149153
s.udpListeners = map[*net.UDPConn]struct{}{}
150154
s.tcpListeners = map[net.Listener]struct{}{}
155+
156+
if s.UDPSize == 0 {
157+
s.UDPSize = dns.MinMsgSize
158+
}
151159
}
152160

153161
// isStarted returns true if the server is processing queries right now

server_tcp.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (w *TCPResponseWriter) RemoteAddr() net.Addr {
3535

3636
// WriteMsg writes DNS message to the client
3737
func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error {
38-
m.Truncate(dnsSize("tcp", w.req))
38+
normalize("tcp", w.req, m)
3939

4040
res, err := w.encrypt(m, w.query)
4141
if err != nil {

server_test.go

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,85 @@ func TestServer_ReadTimeout(t *testing.T) {
5656
testThisServerRespondMessages(t, "tcp", srv)
5757
}
5858

59+
func TestServer_UDPTruncateMessage(t *testing.T) {
60+
// Create a test server that returns large response which should be
61+
// truncated if sent over UDP
62+
srv := newTestServer(t, &testLargeMsgHandler{})
63+
t.Cleanup(func() {
64+
require.NoError(t, srv.Close())
65+
})
66+
67+
// Create client and connect
68+
client := &Client{
69+
Timeout: 1 * time.Second,
70+
Net: "udp",
71+
}
72+
serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
73+
stamp := dnsstamps.ServerStamp{
74+
ServerAddrStr: serverAddr,
75+
ServerPk: srv.resolverPk,
76+
ProviderName: srv.server.ProviderName,
77+
Proto: dnsstamps.StampProtoTypeDNSCrypt,
78+
}
79+
ri, err := client.DialStamp(stamp)
80+
require.NoError(t, err)
81+
require.NotNil(t, ri)
82+
83+
// Send a test message and check that the response was truncated
84+
m := createTestMessage()
85+
res, err := client.Exchange(m, ri)
86+
require.NoError(t, err)
87+
require.NotNil(t, res)
88+
require.Equal(t, dns.RcodeSuccess, res.Rcode)
89+
require.Len(t, res.Answer, 0)
90+
require.True(t, res.Truncated)
91+
}
92+
93+
func TestServer_UDPEDNS0_NoTruncate(t *testing.T) {
94+
// Create a test server that returns large response which should be
95+
// truncated if sent over UDP
96+
// However, when EDNS0 is set with the buffer large enough, there should
97+
// be no truncation
98+
srv := newTestServer(t, &testLargeMsgHandler{})
99+
t.Cleanup(func() {
100+
require.NoError(t, srv.Close())
101+
})
102+
103+
// Create client and connect
104+
client := &Client{
105+
Timeout: 1 * time.Second,
106+
Net: "udp",
107+
UDPSize: 7000, // make sure the client will be able to read the response
108+
}
109+
serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
110+
stamp := dnsstamps.ServerStamp{
111+
ServerAddrStr: serverAddr,
112+
ServerPk: srv.resolverPk,
113+
ProviderName: srv.server.ProviderName,
114+
Proto: dnsstamps.StampProtoTypeDNSCrypt,
115+
}
116+
ri, err := client.DialStamp(stamp)
117+
require.NoError(t, err)
118+
require.NotNil(t, ri)
119+
120+
// Send a test message with UDP buffer size large enough
121+
// and check that the response was NOT truncated
122+
m := createTestMessage()
123+
m.Extra = append(m.Extra, &dns.OPT{
124+
Hdr: dns.RR_Header{
125+
Name: ".",
126+
Rrtype: dns.TypeOPT,
127+
Class: 2000, // Set large enough UDPSize here
128+
},
129+
})
130+
res, err := client.Exchange(m, ri)
131+
require.NoError(t, err)
132+
require.NotNil(t, res)
133+
require.Equal(t, dns.RcodeSuccess, res.Rcode)
134+
require.Len(t, res.Answer, 64)
135+
require.False(t, res.Truncated)
136+
}
137+
59138
func testServerServeCert(t *testing.T, network string) {
60139
srv := newTestServer(t, &testHandler{})
61140
t.Cleanup(func() {
@@ -193,17 +272,44 @@ type testHandler struct{}
193272

194273
// ServeDNS - implements Handler interface
195274
func (h *testHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
196-
// Google DNS
197275
res := new(dns.Msg)
198276
res.SetReply(r)
277+
199278
answer := new(dns.A)
200279
answer.Hdr = dns.RR_Header{
201280
Name: r.Question[0].Name,
202281
Rrtype: dns.TypeA,
203282
Ttl: 300,
204283
Class: dns.ClassINET,
205284
}
285+
// First record is from Google DNS
206286
answer.A = net.IPv4(8, 8, 8, 8)
207287
res.Answer = append(res.Answer, answer)
288+
289+
return rw.WriteMsg(res)
290+
}
291+
292+
// testLargeMsgHandler is a handler that returns a huge response
293+
// used for testing messages truncation
294+
type testLargeMsgHandler struct{}
295+
296+
// ServeDNS - implements Handler interface
297+
func (h *testLargeMsgHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
298+
res := new(dns.Msg)
299+
res.SetReply(r)
300+
301+
for i := 0; i < 64; i++ {
302+
answer := new(dns.A)
303+
answer.Hdr = dns.RR_Header{
304+
Name: r.Question[0].Name,
305+
Rrtype: dns.TypeA,
306+
Ttl: 300,
307+
Class: dns.ClassINET,
308+
}
309+
answer.A = net.IPv4(127, 0, 0, byte(i))
310+
res.Answer = append(res.Answer, answer)
311+
}
312+
313+
res.Compress = true
208314
return rw.WriteMsg(res)
209315
}

server_udp.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type encryptionFunc func(m *dns.Msg, q EncryptedQuery) ([]byte, error)
2020
type UDPResponseWriter struct {
2121
udpConn *net.UDPConn // UDP connection
2222
sess *dns.SessionUDP // SessionUDP (necessary to use dns.WriteToSessionUDP)
23-
encrypt encryptionFunc // DNSCRypt encryption function
23+
encrypt encryptionFunc // DNSCrypt encryption function
2424
req *dns.Msg // DNS query that was processed
2525
query EncryptedQuery // DNSCrypt query properties
2626
}
@@ -40,7 +40,7 @@ func (w *UDPResponseWriter) RemoteAddr() net.Addr {
4040

4141
// WriteMsg writes DNS message to the client
4242
func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error {
43-
m.Truncate(dnsSize("udp", w.req))
43+
normalize("udp", w.req, m)
4444

4545
res, err := w.encrypt(m, w.query)
4646
if err != nil {
@@ -157,7 +157,7 @@ func (s *Server) cleanUpUDP(udpWg *sync.WaitGroup, l *net.UDPConn) {
157157
// readUDPMsg reads incoming UDP message
158158
func (s *Server) readUDPMsg(l *net.UDPConn) ([]byte, *dns.SessionUDP, error) {
159159
_ = l.SetReadDeadline(time.Now().Add(defaultReadTimeout))
160-
b := make([]byte, dns.MinMsgSize)
160+
b := make([]byte, s.UDPSize)
161161
n, sess, err := dns.ReadFromSessionUDP(l, b)
162162
if err != nil {
163163
return nil, nil, err

util.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,25 @@ func unpackTxtString(s string) ([]byte, error) {
184184
return msg, nil
185185
}
186186

187+
// normalize truncates the DNS response if needed depending on the protocol
188+
func normalize(proto string, req *dns.Msg, res *dns.Msg) {
189+
size := dnsSize(proto, req)
190+
// DNSCrypt encryption adds a header to each message, we should
191+
// consider this when truncating a message.
192+
// 64 should cover all cases
193+
size = size - 64
194+
195+
// Truncate response message
196+
res.Truncate(size)
197+
198+
// In case of UDP it is safer to simply remove all response records
199+
// dns.Msg.Truncate method will not consider that we need a response
200+
// shorter than dns.MinMsgSize
201+
if res.Truncated && proto == "udp" {
202+
res.Answer = nil
203+
}
204+
}
205+
187206
// dnsSize returns if buffer size *advertised* in the requests OPT record.
188207
// Or when the request was over TCP, we return the maximum allowed size of 64K.
189208
func dnsSize(proto string, r *dns.Msg) int {

0 commit comments

Comments
 (0)