Commit 5efd997

Produce rows as slices of pyarrow.Table
1 parent ac09074 commit 5efd997

File tree

5 files changed: +247 -42 lines changed

Diff for: TCLIService/ttypes.py (+7 -18)

Some generated files are not rendered by default.

Diff for: pyhive/common.py (+3 -3)

```diff
@@ -43,12 +43,12 @@ def _reset_state(self):

         # Internal helper state
         self._state = self._STATE_NONE
-        self._data = collections.deque()
+        self._data = None
         self._columns = None

-    def _fetch_while(self, fn):
+    def _fetch_while(self, fn, schema):
         while fn():
-            self._fetch_more(schema)
+            self._fetch_more(schema)
         if fn():
             time.sleep(self._poll_interval)
```
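The buffer swap above (a deque of row tuples becomes a single `pyarrow.Table`, or `None` when empty) is what lets `fetchmany` later hand back zero-copy slices. A minimal sketch of that discipline, assuming only that pyarrow is installed:

```python
import pyarrow as pa

buf = pa.table({"x": [1, 2, 3, 4]})  # stand-in for self._data after a fetch
rows = buf[:2]                       # what a fetchmany(2) would return
buf = buf[2:]                        # the remainder stays buffered
print(rows.num_rows, buf.num_rows)   # 2 2
```

Slicing a `pyarrow.Table` produces a view over the same buffers, so consuming rows this way never copies column data.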

Diff for: pyhive/hive.py (+137 -21)

```diff
@@ -10,6 +10,11 @@
 import base64
 import datetime
+import io
+import itertools
+import numpy as np
+import pyarrow as pa
+import pyarrow.json
 import re
 from decimal import Decimal
 from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
@@ -40,7 +45,8 @@
 _logger = logging.getLogger(__name__)

-_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
+_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,9})?)')
+_INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:\.\d+)?)')

 ssl_cert_parameter_map = {
     "none": CERT_NONE,
@@ -106,9 +112,36 @@ def _parse_timestamp(value):
         value = None
     return value

+def _parse_date(value):
+    if value:
+        format = '%Y-%m-%d'
+        value = datetime.datetime.strptime(value, format).date()
+    else:
+        value = None
+    return value

-TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
-                   "TIMESTAMP_TYPE": _parse_timestamp}
+def _parse_interval_day_time(value):
+    if value:
+        match = _INTERVAL_DAY_TIME_PATTERN.match(value)
+        if match:
+            days = int(match.group(1))
+            hours = int(match.group(2))
+            minutes = int(match.group(3))
+            seconds = float(match.group(4))
+            value = datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+        else:
+            raise Exception(
+                'Cannot convert "{}" into an interval_day_time'.format(value))
+    else:
+        value = None
+    return value
+
+TYPES_CONVERTER = {
+    "DECIMAL_TYPE": Decimal,
+    "TIMESTAMP_TYPE": _parse_timestamp,
+    "DATE_TYPE": _parse_date,
+    "INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time,
+}


 class HiveParamEscaper(common.ParamEscaper):
```
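Two converter details worth noting: the timestamp pattern now accepts up to nine fractional digits (Hive emits nanosecond precision), and intervals of the form `D HH:MM:SS[.fff]` become `datetime.timedelta`. A quick, self-contained check of the interval arithmetic (the sample string is illustrative):

```python
import datetime
import re

_INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:\.\d+)?)')

match = _INTERVAL_DAY_TIME_PATTERN.match('3 04:05:06.5')
delta = datetime.timedelta(
    days=int(match.group(1)),
    hours=int(match.group(2)),
    minutes=int(match.group(3)),
    seconds=float(match.group(4)),
)
print(delta)  # 3 days, 4:05:06.500000
```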
```diff
@@ -488,7 +521,50 @@ def cancel(self):
         response = self._connection.client.CancelOperation(req)
         _check_status(response)

-    def _fetch_more(self):
+    def fetchone(self, schema=[]):
+        return self.fetchmany(1, schema)
+
+    def fetchall(self, schema=[]):
+        return self.fetchmany(-1, schema)
+
+    def fetchmany(self, size=None, schema=[]):
+        if size is None:
+            size = self.arraysize
+
+        if self._state == self._STATE_NONE:
+            raise exc.ProgrammingError("No query yet")
+
+        if size == -1:
+            # Fetch everything
+            self._fetch_while(lambda: self._state != self._STATE_FINISHED, schema)
+        else:
+            self._fetch_while(
+                lambda:
+                    (self._state != self._STATE_FINISHED) and
+                    (self._data is None or self._data.num_rows < size),
+                schema
+            )
+
+        if not self._data:
+            return None
+
+        if size == -1:
+            # Return everything buffered
+            size = self._data.num_rows
+        else:
+            size = min(size, self._data.num_rows)
+
+        self._rownumber += size
+        rows = self._data[:size]
+
+        if size == self._data.num_rows:
+            # The buffer was fully consumed
+            self._data = None
+        else:
+            self._data = self._data[size:]
+
+        return rows
+
+    def _fetch_more(self, ext_schema):
         """Send another TFetchResultsReq and update state"""
         assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
         assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
```
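With these overrides, `fetchone`/`fetchall`/`fetchmany` return `pyarrow.Table` slices rather than lists of tuples, a deliberate departure from DB-API 2.0 (which specifies sequences of rows). A hypothetical usage sketch, assuming a reachable HiveServer2; host, port, and table name are placeholders:

```python
from pyhive import hive

conn = hive.Connection(host='localhost', port=10000)
cur = conn.cursor()
cur.execute('SELECT * FROM some_table')

tbl = cur.fetchmany(1000)         # a pyarrow.Table with up to 1000 rows, or None
if tbl is not None:
    print(tbl.num_rows, tbl.schema)
```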
```diff
@@ -503,15 +579,21 @@ def _fetch_more(self):
         _check_status(response)
         schema = self.description
         assert not response.results.rows, 'expected data in columnar format'
-        columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
-                   zip(response.results.columns, schema)]
-        new_data = list(zip(*columns))
-        self._data += new_data
+        columns = [_unwrap_column(col, col_schema[1], e_schema) for col, col_schema, e_schema in
+                   itertools.zip_longest(response.results.columns, schema, ext_schema)]
+        names = [col[0] for col in schema]
+        new_data = pa.Table.from_batches([pa.RecordBatch.from_arrays(columns, names=names)])
         # response.hasMoreRows seems to always be False, so we instead check the number of rows
         # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
         # if not response.hasMoreRows:
-        if not new_data:
+        if new_data.num_rows == 0:
             self._state = self._STATE_FINISHED
+            return
+
+        if self._data is None:
+            self._data = new_data
+        else:
+            self._data = pa.concat_tables([self._data, new_data])

     def poll(self, get_progress_update=True):
         """Poll for and return the raw status data provided by the Hive Thrift REST API.
```
```diff
@@ -585,21 +667,55 @@ def fetch_logs(self):
 #


-def _unwrap_column(col, type_=None):
+def _unwrap_column(col, type_=None, schema=None):
     """Return a list of raw values from a TColumn instance."""
     for attr, wrapper in iteritems(col.__dict__):
         if wrapper is not None:
-            result = wrapper.values
-            nulls = wrapper.nulls  # bit set describing what's null
-            assert isinstance(nulls, bytes)
-            for i, char in enumerate(nulls):
-                byte = ord(char) if sys.version_info[0] == 2 else char
-                for b in range(8):
-                    if byte & (1 << b):
-                        result[i * 8 + b] = None
-            converter = TYPES_CONVERTER.get(type_, None)
-            if converter and type_:
-                result = [converter(row) if row else row for row in result]
+            if attr in ['boolVal', 'byteVal', 'i16Val', 'i32Val', 'i64Val', 'doubleVal']:
+                values = wrapper.values
+                # unpack nulls as a byte array
+                nulls = np.unpackbits(np.frombuffer(wrapper.nulls, dtype='uint8')).view(bool)
+                # override a full mask as trailing False values are not sent
+                mask = np.zeros(values.shape, dtype='?')
+                end = min(len(mask), len(nulls))
+                mask[:end] = nulls[:end]
+
+                # float values are transferred as double
+                if type_ == 'FLOAT_TYPE':
+                    values = values.astype('>f4')
+
+                result = pa.array(values.byteswap().view(values.dtype.newbyteorder()), mask=mask)
+
+            else:
+                result = wrapper.values
+                nulls = wrapper.nulls  # bit set describing what's null
+                if len(result) == 0:
+                    return pa.array([])
+                assert isinstance(nulls, bytes)
+                for i, char in enumerate(nulls):
+                    byte = ord(char) if sys.version_info[0] == 2 else char
+                    for b in range(8):
+                        if byte & (1 << b):
+                            result[i * 8 + b] = None
+                converter = TYPES_CONVERTER.get(type_, None)
+                if converter and type_:
+                    result = [converter(row) if row else row for row in result]
+
+                if type_ in ['ARRAY_TYPE', 'MAP_TYPE', 'STRUCT_TYPE']:
+                    fd = io.BytesIO()
+                    for row in result:
+                        if row is None:
+                            row = 'null'
+                        fd.write(f'{{"c":{row}}}\n'.encode('utf8'))
+                    fd.seek(0)
+
+                    if schema is None:
+                        # NOTE: JSON map conversion (from the original struct) is not supported
+                        result = pa.json.read_json(fd, parse_options=None)[0].combine_chunks()
+                    else:
+                        sch = pa.schema([('c', schema)])
+                        opts = pa.json.ParseOptions(explicit_schema=sch)
+                        result = pa.json.read_json(fd, parse_options=opts)[0].combine_chunks()
             return result
     raise DataError("Got empty column value {}".format(col))  # pragma: no cover
```
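The vectorized branch replaces the old per-bit Python loop: `np.unpackbits` expands the packed null bitmap, the mask is padded out because trailing zero bits are truncated on the wire, and the big-endian wire values are byte-swapped to native order before Arrow sees them. A standalone sketch with made-up wire bytes:

```python
import numpy as np
import pyarrow as pa

values = np.array([10, 20, 30, 40, 50], dtype='>i4')  # big-endian, as on the wire
nulls = np.unpackbits(np.frombuffer(b'\x28', dtype='uint8')).view(bool)

mask = np.zeros(values.shape, dtype='?')   # pad: bitmap may be shorter than values
end = min(len(mask), len(nulls))
mask[:end] = nulls[:end]

arr = pa.array(values.byteswap().view(values.dtype.newbyteorder()), mask=mask)
print(arr)  # [10, 20, null, 40, null] -- bits 2 and 4 of 0x28, MSB-first
```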

Diff for: pyhive/schema.py (+97, new file)

```python
"""
This module attempts to reconstruct an Arrow schema from the info dumped at
the beginning of a Hive query log.

SUPPORTS:
* All primitive types _except_ INTERVAL.
* STRUCT and ARRAY types.
* Composition of any combination of the previous types.

LIMITATIONS:
* PyHive does not support INTERVAL types yet. A converter needs to be implemented.
* Hive always sends complex types as strings, in a format _similar_ to JSON.
* Arrow can parse most of this pseudo-JSON, excluding:
  * MAP and INTERVAL types
* A custom parser would be needed to support all types and their composition.
"""

import pyparsing as pp
import pyarrow as pa


def a_type(s, loc, toks):
    m_basic = {
        'tinyint': pa.int8(),
        'smallint': pa.int16(),
        'int': pa.int32(),
        'bigint': pa.int64(),
        'float': pa.float32(),
        'double': pa.float64(),
        'boolean': pa.bool_(),
        'string': pa.string(),
        'char': pa.string(),
        'varchar': pa.string(),
        'binary': pa.binary(),
        'timestamp': pa.timestamp('ns'),
        'date': pa.date32(),
        #'interval_year_month': pa.month_day_nano_interval(),
        #'interval_day_time': pa.month_day_nano_interval(),
    }

    typ, args = toks[0], toks[1:]

    if typ in m_basic:
        return m_basic[typ]
    if typ == 'decimal':
        return pa.decimal128(*map(int, args))
    if typ == 'array':
        return pa.list_(args[0])
    #if typ == 'map':
    #    return pa.map_(args[0], args[1])
    if typ == 'struct':
        return pa.struct(args)
    raise NotImplementedError(f"Type {typ} is not supported")


def a_field(s, loc, toks):
    return pa.field(toks[0], toks[1])


LB, RB, LP, RP, LT, RT, COMMA, COLON = map(pp.Suppress, "[]()<>,:")


def t_args(n):
    return LP + pp.delimitedList(pp.Word(pp.nums), ",", min=n, max=n) + RP


t_basic = pp.one_of(
    "tinyint smallint int bigint float double boolean string binary timestamp date decimal",
    caseless=True, as_keyword=True
)
t_interval = pp.one_of(
    "interval_year_month interval_day_time",
    caseless=True, as_keyword=True
)
t_char = pp.one_of("char varchar", caseless=True, as_keyword=True) + t_args(1)
t_decimal = pp.CaselessKeyword("decimal") + t_args(2)
t_primitive = (t_basic ^ t_char ^ t_decimal).set_parse_action(a_type)

t_type = pp.Forward()

t_label = pp.Word(pp.alphas + "_", pp.alphanums + "_")
t_array = pp.CaselessKeyword('array') + LT + t_type + RT
t_map = pp.CaselessKeyword('map') + LT + t_primitive + COMMA + t_type + RT
t_struct = pp.CaselessKeyword('struct') + LT + pp.delimitedList((t_label + COLON + t_type).set_parse_action(a_field), ",") + RT
t_complex = (t_array ^ t_map ^ t_struct).set_parse_action(a_type)

t_type <<= t_primitive ^ t_complex
t_top_type = t_type ^ t_interval

l_schema, l_fieldschemas, l_fieldschema, l_name, l_type, l_comment, l_properties, l_null = map(
    lambda x: pp.Keyword(x).suppress(), "Schema fieldSchemas FieldSchema name type comment properties null".split(' ')
)
t_fieldschema = l_fieldschema + LP + l_name + COLON + t_label.suppress() + COMMA + l_type + COLON + t_top_type + COMMA + l_comment + COLON + l_null + RP
t_schema = l_schema + LP + l_fieldschemas + COLON + LB + pp.delimitedList(t_fieldschema, ',') + RB + COMMA + l_properties + COLON + l_null + RP


def parse_schema(logs):
    prefix = 'INFO : Returning Hive schema: '

    for l in logs:
        if l.startswith(prefix):
            str_schema = l[len(prefix):]

            return t_schema.parse_string(str_schema).as_list()
```
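`parse_schema` scans log lines (in practice, the output of `Cursor.fetch_logs()`) for the schema dump and parses it into a list of Arrow types, one per result column. A hypothetical invocation; the schema string below is illustrative, not captured from a real server:

```python
from pyhive.schema import parse_schema

logs = [
    'INFO : Returning Hive schema: '
    'Schema(fieldSchemas:[FieldSchema(name:id, type:int, comment:null), '
    'FieldSchema(name:tags, type:array<string>, comment:null)], properties:null)'
]
print(parse_schema(logs))  # roughly: [int32, list<item: string>]
```

The resulting list appears to be what the new `schema` parameter of `fetchmany` expects; `_unwrap_column` only consults it for ARRAY/MAP/STRUCT columns.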

Diff for: pyproject.toml (+3, new file)

```diff
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
```
