Skip to content

Commit 22334d9

Browse files
authored
Merge pull request #42 from gotd/feature/proto-gzip-mitigate-oom
feat(proto): mitigate possible DOS in gzip decoding
2 parents 12233b7 + dfdb786 commit 22334d9

File tree

8 files changed

+91
-16
lines changed

8 files changed

+91
-16
lines changed

bin/buffer.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@ type Buffer struct {
99
Buf []byte
1010
}
1111

12+
// Encode wrapper.
13+
func (b *Buffer) Encode(e Encoder) error {
14+
return e.Encode(b)
15+
}
16+
17+
// Decode wrapper.
18+
func (b *Buffer) Decode(d Decoder) error {
19+
return d.Decode(b)
20+
}
21+
1222
// ResetN resets buffer and expands it to fit n bytes.
1323
func (b *Buffer) ResetN(n int) {
1424
b.Buf = append(b.Buf[:0], make([]byte, n)...)

internal/proto/codec/codec.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ import (
77
"github.com/gotd/td/bin"
88
)
99

10-
func tryReadLength(r io.Reader, b *bin.Buffer) (int, error) {
10+
// readLen reads 32-bit integer and validates it as message length.
11+
func readLen(r io.Reader, b *bin.Buffer) (int, error) {
1112
b.ResetN(bin.Word)
1213
if _, err := io.ReadFull(r, b.Buf[:bin.Word]); err != nil {
1314
return 0, fmt.Errorf("failed to read length: %w", err)

internal/proto/codec/full.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ var errSeqNoMismatch = errors.New("seq_no mismatch")
8383
var errCRCMismatch = errors.New("crc mismatch")
8484

8585
func readFull(r io.Reader, seqNo int, b *bin.Buffer) error {
86-
n, err := tryReadLength(r, b)
86+
n, err := readLen(r, b)
8787
if err != nil {
88-
return err
88+
return xerrors.Errorf("len: %w", err)
8989
}
9090

9191
// Put length, because it need to count CRC.

internal/proto/codec/intermediate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func writeIntermediate(w io.Writer, b *bin.Buffer) error {
8585

8686
// readIntermediate reads payload from r to b.
8787
func readIntermediate(r io.Reader, b *bin.Buffer) error {
88-
n, err := tryReadLength(r, b)
88+
n, err := readLen(r, b)
8989
if err != nil {
9090
return err
9191
}

internal/proto/gzip.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package proto
33
import (
44
"bytes"
55
"compress/gzip"
6+
"io"
67
"io/ioutil"
78

89
"golang.org/x/xerrors"
@@ -21,6 +22,26 @@ type GZIP struct {
2122
// GZIPTypeID is TL type id of GZIP.
2223
const GZIPTypeID = 0x3072cfa1
2324

25+
// Encode implements bin.Encoder.
26+
func (g GZIP) Encode(b *bin.Buffer) error {
27+
b.PutID(GZIPTypeID)
28+
29+
// Writing compressed data to buf.
30+
buf := new(bytes.Buffer)
31+
w := gzip.NewWriter(buf)
32+
if _, err := io.Copy(w, bytes.NewReader(g.Data)); err != nil {
33+
return xerrors.Errorf("compress: %w", err)
34+
}
35+
if err := w.Close(); err != nil {
36+
return xerrors.Errorf("close: %w", err)
37+
}
38+
39+
// Writing compressed data as bytes.
40+
b.PutBytes(buf.Bytes())
41+
42+
return nil
43+
}
44+
2445
// Decode implements bin.Decoder.
2546
func (g *GZIP) Decode(b *bin.Buffer) error {
2647
if err := b.ConsumeID(GZIPTypeID); err != nil {
@@ -37,13 +58,17 @@ func (g *GZIP) Decode(b *bin.Buffer) error {
3758
}
3859
defer func() { _ = r.Close() }()
3960

40-
if g.Data, err = ioutil.ReadAll(r); err != nil {
41-
return err
61+
// Apply mitigation for reading too much data which can result in OOM.
62+
const maxUncompressedSize = 1024 * 1024 * 10 // 10 mb
63+
// TODO(ernado): fail explicitly if limit is reached
64+
// Currently we just return nil, but it is better than failing with OOM.
65+
if g.Data, err = ioutil.ReadAll(io.LimitReader(r, maxUncompressedSize)); err != nil {
66+
return xerrors.Errorf("decompress: %w", err)
4267
}
4368

4469
if err := r.Close(); err != nil {
45-
// This will verify checksum.
46-
return xerrors.Errorf("gzip error: %w", err)
70+
// This will verify checksum only if limit is not reached.
71+
return xerrors.Errorf("checksum: %w", err)
4772
}
4873

4974
return nil

internal/proto/gzip_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package proto
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/gotd/td/bin"
10+
)
11+
12+
func TestGZIP_Encode(t *testing.T) {
13+
data := bytes.Repeat([]byte{1, 2, 3}, 100)
14+
g := &GZIP{
15+
Data: data,
16+
}
17+
18+
var b bin.Buffer
19+
require.NoError(t, b.Encode(g))
20+
21+
var decoded GZIP
22+
require.NoError(t, b.Decode(&decoded))
23+
require.Equal(t, data, decoded.Data)
24+
}
25+
26+
func TestGZIP_Decode(t *testing.T) {
27+
g := &GZIP{
28+
Data: make([]byte, 1024*1024*15),
29+
}
30+
var b bin.Buffer
31+
require.NoError(t, b.Encode(g))
32+
33+
var decoded GZIP
34+
// TODO(ernado): fail explicitly if limit is reached
35+
require.NoError(t, b.Decode(&decoded))
36+
require.Less(t, len(decoded.Data), len(g.Data))
37+
}

telegram/client_e2e_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ func (h handler) OnMessage(k tgtest.Session, msgID int64, in *bin.Buffer) error
8383

8484
func testTransport(trp *transport.Transport) func(t *testing.T) {
8585
return func(t *testing.T) {
86-
srv := tgtest.NewUnstartedServer(t, trp.Codec())
86+
t.Helper()
87+
88+
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
89+
defer cancel()
90+
91+
srv := tgtest.NewUnstartedServer(ctx, t, trp.Codec())
8792
h := handler{
8893
server: srv,
8994
t: t,
@@ -93,9 +98,6 @@ func testTransport(trp *transport.Transport) func(t *testing.T) {
9398
srv.Start()
9499
defer srv.Close()
95100

96-
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Minute))
97-
defer cancel()
98-
99101
dispatcher := tg.NewUpdateDispatcher()
100102
log, _ := zap.NewDevelopment(zap.IncreaseLevel(zapcore.DebugLevel))
101103
client := NewClient(1, "hash", Options{

telegram/internal/tgtest/server.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,20 @@ func (s *Server) Close() {
5050
_ = s.server.Close()
5151
}
5252

53-
func NewServer(tb TB, codec transport.Codec, h Handler) *Server {
54-
s := NewUnstartedServer(tb, codec)
53+
func NewServer(ctx context.Context, tb TB, codec transport.Codec, h Handler) *Server {
54+
s := NewUnstartedServer(ctx, tb, codec)
5555
s.SetHandler(h)
5656
s.Start()
5757
return s
5858
}
5959

60-
func NewUnstartedServer(tb TB, codec transport.Codec) *Server {
60+
func NewUnstartedServer(ctx context.Context, tb TB, codec transport.Codec) *Server {
6161
k, err := rsa.GenerateKey(rand.Reader, 2048)
6262
if err != nil {
6363
panic(err)
6464
}
6565

66-
ctx, cancel := context.WithCancel(context.Background())
66+
ctx, cancel := context.WithCancel(ctx)
6767
s := &Server{
6868
server: transport.NewCustomServer(codec, newLocalListener()),
6969
tb: tb,

0 commit comments

Comments
 (0)