|
64 | 64 | from typing import Union |
65 | 65 | from zoneinfo import ZoneInfo |
66 | 66 |
|
67 | | -import lz4.block |
| 67 | +try: |
| 68 | + import lz4.block |
| 69 | +except ImportError as err: |
| 70 | + _LZ4_ERROR = str(err) |
| 71 | +else: |
| 72 | + _LZ4_ERROR = None |
| 73 | + |
68 | 74 | try: |
69 | 75 | import orjson as json |
70 | 76 | except ImportError: |
71 | 77 | import json |
72 | 78 |
|
73 | 79 | import requests |
74 | | -import zstandard |
75 | 80 | from requests import Response |
76 | 81 | from requests import Session |
77 | 82 | from requests.structures import CaseInsensitiveDict |
78 | 83 |
|
| 84 | +try: |
| 85 | + import zstandard |
| 86 | +except ImportError as err: |
| 87 | + _ZSTD_ERROR = str(err) |
| 88 | +else: |
| 89 | + _ZSTD_ERROR = None |
| 90 | + |
| 91 | + |
79 | 92 | import trino.logging |
80 | 93 | from trino import constants |
81 | 94 | from trino import exceptions |
|
87 | 100 | from trino.mapper import RowMapper |
88 | 101 | from trino.mapper import RowMapperFactory |
89 | 102 |
|
| 103 | + |
90 | 104 | __all__ = [ |
91 | 105 | "ClientSession", |
92 | 106 | "TrinoQuery", |
@@ -117,6 +131,13 @@ def close_executor(): |
117 | 131 |
|
118 | 132 | _HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$') |
119 | 133 |
|
| 134 | +ENCODINGS = ["json+zstd", "json+lz4", "json"] |
| 135 | +CODECS_UNAVAILABLE = {} |
| 136 | +if _LZ4_ERROR: |
| 137 | + CODECS_UNAVAILABLE["lz4"] = _LZ4_ERROR |
| 138 | +if _ZSTD_ERROR: |
| 139 | + CODECS_UNAVAILABLE["zstd"] = _ZSTD_ERROR |
| 140 | + |
120 | 141 | ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$") |
121 | 142 |
|
122 | 143 |
|
@@ -290,7 +311,7 @@ def timezone(self) -> str: |
290 | 311 | return self._timezone |
291 | 312 |
|
292 | 313 | @property |
293 | | - def encoding(self) -> Union[str, List[str]]: |
| 314 | + def encoding(self) -> Optional[Union[str, List[str]]]: |
294 | 315 | with self._object_lock: |
295 | 316 | return self._encoding |
296 | 317 |
|
@@ -524,7 +545,15 @@ def http_headers(self) -> CaseInsensitiveDict[str]: |
524 | 545 | headers[constants.HEADER_USER] = self._client_session.user |
525 | 546 | headers[constants.HEADER_TIMEZONE] = self._client_session.timezone |
526 | 547 | if self._client_session.encoding is None: |
527 | | - pass |
| 548 | + if not CODECS_UNAVAILABLE: |
| 549 | + pass |
| 550 | + else: |
| 551 | + encoding = [ |
| 552 | + enc |
| 553 | + for enc in ENCODINGS |
| 554 | + if (enc.split("+")[1] if "+" in enc else None) not in CODECS_UNAVAILABLE |
| 555 | + ] |
| 556 | + headers[constants.HEADER_ENCODING] = ",".join(encoding) |
528 | 557 | elif isinstance(self._client_session.encoding, list): |
529 | 558 | headers[constants.HEADER_ENCODING] = ",".join(self._client_session.encoding) |
530 | 559 | elif isinstance(self._client_session.encoding, str): |
@@ -1271,8 +1300,16 @@ def __init__(self, mapper: RowMapper) -> None: |
1271 | 1300 |
|
1272 | 1301 | def create(self, encoding: str) -> QueryDataDecoder: |
1273 | 1302 | if encoding == "json+zstd": |
| 1303 | + if "zstd" in CODECS_UNAVAILABLE: |
| 1304 | + raise ValueError( |
| 1305 | + f"zstd is not installed so json+zstd encoding is not supported: {CODECS_UNAVAILABLE['zstd']}" |
| 1306 | + ) |
1274 | 1307 | return ZStdQueryDataDecoder(JsonQueryDataDecoder(self._mapper)) |
1275 | 1308 | elif encoding == "json+lz4": |
| 1309 | + if "lz4" in CODECS_UNAVAILABLE: |
| 1310 | + raise ValueError( |
| 1311 | + f"lz4 is not installed so json+lz4 encoding is not supported: {CODECS_UNAVAILABLE['lz4']}" |
| 1312 | + ) |
1276 | 1313 | return Lz4QueryDataDecoder(JsonQueryDataDecoder(self._mapper)) |
1277 | 1314 | elif encoding == "json": |
1278 | 1315 | return JsonQueryDataDecoder(self._mapper) |
@@ -1322,10 +1359,14 @@ def decode(self, data: bytes, metadata: _SegmentMetadataTO) -> List[List[Any]]: |
1322 | 1359 |
|
1323 | 1360 |
|
1324 | 1361 | class ZStdQueryDataDecoder(CompressedQueryDataDecoder): |
1325 | | - zstd_decompressor = zstandard.ZstdDecompressor() |
| 1362 | + def __init__(self, delegate: QueryDataDecoder) -> None: |
| 1363 | + super().__init__(delegate) |
| 1364 | + self._decompressor = None |
1326 | 1365 |
|
1327 | 1366 | def decompress(self, data: bytes, metadata: _SegmentMetadataTO) -> bytes: |
1328 | | - return ZStdQueryDataDecoder.zstd_decompressor.decompress(data) |
| 1367 | + if self._decompressor is None: |
| 1368 | + self._decompressor = zstandard.ZstdDecompressor() |
| 1369 | + return self._decompressor.decompress(data) |
1329 | 1370 |
|
1330 | 1371 |
|
1331 | 1372 | class Lz4QueryDataDecoder(CompressedQueryDataDecoder): |
|
0 commit comments