17
17
package zstd
18
18
19
19
import (
20
- "bytes"
21
20
"errors"
22
21
"io"
23
22
"runtime"
@@ -48,21 +47,22 @@ type decoderWrapper struct {
48
47
* zstd.Decoder
49
48
}
50
49
50
+ type encoderWrapper struct {
51
+ * zstd.Encoder
52
+ pool * sync.Pool
53
+ }
54
+
51
55
type compressor struct {
52
- encoder * zstd. Encoder
53
- decoderPool sync.Pool // To hold *zstd.Decoder's.
56
+ encoderPool sync. Pool
57
+ decoderPool sync.Pool
54
58
}
55
59
56
60
func PretendInit (clobbering bool ) {
57
61
if ! clobbering && encoding .GetCompressor (Name ) != nil {
58
62
return
59
63
}
60
64
61
- enc , _ := zstd .NewWriter (nil , encoderOptions ... )
62
- c := & compressor {
63
- encoder : enc ,
64
- }
65
- encoding .RegisterCompressor (c )
65
+ encoding .RegisterCompressor (& compressor {})
66
66
}
67
67
68
68
var ErrNotInUse = errors .New ("SetLevel ineffective because another zstd compressor has been registered" )
@@ -71,40 +71,42 @@ var ErrNotInUse = errors.New("SetLevel ineffective because another zstd compress
71
71
// level. NOTE: this function must only be called from an init function, and
72
72
// is not threadsafe.
73
73
func SetLevel (level zstd.EncoderLevel ) error {
74
- c , ok := encoding .GetCompressor (Name ).(* compressor )
74
+ _ , ok := encoding .GetCompressor (Name ).(* compressor )
75
75
if ! ok {
76
76
return ErrNotInUse
77
77
}
78
78
79
- enc , err := zstd .NewWriter (nil , zstd .WithEncoderLevel (level ))
80
- if err != nil {
81
- return err
82
- }
83
-
84
- c .encoder = enc
79
+ encoderOptions = append (encoderOptions , zstd .WithEncoderLevel (level ))
85
80
return nil
86
81
}
87
82
88
83
func (c * compressor ) Compress (w io.Writer ) (io.WriteCloser , error ) {
89
- return & zstdWriteCloser {
90
- enc : c .encoder ,
91
- writer : w ,
92
- }, nil
93
- }
84
+ var err error
85
+ var found bool
86
+ var encoder * zstd.Encoder
94
87
95
- type zstdWriteCloser struct {
96
- enc * zstd.Encoder
97
- writer io.Writer // Compressed data will be written here.
98
- buf bytes.Buffer // Buffer uncompressed data here, compress on Close.
99
- }
88
+ encoder , found = c .encoderPool .Get ().(* zstd.Encoder )
89
+ if ! found {
90
+ encoder , err = zstd .NewWriter (w , encoderOptions ... )
91
+ if err != nil {
92
+ return nil , err
93
+ }
94
+ } else {
95
+ encoder .Reset (w )
96
+ }
97
+
98
+ wrapper := & encoderWrapper {Encoder : encoder , pool : & c .encoderPool }
99
+ runtime .SetFinalizer (wrapper , func (ew * encoderWrapper ) {
100
+ ew .Reset (nil )
101
+ c .encoderPool .Put (ew .Encoder )
102
+ })
100
103
101
- func (z * zstdWriteCloser ) Write (p []byte ) (int , error ) {
102
- return z .buf .Write (p )
104
+ return wrapper , nil
103
105
}
104
106
105
- func (z * zstdWriteCloser ) Close () error {
106
- compressed := z . enc . EncodeAll ( z . buf . Bytes (), nil )
107
- _ , err := io . Copy ( z . writer , bytes . NewReader ( compressed ) )
107
+ func (w * encoderWrapper ) Close () error {
108
+ err := w . Encoder . Close ( )
109
+ w . pool . Put ( w . Encoder )
108
110
return err
109
111
}
110
112
0 commit comments