import logging
from math import ceil, nan
from struct import pack, unpack
from typing import Any, Sequence

from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef
from clickhouse_connect.datatypes.registry import get_from_name
from clickhouse_connect.driver.ctypes import data_conv
from clickhouse_connect.driver.insert import InsertContext
from clickhouse_connect.driver.options import np
from clickhouse_connect.driver.query import QueryContext
from clickhouse_connect.driver.types import ByteSource

logger = logging.getLogger(__name__)

if np is None:
    logger.info("NumPy not detected. Install NumPy to see an order of magnitude performance gain with QBit columns.")


class QBit(ClickHouseType):
| 21 | + """ |
| 22 | + QBit type - represents bit-transposed vectors for efficient vector search operations. |
| 23 | +
|
| 24 | + Syntax: QBit(element_type, dimension) |
| 25 | + - element_type: BFloat16, Float32, or Float64 |
| 26 | + - dimension: Number of elements per vector |
| 27 | +
|
| 28 | + Over the Native protocol, ClickHouse transmits QBit columns as bit-transposed Tuples. |
| 29 | +
|
| 30 | + Requires: |
| 31 | + - SET allow_experimental_qbit_type = 1 |
| 32 | + - Server version >=25.10 |
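
    Illustrative usage (a sketch, not authoritative; assumes a hypothetical table 'vectors'
    with a single QBit(Float32, 4) column 'v' and an existing clickhouse_connect client):

        client.insert('vectors', [[[0.5, -1.25, 3.0, 0.0]]], column_names=['v'])
        rows = client.query('SELECT v FROM vectors').result_rows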
    """

    __slots__ = (
        "element_type",
        "dimension",
        "_bits_per_element",
        "_bytes_per_fixedstring",
        "_tuple_type",
    )

    python_type = list
    _BIT_SHIFTS = [1 << i for i in range(8)]  # Masks for bits 0-7 within a byte (LSB-first)
    _ELEMENT_BITS = {"BFloat16": 16, "Float32": 32, "Float64": 64}  # Container bits per element type

    def __init__(self, type_def: TypeDef):
        super().__init__(type_def)

        self.element_type = type_def.values[0]
        if self.element_type not in self._ELEMENT_BITS:
            raise ValueError(f"Unsupported QBit element type '{self.element_type}'. Supported types: BFloat16, Float32, Float64.")

        self.dimension = type_def.values[1]
        if self.dimension <= 0:
            raise ValueError(f"QBit dimension must be greater than 0. Got: {self.dimension}.")

        self._name_suffix = f"({self.element_type}, {self.dimension})"
        self._bits_per_element = self._ELEMENT_BITS.get(self.element_type, 32)
        self._bytes_per_fixedstring = ceil(self.dimension / 8)

        # Create the underlying Tuple type for bit-transposed representation
        # E.g., for Float32 with dim=8: Tuple(FixedString(1), FixedString(1), ... x32)
        fixedstring_type = f"FixedString({self._bytes_per_fixedstring})"
        tuple_types = ", ".join([fixedstring_type] * self._bits_per_element)
        tuple_type_name = f"Tuple({tuple_types})"
        self._tuple_type = get_from_name(tuple_type_name)
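        # Fixed serialized size per row: one FixedString plane per bit of the element type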
        self.byte_size = self._bits_per_element * self._bytes_per_fixedstring

    def read_column_prefix(self, source: ByteSource, ctx: QueryContext):
        return self._tuple_type.read_column_prefix(source, ctx)

    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any) -> Sequence:
        """Read bit-transposed Tuple data and convert to flat float vectors."""
        if num_rows == 0:
            return []

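        # Nullable(QBit) columns are prefixed with a one-byte-per-row null map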
        null_map = None
        if self.nullable:
            null_map = source.read_bytes(num_rows)

        tuple_data = self._tuple_type.read_column_data(source, num_rows, ctx, read_state)
        vectors = [self._untranspose_row(t) for t in tuple_data]
        if self.nullable:
            return data_conv.build_nullable_column(vectors, null_map, self._active_null(ctx))
        return vectors

    def write_column_prefix(self, dest: bytearray):
        self._tuple_type.write_column_prefix(dest)

    def write_column_data(self, column: Sequence, dest: bytearray, ctx: InsertContext):
        """Convert flat float vectors to bit-transposed Tuple data and write."""
        if len(column) == 0:
            return

        if self.nullable:
            dest += bytes([1 if x is None else 0 for x in column])

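        # Rows that are None still need placeholder data in the nested Tuple stream; write
        # all-zero planes for them (for Nullable columns the null map above marks the row as NULL)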
        null_tuple = tuple(b"\x00" * self._bytes_per_fixedstring for _ in range(self._bits_per_element))
        tuple_column = [null_tuple if row is None else self._transpose_row(row) for row in column]

        self._tuple_type.write_column_data(tuple_column, dest, ctx)

    def _active_null(self, ctx: QueryContext):
        """Return context-appropriate null value for nullable QBit columns."""
        if ctx.use_none:
            return None
        if ctx.use_extended_dtypes:
            return nan
        return None

    def _values_to_words(self, values: list[float]) -> Sequence[int]:
        """Convert float values to integer words using batch struct processing."""
        count = len(values)

        if self.element_type == "BFloat16":
            # BFloat16 is the top 16 bits of a Float32 (truncate mantissa)
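            # Example: 1.0 packs to Float32 bits 0x3F800000, so its BFloat16 word is 0x3F80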
            raw_ints = unpack(f"<{count}I", pack(f"<{count}f", *values))
            return [(x >> 16) & 0xFFFF for x in raw_ints]

        fmt_char = "I" if self.element_type == "Float32" else "Q"
        float_char = "f" if self.element_type == "Float32" else "d"

        return unpack(f"<{count}{fmt_char}", pack(f"<{count}{float_char}", *values))

    def _words_to_values(self, words: list[int]) -> list[float]:
        """Convert integer words to float values using batch unpacking."""
        count = len(words)

        if self.element_type == "BFloat16":
            # Pad BFloat16 words with zeros to reconstruct valid Float32s
            shifted_words = [(w & 0xFFFF) << 16 for w in words]
            return list(unpack(f"<{count}f", pack(f"<{count}I", *shifted_words)))

        if self.element_type == "Float32":
            return list(unpack(f"<{count}f", pack(f"<{count}I", *words)))

        # Float64
        return list(unpack(f"<{count}d", pack(f"<{count}Q", *words)))

    def _untranspose_row(self, bit_planes: tuple):
        """Convert bit-transposed tuple to flat float vector."""
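        # Layout (as produced by _transpose_row below): plane 0 carries the most significant
        # bit of every element and the last plane the least significant bit. Within a plane,
        # element j lives at byte j >> 3, bit j & 7 (LSB-first within each byte).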
        if np is not None:
            return self._untranspose_row_numpy(bit_planes)

        words = [0] * self.dimension
        bit_shifts = self._BIT_SHIFTS
        dim = self.dimension

        # Iterate over the planes (MSB -> LSB)
        for bit_idx, bit_plane_bytes in enumerate(bit_planes):
            bit_pos = self._bits_per_element - 1 - bit_idx
            mask = 1 << bit_pos

            # Iterate over the bytes in the plane
            for byte_idx, byte_val in enumerate(bit_plane_bytes):
                # If the byte is 0, skip processing its 8 bits
                if byte_val == 0:
                    continue

                base_elem_idx = byte_idx << 3  # Each byte encodes 8 elements

                # Extract the set bits from this byte
                for bit_in_byte in range(8):
                    if byte_val & bit_shifts[bit_in_byte]:
                        elem_idx = base_elem_idx + bit_in_byte
                        if elem_idx < dim:
                            words[elem_idx] |= mask  # Accumulate the bit at position bit_pos

        return self._words_to_values(words)

    def _untranspose_row_numpy(self, bit_planes: tuple) -> list[float]:
        """Vectorized NumPy version of _untranspose_row."""
        # 1. Convert the tuple of bytes to a single uint8 array
        total_bytes = b"".join(bit_planes)
        planes_uint8 = np.frombuffer(total_bytes, dtype=np.uint8)
        planes_uint8 = planes_uint8.reshape(self._bits_per_element, -1)
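        # planes_uint8 now has shape (bits_per_element, bytes_per_fixedstring)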

        # 2. Unpack bits to get the boolean/integer matrix
        bits_matrix: "np.ndarray" = np.unpackbits(planes_uint8, axis=1, bitorder="little")

        # 3. Trim padding if necessary
        if bits_matrix.shape[1] != self.dimension:  # pylint: disable=no-member
            bits_matrix = bits_matrix[:, : self.dimension]  # pylint: disable=invalid-sequence-index

        # 4. Reconstruct the integer words
        if self.element_type == "Float64":
            int_dtype = np.uint64
            final_dtype = np.float64
        else:
            # Float32 and BFloat16 use 32-bit containers
            int_dtype = np.uint32
            final_dtype = np.float32

        # Accumulate bits into integers
        words = np.zeros(self.dimension, dtype=int_dtype)

        for i in range(self._bits_per_element):
            # MSB is at index 0
            shift = self._bits_per_element - 1 - i

            # Where the bit row is 1, set bit 2^shift in the word
            # Cast bits to the target int type before shifting to avoid overflow
            words |= bits_matrix[i].astype(int_dtype) << shift

        # 5. Interpret as floats
        if self.element_type == "BFloat16":
            # Shift back up to the top 16 bits of a Float32
            # Cast to uint32 first to ensure safe shifting
            words = words.astype(np.uint32) << 16
            return words.view(np.float32).tolist()

        return words.view(final_dtype).tolist()

    def _transpose_row(self, values: list[float]) -> tuple:
        """Convert flat float vector to bit-transposed tuple."""
        if len(values) != self.dimension:
            raise ValueError(f"Vector dimension mismatch: expected {self.dimension}, got {len(values)}")

        # If NumPy is available, use the fast path
        if np is not None:
            if isinstance(values, np.ndarray):
                return self._transpose_row_numpy(values)

            # NumPy is available but the caller supplied a plain Python list; convert it to an
            # ndarray anyway for a large performance gain.
            dtype = np.float64 if self.element_type == "Float64" else np.float32
            return self._transpose_row_numpy(np.array(values, dtype=dtype))

        words = self._values_to_words(values)
        bit_planes = []
        bit_shifts = self._BIT_SHIFTS
        bytes_per_fs = self._bytes_per_fixedstring

        for bit_idx in range(self._bits_per_element):
            bit_pos = self._bits_per_element - 1 - bit_idx
            mask = 1 << bit_pos
            plane = bytearray(bytes_per_fs)

            for elem_idx, word in enumerate(words):
                if word & mask:
                    plane[elem_idx >> 3] |= bit_shifts[elem_idx & 7]

            bit_planes.append(bytes(plane))

        return tuple(bit_planes)

    def _transpose_row_numpy(self, vector: "np.ndarray") -> tuple:
        """Fast path for NumPy arrays using vectorized operations."""
        # Cast to an integer view
        if self.element_type == "BFloat16":
            # NumPy doesn't have a bfloat16 dtype. The input is Float32, so just
            # discard the bottom 16 bits.
            v_float = vector.astype(np.float32, copy=False)
            # View as uint32, shift right by 16, cast to uint16
            v_int = (v_float.view(np.uint32) >> 16).astype(np.uint16)

        elif self.element_type == "Float32":
            # Ensure it is a 32-bit float first (handles the float64 -> float32 downcast safely)
            v_float = vector.astype(np.float32, copy=False)
            v_int = v_float.view(np.uint32)

        else:  # Float64
            v_float = vector.astype(np.float64, copy=False)
            v_int = v_float.view(np.uint64)

        bits = self._bits_per_element
        masks = (1 << np.arange(bits - 1, -1, -1, dtype=v_int.dtype)).reshape(-1, 1)

        # Extract bits: (Bits, Dim)
        # v_int is broadcast to (1, Dim)
        bits_extracted = (v_int & masks) != 0

        packed = np.packbits(bits_extracted.view(np.uint8), axis=1, bitorder="little")
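        # bitorder="little" matches the pure-Python packing in _transpose_row (element j ->
        # byte j >> 3, bit j & 7); each plane is zero-padded to bytes_per_fixedstring bytes
        # when dimension is not a multiple of 8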

        return tuple(row.tobytes() for row in packed)