Skip to content

Commit 6944a6c

Browse files
OutSquareCapitalCopilotgeorgesittas
committed
refactor!: Improve typing annotations for planner, schema, serde, and transforms modules (#7579)
* refactor: improve planner typing annotations * refactor: improve planner, schema, and serde annotations Co-authored-by: Copilot <copilot@github.com> * refactor: added lazy annotations import to time module Co-authored-by: Copilot <copilot@github.com> * refactor: improved transforms annotations Co-authored-by: Copilot <copilot@github.com> * fix: make `StackVal` `type alias compatible with python 3.9 Co-authored-by: Copilot <copilot@github.com> * fix: since ast tree is mutated in `transforms::eliminate_qualify`, we indeed need to collect the Iterator in a list first Co-authored-by: Copilot <copilot@github.com> * fix: mypc is buggy with `object::__module__` access, so we can't narrow to a precise type the `Expr` path in serde::dump` * refactor: revert `type` -> `isinstance` usage in `serde::dump` function body * refactor: revert `nodes` list type in `serde::load` * refactor: ignore `node` type in `serde::load` body to avoid errors * refactor: Apply suggestions from code review Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com> * fix: use `Any` for the key type of the trie mapping in `time::format_time` * refactor: move comment of joins_ons in `transforms::eliminate_join_marks` above the line * refactor: change the function body of `trnasforms::move_schema_columns_to_partitioned_by` into something more type safe * fix: revert instance checks in `serde::dump` * fix: revert type hint of `node` in `serde::load` * refactor: make the `_sql_handler` variable in `transforms::preprocess::_to_sql` a Protocol --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
1 parent 6f68a9f commit 6944a6c

5 files changed

Lines changed: 125 additions & 116 deletions

File tree

sqlglot/planner.py

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
class Plan:
1313
def __init__(self, expression: exp.Expr) -> None:
14-
self.expression = expression.copy()
15-
self.root = Step.from_expression(self.expression)
14+
self.expression: exp.Expr = expression.copy()
15+
self.root: Step = Step.from_expression(self.expression)
1616
self._dag: dict[Step, set[Step]] = {}
1717

1818
@property
@@ -93,10 +93,10 @@ def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = No
9393
"""
9494
ctes = ctes or {}
9595
expression = expression.unnest()
96-
with_ = expression.args.get("with_")
96+
with_: exp.With | None = expression.args.get("with_")
9797

9898
# CTEs break the mold of scope and introduce themselves to all in the context.
99-
if with_:
99+
if with_ is not None:
100100
ctes = ctes.copy()
101101
for cte in with_.expressions:
102102
step = Step.from_expression(cte.this, ctes)
@@ -112,23 +112,22 @@ def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = No
112112
else:
113113
step = Scan()
114114

115-
joins = expression.args.get("joins")
115+
joins: list[exp.Join] | None = expression.args.get("joins")
116116

117-
if joins:
117+
if joins is not None:
118118
join = Join.from_joins(joins, ctes)
119119
join.name = step.name
120120
join.source_name = step.name
121121
join.add_dependency(step)
122122
step = join
123-
124-
projections: list[
125-
exp.Expr
126-
] = [] # final selects in this chain of steps representing a select
127-
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
128-
aggregations = {}
123+
# final selects in this chain of steps representing a select
124+
projections: list[exp.Expr] = []
125+
# intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
126+
operands: dict[exp.Expr, str] = {}
127+
aggregations: dict[exp.Expr, None] = {}
129128
next_operand_name = name_sequence("_a_")
130129

131-
def extract_agg_operands(expression):
130+
def extract_agg_operands(expression: exp.Expr) -> bool:
132131
agg_funcs = tuple(expression.find_all(exp.AggFunc))
133132
if agg_funcs:
134133
aggregations[expression] = None
@@ -144,7 +143,7 @@ def extract_agg_operands(expression):
144143

145144
return bool(agg_funcs)
146145

147-
def set_ops_and_aggs(step):
146+
def set_ops_and_aggs(step) -> None:
148147
step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
149148
step.aggregations = list(aggregations)
150149

@@ -155,21 +154,21 @@ def set_ops_and_aggs(step):
155154
else:
156155
projections.append(e)
157156

158-
where = expression.args.get("where")
157+
where: exp.Where | None = expression.args.get("where")
159158

160-
if where:
159+
if where is not None:
161160
step.condition = where.this
162161

163-
group = expression.args.get("group")
162+
group: exp.Group | None = expression.args.get("group")
164163

165-
if group or aggregations:
164+
if group is not None or aggregations:
166165
aggregate = Aggregate()
167166
aggregate.source = step.name
168167
aggregate.name = step.name
169168

170-
having = expression.args.get("having")
169+
having: exp.Having | None = expression.args.get("having")
171170

172-
if having:
171+
if having is not None:
173172
if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
174173
aggregate.condition = exp.column("_h", step.name, quoted=True)
175174
else:
@@ -205,10 +204,10 @@ def set_ops_and_aggs(step):
205204
else:
206205
aggregate = None
207206

208-
order = expression.args.get("order")
207+
order: exp.Order | None = expression.args.get("order")
209208

210-
if order:
211-
if aggregate and isinstance(step, Aggregate):
209+
if order is not None:
210+
if aggregate is not None and isinstance(step, Aggregate):
212211
for i, ordered in enumerate(order.expressions):
213212
if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
214213
ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
@@ -234,9 +233,9 @@ def set_ops_and_aggs(step):
234233
distinct.add_dependency(step)
235234
step = distinct
236235

237-
limit = expression.args.get("limit")
236+
limit: exp.Limit | None = expression.args.get("limit")
238237

239-
if limit:
238+
if limit is not None:
240239
step.limit = int(limit.text("expression"))
241240

242241
return step
@@ -304,7 +303,7 @@ def _to_s(self, _indent: str) -> list[str]:
304303
class Scan(Step):
305304
@classmethod
306305
def from_expression(cls, expression: exp.Expr, ctes: dict[str, Step] | None = None) -> Step:
307-
table = expression
306+
table: exp.Expr = expression
308307
alias_ = expression.alias_or_name
309308

310309
if isinstance(expression, exp.Subquery):
@@ -356,7 +355,7 @@ def _to_s(self, indent: str) -> list[str]:
356355
lines = [f"{indent}Source: {self.source_name or self.name}"]
357356
for name, join in self.joins.items():
358357
lines.append(f"{indent}{name}: {join['side'] or 'INNER'}")
359-
join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or []))
358+
join_key = ", ".join(str(key) for key in t.cast(list[str], join.get("join_key") or []))
360359
if join_key:
361360
lines.append(f"{indent}Key: {join_key}")
362361
if join.get("condition"):
@@ -396,7 +395,7 @@ def _to_s(self, indent: str) -> list[str]:
396395
class Sort(Step):
397396
def __init__(self) -> None:
398397
super().__init__()
399-
self.key = None
398+
self.key: list[exp.Expr] | None = None
400399

401400
def _to_s(self, indent: str) -> list[str]:
402401
lines = [f"{indent}Key:"]
@@ -408,18 +407,12 @@ def _to_s(self, indent: str) -> list[str]:
408407

409408

410409
class SetOperation(Step):
411-
def __init__(
412-
self,
413-
op: type[exp.Expr],
414-
left: str | None,
415-
right: str | None,
416-
distinct: bool = False,
417-
) -> None:
410+
def __init__(self, op: type[exp.Expr], left: str, right: str, distinct: bool = False) -> None:
418411
super().__init__()
419-
self.op = op
420-
self.left = left
421-
self.right = right
422-
self.distinct = distinct
412+
self.op: type[exp.Expr] = op
413+
self.left: str = left
414+
self.right: str = right
415+
self.distinct: bool = distinct
423416

424417
@classmethod
425418
def from_expression(
@@ -442,15 +435,15 @@ def from_expression(
442435
step.add_dependency(left)
443436
step.add_dependency(right)
444437

445-
limit = expression.args.get("limit")
438+
limit: exp.Limit | None = expression.args.get("limit")
446439

447-
if limit:
440+
if limit is not None:
448441
step.limit = int(limit.text("expression"))
449442

450443
return step
451444

452445
def _to_s(self, indent: str) -> list[str]:
453-
lines = []
446+
lines: list[str] = []
454447
if self.distinct:
455448
lines.append(f"{indent}Distinct: {self.distinct}")
456449
return lines

sqlglot/schema.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from collections.abc import Sequence
1919
from typing_extensions import Unpack
2020

21-
ColumnMapping = t.Union[dict, str, list]
21+
ColumnMapping = t.Union[dict[str, t.Any], str, list[str]]
2222

2323

2424
@trait
@@ -344,7 +344,7 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
344344
def find(
345345
self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
346346
) -> t.Any | None:
347-
schema = super().find(
347+
schema: dict[str, object] | None = super().find(
348348
table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
349349
)
350350
if ensure_data_types and isinstance(schema, dict):
@@ -417,7 +417,7 @@ def column_names(
417417
) -> list[str]:
418418
normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
419419

420-
schema = self.find(normalized_table)
420+
schema: dict[str, object] | None = self.find(normalized_table)
421421
if schema is None:
422422
return []
423423

@@ -440,7 +440,7 @@ def get_column_type(
440440
column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
441441
)
442442

443-
table_schema = self.find(normalized_table, raise_on_missing=False)
443+
table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
444444
if table_schema:
445445
column_type = table_schema.get(normalized_column_name)
446446

@@ -500,7 +500,7 @@ def has_column(
500500
column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
501501
)
502502

503-
table_schema = self.find(normalized_table, raise_on_missing=False)
503+
table_schema: dict[str, object] | None = self.find(normalized_table, raise_on_missing=False)
504504
return normalized_column_name in table_schema if table_schema else False
505505

506506
def _normalize(self, schema: dict[str, object]) -> dict[str, object]:
@@ -708,7 +708,7 @@ def ensure_schema(
708708
return MappingSchema(schema, **kwargs)
709709

710710

711-
def ensure_column_mapping(mapping: ColumnMapping | None) -> dict:
711+
def ensure_column_mapping(mapping: ColumnMapping | None) -> dict[str, t.Any]:
712712
if mapping is None:
713713
return {}
714714
elif isinstance(mapping, dict):

sqlglot/serde.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
import typing as t
44

55
from sqlglot import expressions as exp
6+
from types import ModuleType
7+
8+
9+
StackVal = tuple[t.Any, t.Optional[int], t.Optional[str], bool]
610

711

812
INDEX = "i"
@@ -21,8 +25,8 @@ def dump(expression: exp.Expr) -> list[dict[str, t.Any]]:
2125
Dump an Expr into a JSON serializable List.
2226
"""
2327
i = 0
24-
payloads = []
25-
stack: list[tuple[t.Any, int | None, str | None, bool]] = [(expression, None, None, False)]
28+
payloads: list[dict[str, t.Any]] = []
29+
stack: list[StackVal] = [(expression, None, None, False)]
2630

2731
while stack:
2832
node, index, arg_key, is_array = stack.pop()
@@ -90,8 +94,8 @@ def load(
9094
node = payload[VALUE]
9195

9296
nodes.append(node)
93-
parent = nodes[payload[INDEX]]
94-
arg_key = payload[ARG_KEY]
97+
parent: exp.Expr = nodes[payload[INDEX]]
98+
arg_key: str = payload[ARG_KEY]
9599

96100
if payload.get(IS_ARRAY):
97101
parent.append(arg_key, node)
@@ -102,11 +106,11 @@ def load(
102106

103107

104108
def _load(payload: dict[str, t.Any]) -> exp.Expr | exp.DType:
105-
class_name = payload[CLASS]
109+
class_name: str = payload[CLASS]
106110

107111
if class_name == DATA_TYPE:
108112
return exp.DType(payload[VALUE])
109-
113+
module: ModuleType
110114
if "." in class_name:
111115
module_path, class_name = class_name.rsplit(".", maxsplit=1)
112116
module = __import__(module_path, fromlist=[class_name])

sqlglot/time.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
import typing as t
1+
from __future__ import annotations
22
import datetime
3+
import typing as t
34

45
# The generic time format is based on python time.strftime.
56
# https://docs.python.org/3/library/time.html#time.strftime
67
from sqlglot.trie import TrieResult, in_trie, new_trie
78

89

910
def format_time(
10-
string: str, mapping: dict[str, str], trie: t.Optional[dict] = None
11-
) -> t.Optional[str]:
11+
string: str, mapping: dict[str, str], trie: dict[t.Any, t.Any] | None = None
12+
) -> str | None:
1213
"""
1314
Converts a time string given a mapping.
1415
@@ -31,7 +32,7 @@ def format_time(
3132
size = len(string)
3233
trie = trie or new_trie(mapping)
3334
current = trie
34-
chunks = []
35+
chunks: list[str] = []
3536
sym = None
3637

3738
while end <= size:
@@ -61,7 +62,7 @@ def format_time(
6162
return "".join(mapping.get(chars, chars) for chars in chunks)
6263

6364

64-
TIMEZONES = {
65+
TIMEZONES: set[str] = {
6566
tz.lower()
6667
for tz in (
6768
"Africa/Abidjan",

0 commit comments

Comments
 (0)