Skip to content

Commit 59b2776

Browse files
serenajianghsheth2
authored and committed
Support for Presto decimals (dropbox#430)
* Support for Presto decimals
* lower
1 parent b95250b commit 59b2776

File tree

4 files changed

+19
-9
lines changed

4 files changed

+19
-9
lines changed

Diff for: pyhive/presto.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from __future__ import unicode_literals
1010

1111
from builtins import object
12+
from decimal import Decimal
13+
1214
from pyhive import common
1315
from pyhive.common import DBAPITypeObject
1416
# Make all exceptions visible in this module per DB-API
@@ -34,6 +36,11 @@
3436

3537
_logger = logging.getLogger(__name__)
3638

39+
TYPES_CONVERTER = {
40+
"decimal": Decimal,
41+
# As of Presto 0.69, binary data is returned as the varbinary type in base64 format
42+
"varbinary": base64.b64decode
43+
}
3744

3845
class PrestoParamEscaper(common.ParamEscaper):
3946
def escape_datetime(self, item, format):
@@ -307,14 +314,13 @@ def _fetch_more(self):
307314
"""Fetch the next URI and update state"""
308315
self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs))
309316

310-
def _decode_binary(self, rows):
311-
# As of Presto 0.69, binary data is returned as the varbinary type in base64 format
312-
# This function decodes base64 data in place
317+
def _process_data(self, rows):
313318
for i, col in enumerate(self.description):
314-
if col[1] == 'varbinary':
319+
col_type = col[1].split("(")[0].lower()
320+
if col_type in TYPES_CONVERTER:
315321
for row in rows:
316322
if row[i] is not None:
317-
row[i] = base64.b64decode(row[i])
323+
row[i] = TYPES_CONVERTER[col_type](row[i])
318324

319325
def _process_response(self, response):
320326
"""Given the JSON response from Presto's REST API, update the internal state with the next
@@ -341,7 +347,7 @@ def _process_response(self, response):
341347
if 'data' in response_json:
342348
assert self._columns
343349
new_data = response_json['data']
344-
self._decode_binary(new_data)
350+
self._process_data(new_data)
345351
self._data += map(tuple, new_data)
346352
if 'nextUri' not in response_json:
347353
self._state = self._STATE_FINISHED

Diff for: pyhive/tests/test_presto.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import contextlib
1111
import os
12+
from decimal import Decimal
13+
1214
import requests
1315

1416
from pyhive import exc
@@ -93,7 +95,7 @@ def test_complex(self, cursor):
9395
{"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
9496
[1, 2], # struct is returned as a list of elements
9597
# '{0:1}',
96-
'0.1',
98+
Decimal('0.1'),
9799
)]
98100
self.assertEqual(rows, expected)
99101
# catch unicode/str

Diff for: pyhive/tests/test_trino.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
import contextlib
1111
import os
12+
from decimal import Decimal
13+
1214
import requests
1315

1416
from pyhive import exc
@@ -89,7 +91,7 @@ def test_complex(self, cursor):
8991
{"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON
9092
[1, 2], # struct is returned as a list of elements
9193
# '{0:1}',
92-
'0.1',
94+
Decimal('0.1'),
9395
)]
9496
self.assertEqual(rows, expected)
9597
# catch unicode/str

Diff for: pyhive/trino.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _process_response(self, response):
124124
if 'data' in response_json:
125125
assert self._columns
126126
new_data = response_json['data']
127-
self._decode_binary(new_data)
127+
self._process_data(new_data)
128128
self._data += map(tuple, new_data)
129129
if 'nextUri' not in response_json:
130130
self._state = self._STATE_FINISHED

0 commit comments

Comments
 (0)