Skip to content

Commit 0697774

Browse files
add array join and final in sqa
1 parent def9a93 commit 0697774

File tree

7 files changed

+427
-7
lines changed

7 files changed

+427
-7
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from clickhouse_connect import driver_name
22
from clickhouse_connect.cc_sqlalchemy.datatypes.base import schema_types
3+
from clickhouse_connect.cc_sqlalchemy.sql import final
4+
from clickhouse_connect.cc_sqlalchemy.sql.clauses import array_join, ArrayJoin
35

46
# pylint: disable=invalid-name
# Module-level aliases: dialect_name mirrors driver_name and ischema_names
# mirrors schema_types.
dialect_name = driver_name
ischema_names = schema_types

# Names exported by ``from clickhouse_connect.cc_sqlalchemy import *``.
__all__ = [
    'dialect_name',
    'ischema_names',
    'array_join',
    'ArrayJoin',
    'final',
]

clickhouse_connect/cc_sqlalchemy/sql/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional
22

33
from sqlalchemy import Table
4+
from sqlalchemy.sql.selectable import FromClause, Select
45

56
from clickhouse_connect.driver.binding import quote_identifier
67

@@ -13,3 +14,46 @@ def full_table(table_name: str, schema: Optional[str] = None) -> str:
1314

1415
def format_table(table: Table):
    """Return ``full_table`` applied to the table's ``name`` and ``schema``."""
    name, schema = table.name, table.schema
    return full_table(name, schema)
17+
18+
19+
def final(select_stmt: Select, table: Optional[FromClause] = None) -> Select:
    """
    Attach the ClickHouse ``FINAL`` hint to one FROM element of a select.

    Args:
        select_stmt: The SQLAlchemy Select statement to modify.
        table: Table/alias that receives the hint. When ``None``, the single
            FROM element reported by ``select_stmt.get_final_froms()`` is used.

    Returns:
        The Select returned by ``select_stmt.with_hint(target, "FINAL")``.

    Raises:
        TypeError: If ``select_stmt`` is not a Select, or the resolved target
            is not a FromClause.
        ValueError: If ``table`` is omitted and the statement has no FROM
            element or more than one.
    """
    if not isinstance(select_stmt, Select):
        raise TypeError("final() expects a SQLAlchemy Select instance")

    if table is not None:
        target = table
    else:
        # No explicit target: only an unambiguous single FROM is acceptable.
        candidates = select_stmt.get_final_froms()
        if not candidates:
            raise ValueError("final() requires a table to apply the FINAL modifier.")
        if len(candidates) > 1:
            raise ValueError("final() is ambiguous for statements with multiple FROM clauses. Specify the table explicitly.")
        target = candidates[0]

    if isinstance(target, FromClause):
        return select_stmt.with_hint(target, "FINAL")

    raise TypeError("table must be a SQLAlchemy FromClause when provided")
49+
50+
51+
def _select_final(self: Select, table: Optional[FromClause] = None) -> Select:
    """
    Method form of the module-level final() helper; ``self`` is the Select
    the FINAL hint is applied to.
    """
    return final(select_stmt=self, table=table)


# Monkey-patch the Select class so statements can call ``stmt.final(...)``
# instead of ``final(stmt, ...)``.
Select.final = _select_final
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from sqlalchemy.sql.base import Immutable
2+
from sqlalchemy.sql.selectable import FromClause
3+
4+
5+
# pylint: disable=protected-access,too-many-ancestors,abstract-method,unused-argument
6+
class ArrayJoin(Immutable, FromClause):
    """FROM element for the ClickHouse ``ARRAY JOIN`` / ``LEFT ARRAY JOIN`` clause."""

    __visit_name__ = "array_join"
    _is_from_container = True
    named_with_column = False
    _is_join = True

    def __init__(self, left, array_column, alias=None, is_left=False):
        """Store the pieces of an ARRAY JOIN clause.

        Args:
            left: The left side (table or subquery)
            array_column: The array column to join
            alias: Optional alias for the joined array elements
            is_left: If True, use LEFT ARRAY JOIN instead of ARRAY JOIN
        """
        super().__init__()
        self.left = left
        self.array_column = array_column
        self.alias = alias
        self.is_left = is_left
        self._is_clone_of = None

    @property
    def selectable(self):
        """The selectable behind this clause, i.e. the left side."""
        return self.left

    @property
    def _hide_froms(self):
        """FROM elements to hide: the left side, which this clause contains."""
        return [self.left]

    @property
    def _from_objects(self):
        """The FROM objects of the left side."""
        return self.left._from_objects

    def _clone(self, **kw):
        """Return a shallow copy that records this instance as its origin."""
        cls = self.__class__
        duplicate = cls.__new__(cls)
        duplicate.__dict__ = self.__dict__.copy()
        duplicate._is_clone_of = self
        return duplicate

    def _copy_internals(self, clone=None, **kw):
        """Replace ``left`` and ``array_column`` with ``clone(...)`` of each.

        The attributes are reassigned in place on this instance. When no
        ``clone`` callable is supplied, an identity function is used and the
        elements are kept as they are.
        """
        cloner = clone if clone is not None else (lambda elem, **kwargs: elem)

        self.left = cloner(self.left, **kw)
        self.array_column = cloner(self.array_column, **kw)
68+
69+
70+
def array_join(left, array_column, alias=None, is_left=False):
    """Build an :class:`ArrayJoin` clause element.

    Args:
        left: The left side (table or subquery)
        array_column: The array column to join
        alias: Optional alias for the joined array elements
        is_left: If True, use LEFT ARRAY JOIN instead of ARRAY JOIN

    Returns:
        ArrayJoin: An ArrayJoin clause element

    Example:
        from clickhouse_connect.cc_sqlalchemy.sql.clauses import array_join

        # Basic ARRAY JOIN
        query = select(table).select_from(array_join(table, table.c.tags))

        # LEFT ARRAY JOIN with alias
        query = select(table).select_from(
            array_join(table, table.c.tags, alias='tag', is_left=True)
        )
    """
    clause = ArrayJoin(left, array_column, alias=alias, is_left=is_left)
    return clause

clickhouse_connect/cc_sqlalchemy/sql/compiler.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from sqlalchemy.exc import CompileError
22
from sqlalchemy.sql.compiler import SQLCompiler
33

4+
from clickhouse_connect.cc_sqlalchemy import ArrayJoin
45
from clickhouse_connect.cc_sqlalchemy.sql import format_table
56

67

78
# pylint: disable=arguments-differ
89
class ChStatementCompiler(SQLCompiler):
910

10-
# pylint: disable=attribute-defined-outside-init
11+
# pylint: disable=attribute-defined-outside-init,unused-argument
1112
def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
1213
table = delete_stmt.table
1314
text = f"DELETE FROM {format_table(table)}"
@@ -23,10 +24,20 @@ def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
2324

2425
return text
2526

26-
def visit_select(self, select_stmt, **kw):
27-
return super().visit_select(select_stmt, **kw)
27+
def visit_array_join(self, array_join_clause, asfrom=False, from_linter=None, **kw):
    """Render an ArrayJoin as ``<left> [LEFT ]ARRAY JOIN <column>[ AS <alias>]``.

    The left side is compiled as a FROM element (``asfrom=True``) and is given
    the ``from_linter``; the array column is compiled with ``**kw`` only. A
    truthy alias is passed through ``self.preparer.quote``.
    """
    left_sql = self.process(array_join_clause.left, asfrom=True, from_linter=from_linter, **kw)
    column_sql = self.process(array_join_clause.array_column, **kw)

    keyword = "ARRAY JOIN"
    if array_join_clause.is_left:
        keyword = "LEFT ARRAY JOIN"

    alias = array_join_clause.alias
    alias_sql = f" AS {self.preparer.quote(alias)}" if alias else ""

    return f"{left_sql} {keyword} {column_sql}{alias_sql}"
2836

2937
def visit_join(self, join, **kw):
38+
if isinstance(join, ArrayJoin):
39+
return self.visit_array_join(join, **kw)
40+
3041
left = self.process(join.left, **kw)
3142
right = self.process(join.right, **kw)
3243
onclause = join.onclause
@@ -60,7 +71,7 @@ def visit_column(self, column, add_to_result_map=None, include_table=True, resul
6071
)
6172

6273
# Abstract methods required by SQLCompiler
63-
def delete_extra_from_clause(self, delete_stmt, from_table, extra_froms, from_hints, **kw):
74+
def delete_extra_from_clause(self, update_stmt, from_table, extra_froms, from_hints, **kw):
    """Always raise: DELETE with an extra FROM clause is not supported here.

    Raises:
        NotImplementedError: On every call.
    """
    message = "ClickHouse doesn't support DELETE with extra FROM clause"
    raise NotImplementedError(message)
6576

6677
def update_from_clause(self, update_stmt, from_table, extra_froms, from_hints, **kw):
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from sqlalchemy import Column, MetaData, Table, literal_column, select
2+
from sqlalchemy.engine.base import Engine
3+
from sqlalchemy.types import Integer, String
4+
5+
from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import Array
6+
from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import String as ChString
7+
from clickhouse_connect.cc_sqlalchemy.ddl.tableengine import MergeTree
8+
from clickhouse_connect.cc_sqlalchemy.sql.clauses import array_join
9+
10+
11+
def test_array_join(test_engine: Engine, test_db: str):
    """Test ARRAY JOIN clause.

    Creates a small table with an array column, checks that the compiled
    statement contains ``ARRAY JOIN``, and checks that rows with an empty
    array are absent from the result.
    """
    with test_engine.begin() as conn:
        metadata = MetaData(schema=test_db)

        test_table = Table(
            "test_array_join",
            metadata,
            Column("id", Integer),
            Column("name", String),
            Column("tags", Array(ChString)),
            MergeTree(order_by="id"),
        )

        test_table.drop(conn, checkfirst=True)
        test_table.create(conn)

        try:
            conn.execute(
                test_table.insert(),
                [
                    {"id": 1, "name": "Alice", "tags": ["python", "sql", "clickhouse"]},
                    {"id": 2, "name": "Bob", "tags": ["java", "sql"]},
                    {"id": 3, "name": "Joe", "tags": ["python", "javascript"]},
                    {"id": 4, "name": "Charlie", "tags": []},
                ],
            )

            query = (
                select(test_table.c.id, test_table.c.name, test_table.c.tags)
                .select_from(array_join(test_table, test_table.c.tags))
                .order_by(test_table.c.id)
                .order_by(test_table.c.tags)
            )

            compiled = query.compile(dialect=test_engine.dialect)
            assert "ARRAY JOIN" in str(compiled).upper()

            result = conn.execute(query)
            rows = result.fetchall()
            # 3 (Alice) + 2 (Bob) + 2 (Joe); Charlie's empty array yields no rows
            assert len(rows) == 7
            assert rows[0].id == 1
            assert rows[0].name == "Alice"
            assert rows[0].tags == "clickhouse"
            # ARRAY JOIN should not contain items with empty lists
            assert "Charlie" not in [row.name for row in rows]
        finally:
            # Drop the table even when an assertion or query fails, so a
            # failed run does not leave it behind in the test database.
            test_table.drop(conn, checkfirst=True)
58+
59+
60+
def test_left_array_join_with_alias(test_engine: Engine, test_db: str):
    """Test LEFT ARRAY JOIN with alias.

    Checks that the compiled statement contains ``LEFT ARRAY JOIN`` and an
    ``AS`` alias clause, and that a row with an empty array is kept with an
    empty-string element.
    """
    with test_engine.begin() as conn:
        metadata = MetaData(schema=test_db)

        test_table = Table(
            "test_left_array_join",
            metadata,
            Column("id", Integer),
            Column("name", String),
            Column("tags", Array(ChString)),
            MergeTree(order_by="id"),
        )

        test_table.drop(conn, checkfirst=True)
        test_table.create(conn)

        try:
            conn.execute(
                test_table.insert(),
                [
                    {"id": 1, "name": "Alice", "tags": ["python", "sql", "clickhouse"]},
                    {"id": 2, "name": "Bob", "tags": ["java", "sql"]},
                    {"id": 3, "name": "Joe", "tags": ["python", "javascript"]},
                    {"id": 4, "name": "Charlie", "tags": []},
                ],
            )

            query = (
                select(
                    test_table.c.id,
                    test_table.c.name,
                    literal_column("tag"),  # Needed when using alias
                )
                .select_from(array_join(test_table, test_table.c.tags, alias="tag", is_left=True))
                .order_by(test_table.c.id)
                .order_by(literal_column("tag"))
            )

            compiled = query.compile(dialect=test_engine.dialect)
            compiled_str = str(compiled).upper()
            assert "LEFT ARRAY JOIN" in compiled_str
            # Match the keyword with its surrounding spaces: a bare "AS" would
            # also match inside any identifier that happens to contain it.
            assert " AS " in compiled_str

            result = conn.execute(query)
            rows = result.fetchall()
            # 3 (Alice) + 2 (Bob) + 2 (Joe) + 1 (Charlie, kept by LEFT ARRAY JOIN)
            assert len(rows) == 8

            alice_tags = [row.tag for row in rows if row.name == "Alice"]
            assert len(alice_tags) == 3
            assert alice_tags == sorted(["python", "sql", "clickhouse"])

            bob_tags = [row.tag for row in rows if row.name == "Bob"]
            assert len(bob_tags) == 2
            assert bob_tags == sorted(["java", "sql"])

            charlie_rows = [row for row in rows if row.name == "Charlie"]
            assert len(charlie_rows) == 1
            assert charlie_rows[0].tag == ""
        finally:
            # Drop the table even when an assertion or query fails, so a
            # failed run does not leave it behind in the test database.
            test_table.drop(conn, checkfirst=True)

0 commit comments

Comments
 (0)