Skip to content

Commit c5e3820

Browse files
committed
Fix read DNS message
1 parent 9ac31d0 commit c5e3820

File tree

2 files changed

+47
-40
lines changed

2 files changed

+47
-40
lines changed

common/sniff/dns.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.
2222
if err != nil {
2323
return nil, err
2424
}
25-
if length > 512 {
25+
if length == 0 {
2626
return nil, os.ErrInvalid
2727
}
2828
_buffer := buf.StackNewSize(int(length))

outbound/dns.go

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ package outbound
33
import (
44
"context"
55
"encoding/binary"
6-
"io"
76
"net"
87
"os"
98

109
"github.com/sagernet/sing-box/adapter"
1110
"github.com/sagernet/sing-box/common/canceler"
1211
C "github.com/sagernet/sing-box/constant"
12+
"github.com/sagernet/sing-dns"
1313
"github.com/sagernet/sing/common"
1414
"github.com/sagernet/sing/common/buf"
1515
M "github.com/sagernet/sing/common/metadata"
@@ -47,53 +47,60 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa
4747
func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
4848
defer conn.Close()
4949
ctx = adapter.WithContext(ctx, &metadata)
50-
_buffer := buf.StackNewSize(1024)
51-
defer common.KeepAlive(_buffer)
52-
buffer := common.Dup(_buffer)
53-
defer buffer.Release()
5450
for {
55-
var queryLength uint16
56-
err := binary.Read(conn, binary.BigEndian, &queryLength)
51+
err := d.handleConnection(ctx, conn, metadata)
5752
if err != nil {
5853
return err
5954
}
60-
if queryLength > 1024 {
61-
return io.ErrShortBuffer
62-
}
63-
buffer.FullReset()
64-
_, err = buffer.ReadFullFrom(conn, int(queryLength))
55+
}
56+
}
57+
58+
func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
59+
var queryLength uint16
60+
err := binary.Read(conn, binary.BigEndian, &queryLength)
61+
if err != nil {
62+
return err
63+
}
64+
if queryLength == 0 {
65+
return dns.RCodeFormatError
66+
}
67+
_buffer := buf.StackNewSize(int(queryLength))
68+
defer common.KeepAlive(_buffer)
69+
buffer := common.Dup(_buffer)
70+
defer buffer.Release()
71+
_, err = buffer.ReadFullFrom(conn, int(queryLength))
72+
if err != nil {
73+
return err
74+
}
75+
var message dnsmessage.Message
76+
err = message.Unpack(buffer.Bytes())
77+
if err != nil {
78+
return err
79+
}
80+
if len(message.Questions) > 0 {
81+
question := message.Questions[0]
82+
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
83+
}
84+
go func() error {
85+
response, err := d.router.Exchange(ctx, &message)
6586
if err != nil {
6687
return err
6788
}
68-
var message dnsmessage.Message
69-
err = message.Unpack(buffer.Bytes())
89+
_responseBuffer := buf.StackNewPacket()
90+
defer common.KeepAlive(_responseBuffer)
91+
responseBuffer := common.Dup(_responseBuffer)
92+
defer responseBuffer.Release()
93+
responseBuffer.Resize(2, 0)
94+
n, err := response.AppendPack(responseBuffer.Index(0))
7095
if err != nil {
7196
return err
7297
}
73-
if len(message.Questions) > 0 {
74-
question := message.Questions[0]
75-
metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
76-
}
77-
go func() error {
78-
response, err := d.router.Exchange(ctx, &message)
79-
if err != nil {
80-
return err
81-
}
82-
_responseBuffer := buf.StackNewPacket()
83-
defer common.KeepAlive(_responseBuffer)
84-
responseBuffer := common.Dup(_responseBuffer)
85-
defer responseBuffer.Release()
86-
responseBuffer.Resize(2, 0)
87-
n, err := response.AppendPack(responseBuffer.Index(0))
88-
if err != nil {
89-
return err
90-
}
91-
responseBuffer.Truncate(len(n))
92-
binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
93-
_, err = conn.Write(responseBuffer.Bytes())
94-
return err
95-
}()
96-
}
98+
responseBuffer.Truncate(len(n))
99+
binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
100+
_, err = conn.Write(responseBuffer.Bytes())
101+
return err
102+
}()
103+
return nil
97104
}
98105

99106
func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
@@ -103,7 +110,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
103110
var group task.Group
104111
group.Append0(func(ctx context.Context) error {
105112
defer cancel()
106-
_buffer := buf.StackNewSize(1024)
113+
_buffer := buf.StackNewSize(dns.FixedPacketSize)
107114
defer common.KeepAlive(_buffer)
108115
buffer := common.Dup(_buffer)
109116
defer buffer.Release()

0 commit comments

Comments
 (0)