|
| 1 | +import struct |
| 2 | + |
| 3 | +from clickhouse_connect.driver.types import ByteSource |
| 4 | + |
| 5 | + |
| 6 | +class ByteArraySource(ByteSource): |
| 7 | + """ |
| 8 | + ByteSource implementation for in-memory byte arrays. |
| 9 | +
|
| 10 | + This class wraps a byte array and provides the ByteSource interface, |
| 11 | + allowing ClickHouse type decoders to read from in-memory data instead |
| 12 | + of a network stream. |
| 13 | +
|
| 14 | + Used primarily for decoding variant-encoded values from JSON shared data |
| 15 | + where each value is a complete serialized type instance. |
| 16 | + """ |
| 17 | + |
| 18 | + def __init__(self, data: bytes, encoding: str = "utf-8"): |
| 19 | + """ |
| 20 | + Initialize ByteArraySource with byte array data. |
| 21 | +
|
| 22 | + :param data: The byte array to read from |
| 23 | + :param encoding: Character encoding for string operations (default: utf-8) |
| 24 | + """ |
| 25 | + self.data = data |
| 26 | + self.pos = 0 |
| 27 | + self.encoding = encoding |
| 28 | + |
| 29 | + def read_byte(self) -> int: |
| 30 | + """Read a single byte and advance position.""" |
| 31 | + if self.pos >= len(self.data): |
| 32 | + raise EOFError("Attempted to read past end of byte array") |
| 33 | + b = self.data[self.pos] |
| 34 | + self.pos += 1 |
| 35 | + return b |
| 36 | + |
| 37 | + def read_bytes(self, sz: int) -> bytes: |
| 38 | + """Read specified number of bytes and advance position.""" |
| 39 | + if self.pos + sz > len(self.data): |
| 40 | + raise EOFError(f"Attempted to read {sz} bytes, only {len(self.data) - self.pos} available") |
| 41 | + result = self.data[self.pos : self.pos + sz] |
| 42 | + self.pos += sz |
| 43 | + return result |
| 44 | + |
| 45 | + def read_leb128(self) -> int: |
| 46 | + """Read a LEB128 (variable-length) encoded integer.""" |
| 47 | + sz = 0 |
| 48 | + shift = 0 |
| 49 | + while self.pos < len(self.data): |
| 50 | + b = self.read_byte() |
| 51 | + sz += (b & 0x7F) << shift |
| 52 | + if (b & 0x80) == 0: |
| 53 | + return sz |
| 54 | + shift += 7 |
| 55 | + raise EOFError("Unexpected end while reading LEB128") |
| 56 | + |
| 57 | + def read_leb128_str(self) -> str: |
| 58 | + """Read a LEB128 length-prefixed string.""" |
| 59 | + sz = self.read_leb128() |
| 60 | + return self.read_bytes(sz).decode(self.encoding) |
| 61 | + |
| 62 | + def read_uint64(self) -> int: |
| 63 | + """Read an unsigned 64-bit integer (little-endian).""" |
| 64 | + return int.from_bytes(self.read_bytes(8), "little", signed=False) |
| 65 | + |
| 66 | + def read_int64(self) -> int: |
| 67 | + """Read a signed 64-bit integer (little-endian).""" |
| 68 | + return int.from_bytes(self.read_bytes(8), "little", signed=True) |
| 69 | + |
| 70 | + def read_uint32(self) -> int: |
| 71 | + """Read an unsigned 32-bit integer (little-endian).""" |
| 72 | + return int.from_bytes(self.read_bytes(4), "little", signed=False) |
| 73 | + |
| 74 | + def read_int32(self) -> int: |
| 75 | + """Read a signed 32-bit integer (little-endian).""" |
| 76 | + return int.from_bytes(self.read_bytes(4), "little", signed=True) |
| 77 | + |
| 78 | + def read_uint16(self) -> int: |
| 79 | + """Read an unsigned 16-bit integer (little-endian).""" |
| 80 | + return int.from_bytes(self.read_bytes(2), "little", signed=False) |
| 81 | + |
| 82 | + def read_int16(self) -> int: |
| 83 | + """Read a signed 16-bit integer (little-endian).""" |
| 84 | + return int.from_bytes(self.read_bytes(2), "little", signed=True) |
| 85 | + |
| 86 | + def read_float32(self) -> float: |
| 87 | + """Read a 32-bit float (little-endian).""" |
| 88 | + return struct.unpack("<f", self.read_bytes(4))[0] |
| 89 | + |
| 90 | + def read_float64(self) -> float: |
| 91 | + """Read a 64-bit float (double, little-endian).""" |
| 92 | + return struct.unpack("<d", self.read_bytes(8))[0] |
| 93 | + |
| 94 | + # pylint: disable=too-many-return-statements |
| 95 | + def read_array(self, array_type: str, num_rows: int): # type: ignore |
| 96 | + """ |
| 97 | + Limited implementation of array reading for basic types. |
| 98 | +
|
| 99 | + Args: |
| 100 | + array_type: Python struct format character |
| 101 | + 'B' = UInt8, 'H' = UInt16, 'I' = UInt32, 'Q' = UInt64 |
| 102 | + 'b' = Int8, 'h' = Int16, 'i' = Int32, 'q' = Int64 |
| 103 | + 'f' = Float32, 'd' = Float64 |
| 104 | + num_rows: Number of elements to read |
| 105 | +
|
| 106 | + Returns: |
| 107 | + List of values |
| 108 | + """ |
| 109 | + if array_type == "B": |
| 110 | + return [self.read_byte() for _ in range(num_rows)] |
| 111 | + elif array_type == "H": |
| 112 | + return [self.read_uint16() for _ in range(num_rows)] |
| 113 | + elif array_type == "I": |
| 114 | + return [self.read_uint32() for _ in range(num_rows)] |
| 115 | + elif array_type == "Q": |
| 116 | + return [self.read_uint64() for _ in range(num_rows)] |
| 117 | + elif array_type == "b": |
| 118 | + return [int.from_bytes([self.read_byte()], "little", signed=True) for _ in range(num_rows)] |
| 119 | + elif array_type == "h": |
| 120 | + return [self.read_int16() for _ in range(num_rows)] |
| 121 | + elif array_type == "i": |
| 122 | + return [self.read_int32() for _ in range(num_rows)] |
| 123 | + elif array_type == "q": |
| 124 | + return [self.read_int64() for _ in range(num_rows)] |
| 125 | + elif array_type == "f": |
| 126 | + return [self.read_float32() for _ in range(num_rows)] |
| 127 | + elif array_type == "d": |
| 128 | + return [self.read_float64() for _ in range(num_rows)] |
| 129 | + else: |
| 130 | + raise NotImplementedError(f"Array type {array_type} not implemented for ByteArraySource") |
| 131 | + |
| 132 | + # Minimal implementations for other ByteSource methods that aren't needed |
| 133 | + # for single-value decoding but are required by the interface |
| 134 | + |
| 135 | + def read_str_col(self, num_rows, encoding, nullable=False, null_obj=None): # type: ignore |
| 136 | + """ |
| 137 | + Read a column of strings. |
| 138 | + For single-value decoding (num_rows=1), read one LEB128 length-prefixed string. |
| 139 | + """ |
| 140 | + if num_rows != 1: |
| 141 | + raise NotImplementedError("read_str_col only supports num_rows=1 for single-value decoding") |
| 142 | + |
| 143 | + length = self.read_leb128() |
| 144 | + string_bytes = self.read_bytes(length) |
| 145 | + |
| 146 | + if encoding is None: |
| 147 | + return [string_bytes] |
| 148 | + |
| 149 | + return [string_bytes.decode(encoding)] |
| 150 | + |
| 151 | + def read_bytes_col(self, sz, num_rows): |
| 152 | + """Not used for single-value decoding.""" |
| 153 | + raise NotImplementedError("read_bytes_col not needed for single-value decoding") |
| 154 | + |
| 155 | + def read_fixed_str_col(self, sz, num_rows, encoding): |
| 156 | + """Not used for single-value decoding.""" |
| 157 | + raise NotImplementedError("read_fixed_str_col not needed for single-value decoding") |
| 158 | + |
| 159 | + def close(self): |
| 160 | + """No cleanup needed for byte arrays.""" |
0 commit comments