Skip to content

Commit ba2c727

Browse files
committed
Gracefully handle missing/broken lz4/zstd codecs to improve portability
1 parent 2108c38 commit ba2c727

File tree

4 files changed

+123
-9
lines changed

4 files changed

+123
-9
lines changed

tests/unit/test_client.py

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@
5757
from trino.client import _retry_with
5858
from trino.client import _RetryWithExponentialBackoff
5959
from trino.client import ClientSession
60+
from trino.client import CompressedQueryDataDecoderFactory
6061
from trino.client import TrinoQuery
6162
from trino.client import TrinoRequest
6263
from trino.client import TrinoResult
@@ -1450,3 +1451,51 @@ def delete_password(self, servicename, username):
14501451
return None
14511452

14521453
os.remove(file_path)
1454+
1455+
1456+
def test_trino_request_headers_encoding_default_behavior():
    """Check the encoding header built for a session whose encoding is None.

    ``trino.client.CODECS_UNAVAILABLE`` is patched to simulate each
    combination of missing codecs:

    * nothing missing -> the encoding header is absent from ``http_headers``;
    * some codec missing -> the header lists only the remaining encodings,
      always ending with plain ``json``.
    """
    session = ClientSession(user="test", encoding=None)

    # Every codec available: no encoding header at all.
    with mock.patch("trino.client.CODECS_UNAVAILABLE", {}):
        request = TrinoRequest("host", 8080, session)
        assert constants.HEADER_ENCODING not in request.http_headers

    # (codecs reported as unavailable, expected header value)
    degraded_cases = [
        ({"zstd": "Not installed"}, "json+lz4,json"),
        ({"lz4": "Not installed"}, "json+zstd,json"),
        ({"lz4": "Not installed", "zstd": "Not installed"}, "json"),
    ]
    for unavailable, expected_header in degraded_cases:
        with mock.patch("trino.client.CODECS_UNAVAILABLE", unavailable):
            request = TrinoRequest("host", 8080, session)
            assert request.http_headers[constants.HEADER_ENCODING] == expected_header
1478+
1479+
1480+
def test_decoder_factory_raises_with_message_on_missing_zstd():
    """Requesting ``json+zstd`` while zstd is unavailable raises ValueError.

    The raised message must contain the fixed explanation followed by the
    import error text recorded in ``trino.client.CODECS_UNAVAILABLE["zstd"]``.
    """
    # Local import so this test does not rely on the module-level imports.
    import re

    mapper = mock.Mock()
    factory = CompressedQueryDataDecoderFactory(mapper)
    error_message = "No module named 'zstandard'"
    expected_message = (
        f"zstd is not installed so json+zstd encoding is not supported: {error_message}"
    )
    with mock.patch("trino.client.CODECS_UNAVAILABLE", {"zstd": error_message}):
        # `match` is applied as a regular-expression search, so escape the whole
        # literal instead of hand-escaping individual metacharacters.
        with pytest.raises(ValueError, match=re.escape(expected_message)):
            factory.create("json+zstd")
1490+
1491+
1492+
def test_decoder_factory_raises_with_message_on_missing_lz4():
    """Requesting ``json+lz4`` while lz4 is unavailable raises ValueError.

    The raised message must contain the fixed explanation followed by the
    import error text recorded in ``trino.client.CODECS_UNAVAILABLE["lz4"]``.
    """
    # Local import so this test does not rely on the module-level imports.
    import re

    mapper = mock.Mock()
    factory = CompressedQueryDataDecoderFactory(mapper)
    error_message = "No module named 'lz4.block'"
    expected_message = (
        f"lz4 is not installed so json+lz4 encoding is not supported: {error_message}"
    )
    with mock.patch("trino.client.CODECS_UNAVAILABLE", {"lz4": error_message}):
        # `match` is applied as a regular-expression search. Unescaped, the "."
        # in "lz4.block" would match any character and make the check too lax.
        with pytest.raises(ValueError, match=re.escape(expected_message)):
            factory.create("json+lz4")

tests/unit/test_dbapi.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -338,3 +338,27 @@ def test_description_is_none_when_cursor_is_not_executed():
338338
def test_setting_http_scheme(host, port, http_scheme_input_argument, http_scheme_set):
339339
connection = Connection(host, port, http_scheme=http_scheme_input_argument)
340340
assert connection.http_scheme == http_scheme_set
341+
342+
343+
def test_default_encoding_no_compression():
    """With both lz4 and zstd unavailable, the default encoding is plain json."""
    unavailable = {"lz4": "Not installed", "zstd": "Not installed"}
    with patch("trino.client.CODECS_UNAVAILABLE", unavailable):
        session = Connection("host", 8080, user="test")._client_session
        assert session.encoding == ["json"]
347+
348+
349+
def test_default_encoding_lz4():
    """With only zstd unavailable, the default encoding is lz4 then plain json."""
    unavailable = {"zstd": "Not installed"}
    with patch("trino.client.CODECS_UNAVAILABLE", unavailable):
        session = Connection("host", 8080, user="test")._client_session
        assert session.encoding == ["json+lz4", "json"]
353+
354+
355+
def test_default_encoding_zstd():
    """With only lz4 unavailable, the default encoding is zstd then plain json."""
    unavailable = {"lz4": "Not installed"}
    with patch("trino.client.CODECS_UNAVAILABLE", unavailable):
        session = Connection("host", 8080, user="test")._client_session
        assert session.encoding == ["json+zstd", "json"]
359+
360+
361+
def test_default_encoding_all():
    """With every codec available, the default encoding lists zstd, lz4, then json."""
    with patch("trino.client.CODECS_UNAVAILABLE", {}):
        session = Connection("host", 8080, user="test")._client_session
        assert session.encoding == ["json+zstd", "json+lz4", "json"]

trino/client.py

Lines changed: 47 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -64,18 +64,31 @@
6464
from typing import Union
6565
from zoneinfo import ZoneInfo
6666

67-
import lz4.block
67+
try:
68+
import lz4.block
69+
except ImportError as err:
70+
_LZ4_ERROR = str(err)
71+
else:
72+
_LZ4_ERROR = None
73+
6874
try:
6975
import orjson as json
7076
except ImportError:
7177
import json
7278

7379
import requests
74-
import zstandard
7580
from requests import Response
7681
from requests import Session
7782
from requests.structures import CaseInsensitiveDict
7883

84+
try:
85+
import zstandard
86+
except ImportError as err:
87+
_ZSTD_ERROR = str(err)
88+
else:
89+
_ZSTD_ERROR = None
90+
91+
7992
import trino.logging
8093
from trino import constants
8194
from trino import exceptions
@@ -87,6 +100,7 @@
87100
from trino.mapper import RowMapper
88101
from trino.mapper import RowMapperFactory
89102

103+
90104
__all__ = [
91105
"ClientSession",
92106
"TrinoQuery",
@@ -117,6 +131,13 @@ def close_executor():
117131

118132
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')
119133

134+
# Encodings this client knows about, in the order they are listed when a
# default encoding list is built.
ENCODINGS = ["json+zstd", "json+lz4", "json"]

# Maps a codec name ("lz4" / "zstd") to the import error text captured at
# module import time. A codec that imported successfully has no entry.
CODECS_UNAVAILABLE = {
    codec: import_error
    for codec, import_error in (("lz4", _LZ4_ERROR), ("zstd", _ZSTD_ERROR))
    if import_error
}
140+
120141
ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$")
121142

122143

@@ -290,7 +311,7 @@ def timezone(self) -> str:
290311
return self._timezone
291312

292313
@property
def encoding(self) -> Optional[Union[str, List[str]]]:
    """Return the session's encoding setting, read while holding the object lock.

    The value is a single encoding name, a list of encoding names, or
    ``None`` when no encoding has been set on the session.
    """
    with self._object_lock:
        current_encoding = self._encoding
    return current_encoding
296317

@@ -524,7 +545,15 @@ def http_headers(self) -> CaseInsensitiveDict[str]:
524545
headers[constants.HEADER_USER] = self._client_session.user
525546
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
526547
if self._client_session.encoding is None:
527-
pass
548+
if not CODECS_UNAVAILABLE:
549+
pass
550+
else:
551+
encoding = [
552+
enc
553+
for enc in ENCODINGS
554+
if (enc.split("+")[1] if "+" in enc else None) not in CODECS_UNAVAILABLE
555+
]
556+
headers[constants.HEADER_ENCODING] = ",".join(encoding)
528557
elif isinstance(self._client_session.encoding, list):
529558
headers[constants.HEADER_ENCODING] = ",".join(self._client_session.encoding)
530559
elif isinstance(self._client_session.encoding, str):
@@ -1271,8 +1300,16 @@ def __init__(self, mapper: RowMapper) -> None:
12711300

12721301
def create(self, encoding: str) -> QueryDataDecoder:
12731302
if encoding == "json+zstd":
1303+
if "zstd" in CODECS_UNAVAILABLE:
1304+
raise ValueError(
1305+
f"zstd is not installed so json+zstd encoding is not supported: {CODECS_UNAVAILABLE['zstd']}"
1306+
)
12741307
return ZStdQueryDataDecoder(JsonQueryDataDecoder(self._mapper))
12751308
elif encoding == "json+lz4":
1309+
if "lz4" in CODECS_UNAVAILABLE:
1310+
raise ValueError(
1311+
f"lz4 is not installed so json+lz4 encoding is not supported: {CODECS_UNAVAILABLE['lz4']}"
1312+
)
12761313
return Lz4QueryDataDecoder(JsonQueryDataDecoder(self._mapper))
12771314
elif encoding == "json":
12781315
return JsonQueryDataDecoder(self._mapper)
@@ -1322,10 +1359,14 @@ def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]:
13221359

13231360

13241361
class ZStdQueryDataDecoder(CompressedQueryDataDecoder):
    """Decoder that zstd-decompresses data before handing it to its delegate.

    The decompressor is created per instance on first use rather than at
    class-definition time, so this class can still be defined when the
    optional ``zstandard`` import has failed.
    """

    def __init__(self, delegate: QueryDataDecoder) -> None:
        super().__init__(delegate)
        # Populated on the first decompress() call.
        self._decompressor = None

    def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes:
        """Return ``data`` decompressed with zstd; ``metadata`` is not used here."""
        decompressor = self._decompressor
        if decompressor is None:
            decompressor = self._decompressor = zstandard.ZstdDecompressor()
        return decompressor.decompress(data)
13291370

13301371

13311372
class Lz4QueryDataDecoder(CompressedQueryDataDecoder):

trino/dbapi.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -170,9 +170,9 @@ def __init__(
170170

171171
if encoding is _USE_DEFAULT_ENCODING:
172172
encoding = [
173-
"json+zstd",
174-
"json+lz4",
175-
"json",
173+
enc
174+
for enc in trino.client.ENCODINGS
175+
if (enc.split("+")[1] if "+" in enc else None) not in trino.client.CODECS_UNAVAILABLE
176176
]
177177

178178
self.host = host if parsed_host.hostname is None else parsed_host.hostname + parsed_host.path

0 commit comments

Comments
 (0)