diff --git a/c_src/zstd_nif.c b/c_src/zstd_nif.c index 6997c61..7efaf03 100644 --- a/c_src/zstd_nif.c +++ b/c_src/zstd_nif.c @@ -111,6 +111,7 @@ static ERL_NIF_TERM zstd_nif_new_decompression_stream(ErlNifEnv* env, int argc, static ERL_NIF_TERM zstd_nif_init_compression_stream(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { int level = ZSTD_CLEVEL_DEFAULT; + int window_log = 0; size_t ret; ZSTD_CStream **pzcs; @@ -119,12 +120,20 @@ static ERL_NIF_TERM zstd_nif_init_compression_stream(ErlNifEnv* env, int argc, c return enif_make_tuple2(env, zstd_atom_error, zstd_atom_invalid); /* extract the compression level if any */ - if ((argc == 2) && !(enif_get_int(env, argv[1], &level))) + if ((argc >= 2) && !(enif_get_int(env, argv[1], &level))) + return enif_make_badarg(env); + + /* extract the window log if any */ + if ((argc == 3) && !(enif_get_int(env, argv[2], &window_log))) return enif_make_badarg(env); /* initialize the stream */ if (ZSTD_isError(ret = ZSTD_initCStream(*pzcs, level))) return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1)); + if (ZSTD_isError(ret = ZSTD_CCtx_setParameter(*pzcs, ZSTD_c_windowLog, window_log))) + return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1)); + if (ZSTD_isError(ret = ZSTD_CCtx_setParameter(*pzcs, ZSTD_c_checksumFlag, 1))) + return enif_make_tuple2(env, zstd_atom_error, enif_make_string(env, ZSTD_getErrorName(ret), ERL_NIF_LATIN1)); /* stream initialization successful */ return zstd_atom_ok; @@ -413,6 +422,7 @@ static ErlNifFunc nif_funcs[] = { { "compression_stream_init" , 1, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND }, { "compression_stream_init" , 2, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND }, + { "compression_stream_init" , 3, zstd_nif_init_compression_stream , ERL_DIRTY_JOB_CPU_BOUND }, { "decompression_stream_init" , 1, zstd_nif_init_decompression_stream , ERL_DIRTY_JOB_CPU_BOUND }, { "compression_stream_reset" , 2, zstd_nif_reset_compression_stream }, diff --git a/src/zstd.erl b/src/zstd.erl index 1080bec..b8f924b 100644 --- a/src/zstd.erl +++ b/src/zstd.erl @@ -2,10 +2,10 @@ -export([compress/1, compress/2]). -export([decompress/1]). --export([new_compression_stream/0, new_decompression_stream/0, compression_stream_init/1, - compression_stream_init/2, decompression_stream_init/1, compression_stream_reset/2, - compression_stream_reset/1, decompression_stream_reset/1, stream_flush/1, - stream_compress/2, stream_decompress/2]). +-export([new_compression_stream/0, compression_stream_init/1, compression_stream_init/2, + compression_stream_init/3, compression_stream_reset/1, compression_stream_reset/2, + stream_compress/2, stream_flush/1, new_decompression_stream/0, + decompression_stream_init/1, decompression_stream_reset/1, stream_decompress/2]). -on_load init/0. @@ -41,6 +41,11 @@ compression_stream_init(_Ref) -> compression_stream_init(_Ref, _Level) -> erlang:nif_error(?LINE). +-spec compression_stream_init(reference(), 0..22, integer()) -> + ok | {error, invalid | string()}. +compression_stream_init(_Ref, _Level, _WindowLog) -> + erlang:nif_error(?LINE). + -spec decompression_stream_init(reference()) -> ok | {error, invalid | string()}. decompression_stream_init(_Ref) -> erlang:nif_error(?LINE). diff --git a/test/zstd_tests.erl b/test/zstd_tests.erl index a67ea1b..d413b02 100644 --- a/test/zstd_tests.erl +++ b/test/zstd_tests.erl @@ -11,7 +11,7 @@ zstd_test() -> zstd_stream_test() -> Bin = << <<"A">> || _ <- lists:seq(1, 1024 * 1024) >>, CStream = zstd:new_compression_stream(), - ok = zstd:compression_stream_init(CStream), + ok = zstd:compression_stream_init(CStream, 5, 23), {ok, CompressionBin} = zstd:stream_compress(CStream, Bin), {ok, FlushBin} = zstd:stream_flush(CStream),