Skip to content

Commit 0dad8f6

Browse files
implement qbit type
1 parent f40cd28 commit 0dad8f6

File tree

7 files changed

+1357
-2
lines changed

7 files changed

+1357
-2
lines changed

clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,3 +497,20 @@ def __init__(self, *params, type_def: TypeDef = None):
497497
values += (x,)
498498
type_def = TypeDef(values=values)
499499
super().__init__(type_def)
500+
501+
502+
class QBit(ChSqlaType, UserDefinedType):
503+
python_type = list
504+
505+
def __init__(self, element_type: str = None, dimension: int = None, type_def: TypeDef = None):
506+
"""
507+
QBit constructor for bit-transposed vector types
508+
:param element_type: Element type (BFloat16, Float32, or Float64)
509+
:param dimension: Number of elements in the vector
510+
:param type_def: TypeDef from parse_name function (used during reflection)
511+
"""
512+
if not type_def:
513+
if not element_type or not dimension:
514+
raise ArgumentError("QBit requires element_type and dimension parameters")
515+
type_def = TypeDef(values=(element_type, dimension))
516+
super().__init__(type_def)

clickhouse_connect/datatypes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import clickhouse_connect.datatypes.string
66
import clickhouse_connect.datatypes.temporal
77
import clickhouse_connect.datatypes.geometric
8+
import clickhouse_connect.datatypes.vector
89
import clickhouse_connect.datatypes.dynamic
910
import clickhouse_connect.datatypes.registry
1011
import clickhouse_connect.datatypes.postinit
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
import logging
2+
from math import ceil, nan
3+
from struct import pack, unpack
4+
from typing import Any, Sequence
5+
6+
from clickhouse_connect.datatypes.base import ClickHouseType, TypeDef
7+
from clickhouse_connect.datatypes.registry import get_from_name
8+
from clickhouse_connect.driver.ctypes import data_conv
9+
from clickhouse_connect.driver.insert import InsertContext
10+
from clickhouse_connect.driver.options import np
11+
from clickhouse_connect.driver.query import QueryContext
12+
from clickhouse_connect.driver.types import ByteSource
13+
14+
logger = logging.getLogger(__name__)
15+
16+
if np is None:
17+
logger.info("NumPy not detected. Install NumPy to see 10-30x performance gains with QBit columns.")
18+
19+
20+
class QBit(ClickHouseType):
21+
"""
22+
QBit type - represents bit-transposed vectors for efficient vector search operations.
23+
24+
Syntax: QBit(element_type, dimension)
25+
- element_type: BFloat16, Float32, or Float64
26+
- dimension: Number of elements per vector
27+
28+
Over the Native protocol, ClickHouse transmits QBit columns as bit-transposed Tuples.
29+
30+
Requires:
31+
- SET allow_experimental_qbit_type = 1
32+
- Server version >=25.10
33+
"""
34+
35+
__slots__ = (
36+
"element_type",
37+
"dimension",
38+
"_bits_per_element",
39+
"_bytes_per_fixedstring",
40+
"_tuple_type",
41+
)
42+
43+
python_type = list
44+
_BIT_SHIFTS = [1 << i for i in range(8)]
45+
_ELEMENT_BITS = {"BFloat16": 16, "Float32": 32, "Float64": 64}
46+
47+
def __init__(self, type_def: TypeDef):
48+
super().__init__(type_def)
49+
50+
self.element_type = type_def.values[0]
51+
if self.element_type not in self._ELEMENT_BITS:
52+
raise ValueError(f"Unsupported QBit element type '{self.element_type}'. Supported types: BFloat16, Float32, Float64.")
53+
54+
self.dimension = type_def.values[1]
55+
if self.dimension <= 0:
56+
raise ValueError(f"QBit dimension must be greater than 0. Got: {self.dimension}.")
57+
58+
self._name_suffix = f"({self.element_type}, {self.dimension})"
59+
self._bits_per_element = self._ELEMENT_BITS.get(self.element_type, 32)
60+
self._bytes_per_fixedstring = ceil(self.dimension / 8)
61+
62+
# Create the underlying Tuple type for bit-transposed representation
63+
# E.g., for Float32 with dim=8: Tuple(FixedString(1), FixedString(1), ... x32)
64+
fixedstring_type = f"FixedString({self._bytes_per_fixedstring})"
65+
tuple_types = ", ".join([fixedstring_type] * self._bits_per_element)
66+
tuple_type_name = f"Tuple({tuple_types})"
67+
self._tuple_type = get_from_name(tuple_type_name)
68+
self.byte_size = self._bits_per_element * self._bytes_per_fixedstring
69+
70+
def read_column_prefix(self, source: ByteSource, ctx: QueryContext):
71+
return self._tuple_type.read_column_prefix(source, ctx)
72+
73+
def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any) -> Sequence:
74+
"""Read bit-transposed Tuple data and convert to flat float vectors."""
75+
if num_rows == 0:
76+
return []
77+
78+
null_map = None
79+
if self.nullable:
80+
null_map = source.read_bytes(num_rows)
81+
82+
tuple_data = self._tuple_type.read_column_data(source, num_rows, ctx, read_state)
83+
vectors = [self._untranspose_row(t) for t in tuple_data]
84+
if self.nullable:
85+
return data_conv.build_nullable_column(vectors, null_map, self._active_null(ctx))
86+
return vectors
87+
88+
def write_column_prefix(self, dest: bytearray):
89+
self._tuple_type.write_column_prefix(dest)
90+
91+
def write_column_data(self, column: Sequence, dest: bytearray, ctx: InsertContext):
92+
"""Convert flat float vectors to bit-transposed Tuple data and write."""
93+
if len(column) == 0:
94+
return
95+
96+
if self.nullable:
97+
dest += bytes([1 if x is None else 0 for x in column])
98+
99+
null_tuple = tuple(b"\x00" * self._bytes_per_fixedstring for _ in range(self._bits_per_element))
100+
tuple_column = [null_tuple if row is None else self._transpose_row(row) for row in column]
101+
102+
self._tuple_type.write_column_data(tuple_column, dest, ctx)
103+
104+
def _active_null(self, ctx: QueryContext):
105+
"""Return context-appropriate null value for nullable QBit columns."""
106+
if ctx.use_none:
107+
return None
108+
if ctx.use_extended_dtypes:
109+
return nan
110+
return None
111+
112+
def _values_to_words(self, values: list[float]) -> Sequence[int]:
113+
"""Convert float values to integer words using batch struct processing."""
114+
count = len(values)
115+
116+
if self.element_type == "BFloat16":
117+
# BFloat16 is the top 16 bits of a Float32 (truncate mantissa)
118+
raw_ints = unpack(f"<{count}I", pack(f"<{count}f", *values))
119+
return [(x >> 16) & 0xFFFF for x in raw_ints]
120+
121+
fmt_char = "I" if self.element_type == "Float32" else "Q"
122+
float_char = "f" if self.element_type == "Float32" else "d"
123+
124+
return unpack(f"<{count}{fmt_char}", pack(f"<{count}{float_char}", *values))
125+
126+
def _words_to_values(self, words: list[int]) -> list[float]:
127+
"""Convert integer words to float values using batch unpacking."""
128+
count = len(words)
129+
130+
if self.element_type == "BFloat16":
131+
# Pad BFloat16 words with zeros to reconstruct valid Float32s
132+
shifted_words = [(w & 0xFFFF) << 16 for w in words]
133+
return list(unpack(f"<{count}f", pack(f"<{count}I", *shifted_words)))
134+
135+
if self.element_type == "Float32":
136+
return list(unpack(f"<{count}f", pack(f"<{count}I", *words)))
137+
138+
# Float64
139+
return list(unpack(f"<{count}d", pack(f"<{count}Q", *words)))
140+
141+
def _untranspose_row(self, bit_planes: tuple):
142+
"""Convert bit-transposed tuple to flat float vector."""
143+
if np is not None:
144+
return self._untranspose_row_numpy(bit_planes)
145+
146+
words = [0] * self.dimension
147+
bit_shifts = self._BIT_SHIFTS
148+
dim = self.dimension
149+
150+
# Iterate Planes (MSB -> LSB)
151+
for bit_idx, bit_plane_bytes in enumerate(bit_planes):
152+
bit_pos = self._bits_per_element - 1 - bit_idx
153+
mask = 1 << bit_pos
154+
155+
# Iterate Bytes in Plane
156+
for byte_idx, byte_val in enumerate(bit_plane_bytes):
157+
# if byte is 0, skip processing 8 bits
158+
if byte_val == 0:
159+
continue
160+
161+
base_elem_idx = byte_idx << 3 # Each byte encodes 8 elements
162+
163+
# Extract set bits from this byte
164+
for bit_in_byte in range(8):
165+
if byte_val & bit_shifts[bit_in_byte]:
166+
elem_idx = base_elem_idx + bit_in_byte
167+
if elem_idx < dim:
168+
words[elem_idx] |= mask # Accumulate bit at position bit_pos
169+
170+
return self._words_to_values(words)
171+
172+
def _untranspose_row_numpy(self, bit_planes: tuple) -> list[float]:
173+
"""Vectorized numpy operations version of _untranspose_row"""
174+
# 1. Convert tuple of bytes to a single uint8 array
175+
total_bytes = b"".join(bit_planes)
176+
planes_uint8 = np.frombuffer(total_bytes, dtype=np.uint8)
177+
planes_uint8 = planes_uint8.reshape(self._bits_per_element, -1)
178+
179+
# 2. Unpack bits to get the boolean/integer matrix
180+
bits_matrix: "np.ndarray" = np.unpackbits(planes_uint8, axis=1, bitorder="little")
181+
182+
# 3. Trim padding if necessary
183+
if bits_matrix.shape[1] != self.dimension: # pylint: disable=no-member
184+
bits_matrix = bits_matrix[:, : self.dimension] # pylint: disable=invalid-sequence-index
185+
186+
# 4. Reconstruct the integer words
187+
if self.element_type == "Float64":
188+
int_dtype = np.uint64
189+
final_dtype = np.float64
190+
else:
191+
# Float32 and BFloat16 use 32-bit containers
192+
int_dtype = np.uint32
193+
final_dtype = np.float32
194+
195+
# Accumulate bits into integers
196+
words = np.zeros(self.dimension, dtype=int_dtype)
197+
198+
for i in range(self._bits_per_element):
199+
# MSB is at index 0
200+
shift = self._bits_per_element - 1 - i
201+
202+
# If the bit row is 1, add 2^shift to the word
203+
# Cast bits to the target int type before shifting to avoid overflow
204+
words |= bits_matrix[i].astype(int_dtype) << shift
205+
206+
# 5. Interpret as Floats
207+
if self.element_type == "BFloat16":
208+
# Shift back up to the top 16 bits of a Float32
209+
# Cast to uint32 first to ensure safe shifting
210+
words = words.astype(np.uint32) << 16
211+
return words.view(np.float32).tolist()
212+
213+
return words.view(final_dtype).tolist()
214+
215+
def _transpose_row(self, values: list[float]) -> tuple:
216+
"""Convert flat float vector to bit-transposed tuple."""
217+
if len(values) != self.dimension:
218+
raise ValueError(f"Vector dimension mismatch: expected {self.dimension}, got {len(values)}")
219+
220+
# If numpy is available, use the fast path
221+
if np is not None:
222+
if isinstance(values, np.ndarray):
223+
return self._transpose_row_numpy(values)
224+
225+
# If numpy is available but user supplied python list, convert to np array anyway for
226+
# huge performance gains.
227+
dtype = np.float64 if self.element_type == "Float64" else np.float32
228+
return self._transpose_row_numpy(np.array(values, dtype=dtype))
229+
230+
words = self._values_to_words(values)
231+
bit_planes = []
232+
bit_shifts = self._BIT_SHIFTS
233+
bytes_per_fs = self._bytes_per_fixedstring
234+
235+
for bit_idx in range(self._bits_per_element):
236+
bit_pos = self._bits_per_element - 1 - bit_idx
237+
mask = 1 << bit_pos
238+
plane = bytearray(bytes_per_fs)
239+
240+
for elem_idx, word in enumerate(words):
241+
if word & mask:
242+
plane[elem_idx >> 3] |= bit_shifts[elem_idx & 7]
243+
244+
bit_planes.append(bytes(plane))
245+
246+
return tuple(bit_planes)
247+
248+
def _transpose_row_numpy(self, vector: "np.ndarray") -> tuple:
249+
"""Fast path for numpy arrays using vectorized operations."""
250+
# Cast to int view
251+
if self.element_type == "BFloat16":
252+
# Numpy doesn't have bfloat16. Input is Float32 so just
253+
# discard the bottom 16 bits.
254+
v_float = vector.astype(np.float32, copy=False)
255+
# View as uint32, shift right 16, cast to uint16
256+
v_int = (v_float.view(np.uint32) >> 16).astype(np.uint16)
257+
258+
elif self.element_type == "Float32":
259+
# Ensure it is 32-bit float first (handles float64->32 downcast safely)
260+
v_float = vector.astype(np.float32, copy=False)
261+
v_int = v_float.view(np.uint32)
262+
263+
else: # Float64
264+
v_float = vector.astype(np.float64, copy=False)
265+
v_int = v_float.view(np.uint64)
266+
267+
bits = self._bits_per_element
268+
masks = (1 << np.arange(bits - 1, -1, -1, dtype=v_int.dtype)).reshape(-1, 1)
269+
270+
# Extract bits: (Bits, Dim)
271+
# v_int broadcasted to (1, Dim)
272+
bits_extracted = (v_int & masks) != 0
273+
274+
packed = np.packbits(bits_extracted.view(np.uint8), axis=1, bitorder="little")
275+
276+
return tuple(row.tobytes() for row in packed)

tests/integration_tests/test_sqlalchemy/test_ddl.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from clickhouse_connect import common
1212
from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import Int8, UInt16, Decimal, Enum16, Float64, Boolean, \
1313
FixedString, String, UInt64, UUID, DateTime, DateTime64, LowCardinality, Nullable, Array, AggregateFunction, \
14-
UInt32, IPv4
14+
UInt32, IPv4, QBit
1515
from clickhouse_connect.cc_sqlalchemy import final
1616
from clickhouse_connect.cc_sqlalchemy.ddl.custom import CreateDatabase, DropDatabase
1717
from clickhouse_connect.cc_sqlalchemy.ddl.tableengine import engine_map, ReplacingMergeTree
@@ -181,3 +181,32 @@ def test_final_modifier_error_cases(test_engine: Engine, test_db: str):
181181

182182
test_table.drop(conn)
183183
other_table.drop(conn)
184+
185+
186+
def test_qbit_table(test_engine: Engine, test_db: str, test_table_engine: str):
187+
"""Test QBit type DDL and basic operations"""
188+
common.set_setting('invalid_setting_action', 'drop')
189+
with test_engine.begin() as conn:
190+
if not conn.connection.driver_connection.client.min_version('25.10'):
191+
pytest.skip('QBit type requires ClickHouse version 25.10+')
192+
193+
conn.execute(text('SET allow_experimental_qbit_type = 1'))
194+
195+
table_cls = engine_map[test_table_engine]
196+
metadata = MetaData(schema=test_db)
197+
conn.execute(text('DROP TABLE IF EXISTS qbit_test'))
198+
199+
table = db.Table('qbit_test', metadata,
200+
db.Column('id', UInt32),
201+
db.Column('vector', QBit('Float32', 8)),
202+
db.Column('embedding', QBit('Float32', 128)),
203+
table_cls('id'))
204+
table.create(conn)
205+
206+
# Verify table was created
207+
result = conn.execute(text("SHOW CREATE TABLE qbit_test"))
208+
create_sql = result.fetchone()[0]
209+
assert 'QBit(Float32, 8)' in create_sql
210+
assert 'QBit(Float32, 128)' in create_sql
211+
212+
conn.execute(text('DROP TABLE qbit_test'))

0 commit comments

Comments
 (0)