Skip to content

Commit 875207b

Browse files
authored
Avoid deep copy on lz4 decompression (#7437)
Speed up deserialization when a. lz4 is installed, and b. the buffer is compressible, and c. the buffer is smaller than 64 MiB (distributed.comm.shard) Note that the default chunk size in dask.array is 128 MiB. Note that this does not prevent a memory flare, as there's an unnecessary deep copy upstream as well: https://github.com/python-lz4/python-lz4/blob/79370987909663d4e6ef743762768ebf970a2383/lz4/block/_block.c#L256
1 parent f3995b5 commit 875207b

File tree

5 files changed

+80
-21
lines changed

5 files changed

+80
-21
lines changed

Diff for: distributed/protocol/compression.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
from collections.abc import Callable
1010
from contextlib import suppress
11+
from functools import partial
1112
from random import randint
1213
from typing import Literal
1314

@@ -64,12 +65,16 @@
6465
if parse_version(lz4.__version__) < parse_version("0.23.1"):
6566
raise ImportError("Need lz4 >= 0.23.1")
6667

67-
from lz4.block import compress as lz4_compress
68-
from lz4.block import decompress as lz4_decompress
68+
import lz4.block
6969

7070
compressions["lz4"] = {
71-
"compress": lz4_compress,
72-
"decompress": lz4_decompress,
71+
"compress": lz4.block.compress,
72+
# Avoid expensive deep copies when deserializing writeable numpy arrays
73+
# See distributed.protocol.numpy.deserialize_numpy_ndarray
74+
# Note that this is only useful for buffers smaller than distributed.comm.shard;
75+
# larger ones are deep-copied between decompression and serialization anyway in
76+
# order to merge them.
77+
"decompress": partial(lz4.block.decompress, return_bytearray=True),
7378
}
7479
default_compression = "lz4"
7580

Diff for: distributed/protocol/numpy.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,12 @@ def deserialize_numpy_ndarray(header, frames):
138138
# This should exclusively happen when the underlying buffer is read-only, e.g.
139139
# a read-only mmap.mmap or a bytes object.
140140
# Specifically, these are the known use cases:
141-
# 1. decompressed output of a buffer that was not sharded
141+
# 1. decompression with a library that does not support output to bytearray
142+
# (lz4 does; snappy, zlib, and zstd don't).
143+
# Note that this only applies to buffers whose uncompressed size was small
144+
# enough that they weren't sharded (distributed.comm.shard); for larger
145+
# buffers the decompressed output is deep-copied beforehand into a bytearray
146+
# in order to merge it.
142147
# 2. unspill with zict <2.3.0 (https://github.com/dask/zict/pull/74)
143148
x = np.require(x, requirements=["W"])
144149

Diff for: distributed/protocol/tests/test_numpy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
serialize,
1818
to_serialize,
1919
)
20-
from distributed.protocol.compression import maybe_compress
20+
from distributed.protocol.compression import default_compression, maybe_compress
2121
from distributed.protocol.numpy import itemsize
2222
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
2323
from distributed.system import MEMORY_LIMIT
@@ -216,8 +216,8 @@ def test_itemsize(dt, size):
216216
assert itemsize(np.dtype(dt)) == size
217217

218218

219+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
219220
def test_compress_numpy():
220-
pytest.importorskip("lz4")
221221
x = np.ones(10000000, dtype="i4")
222222
frames = dumps({"x": to_serialize(x)})
223223
assert sum(map(nbytes, frames)) < x.nbytes

Diff for: distributed/protocol/tests/test_protocol.py

+61-13
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,25 @@ def test_protocol():
3131
assert loads(dumps(msg)) == msg
3232

3333

34+
def test_default_compression():
35+
"""Test that the default compression algorithm is lz4 -> snappy -> None.
36+
If neither is installed, test that we don't fall back to the very slow zlib.
37+
"""
38+
try:
39+
import lz4 # noqa: F401
40+
41+
assert default_compression == "lz4"
42+
return
43+
except ImportError:
44+
pass
45+
try:
46+
import snappy # noqa: F401
47+
48+
assert default_compression == "snappy"
49+
except ImportError:
50+
assert default_compression is None
51+
52+
3453
@pytest.mark.parametrize(
3554
"config,default",
3655
[
@@ -49,8 +68,8 @@ def test_compression_config(config, default):
4968
assert get_default_compression() == default
5069

5170

71+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
5272
def test_compression_1():
53-
pytest.importorskip("lz4")
5473
np = pytest.importorskip("numpy")
5574
x = np.ones(1000000)
5675
b = x.tobytes()
@@ -60,17 +79,17 @@ def test_compression_1():
6079
assert {"x": b} == y
6180

6281

82+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
6383
def test_compression_2():
64-
pytest.importorskip("lz4")
6584
np = pytest.importorskip("numpy")
6685
x = np.random.random(10000)
6786
msg = dumps(to_serialize(x.data))
6887
compression = msgpack.loads(msg[1]).get("compression")
6988
assert all(c is None for c in compression)
7089

7190

91+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
7292
def test_compression_3():
73-
pytest.importorskip("lz4")
7493
np = pytest.importorskip("numpy")
7594
x = np.ones(1000000)
7695
frames = dumps({"x": Serialize(x.data)})
@@ -79,8 +98,8 @@ def test_compression_3():
7998
assert {"x": x.data} == y
8099

81100

101+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
82102
def test_compression_without_deserialization():
83-
pytest.importorskip("lz4")
84103
np = pytest.importorskip("numpy")
85104
x = np.ones(1000000)
86105

@@ -91,6 +110,18 @@ def test_compression_without_deserialization():
91110
assert all(len(frame) < 1000000 for frame in msg["x"].frames)
92111

93112

113+
def test_lz4_decompression_avoids_deep_copy():
114+
"""Test that lz4 output is a bytearray, not bytes, so that numpy deserialization is
115+
not forced to perform a deep copy to obtain a writeable array.
116+
Note that zlib, zstandard, and snappy don't have this option.
117+
"""
118+
pytest.importorskip("lz4")
119+
a = bytearray(1_000_000)
120+
b = compressions["lz4"]["compress"](a)
121+
c = compressions["lz4"]["decompress"](b)
122+
assert isinstance(c, bytearray)
123+
124+
94125
def test_small():
95126
assert sum(map(nbytes, dumps(b""))) < 10
96127
assert sum(map(nbytes, dumps(1))) < 10
@@ -106,7 +137,13 @@ def test_small_and_big():
106137

107138
@pytest.mark.parametrize(
108139
"lib,compression",
109-
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
140+
[
141+
(None, None),
142+
("zlib", "zlib"),
143+
("lz4", "lz4"),
144+
("snappy", "snappy"),
145+
("zstandard", "zstd"),
146+
],
110147
)
111148
def test_maybe_compress(lib, compression):
112149
if lib:
@@ -126,7 +163,13 @@ def test_maybe_compress(lib, compression):
126163

127164
@pytest.mark.parametrize(
128165
"lib,compression",
129-
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
166+
[
167+
(None, None),
168+
("zlib", "zlib"),
169+
("lz4", "lz4"),
170+
("snappy", "snappy"),
171+
("zstandard", "zstd"),
172+
],
130173
)
131174
def test_compression_thread_safety(lib, compression):
132175
if lib:
@@ -164,7 +207,13 @@ def test_compress_decompress(fn):
164207

165208
@pytest.mark.parametrize(
166209
"lib,compression",
167-
[(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
210+
[
211+
(None, None),
212+
("zlib", "zlib"),
213+
("lz4", "lz4"),
214+
("snappy", "snappy"),
215+
("zstandard", "zstd"),
216+
],
168217
)
169218
def test_maybe_compress_config_default(lib, compression):
170219
if lib:
@@ -183,9 +232,9 @@ def test_maybe_compress_config_default(lib, compression):
183232
assert compressions[rc]["decompress"](rd) == payload
184233

185234

235+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
186236
def test_maybe_compress_sample():
187237
np = pytest.importorskip("numpy")
188-
lz4 = pytest.importorskip("lz4")
189238
payload = np.random.randint(0, 255, size=10000).astype("u1").tobytes()
190239
fmt, compressed = maybe_compress(payload)
191240
assert fmt is None
@@ -202,10 +251,9 @@ def test_large_bytes():
202251
assert len(frames[1]) < 1000
203252

204253

205-
@pytest.mark.slow
254+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
206255
def test_large_messages():
207256
np = pytest.importorskip("numpy")
208-
pytest.importorskip("lz4")
209257
if MEMORY_LIMIT < 8e9:
210258
pytest.skip("insufficient memory")
211259

@@ -248,8 +296,8 @@ def test_loads_deserialize_False():
248296
assert result == 123
249297

250298

299+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
251300
def test_loads_without_deserialization_avoids_compression():
252-
pytest.importorskip("lz4")
253301
b = b"0" * 100000
254302

255303
msg = {"x": 1, "data": to_serialize(b)}
@@ -311,12 +359,12 @@ def test_dumps_loads_Serialized():
311359
assert result == result3
312360

313361

362+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
314363
def test_maybe_compress_memoryviews():
315364
np = pytest.importorskip("numpy")
316-
pytest.importorskip("lz4")
317365
x = np.arange(1000000, dtype="int64")
318366
compression, payload = maybe_compress(x.data)
319-
assert compression == "lz4"
367+
assert compression == default_compression
320368
assert len(payload) < x.nbytes * 0.75
321369

322370

Diff for: distributed/protocol/tests/test_serialize.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
serialize_bytes,
3434
to_serialize,
3535
)
36+
from distributed.protocol.compression import default_compression
3637
from distributed.protocol.serialize import check_dask_serializable
3738
from distributed.utils import ensure_memoryview, nbytes
3839
from distributed.utils_test import NO_AMM, gen_test, inc
@@ -279,9 +280,9 @@ def test_serialize_bytes(kwargs):
279280
assert str(x) == str(y)
280281

281282

283+
@pytest.mark.skipif(default_compression is None, reason="requires lz4 or snappy")
282284
@pytest.mark.skipif(np is None, reason="Test needs numpy")
283285
def test_serialize_list_compress():
284-
pytest.importorskip("lz4")
285286
x = np.ones(1000000)
286287
L = serialize_bytelist(x)
287288
assert sum(map(nbytes, L)) < x.nbytes / 2

0 commit comments

Comments
 (0)