Skip to content

Commit 70bd2d2

Browse files
authored
add feature to create tables from pyarrow objects (#597)
* working addition for creating tables from pyarrow objects * update the changelog * address PR review * more pr fixes * ran linting + updated docstring in ddl function * linting again
1 parent d18b4bf commit 70bd2d2

File tree

4 files changed

+352
-0
lines changed

4 files changed

+352
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ The supported method of passing ClickHouse server settings is to prefix such arg
2828

2929
### Improvements
3030
- Add support for QBit data type. Closes [#570](https://github.com/ClickHouse/clickhouse-connect/issues/570)
31+
- Add the ability to create tables from PyArrow objects. Addresses [#588](https://github.com/ClickHouse/clickhouse-connect/issues/588)
3132

3233
## 0.10.0, 2025-11-14
3334

clickhouse_connect/driver/ddl.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import NamedTuple, Sequence
22

33
from clickhouse_connect.datatypes.base import ClickHouseType
4+
from clickhouse_connect.driver.options import check_arrow
45

56

67
class TableColumnDef(NamedTuple):
@@ -26,3 +27,111 @@ def create_table(table_name: str, columns: Sequence[TableColumnDef], engine: str
2627
for key, value in engine_params.items():
2728
stmt += f' {key} {value}'
2829
return stmt
30+
31+
32+
def _arrow_type_to_ch(arrow_type: "pa.DataType") -> str:
    """
    Best-effort mapping from common PyArrow types to ClickHouse type names.

    Covers core scalar types. For anything unknown, we raise so the
    caller is aware that the automatic mapping is not implemented for that Arrow type.

    :param arrow_type: a pyarrow DataType instance to translate
    :return: the ClickHouse type name as a string
    :raises TypeError: when no mapping exists for the given Arrow type
    """
    pa = check_arrow()
    pat = pa.types

    # Ordered (predicate, ClickHouse type name) dispatch table covering the
    # scalar types this helper knows how to translate.  float16 is widened to
    # Float32 and large_string collapses to String, matching ClickHouse's
    # available types.
    dispatch = (
        (pat.is_int8, 'Int8'),
        (pat.is_int16, 'Int16'),
        (pat.is_int32, 'Int32'),
        (pat.is_int64, 'Int64'),
        (pat.is_uint8, 'UInt8'),
        (pat.is_uint16, 'UInt16'),
        (pat.is_uint32, 'UInt32'),
        (pat.is_uint64, 'UInt64'),
        (lambda t: pat.is_float16(t) or pat.is_float32(t), 'Float32'),
        (pat.is_float64, 'Float64'),
        (pat.is_boolean, 'Bool'),
        (lambda t: pat.is_string(t) or pat.is_large_string(t), 'String'),
    )
    for predicate, ch_name in dispatch:
        if predicate(arrow_type):
            return ch_name

    # For any currently unsupported type, raise so it's clear that
    # this Arrow type isn't supported by the helper yet.
    raise TypeError(f'Unsupported Arrow type for automatic mapping: {arrow_type!r}')
80+
81+
82+
class _DDLType:
83+
"""
84+
Minimal helper used to satisfy TableColumnDef.ch_type.
85+
86+
create_table() only needs ch_type.name when building the DDL string,
87+
so we'll wrap the ClickHouse type name in this tiny object instead of
88+
constructing full ClickHouseType instances here.
89+
"""
90+
def __init__(self, name: str):
91+
self.name = name
92+
93+
94+
def arrow_schema_to_column_defs(schema: "pa.Schema") -> list[TableColumnDef]:
    """
    Convert a PyArrow Schema into a list of TableColumnDef objects.

    This helper uses an *optimistic non-null* strategy: it always produces
    non-nullable ClickHouse types, even though Arrow fields are nullable by
    default.

    Note that if the user inserts a table with nulls into a non-Nullable column,
    ClickHouse will silently convert those nulls to default values due to the default
    server setting input_format_null_as_default=1 and current lack of client-side
    validation on arrow inserts.

    :param schema: the pyarrow Schema to convert
    :return: one TableColumnDef per schema field, in schema order
    :raises TypeError: if schema is not a pyarrow.Schema, or a field type is unmapped
    """
    pa = check_arrow()

    if not isinstance(schema, pa.Schema):
        raise TypeError(f'Expected pyarrow.Schema, got {type(schema)!r}')

    # One column definition per field; _arrow_type_to_ch raises for any
    # field type without a ClickHouse mapping.
    return [
        TableColumnDef(name=field.name, ch_type=_DDLType(_arrow_type_to_ch(field.type)))
        for field in schema
    ]
122+
123+
124+
def create_table_from_arrow_schema(
    table_name: str,
    schema: "pa.Schema",
    engine: str,
    engine_params: dict,
) -> str:
    """
    Helper function to build a CREATE TABLE statement from a PyArrow Schema.

    Internally:
      schema -> arrow_schema_to_column_defs -> create_table(...)

    :param table_name: name of the table to create
    :param schema: pyarrow Schema describing the columns
    :param engine: ClickHouse table engine (e.g. 'MergeTree')
    :param engine_params: engine clauses appended to the DDL (e.g. {'ORDER BY': 'id'})
    :return: the complete CREATE TABLE statement as a string
    """
    return create_table(
        table_name,
        arrow_schema_to_column_defs(schema),
        engine,
        engine_params,
    )
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import pytest
2+
import pyarrow as pa
3+
4+
from clickhouse_connect.driver import Client
5+
from clickhouse_connect.driver.ddl import (
6+
arrow_schema_to_column_defs,
7+
create_table,
8+
create_table_from_arrow_schema,
9+
)
10+
11+
pytest.importorskip("pyarrow")
12+
13+
14+
def test_arrow_create_table_and_insert(test_client: Client):
    """End-to-end check: build DDL from an Arrow schema, create the table,
    insert an Arrow table, and read the rows back."""
    if not test_client.min_version("20"):
        pytest.skip(
            f"Not supported server version {test_client.server_version}"
        )

    table_name = "test_arrow_basic_integration"

    test_client.command(f"DROP TABLE IF EXISTS {table_name}")

    schema = pa.schema(
        [
            ("id", pa.int64()),
            ("name", pa.string()),
            ("score", pa.float32()),
            ("flag", pa.bool_()),
        ]
    )

    try:
        ddl = create_table_from_arrow_schema(
            table_name=table_name,
            schema=schema,
            engine="MergeTree",
            engine_params={"ORDER BY": "id"},
        )
        test_client.command(ddl)

        arrow_table = pa.table(
            {
                "id": [1, 2],
                "name": ["a", "b"],
                "score": [1.5, 2.5],
                "flag": [True, False],
            },
            schema=schema,
        )

        test_client.insert_arrow(table=table_name, arrow_table=arrow_table)

        result = test_client.query(
            f"SELECT id, name, score, flag FROM {table_name} ORDER BY id"
        )
        assert result.result_rows == [
            (1, "a", 1.5, True),
            (2, "b", 2.5, False),
        ]
    finally:
        # Drop the table even when an assertion fails so reruns start clean.
        test_client.command(f"DROP TABLE IF EXISTS {table_name}")
62+
63+
64+
def test_arrow_schema_to_column_defs(test_client: Client):
    """Integration check of the explicit helper path:
    arrow_schema_to_column_defs + create_table, then insert and read back."""
    table_name = "test_arrow_manual_integration"

    test_client.command(f"DROP TABLE IF EXISTS {table_name}")

    schema = pa.schema(
        [
            ("id", pa.int64()),
            ("name", pa.string()),
        ]
    )

    try:
        # check using the explicit helper path.
        col_defs = arrow_schema_to_column_defs(schema)

        ddl = create_table(
            table_name=table_name,
            columns=col_defs,
            engine="MergeTree",
            engine_params={"ORDER BY": "id"},
        )
        test_client.command(ddl)

        arrow_table = pa.table(
            {
                "id": [10, 20],
                "name": ["x", "y"],
            },
            schema=schema,
        )

        test_client.insert_arrow(table=table_name, arrow_table=arrow_table)

        result = test_client.query(f"SELECT id, name FROM {table_name} ORDER BY id")
        assert result.result_rows == [
            (10, "x"),
            (20, "y"),
        ]
    finally:
        # Drop the table even when an assertion fails so reruns start clean.
        test_client.command(f"DROP TABLE IF EXISTS {table_name}")
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# pylint: disable=duplicate-code
2+
3+
import pytest
4+
import pyarrow as pa
5+
6+
from clickhouse_connect.driver.ddl import (
7+
arrow_schema_to_column_defs,
8+
create_table,
9+
create_table_from_arrow_schema,
10+
)
11+
12+
pytest.importorskip("pyarrow")
13+
14+
15+
def test_arrow_schema_to_column_defs_basic_mappings():
    """Every supported Arrow scalar type maps to the expected ClickHouse type name."""
    # (column name, arrow type, expected ClickHouse type) triples keep the
    # input schema and both expectations in a single place.
    cases = [
        ("i8", pa.int8(), "Int8"),
        ("i16", pa.int16(), "Int16"),
        ("i32", pa.int32(), "Int32"),
        ("i64", pa.int64(), "Int64"),
        ("u8", pa.uint8(), "UInt8"),
        ("u16", pa.uint16(), "UInt16"),
        ("u32", pa.uint32(), "UInt32"),
        ("u64", pa.uint64(), "UInt64"),
        ("f16", pa.float16(), "Float32"),
        ("f32", pa.float32(), "Float32"),
        ("f64", pa.float64(), "Float64"),
        ("s", pa.string(), "String"),
        ("ls", pa.large_string(), "String"),
        ("b", pa.bool_(), "Bool"),
    ]

    schema = pa.schema([(name, arrow_type) for name, arrow_type, _ in cases])

    col_defs = arrow_schema_to_column_defs(schema)

    assert [c.name for c in col_defs] == [name for name, _, _ in cases]
    assert [c.ch_type.name for c in col_defs] == [ch_name for _, _, ch_name in cases]
72+
73+
74+
def test_arrow_schema_to_column_defs_unsupported_type_raises():
    """An Arrow type with no ClickHouse mapping (timestamp) raises TypeError."""
    unsupported = pa.schema([("ts", pa.timestamp("ms"))])

    with pytest.raises(TypeError, match="Unsupported Arrow type"):
        arrow_schema_to_column_defs(unsupported)
83+
84+
85+
def test_arrow_schema_to_column_defs_invalid_input_type():
    """Anything other than a pyarrow.Schema is rejected with TypeError."""
    with pytest.raises(TypeError, match="Expected pyarrow.Schema"):
        arrow_schema_to_column_defs("not a schema")
88+
89+
90+
def test_create_table_from_arrow_schema_builds_expected_ddl():
    """The generated CREATE TABLE statement matches the expected SQL exactly."""
    schema = pa.schema(
        [
            ("id", pa.int64()),
            ("name", pa.string()),
            ("score", pa.float32()),
            ("flag", pa.bool_()),
        ]
    )

    ddl = create_table_from_arrow_schema(
        table_name="arrow_basic_test",
        schema=schema,
        engine="MergeTree",
        engine_params={"ORDER BY": "id"},
    )

    expected = (
        "CREATE TABLE arrow_basic_test "
        "(id Int64, name String, score Float32, flag Bool) "
        "ENGINE MergeTree ORDER BY id"
    )
    assert ddl == expected
113+
114+
115+
def test_create_table_from_arrow_schema_matches_manual_create_table():
    """The wrapper produces the same DDL as the manual two-step helper path."""
    schema = pa.schema(
        [
            ("id", pa.int64()),
            ("name", pa.string()),
        ]
    )

    # Shared arguments so both paths build DDL for an identical table.
    shared = {
        "table_name": "arrow_compare_test",
        "engine": "MergeTree",
        "engine_params": {"ORDER BY": "id"},
    }

    ddl_manual = create_table(columns=arrow_schema_to_column_defs(schema), **shared)
    ddl_wrapper = create_table_from_arrow_schema(schema=schema, **shared)

    assert ddl_manual == ddl_wrapper

0 commit comments

Comments
 (0)