import base64
import datetime
+ import io
+ import itertools
+ import numpy as np
+ import pyarrow as pa
+ import pyarrow.json
import re
from decimal import Decimal
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context

_logger = logging.getLogger(__name__)

- _TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
+ _TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,9})?)')
+ _INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:.\d+)?)')

ssl_cert_parameter_map = {
    "none": CERT_NONE,
@@ -106,9 +112,36 @@ def _parse_timestamp(value):
        value = None
    return value

+ def _parse_date(value):
+     if value:
+         format = '%Y-%m-%d'
+         value = datetime.datetime.strptime(value, format).date()
+     else:
+         value = None
+     return value

- TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
-                    "TIMESTAMP_TYPE": _parse_timestamp}
+ def _parse_interval_day_time(value):
+     if value:
+         match = _INTERVAL_DAY_TIME_PATTERN.match(value)
+         if match:
+             days = int(match.group(1))
+             hours = int(match.group(2))
+             minutes = int(match.group(3))
+             seconds = float(match.group(4))
+             value = datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+         else:
+             raise Exception(
+                 'Cannot convert "{}" into an interval_day_time'.format(value))
+     else:
+         value = None
+     return value
+
+ TYPES_CONVERTER = {
+     "DECIMAL_TYPE": Decimal,
+     "TIMESTAMP_TYPE": _parse_timestamp,
+     "DATE_TYPE": _parse_date,
+     "INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time,
+ }


class HiveParamEscaper(common.ParamEscaper):
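The converters added above can be checked from a REPL once the module is imported; the input strings are invented but follow the formats the functions expect:

print(_parse_date('2021-03-04'))
# 2021-03-04  (a datetime.date)

print(_parse_interval_day_time('3 04:05:06.789'))
# 3 days, 4:05:06.789000  (a datetime.timedelta)

# TYPES_CONVERTER is keyed by the Thrift column type name, so DATE_TYPE and
# INTERVAL_DAY_TIME_TYPE columns are now decoded instead of being passed
# through as raw strings.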
@@ -488,7 +521,50 @@ def cancel(self):
        response = self._connection.client.CancelOperation(req)
        _check_status(response)

-     def _fetch_more(self):
+     def fetchone(self, schema=[]):
+         return self.fetchmany(1, schema)
+
+     def fetchall(self, schema=[]):
+         return self.fetchmany(-1, schema)
+
+     def fetchmany(self, size=None, schema=[]):
+         if size is None:
+             size = self.arraysize
+
+         if self._state == self._STATE_NONE:
+             raise exc.ProgrammingError("No query yet")
+
+         if size == -1:
+             # Fetch everything
+             self._fetch_while(lambda: self._state != self._STATE_FINISHED, schema)
+         else:
+             self._fetch_while(lambda:
+                 (self._state != self._STATE_FINISHED) and
+                 (self._data is None or self._data.num_rows < size),
+                 schema
+             )
+
+         if not self._data:
+             return None
+
+         if size == -1:
+             # Fetch everything
+             size = self._data.num_rows
+         else:
+             size = min(size, self._data.num_rows)
+
+         self._rownumber += size
+         rows = self._data[:size]
+
+         if size == self._data.num_rows:
+             # Fetch everything
+             self._data = None
+         else:
+             self._data = self._data[size:]
+
+         return rows
+
+     def _fetch_more(self, ext_schema):
        """Send another TFetchResultsReq and update state"""
        assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
        assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
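With the hunk above, the fetch methods hand back pyarrow data rather than lists of tuples. A rough usage sketch, assuming this fork keeps the pyhive-style connect() entry point (host, table and column names are illustrative):

from pyhive import hive          # assumed import path for this fork

conn = hive.connect(host='localhost')                  # illustrative connection
cursor = conn.cursor()
cursor.execute('SELECT id, tags FROM example_table')   # illustrative query

table = cursor.fetchall()        # a pyarrow.Table, or None when there are no rows
if table is not None:
    print(table.num_rows, table.column_names)
    df = table.to_pandas()       # hand off to pandas / Arrow tooling from here

The extra schema argument on fetchone/fetchall/fetchmany appears to flow through _fetch_while into _fetch_more as ext_schema, supplying explicit Arrow types for complex columns.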
@@ -503,15 +579,21 @@ def _fetch_more(self):
        _check_status(response)
        schema = self.description
        assert not response.results.rows, 'expected data in columnar format'
-         columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
-                    zip(response.results.columns, schema)]
-         new_data = list(zip(*columns))
-         self._data += new_data
+         columns = [_unwrap_column(col, col_schema[1], e_schema) for col, col_schema, e_schema in
+                    itertools.zip_longest(response.results.columns, schema, ext_schema)]
+         names = [col[0] for col in schema]
+         new_data = pa.Table.from_batches([pa.RecordBatch.from_arrays(columns, names=names)])
        # response.hasMoreRows seems to always be False, so we instead check the number of rows
        # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
        # if not response.hasMoreRows:
-         if not new_data:
+         if new_data.num_rows == 0:
            self._state = self._STATE_FINISHED
+             return
+
+         if self._data is None:
+             self._data = new_data
+         else:
+             self._data = pa.concat_tables([self._data, new_data])

    def poll(self, get_progress_update=True):
        """Poll for and return the raw status data provided by the Hive Thrift REST API.
@@ -585,21 +667,55 @@ def fetch_logs(self):
#


- def _unwrap_column(col, type_=None):
+ def _unwrap_column(col, type_=None, schema=None):
    """Return a list of raw values from a TColumn instance."""
    for attr, wrapper in iteritems(col.__dict__):
        if wrapper is not None:
-             result = wrapper.values
-             nulls = wrapper.nulls  # bit set describing what's null
-             assert isinstance(nulls, bytes)
-             for i, char in enumerate(nulls):
-                 byte = ord(char) if sys.version_info[0] == 2 else char
-                 for b in range(8):
-                     if byte & (1 << b):
-                         result[i * 8 + b] = None
-             converter = TYPES_CONVERTER.get(type_, None)
-             if converter and type_:
-                 result = [converter(row) if row else row for row in result]
+             if attr in ['boolVal', 'byteVal', 'i16Val', 'i32Val', 'i64Val', 'doubleVal']:
+                 values = wrapper.values
+                 # unpack nulls as a byte array
+                 nulls = np.unpackbits(np.frombuffer(wrapper.nulls, dtype='uint8')).view(bool)
+                 # override a full mask as trailing False values are not sent
+                 mask = np.zeros(values.shape, dtype='?')
+                 end = min(len(mask), len(nulls))
+                 mask[:end] = nulls[:end]
+
+                 # float values are transferred as double
+                 if type_ == 'FLOAT_TYPE':
+                     values = values.astype('>f4')
+
+                 result = pa.array(values.byteswap().view(values.dtype.newbyteorder()), mask=mask)
+
+             else:
+                 result = wrapper.values
+                 nulls = wrapper.nulls  # bit set describing what's null
+                 if len(result) == 0:
+                     return pa.array([])
+                 assert isinstance(nulls, bytes)
+                 for i, char in enumerate(nulls):
+                     byte = ord(char) if sys.version_info[0] == 2 else char
+                     for b in range(8):
+                         if byte & (1 << b):
+                             result[i * 8 + b] = None
+                 converter = TYPES_CONVERTER.get(type_, None)
+                 if converter and type_:
+                     result = [converter(row) if row else row for row in result]
+
+                 if type_ in ['ARRAY_TYPE', 'MAP_TYPE', 'STRUCT_TYPE']:
+                     fd = io.BytesIO()
+                     for row in result:
+                         if row is None:
+                             row = 'null'
+                         fd.write(f'{{"c":{row}}}\n'.encode('utf8'))
+                     fd.seek(0)
+
+                     if schema == None:
+                         # NOTE: JSON map conversion (from the original struct) is not supported
+                         result = pa.json.read_json(fd, parse_options=None)[0].combine_chunks()
+                     else:
+                         sch = pa.schema([('c', schema)])
+                         opts = pa.json.ParseOptions(explicit_schema=sch)
+                         result = pa.json.read_json(fd, parse_options=opts)[0].combine_chunks()
            return result
    raise DataError("Got empty column value {}".format(col))  # pragma: no cover
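The complex-type branch above round-trips ARRAY/MAP/STRUCT values through line-delimited JSON. A standalone sketch of that decoding path, with invented row values and an assumed element type:

import io

import pyarrow as pa
import pyarrow.json

rows = ['[1,2,3]', None, '[4]']      # illustrative ARRAY_TYPE values as Hive sends them

fd = io.BytesIO()
for row in rows:
    if row is None:
        row = 'null'
    fd.write(f'{{"c":{row}}}\n'.encode('utf8'))
fd.seek(0)

# With an explicit schema the column comes back strongly typed; without one,
# pyarrow infers the type from the JSON itself.
sch = pa.schema([('c', pa.list_(pa.int64()))])     # assumed element type
opts = pa.json.ParseOptions(explicit_schema=sch)
arr = pa.json.read_json(fd, parse_options=opts)[0].combine_chunks()
print(arr.type)      # list<item: int64>
print(arr[1])        # the None row comes through as a null list value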