Skip to content

Commit ce6ebe9

Browse files
committed
Avoid using zstd.Encoder.EncodeAll, to reduce memory usage
Followup to #25.
1 parent 9adaaed commit ce6ebe9

File tree

1 file changed

+32
-30
lines changed

1 file changed

+32
-30
lines changed

internal/zstd/zstd.go

+32-30
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package zstd
1818

1919
import (
20-
"bytes"
2120
"errors"
2221
"io"
2322
"runtime"
@@ -48,21 +47,22 @@ type decoderWrapper struct {
4847
*zstd.Decoder
4948
}
5049

50+
type encoderWrapper struct {
51+
*zstd.Encoder
52+
pool *sync.Pool
53+
}
54+
5155
type compressor struct {
52-
encoder *zstd.Encoder
53-
decoderPool sync.Pool // To hold *zstd.Decoder's.
56+
encoderPool sync.Pool
57+
decoderPool sync.Pool
5458
}
5559

5660
func PretendInit(clobbering bool) {
5761
if !clobbering && encoding.GetCompressor(Name) != nil {
5862
return
5963
}
6064

61-
enc, _ := zstd.NewWriter(nil, encoderOptions...)
62-
c := &compressor{
63-
encoder: enc,
64-
}
65-
encoding.RegisterCompressor(c)
65+
encoding.RegisterCompressor(&compressor{})
6666
}
6767

6868
var ErrNotInUse = errors.New("SetLevel ineffective because another zstd compressor has been registered")
@@ -71,40 +71,42 @@ var ErrNotInUse = errors.New("SetLevel ineffective because another zstd compress
7171
// level. NOTE: this function must only be called from an init function, and
7272
// is not threadsafe.
7373
func SetLevel(level zstd.EncoderLevel) error {
74-
c, ok := encoding.GetCompressor(Name).(*compressor)
74+
_, ok := encoding.GetCompressor(Name).(*compressor)
7575
if !ok {
7676
return ErrNotInUse
7777
}
7878

79-
enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
80-
if err != nil {
81-
return err
82-
}
83-
84-
c.encoder = enc
79+
encoderOptions = append(encoderOptions, zstd.WithEncoderLevel(level))
8580
return nil
8681
}
8782

8883
func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) {
89-
return &zstdWriteCloser{
90-
enc: c.encoder,
91-
writer: w,
92-
}, nil
93-
}
84+
var err error
85+
var found bool
86+
var encoder *zstd.Encoder
9487

95-
type zstdWriteCloser struct {
96-
enc *zstd.Encoder
97-
writer io.Writer // Compressed data will be written here.
98-
buf bytes.Buffer // Buffer uncompressed data here, compress on Close.
99-
}
88+
encoder, found = c.encoderPool.Get().(*zstd.Encoder)
89+
if !found {
90+
encoder, err = zstd.NewWriter(w, encoderOptions...)
91+
if err != nil {
92+
return nil, err
93+
}
94+
} else {
95+
encoder.Reset(w)
96+
}
97+
98+
wrapper := &encoderWrapper{Encoder: encoder, pool: &c.encoderPool}
99+
runtime.SetFinalizer(wrapper, func(ew *encoderWrapper) {
100+
ew.Reset(nil)
101+
c.encoderPool.Put(ew.Encoder)
102+
})
100103

101-
func (z *zstdWriteCloser) Write(p []byte) (int, error) {
102-
return z.buf.Write(p)
104+
return wrapper, nil
103105
}
104106

105-
func (z *zstdWriteCloser) Close() error {
106-
compressed := z.enc.EncodeAll(z.buf.Bytes(), nil)
107-
_, err := io.Copy(z.writer, bytes.NewReader(compressed))
107+
func (w *encoderWrapper) Close() error {
108+
err := w.Encoder.Close()
109+
w.pool.Put(w.Encoder)
108110
return err
109111
}
110112

0 commit comments

Comments
 (0)