Skip to content

Commit 0967cb1

Browse files
authored
Merge branch 'master' into master
2 parents 98cbf46 + 4072bfb commit 0967cb1

File tree

7 files changed

+314
-49
lines changed

7 files changed

+314
-49
lines changed

pypika/dialects.py

+27-14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import warnings
23
from copy import copy
34
from typing import Any, Optional, Union, Tuple as TypedTuple
45

@@ -87,7 +88,7 @@ class MySQLQueryBuilder(QueryBuilder):
8788
QUERY_CLS = MySQLQuery
8889

8990
def __init__(self, **kwargs: Any) -> None:
90-
super().__init__(dialect=Dialects.MYSQL, wrap_set_operation_queries=False, **kwargs)
91+
super().__init__(dialect=Dialects.MYSQL, **kwargs)
9192
self._duplicate_updates = []
9293
self._ignore_duplicates = False
9394
self._modifiers = []
@@ -347,6 +348,19 @@ def __str__(self) -> str:
347348
return self.get_sql()
348349

349350

351+
class FetchNextAndOffsetRowsQueryBuilder(QueryBuilder):
352+
def _limit_sql(self) -> str:
353+
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit)
354+
355+
def _offset_sql(self) -> str:
356+
return " OFFSET {offset} ROWS".format(offset=self._offset or 0)
357+
358+
@builder
359+
def fetch_next(self, limit: int):
360+
warnings.warn("`fetch_next` is deprecated - please use the `limit` method", DeprecationWarning)
361+
self._limit = limit
362+
363+
350364
class OracleQuery(Query):
351365
"""
352366
Defines a query class for use with Oracle.
@@ -357,7 +371,7 @@ def _builder(cls, **kwargs: Any) -> "OracleQueryBuilder":
357371
return OracleQueryBuilder(**kwargs)
358372

359373

360-
class OracleQueryBuilder(QueryBuilder):
374+
class OracleQueryBuilder(FetchNextAndOffsetRowsQueryBuilder):
361375
QUOTE_CHAR = None
362376
QUERY_CLS = OracleQuery
363377

@@ -370,6 +384,16 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str:
370384
kwargs['groupby_alias'] = False
371385
return super().get_sql(*args, **kwargs)
372386

387+
def _apply_pagination(self, querystring: str) -> str:
388+
# Note: Overridden as Oracle specifies offset before the fetch next limit
389+
if self._offset:
390+
querystring += self._offset_sql()
391+
392+
if self._limit is not None:
393+
querystring += self._limit_sql()
394+
395+
return querystring
396+
373397

374398
class PostgreSQLQuery(Query):
375399
"""
@@ -670,7 +694,7 @@ def _builder(cls, **kwargs: Any) -> "MSSQLQueryBuilder":
670694
return MSSQLQueryBuilder(**kwargs)
671695

672696

673-
class MSSQLQueryBuilder(QueryBuilder):
697+
class MSSQLQueryBuilder(FetchNextAndOffsetRowsQueryBuilder):
674698
QUERY_CLS = MSSQLQuery
675699

676700
def __init__(self, **kwargs: Any) -> None:
@@ -695,17 +719,6 @@ def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = F
695719
self._top_percent: bool = percent
696720
self._top_with_ties: bool = with_ties
697721

698-
@builder
699-
def fetch_next(self, limit: int) -> "MSSQLQueryBuilder":
700-
# Overridden to provide a more domain-specific API for T-SQL users
701-
self._limit = limit
702-
703-
def _offset_sql(self) -> str:
704-
return " OFFSET {offset} ROWS".format(offset=self._offset or 0)
705-
706-
def _limit_sql(self) -> str:
707-
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit)
708-
709722
def _apply_pagination(self, querystring: str) -> str:
710723
# Note: Overridden as MSSQL specifies offset before the fetch next limit
711724
if self._limit is not None or self._offset:

pypika/terms.py

+130-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,21 @@
33
import uuid
44
from datetime import date
55
from enum import Enum
6-
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
Callable,
10+
Iterable,
11+
Iterator,
12+
List,
13+
Optional,
14+
Sequence,
15+
Set,
16+
Tuple,
17+
Type,
18+
TypeVar,
19+
Union,
20+
)
721

822
from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order
923
from pypika.utils import (
@@ -288,57 +302,111 @@ def get_sql(self, **kwargs: Any) -> str:
288302
raise NotImplementedError()
289303

290304

305+
def idx_placeholder_gen(idx: int) -> str:
306+
return str(idx + 1)
307+
308+
309+
def named_placeholder_gen(idx: int) -> str:
310+
return f'param{idx + 1}'
311+
312+
291313
class Parameter(Term):
292314
is_aggregate = None
293315

294316
def __init__(self, placeholder: Union[str, int]) -> None:
295317
super().__init__()
296-
self.placeholder = placeholder
318+
self._placeholder = placeholder
319+
320+
@property
321+
def placeholder(self):
322+
return self._placeholder
297323

298324
def get_sql(self, **kwargs: Any) -> str:
299325
return str(self.placeholder)
300326

327+
def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
328+
pass
301329

302-
class QmarkParameter(Parameter):
303-
"""Question mark style, e.g. ...WHERE name=?"""
330+
def get_param_key(self, placeholder: Any, **kwargs):
331+
return placeholder
304332

305-
def __init__(self) -> None:
306-
pass
307333

308-
def get_sql(self, **kwargs: Any) -> str:
309-
return "?"
334+
class ListParameter(Parameter):
335+
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None:
336+
super().__init__(placeholder=placeholder)
337+
self._parameters = list()
310338

339+
@property
340+
def placeholder(self) -> str:
341+
if callable(self._placeholder):
342+
return self._placeholder(len(self._parameters))
311343

312-
class NumericParameter(Parameter):
313-
"""Numeric, positional style, e.g. ...WHERE name=:1"""
344+
return str(self._placeholder)
314345

315-
def get_sql(self, **kwargs: Any) -> str:
316-
return ":{placeholder}".format(placeholder=self.placeholder)
346+
def get_parameters(self, **kwargs):
347+
return self._parameters
317348

349+
def update_parameters(self, value: Any, **kwargs):
350+
self._parameters.append(value)
318351

319-
class NamedParameter(Parameter):
320-
"""Named style, e.g. ...WHERE name=:name"""
352+
353+
class DictParameter(Parameter):
354+
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None:
355+
super().__init__(placeholder=placeholder)
356+
self._parameters = dict()
357+
358+
@property
359+
def placeholder(self) -> str:
360+
if callable(self._placeholder):
361+
return self._placeholder(len(self._parameters))
362+
363+
return str(self._placeholder)
364+
365+
def get_parameters(self, **kwargs):
366+
return self._parameters
367+
368+
def get_param_key(self, placeholder: Any, **kwargs):
369+
return placeholder[1:]
370+
371+
def update_parameters(self, param_key: Any, value: Any, **kwargs):
372+
self._parameters[param_key] = value
373+
374+
375+
class QmarkParameter(ListParameter):
376+
def get_sql(self, **kwargs):
377+
return '?'
378+
379+
380+
class NumericParameter(ListParameter):
381+
"""Numeric, positional style, e.g. ...WHERE name=:1"""
321382

322383
def get_sql(self, **kwargs: Any) -> str:
323384
return ":{placeholder}".format(placeholder=self.placeholder)
324385

325386

326-
class FormatParameter(Parameter):
387+
class FormatParameter(ListParameter):
327388
"""ANSI C printf format codes, e.g. ...WHERE name=%s"""
328389

329-
def __init__(self) -> None:
330-
pass
331-
332390
def get_sql(self, **kwargs: Any) -> str:
333391
return "%s"
334392

335393

336-
class PyformatParameter(Parameter):
394+
class NamedParameter(DictParameter):
395+
"""Named style, e.g. ...WHERE name=:name"""
396+
397+
def get_sql(self, **kwargs: Any) -> str:
398+
return ":{placeholder}".format(placeholder=self.placeholder)
399+
400+
401+
class PyformatParameter(DictParameter):
337402
"""Python extended format codes, e.g. ...WHERE name=%(name)s"""
338403

339404
def get_sql(self, **kwargs: Any) -> str:
340405
return "%({placeholder})s".format(placeholder=self.placeholder)
341406

407+
def get_param_key(self, placeholder: Any, **kwargs):
408+
return placeholder[2:-2]
409+
342410

343411
class Negative(Term):
344412
def __init__(self, term: Term) -> None:
@@ -385,9 +453,44 @@ def get_formatted_value(cls, value: Any, **kwargs):
385453
return "null"
386454
return str(value)
387455

388-
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str:
389-
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
390-
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
456+
def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
457+
param_sql = parameter.get_sql(**kwargs)
458+
param_key = parameter.get_param_key(placeholder=param_sql)
459+
460+
return param_sql, param_key
461+
462+
def get_sql(
463+
self,
464+
quote_char: Optional[str] = None,
465+
secondary_quote_char: str = "'",
466+
parameter: Parameter = None,
467+
**kwargs: Any,
468+
) -> str:
469+
if parameter is None:
470+
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
471+
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
472+
473+
# Don't stringify numbers when using a parameter
474+
if isinstance(self.value, (int, float)):
475+
value_sql = self.value
476+
else:
477+
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
478+
param_sql, param_key = self._get_param_data(parameter, **kwargs)
479+
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)
480+
481+
return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)
482+
483+
484+
class ParameterValueWrapper(ValueWrapper):
485+
def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None:
486+
super().__init__(value, alias)
487+
self._parameter = parameter
488+
489+
def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
490+
param_sql = self._parameter.get_sql(**kwargs)
491+
param_key = self._parameter.get_param_key(placeholder=param_sql)
492+
493+
return param_sql, param_key
391494

392495

393496
class JSON(Term):
@@ -548,6 +651,11 @@ def __init__(
548651
) -> None:
549652
super().__init__(alias=alias)
550653
self.name = name
654+
if isinstance(table, str):
655+
# avoid circular import at load time
656+
from pypika.queries import Table
657+
658+
table = Table(table)
551659
self.table = table
552660

553661
def nodes_(self) -> Iterator[NodeT]:

pypika/tests/dialects/test_mssql.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,13 @@ def test_limit(self):
5353

5454
self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY', str(q))
5555

56-
def test_fetch_next(self):
57-
q = MSSQLQuery.from_("abc").select("def").orderby("def").fetch_next(10)
58-
59-
self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY', str(q))
60-
6156
def test_offset(self):
6257
q = MSSQLQuery.from_("abc").select("def").orderby("def").offset(10)
6358

6459
self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 10 ROWS', str(q))
6560

66-
def test_fetch_next_with_offset(self):
67-
q = MSSQLQuery.from_("abc").select("def").orderby("def").fetch_next(10).offset(10)
61+
def test_limit_with_offset(self):
62+
q = MSSQLQuery.from_("abc").select("def").orderby("def").limit(10).offset(10)
6863

6964
self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 10 ROWS FETCH NEXT 10 ROWS ONLY', str(q))
7065

pypika/tests/dialects/test_oracle.py

+30
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,33 @@ def test_groupby_alias_False_does_not_group_by_alias_when_subqueries_are_present
1919
q = OracleQuery.from_(subquery).select(col, Count('*')).groupby(col)
2020

2121
self.assertEqual('SELECT sq0.abc a,COUNT(\'*\') FROM (SELECT abc FROM table1) sq0 GROUP BY sq0.abc', str(q))
22+
23+
def test_limit_query(self):
24+
t = Table('table1')
25+
limit = 5
26+
q = OracleQuery.from_(t).select(t.test).limit(limit)
27+
28+
self.assertEqual(f'SELECT test FROM table1 FETCH NEXT {limit} ROWS ONLY', str(q))
29+
30+
def test_offset_query(self):
31+
t = Table('table1')
32+
offset = 5
33+
q = OracleQuery.from_(t).select(t.test).offset(offset)
34+
35+
self.assertEqual(f'SELECT test FROM table1 OFFSET {offset} ROWS', str(q))
36+
37+
def test_limit_offset_query(self):
38+
t = Table('table1')
39+
limit = 5
40+
offset = 5
41+
q = OracleQuery.from_(t).select(t.test).limit(limit).offset(offset)
42+
43+
self.assertEqual(f'SELECT test FROM table1 OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY', str(q))
44+
45+
def test_fetch_next_method_deprecated(self):
46+
with self.assertWarns(DeprecationWarning):
47+
t = Table('table1')
48+
limit = 5
49+
q = OracleQuery.from_(t).select(t.test).fetch_next(limit)
50+
51+
self.assertEqual(f'SELECT test FROM table1 FETCH NEXT {limit} ROWS ONLY', str(q))

0 commit comments

Comments
 (0)