Skip to content

Commit e7069e7

Browse files
committed
Make flush work when we have multithreading enabled
1 parent 99ee413 commit e7069e7

File tree

3 files changed

+53
-23
lines changed

3 files changed

+53
-23
lines changed

c_src/zstd_nif.c

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ static ERL_NIF_TERM zstd_nif_init_compression_stream(ErlNifEnv* env, int argc, c
113113
int level = ZSTD_CLEVEL_DEFAULT;
114114
int window_log = 0;
115115
int enable_long_distance_matching = 0;
116+
int num_workers = 0;
116117
size_t ret;
117118
ZSTD_CStream **pzcs;
118119

@@ -129,7 +130,11 @@ static ERL_NIF_TERM zstd_nif_init_compression_stream(ErlNifEnv* env, int argc, c
129130
return enif_make_badarg(env);
130131

131132
/* extract the enable long distance matching if any */
132-
if ((argc == 4) && !(enif_get_int(env, argv[3], &enable_long_distance_matching)))
133+
if ((argc >= 4) && !(enif_get_int(env, argv[3], &enable_long_distance_matching)))
134+
return enif_make_badarg(env);
135+
136+
/* extract the number of workers */
137+
if ((argc == 5) && !(enif_get_int(env, argv[4], &num_workers)))
133138
return enif_make_badarg(env);
134139

135140
/* initialize the stream */
@@ -141,6 +146,8 @@ static ERL_NIF_TERM zstd_nif_init_compression_stream(ErlNifEnv* env, int argc, c
141146
return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1));
142147
if (ZSTD_isError(ret = ZSTD_CCtx_setParameter(*pzcs, ZSTD_c_checksumFlag, 1)))
143148
return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1));
149+
if (ZSTD_isError(ret = ZSTD_CCtx_setParameter(*pzcs, ZSTD_c_nbWorkers, num_workers)))
150+
return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1));
144151

145152
/* stream initialization successful */
146153
return zstd_atom_ok;
@@ -200,7 +207,6 @@ static ERL_NIF_TERM zstd_nif_reset_decompression_stream(ErlNifEnv* env, int argc
200207
}
201208

202209
static ERL_NIF_TERM zstd_nif_flush_compression_stream(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
203-
size_t ret;
204210
ErlNifBinary bin;
205211
ZSTD_CStream **pzcs;
206212

@@ -212,27 +218,36 @@ static ERL_NIF_TERM zstd_nif_flush_compression_stream(ErlNifEnv* env, int argc,
212218
if (!(enif_alloc_binary(ZSTD_CStreamOutSize(), &bin)))
213219
return enif_make_tuple2(env, zstd_atom_error, zstd_atom_enomem);
214220

215-
/* output buffer */
216-
ZSTD_outBuffer outbuf = {
217-
.pos = 0,
218-
.dst = bin.data,
219-
.size = bin.size,
220-
};
221-
222-
/* reset the stream */
223-
if (ZSTD_isError(ret = ZSTD_endStream(*pzcs, &outbuf)))
224-
{
225-
enif_release_binary(&bin);
226-
return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1));
227-
}
221+
int finished;
222+
size_t offset = 0;
223+
ZSTD_outBuffer outbuf;
224+
do {
225+
/* output buffer */
226+
outbuf.pos = 0;
227+
outbuf.dst = bin.data + offset;
228+
outbuf.size = bin.size - offset;
229+
230+
/* ends the stream */
231+
size_t const remaining = ZSTD_endStream(*pzcs, &outbuf);
232+
if (ZSTD_isError(remaining))
233+
{
234+
enif_release_binary(&bin);
235+
return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(remaining), ERL_NIF_LATIN1));
236+
}
237+
finished = remaining == 0;
238+
if(!finished) {
239+
offset += ZSTD_CStreamOutSize();
240+
enif_realloc_binary(&bin, bin.size + ZSTD_CStreamOutSize());
241+
}
242+
} while (!finished);
228243

229244
/* transfer to binary object */
230245
ERL_NIF_TERM binary = enif_make_binary(env, &bin);
231246
ERL_NIF_TERM result = binary;
232247

233248
/* remove unused spaces */
234249
if (outbuf.pos < outbuf.size)
235-
result = enif_make_sub_binary(env, binary, 0, outbuf.pos);
250+
result = enif_make_sub_binary(env, binary, 0, bin.size - (outbuf.size - outbuf.pos));
236251

237252
/* construct the result tuple */
238253
return enif_make_tuple2(env, zstd_atom_ok, result);
@@ -436,6 +451,7 @@ static ErlNifFunc nif_funcs[] = {
436451
{ "compression_stream_init" , 2, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND },
437452
{ "compression_stream_init" , 3, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND },
438453
{ "compression_stream_init" , 4, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND },
454+
{ "compression_stream_init" , 5, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND },
439455
{ "decompression_stream_init" , 1, zstd_nif_init_decompression_stream , ERL_DIRTY_JOB_CPU_BOUND },
440456

441457
{ "compression_stream_reset" , 2, zstd_nif_reset_compression_stream },

src/zstd.erl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
-export([compress/1, compress/2]).
44
-export([decompress/1]).
55
-export([new_compression_stream/0, compression_stream_init/1, compression_stream_init/2,
6-
compression_stream_init/3, compression_stream_init/4, compression_stream_reset/1,
7-
compression_stream_reset/2, stream_compress/2, stream_flush/1, new_decompression_stream/0,
8-
decompression_stream_init/1, decompression_stream_reset/1, stream_decompress/2]).
6+
compression_stream_init/3, compression_stream_init/4, compression_stream_init/5,
7+
compression_stream_reset/1, compression_stream_reset/2, stream_compress/2, stream_flush/1,
8+
new_decompression_stream/0, decompression_stream_init/1, decompression_stream_reset/1,
9+
stream_decompress/2]).
910

1011
-on_load init/0.
1112

@@ -51,6 +52,15 @@ compression_stream_init(_Ref, _Level, _WindowLog) ->
5152
compression_stream_init(_Ref, _Level, _WindowLog, _EnableLongDistanceMatching) ->
5253
erlang:nif_error(?LINE).
5354

55+
-spec compression_stream_init(reference(), 0..22, integer(), integer(), integer()) ->
56+
ok | {error, invalid | string()}.
57+
compression_stream_init(_Ref,
58+
_Level,
59+
_WindowLog,
60+
_EnableLongDistanceMatching,
61+
_NumWorkers) ->
62+
erlang:nif_error(?LINE).
63+
5464
-spec decompression_stream_init(reference()) -> ok | {error, invalid | string()}.
5565
decompression_stream_init(_Ref) ->
5666
erlang:nif_error(?LINE).

test/zstd_tests.erl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
-define(COMPRESSION_LEVEL, 5).
44
-define(WINDOW_LOG, 23).
55
-define(ENABLE_LONG_DISTANCE_MATCHING, 1).
6+
-define(NUM_WORKERS, 8).
67

78
-include_lib("eunit/include/eunit.hrl").
89

@@ -19,12 +20,15 @@ zstd_stream_test() ->
1920
zstd:compression_stream_init(CStream,
2021
?COMPRESSION_LEVEL,
2122
?WINDOW_LOG,
22-
?ENABLE_LONG_DISTANCE_MATCHING),
23+
?ENABLE_LONG_DISTANCE_MATCHING,
24+
?NUM_WORKERS),
2325
{ok, CompressionBin} = zstd:stream_compress(CStream, Bin),
24-
{ok, FlushBin} = zstd:stream_flush(CStream),
26+
{ok, LastBin} = zstd:stream_flush(CStream),
2527

2628
DStream = zstd:new_decompression_stream(),
2729
ok = zstd:decompression_stream_init(DStream),
2830
{ok, DBin1} = zstd:stream_decompress(DStream, CompressionBin),
29-
{ok, DBin2} = zstd:stream_decompress(DStream, FlushBin),
30-
?assertEqual(Bin, <<DBin1/binary, DBin2/binary>>).
31+
{ok, DBin2} = zstd:stream_decompress(DStream, LastBin),
32+
DecompressBin = <<DBin1/binary, DBin2/binary>>,
33+
?assertEqual(size(Bin), size(DecompressBin)),
34+
?assertEqual(Bin, DecompressBin).

0 commit comments

Comments
 (0)