Skip to content

Commit a50b367

Browse files
authored
fix some types when downstream enabled this package in mypy (#815)
* fix: select arg type * fix some type error when mypy enable * revert * revert * fix
1 parent c1a2fe1 commit a50b367

File tree

3 files changed

+88
-28
lines changed

3 files changed

+88
-28
lines changed

pypika/__init__.py

+54
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,63 @@
112112
FunctionException,
113113
)
114114

115+
115116
__author__ = "Timothy Heys"
116117
__email__ = "[email protected]"
117118
__version__ = "0.48.9"
118119

119120
NULL = NullValue()
120121
SYSTEM_TIME = SystemTimeValue()
122+
123+
__all__ = (
124+
'ClickHouseQuery',
125+
'Dialects',
126+
'MSSQLQuery',
127+
'MySQLQuery',
128+
'OracleQuery',
129+
'PostgreSQLQuery',
130+
'RedshiftQuery',
131+
'SQLLiteQuery',
132+
'VerticaQuery',
133+
'DatePart',
134+
'JoinType',
135+
'Order',
136+
'AliasedQuery',
137+
'Query',
138+
'Schema',
139+
'Table',
140+
'Column',
141+
'Database',
142+
'Tables',
143+
'Columns',
144+
'Array',
145+
'Bracket',
146+
'Case',
147+
'Criterion',
148+
'EmptyCriterion',
149+
'Field',
150+
'Index',
151+
'Interval',
152+
'JSON',
153+
'Not',
154+
'NullValue',
155+
'SystemTimeValue',
156+
'Parameter',
157+
'QmarkParameter',
158+
'NumericParameter',
159+
'NamedParameter',
160+
'FormatParameter',
161+
'PyformatParameter',
162+
'Rollup',
163+
'Tuple',
164+
'CustomFunction',
165+
'CaseException',
166+
'GroupingException',
167+
'JoinException',
168+
'QueryException',
169+
'RollupException',
170+
'SetOperationException',
171+
'FunctionException',
172+
'NULL',
173+
'SYSTEM_TIME',
174+
)

pypika/functions.py

+33-27
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
"""
22
Package for SQL functions wrappers
33
"""
4+
from __future__ import annotations
5+
6+
from typing import Optional
7+
8+
from pypika import Field
49
from pypika.enums import SqlTypes
510
from pypika.terms import (
611
AggregateFunction,
@@ -10,6 +15,7 @@
1015
)
1116
from pypika.utils import builder
1217

18+
1319
__author__ = "Timothy Heys"
1420
__email__ = "[email protected]"
1521

@@ -34,64 +40,64 @@ def distinct(self):
3440

3541

3642
class Count(DistinctOptionFunction):
37-
def __init__(self, param, alias=None):
43+
def __init__(self, param: str | Field, alias: Optional[str] = None) -> None:
3844
is_star = isinstance(param, str) and "*" == param
3945
super(Count, self).__init__("COUNT", Star() if is_star else param, alias=alias)
4046

4147

4248
# Arithmetic Functions
4349
class Sum(DistinctOptionFunction):
44-
def __init__(self, term, alias=None):
50+
def __init__(self, term: str | Field, alias: Optional[str] = None):
4551
super(Sum, self).__init__("SUM", term, alias=alias)
4652

4753

4854
class Avg(AggregateFunction):
49-
def __init__(self, term, alias=None):
55+
def __init__(self, term: str | Field, alias: Optional[str] = None):
5056
super(Avg, self).__init__("AVG", term, alias=alias)
5157

5258

5359
class Min(AggregateFunction):
54-
def __init__(self, term, alias=None):
60+
def __init__(self, term: str | Field, alias: Optional[str] = None):
5561
super(Min, self).__init__("MIN", term, alias=alias)
5662

5763

5864
class Max(AggregateFunction):
59-
def __init__(self, term, alias=None):
65+
def __init__(self, term: str | Field, alias: Optional[str] = None):
6066
super(Max, self).__init__("MAX", term, alias=alias)
6167

6268

6369
class Std(AggregateFunction):
64-
def __init__(self, term, alias=None):
70+
def __init__(self, term: str | Field, alias: Optional[str] = None):
6571
super(Std, self).__init__("STD", term, alias=alias)
6672

6773

6874
class StdDev(AggregateFunction):
69-
def __init__(self, term, alias=None):
75+
def __init__(self, term: str | Field, alias: Optional[str] = None):
7076
super(StdDev, self).__init__("STDDEV", term, alias=alias)
7177

7278

7379
class Abs(AggregateFunction):
74-
def __init__(self, term, alias=None):
80+
def __init__(self, term: str | Field, alias: Optional[str] = None):
7581
super(Abs, self).__init__("ABS", term, alias=alias)
7682

7783

7884
class First(AggregateFunction):
79-
def __init__(self, term, alias=None):
85+
def __init__(self, term: str | Field, alias: Optional[str] = None):
8086
super(First, self).__init__("FIRST", term, alias=alias)
8187

8288

8389
class Last(AggregateFunction):
84-
def __init__(self, term, alias=None):
90+
def __init__(self, term: str | Field, alias: Optional[str] = None):
8591
super(Last, self).__init__("LAST", term, alias=alias)
8692

8793

8894
class Sqrt(Function):
89-
def __init__(self, term, alias=None):
95+
def __init__(self, term: str | Field, alias: Optional[str] = None):
9096
super(Sqrt, self).__init__("SQRT", term, alias=alias)
9197

9298

9399
class Floor(Function):
94-
def __init__(self, term, alias=None):
100+
def __init__(self, term: str | Field, alias: Optional[str] = None):
95101
super(Floor, self).__init__("FLOOR", term, alias=alias)
96102

97103

@@ -131,17 +137,17 @@ def __init__(self, term, as_type, alias=None):
131137

132138

133139
class Signed(Cast):
134-
def __init__(self, term, alias=None):
140+
def __init__(self, term: str | Field, alias: Optional[str] = None):
135141
super(Signed, self).__init__(term, SqlTypes.SIGNED, alias=alias)
136142

137143

138144
class Unsigned(Cast):
139-
def __init__(self, term, alias=None):
145+
def __init__(self, term: str | Field, alias: Optional[str] = None):
140146
super(Unsigned, self).__init__(term, SqlTypes.UNSIGNED, alias=alias)
141147

142148

143149
class Date(Function):
144-
def __init__(self, term, alias=None):
150+
def __init__(self, term: str | Field, alias: Optional[str] = None):
145151
super(Date, self).__init__("DATE", term, alias=alias)
146152

147153

@@ -156,7 +162,7 @@ def __init__(self, start_time, end_time, alias=None):
156162

157163

158164
class DateAdd(Function):
159-
def __init__(self, date_part, interval, term, alias=None):
165+
def __init__(self, date_part, interval, term: str, alias: Optional[str] = None):
160166
date_part = getattr(date_part, "value", date_part)
161167
super(DateAdd, self).__init__("DATE_ADD", LiteralValue(date_part), interval, term, alias=alias)
162168

@@ -167,19 +173,19 @@ def __init__(self, value, format_mask, alias=None):
167173

168174

169175
class Timestamp(Function):
170-
def __init__(self, term, alias=None):
176+
def __init__(self, term: str | Field, alias: Optional[str] = None):
171177
super(Timestamp, self).__init__("TIMESTAMP", term, alias=alias)
172178

173179

174180
class TimestampAdd(Function):
175-
def __init__(self, date_part, interval, term, alias=None):
181+
def __init__(self, date_part, interval, term: str, alias: Optional[str] = None):
176182
date_part = getattr(date_part, 'value', date_part)
177183
super(TimestampAdd, self).__init__("TIMESTAMPADD", LiteralValue(date_part), interval, term, alias=alias)
178184

179185

180186
# String Functions
181187
class Ascii(Function):
182-
def __init__(self, term, alias=None):
188+
def __init__(self, term: str | Field, alias: Optional[str] = None):
183189
super(Ascii, self).__init__("ASCII", term, alias=alias)
184190

185191

@@ -189,7 +195,7 @@ def __init__(self, term, condition, **kwargs):
189195

190196

191197
class Bin(Function):
192-
def __init__(self, term, alias=None):
198+
def __init__(self, term: str | Field, alias: Optional[str] = None):
193199
super(Bin, self).__init__("BIN", term, alias=alias)
194200

195201

@@ -205,17 +211,17 @@ def __init__(self, term, start, stop, subterm, alias=None):
205211

206212

207213
class Length(Function):
208-
def __init__(self, term, alias=None):
214+
def __init__(self, term: str | Field, alias: Optional[str] = None):
209215
super(Length, self).__init__("LENGTH", term, alias=alias)
210216

211217

212218
class Upper(Function):
213-
def __init__(self, term, alias=None):
219+
def __init__(self, term: str | Field, alias: Optional[str] = None):
214220
super(Upper, self).__init__("UPPER", term, alias=alias)
215221

216222

217223
class Lower(Function):
218-
def __init__(self, term, alias=None):
224+
def __init__(self, term: str | Field, alias: Optional[str] = None):
219225
super(Lower, self).__init__("LOWER", term, alias=alias)
220226

221227

@@ -225,12 +231,12 @@ def __init__(self, term, start, stop, alias=None):
225231

226232

227233
class Reverse(Function):
228-
def __init__(self, term, alias=None):
234+
def __init__(self, term: str | Field, alias: Optional[str] = None):
229235
super(Reverse, self).__init__("REVERSE", term, alias=alias)
230236

231237

232238
class Trim(Function):
233-
def __init__(self, term, alias=None):
239+
def __init__(self, term: str | Field, alias: Optional[str] = None):
234240
super(Trim, self).__init__("TRIM", term, alias=alias)
235241

236242

@@ -297,7 +303,7 @@ def get_special_params_sql(self, **kwargs):
297303

298304
# Null Functions
299305
class IsNull(Function):
300-
def __init__(self, term, alias=None):
306+
def __init__(self, term: str | Field, alias: Optional[str] = None):
301307
super(IsNull, self).__init__("ISNULL", term, alias=alias)
302308

303309

@@ -312,5 +318,5 @@ def __init__(self, condition, term, **kwargs):
312318

313319

314320
class NVL(Function):
315-
def __init__(self, condition, term, alias=None):
321+
def __init__(self, condition, term: str, alias: Optional[str] = None):
316322
super(NVL, self).__init__("NVL", condition, term, alias=alias)

pypika/queries.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def __ne__(self, other: Any) -> bool:
206206
def __hash__(self) -> int:
207207
return hash(str(self))
208208

209-
def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> "QueryBuilder":
209+
def select(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilder":
210210
"""
211211
Perform a SELECT operation on the current table
212212

0 commit comments

Comments
 (0)