Add multi-threaded (nbWorkers) support to the streaming compressor.

What the patch does:
  * compression_stream_init gains an optional 5th argument, the number of
    workers, which is forwarded to ZSTD_c_nbWorkers (0 when omitted).
  * stream_flush now calls ZSTD_endStream in a loop, growing the output
    binary by ZSTD_CStreamOutSize() until the call returns 0.
  * stream_compress checks the result of ZSTD_compressBound with
    ZSTD_isError before allocating the output binary.
  * Both stream destructors set the stored handle to NULL after freeing it.
  * rebar.config builds the libzstd.a-mt target instead of libzstd.a.
  * Tests use the 5-argument init and add a test that compresses from 100
    spawned processes.

Review amendments applied to the added lines:
  * The stream guard was written as
        !(enif_get_resource(...)) && *pzcs != NULL
    which dereferences the uninitialised pzcs when the lookup fails and never
    checks for NULL when it succeeds. It now reads
        !(enif_get_resource(...)) || *pzcs == NULL
    in the init, flush and compress functions.
  * The flush loop advanced its write offset by ZSTD_CStreamOutSize() instead
    of by the bytes actually produced; it now advances by outbuf.pos so the
    output stays contiguous even if a call returns before filling the buffer.
  * The binary constructions in test/zstd_tests.erl are restored to
    <<DBin1/binary, DBin2/binary>>.

NOTE(review): indentation and blank context lines were reconstructed from a
whitespace-collapsed copy of this patch, and the "index" lines still carry the
original post-image hashes. Regenerate the patch against the tree before
applying it.
NOTE(review): zstd_parallel_test calls stream_compress on one shared stream
from 100 processes at once - confirm the NIF is meant to support that.

diff --git a/c_src/zstd_nif.c b/c_src/zstd_nif.c
index 894395e..c6ce698 100644
--- a/c_src/zstd_nif.c
+++ b/c_src/zstd_nif.c
@@ -113,11 +113,12 @@ static ERL_NIF_TERM zstd_nif_init_compression_stream(ErlNifEnv* env, int argc, c
     int level = ZSTD_CLEVEL_DEFAULT;
     int window_log = 0;
     int enable_long_distance_matching = 0;
+    int num_workers = 0;
     size_t ret;
     ZSTD_CStream **pzcs;
 
     /* extract the stream */
-    if (!(enif_get_resource(env, argv[0], zstd_compression_stream_type, (void **)&pzcs)))
+    if (!(enif_get_resource(env, argv[0], zstd_compression_stream_type, (void **)&pzcs)) || *pzcs == NULL)
         return enif_make_tuple2(env, zstd_atom_error, zstd_atom_invalid);
 
     /* extract the compression level if any */
@@ -129,7 +130,11 @@ static ERL_NIF_TERM zstd_nif_init_compression_stream(ErlNifEnv* env, int argc, c
         return enif_make_badarg(env);
 
     /* extract the enable long distance matching if any */
-    if ((argc == 4) && !(enif_get_int(env, argv[3], &enable_long_distance_matching)))
+    if ((argc >= 4) && !(enif_get_int(env, argv[3], &enable_long_distance_matching)))
+        return enif_make_badarg(env);
+
+    /* extract the number of workers if any */
+    if ((argc == 5) && !(enif_get_int(env, argv[4], &num_workers)))
         return enif_make_badarg(env);
 
     /* initialize the stream */
@@ -141,6 +146,8 @@ static ERL_NIF_TERM zstd_nif_init_compression_stream(ErlNifEnv* env, int argc, c
         return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1));
     if (ZSTD_isError(ret = ZSTD_CCtx_setParameter(*pzcs, ZSTD_c_checksumFlag, 1)))
         return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1));
+    if (ZSTD_isError(ret = ZSTD_CCtx_setParameter(*pzcs, ZSTD_c_nbWorkers, num_workers)))
+        return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1));
 
     /* stream initialization successful */
     return zstd_atom_ok;
@@ -200,31 +207,42 @@ static ERL_NIF_TERM zstd_nif_reset_decompression_stream(ErlNifEnv* env, int argc
 }
 
 static ERL_NIF_TERM zstd_nif_flush_compression_stream(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-    size_t ret;
     ErlNifBinary bin;
     ZSTD_CStream **pzcs;
 
     /* extract the stream */
-    if (!(enif_get_resource(env, argv[0], zstd_compression_stream_type, (void **)&pzcs)))
+    if (!(enif_get_resource(env, argv[0], zstd_compression_stream_type, (void **)&pzcs)) || *pzcs == NULL)
         return enif_make_tuple2(env, zstd_atom_error, zstd_atom_invalid);
 
     /* allocate binary buffer */
     if (!(enif_alloc_binary(ZSTD_CStreamOutSize(), &bin)))
         return enif_make_tuple2(env, zstd_atom_error, zstd_atom_enomem);
 
-    /* output buffer */
-    ZSTD_outBuffer outbuf = {
-        .pos = 0,
-        .dst = bin.data,
-        .size = bin.size,
-    };
-
-    /* reset the stream */
-    if (ZSTD_isError(ret = ZSTD_endStream(*pzcs, &outbuf)))
-    {
-        enif_release_binary(&bin);
-        return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1));
-    }
+    int finished;
+    size_t offset = 0;
+    ZSTD_outBuffer outbuf;
+    do {
+        /* output buffer: the unused tail of the binary (bin.data may move on realloc) */
+        outbuf.pos = 0;
+        outbuf.dst = bin.data + offset;
+        outbuf.size = bin.size - offset;
+
+        /* end the stream; a non-zero result means more output is still pending */
+        size_t const remaining = ZSTD_endStream(*pzcs, &outbuf);
+        if (ZSTD_isError(remaining))
+        {
+            enif_release_binary(&bin);
+            return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(remaining), ERL_NIF_LATIN1));
+        }
+        finished = remaining == 0;
+        if (!finished) {
+            offset += outbuf.pos;
+            if (!enif_realloc_binary(&bin, bin.size + ZSTD_CStreamOutSize())) {
+                enif_release_binary(&bin);
+                return enif_make_tuple2(env, zstd_atom_error, zstd_atom_enomem);
+            }
+        }
+    } while (!finished);
 
     /* transfer to binary object */
     ERL_NIF_TERM binary = enif_make_binary(env, &bin);
@@ -232,7 +250,7 @@ static ERL_NIF_TERM zstd_nif_flush_compression_stream(ErlNifEnv* env, int argc,
 
     /* remove unused spaces */
     if (outbuf.pos < outbuf.size)
-        result = enif_make_sub_binary(env, binary, 0, outbuf.pos);
+        result = enif_make_sub_binary(env, binary, 0, bin.size - (outbuf.size - outbuf.pos));
 
     /* construct the result tuple */
     return enif_make_tuple2(env, zstd_atom_ok, result);
@@ -245,12 +263,17 @@ static ERL_NIF_TERM zstd_nif_compress_stream(ErlNifEnv* env, int argc, const ERL
     ZSTD_CStream **pzcs;
 
     /* extract the stream */
-    if (!(enif_get_resource(env, argv[0], zstd_compression_stream_type, (void **)&pzcs)) ||
+    if (!(enif_get_resource(env, argv[0], zstd_compression_stream_type, (void **)&pzcs)) || (*pzcs) == NULL ||
         !(enif_inspect_iolist_as_binary(env, argv[1], &in)))
         return enif_make_tuple2(env, zstd_atom_error, zstd_atom_invalid);
 
     /* all output binary buffer */
-    if (!(enif_alloc_binary(ZSTD_compressBound(in.size), &out))) {
+    size_t buffer_size = ZSTD_compressBound(in.size);
+    if (ZSTD_isError(buffer_size)) {
+        enif_release_binary(&in);
+        return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(buffer_size), ERL_NIF_LATIN1));
+    }
+    if (!(enif_alloc_binary(buffer_size, &out))) {
         enif_release_binary(&in);
         return enif_make_tuple2(env, zstd_atom_error, zstd_atom_enomem);
     }
@@ -363,11 +386,13 @@ static ERL_NIF_TERM zstd_nif_decompress_stream(ErlNifEnv* env, int argc, const E
 static void zstd_compression_stream_destructor(ErlNifEnv *env, void *stream) {
     ZSTD_CStream **handle = stream;
     ZSTD_freeCStream(*handle);
+    *handle = NULL;
 }
 
 static void zstd_decompression_stream_destructor(ErlNifEnv *env, void *stream) {
     ZSTD_DStream **handle = stream;
     ZSTD_freeDStream(*handle);
+    *handle = NULL;
 }
 
 static int zstd_init(ErlNifEnv *env) {
@@ -431,6 +456,7 @@ static ErlNifFunc nif_funcs[] = {
     { "compression_stream_init" , 2, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND },
     { "compression_stream_init" , 3, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND },
     { "compression_stream_init" , 4, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND },
+    { "compression_stream_init" , 5, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND },
     { "decompression_stream_init" , 1, zstd_nif_init_decompression_stream , ERL_DIRTY_JOB_CPU_BOUND },
 
     { "compression_stream_reset" , 2, zstd_nif_reset_compression_stream },
diff --git a/rebar.config b/rebar.config
index 4f3cef4..1a0ed8f 100644
--- a/rebar.config
+++ b/rebar.config
@@ -6,9 +6,11 @@
 {deps, []}.
 
 {pre_hooks,
- [{"(linux|darwin|solaris)", compile, "make MOREFLAGS=-fPIC -C priv/zstd/lib libzstd.a"},
+ [{"(linux|darwin|solaris)",
+   compile,
+   "make MOREFLAGS=-fPIC -C priv/zstd/lib libzstd.a-mt"},
   {"(linux|darwin|solaris)", compile, "make -C c_src"},
-  {"(freebsd)", compile, "gmake MOREFLAGS=-fPIC -C priv/zstd/lib libzstd.a"},
+  {"(freebsd)", compile, "gmake MOREFLAGS=-fPIC -C priv/zstd/lib libzstd.a-mt"},
   {"(freebsd)", compile, "gmake -C c_src"}]}.
 
 {post_hooks,
diff --git a/src/zstd.erl b/src/zstd.erl
index 44e968f..ee31ba9 100644
--- a/src/zstd.erl
+++ b/src/zstd.erl
@@ -3,9 +3,10 @@
 -export([compress/1, compress/2]).
 -export([decompress/1]).
 -export([new_compression_stream/0, compression_stream_init/1, compression_stream_init/2,
-         compression_stream_init/3, compression_stream_init/4, compression_stream_reset/1,
-         compression_stream_reset/2, stream_compress/2, stream_flush/1, new_decompression_stream/0,
-         decompression_stream_init/1, decompression_stream_reset/1, stream_decompress/2]).
+         compression_stream_init/3, compression_stream_init/4, compression_stream_init/5,
+         compression_stream_reset/1, compression_stream_reset/2, stream_compress/2, stream_flush/1,
+         new_decompression_stream/0, decompression_stream_init/1, decompression_stream_reset/1,
+         stream_decompress/2]).
 
 -on_load init/0.
 
@@ -51,6 +52,15 @@ compression_stream_init(_Ref, _Level, _WindowLog) ->
 compression_stream_init(_Ref, _Level, _WindowLog, _EnableLongDistanceMatching) ->
     erlang:nif_error(?LINE).
 
+-spec compression_stream_init(reference(), 0..22, integer(), integer(), integer()) ->
+                                 ok | {error, invalid | string()}.
+compression_stream_init(_Ref,
+                        _Level,
+                        _WindowLog,
+                        _EnableLongDistanceMatching,
+                        _NumWorkers) ->
+    erlang:nif_error(?LINE).
+
 -spec decompression_stream_init(reference()) -> ok | {error, invalid | string()}.
 decompression_stream_init(_Ref) ->
     erlang:nif_error(?LINE).
diff --git a/test/zstd_tests.erl b/test/zstd_tests.erl
index 51b78b8..4620b7e 100644
--- a/test/zstd_tests.erl
+++ b/test/zstd_tests.erl
@@ -3,6 +3,7 @@
 -define(COMPRESSION_LEVEL, 5).
 -define(WINDOW_LOG, 23).
 -define(ENABLE_LONG_DISTANCE_MATCHING, 1).
+-define(NUM_WORKERS, 8).
 
 -include_lib("eunit/include/eunit.hrl").
 
@@ -13,18 +14,46 @@ zstd_test() ->
                  zstd:compress(Data))).
 
 zstd_stream_test() ->
-    Bin = << <<"A">> || _ <- lists:seq(1, 1024 * 1024) >>,
+    Bin = rand:bytes(1000000),
     CStream = zstd:new_compression_stream(),
     ok =
         zstd:compression_stream_init(CStream,
                                      ?COMPRESSION_LEVEL,
                                      ?WINDOW_LOG,
-                                     ?ENABLE_LONG_DISTANCE_MATCHING),
+                                     ?ENABLE_LONG_DISTANCE_MATCHING,
+                                     ?NUM_WORKERS),
     {ok, CompressionBin} = zstd:stream_compress(CStream, Bin),
-    {ok, FlushBin} = zstd:stream_flush(CStream),
+    {ok, LastBin} = zstd:stream_flush(CStream),
 
     DStream = zstd:new_decompression_stream(),
     ok = zstd:decompression_stream_init(DStream),
     {ok, DBin1} = zstd:stream_decompress(DStream, CompressionBin),
-    {ok, DBin2} = zstd:stream_decompress(DStream, FlushBin),
-    ?assertEqual(Bin, <<DBin1/binary, DBin2/binary>>).
+    {ok, DBin2} = zstd:stream_decompress(DStream, LastBin),
+    DecompressBin = <<DBin1/binary, DBin2/binary>>,
+    ?assertEqual(size(Bin), size(DecompressBin)),
+    ?assertEqual(Bin, DecompressBin).
+
+zstd_parallel_test() ->
+    Bin = rand:bytes(1_000_000),
+    CStream = zstd:new_compression_stream(),
+    ok =
+        zstd:compression_stream_init(CStream,
+                                     ?COMPRESSION_LEVEL,
+                                     ?WINDOW_LOG,
+                                     ?ENABLE_LONG_DISTANCE_MATCHING,
+                                     ?NUM_WORKERS),
+    Parent = self(),
+    F = fun() ->
+           zstd:stream_compress(CStream, Bin),
+           Parent ! done
+        end,
+    lists:foreach(fun(_) -> spawn(F) end, lists:seq(1, 100)),
+    lists:foreach(fun(_) ->
+                     receive
+                         done ->
+                             ok
+                     end
+                  end,
+                  lists:seq(1, 100)),
+    {ok, _} = zstd:stream_flush(CStream),
+    ok.