Skip to content

Commit 4072bfb

Browse files
⭐ Improve parameterized query support - fixes #793 (#794)
* Add parameterized query support * Revert base Parameter constructor back to it's original signature * Fix a few typehints and make code more DRY * add test for PyformatParameter * fix linting issues --------- Co-authored-by: Lars Schwegmann <[email protected]>
1 parent 53a77eb commit 4072bfb

File tree

3 files changed

+239
-23
lines changed

3 files changed

+239
-23
lines changed

pypika/terms.py

+126-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,21 @@
33
import uuid
44
from datetime import date
55
from enum import Enum
6-
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
Callable,
10+
Iterable,
11+
Iterator,
12+
List,
13+
Optional,
14+
Sequence,
15+
Set,
16+
Tuple,
17+
Type,
18+
TypeVar,
19+
Union,
20+
)
721

822
from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order
923
from pypika.utils import (
@@ -288,57 +302,111 @@ def get_sql(self, **kwargs: Any) -> str:
288302
raise NotImplementedError()
289303

290304

305+
def idx_placeholder_gen(idx: int) -> str:
306+
return str(idx + 1)
307+
308+
309+
def named_placeholder_gen(idx: int) -> str:
310+
return f'param{idx + 1}'
311+
312+
291313
class Parameter(Term):
292314
is_aggregate = None
293315

294316
def __init__(self, placeholder: Union[str, int]) -> None:
295317
super().__init__()
296-
self.placeholder = placeholder
318+
self._placeholder = placeholder
319+
320+
@property
321+
def placeholder(self):
322+
return self._placeholder
297323

298324
def get_sql(self, **kwargs: Any) -> str:
299325
return str(self.placeholder)
300326

327+
def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
328+
pass
301329

302-
class QmarkParameter(Parameter):
303-
"""Question mark style, e.g. ...WHERE name=?"""
330+
def get_param_key(self, placeholder: Any, **kwargs):
331+
return placeholder
304332

305-
def __init__(self) -> None:
306-
pass
307333

308-
def get_sql(self, **kwargs: Any) -> str:
309-
return "?"
334+
class ListParameter(Parameter):
335+
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None:
336+
super().__init__(placeholder=placeholder)
337+
self._parameters = list()
310338

339+
@property
340+
def placeholder(self) -> str:
341+
if callable(self._placeholder):
342+
return self._placeholder(len(self._parameters))
311343

312-
class NumericParameter(Parameter):
313-
"""Numeric, positional style, e.g. ...WHERE name=:1"""
344+
return str(self._placeholder)
314345

315-
def get_sql(self, **kwargs: Any) -> str:
316-
return ":{placeholder}".format(placeholder=self.placeholder)
346+
def get_parameters(self, **kwargs):
347+
return self._parameters
317348

349+
def update_parameters(self, value: Any, **kwargs):
350+
self._parameters.append(value)
318351

319-
class NamedParameter(Parameter):
320-
"""Named style, e.g. ...WHERE name=:name"""
352+
353+
class DictParameter(Parameter):
354+
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None:
355+
super().__init__(placeholder=placeholder)
356+
self._parameters = dict()
357+
358+
@property
359+
def placeholder(self) -> str:
360+
if callable(self._placeholder):
361+
return self._placeholder(len(self._parameters))
362+
363+
return str(self._placeholder)
364+
365+
def get_parameters(self, **kwargs):
366+
return self._parameters
367+
368+
def get_param_key(self, placeholder: Any, **kwargs):
369+
return placeholder[1:]
370+
371+
def update_parameters(self, param_key: Any, value: Any, **kwargs):
372+
self._parameters[param_key] = value
373+
374+
375+
class QmarkParameter(ListParameter):
376+
def get_sql(self, **kwargs):
377+
return '?'
378+
379+
380+
class NumericParameter(ListParameter):
381+
"""Numeric, positional style, e.g. ...WHERE name=:1"""
321382

322383
def get_sql(self, **kwargs: Any) -> str:
323384
return ":{placeholder}".format(placeholder=self.placeholder)
324385

325386

326-
class FormatParameter(Parameter):
387+
class FormatParameter(ListParameter):
327388
"""ANSI C printf format codes, e.g. ...WHERE name=%s"""
328389

329-
def __init__(self) -> None:
330-
pass
331-
332390
def get_sql(self, **kwargs: Any) -> str:
333391
return "%s"
334392

335393

336-
class PyformatParameter(Parameter):
394+
class NamedParameter(DictParameter):
395+
"""Named style, e.g. ...WHERE name=:name"""
396+
397+
def get_sql(self, **kwargs: Any) -> str:
398+
return ":{placeholder}".format(placeholder=self.placeholder)
399+
400+
401+
class PyformatParameter(DictParameter):
337402
"""Python extended format codes, e.g. ...WHERE name=%(name)s"""
338403

339404
def get_sql(self, **kwargs: Any) -> str:
340405
return "%({placeholder})s".format(placeholder=self.placeholder)
341406

407+
def get_param_key(self, placeholder: Any, **kwargs):
408+
return placeholder[2:-2]
409+
342410

343411
class Negative(Term):
344412
def __init__(self, term: Term) -> None:
@@ -385,9 +453,44 @@ def get_formatted_value(cls, value: Any, **kwargs):
385453
return "null"
386454
return str(value)
387455

388-
def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str:
389-
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
390-
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
456+
def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
457+
param_sql = parameter.get_sql(**kwargs)
458+
param_key = parameter.get_param_key(placeholder=param_sql)
459+
460+
return param_sql, param_key
461+
462+
def get_sql(
463+
self,
464+
quote_char: Optional[str] = None,
465+
secondary_quote_char: str = "'",
466+
parameter: Parameter = None,
467+
**kwargs: Any,
468+
) -> str:
469+
if parameter is None:
470+
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
471+
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
472+
473+
# Don't stringify numbers when using a parameter
474+
if isinstance(self.value, (int, float)):
475+
value_sql = self.value
476+
else:
477+
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
478+
param_sql, param_key = self._get_param_data(parameter, **kwargs)
479+
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)
480+
481+
return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)
482+
483+
484+
class ParameterValueWrapper(ValueWrapper):
485+
def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None:
486+
super().__init__(value, alias)
487+
self._parameter = parameter
488+
489+
def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
490+
param_sql = self._parameter.get_sql(**kwargs)
491+
param_key = self._parameter.get_param_key(placeholder=param_sql)
492+
493+
return param_sql, param_key
391494

392495

393496
class JSON(Term):
@@ -551,6 +654,7 @@ def __init__(
551654
if isinstance(table, str):
552655
# avoid circular import at load time
553656
from pypika.queries import Table
657+
554658
table = Table(table)
555659
self.table = table
556660

pypika/tests/test_parameter.py

+112
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from datetime import date
23

34
from pypika import (
45
FormatParameter,
@@ -10,6 +11,7 @@
1011
Query,
1112
Tables,
1213
)
14+
from pypika.terms import ListParameter, ParameterValueWrapper
1315

1416

1517
class ParametrizedTests(unittest.TestCase):
@@ -92,3 +94,113 @@ def test_format_parameter(self):
9294

9395
def test_pyformat_parameter(self):
9496
self.assertEqual('%(buz)s', PyformatParameter('buz').get_sql())
97+
98+
99+
class ParametrizedTestsWithValues(unittest.TestCase):
100+
table_abc, table_efg = Tables("abc", "efg")
101+
102+
def test_param_insert(self):
103+
q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo')
104+
105+
parameter = QmarkParameter()
106+
sql = q.get_sql(parameter=parameter)
107+
self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql)
108+
self.assertEqual([1, 2.2, 'foo'], parameter.get_parameters())
109+
110+
def test_param_select_join(self):
111+
q = (
112+
Query.from_(self.table_abc)
113+
.select("*")
114+
.where(self.table_abc.category == 'foobar')
115+
.join(self.table_efg)
116+
.on(self.table_abc.id == self.table_efg.abc_id)
117+
.where(self.table_efg.date >= date(2024, 2, 22))
118+
.limit(10)
119+
)
120+
121+
parameter = FormatParameter()
122+
sql = q.get_sql(parameter=parameter)
123+
self.assertEqual(
124+
'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10',
125+
sql,
126+
)
127+
self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters())
128+
129+
def test_param_select_subquery(self):
130+
q = (
131+
Query.from_(self.table_abc)
132+
.select("*")
133+
.where(self.table_abc.category == 'foobar')
134+
.where(
135+
self.table_abc.id.isin(
136+
Query.from_(self.table_efg)
137+
.select(self.table_efg.abc_id)
138+
.where(self.table_efg.date >= date(2024, 2, 22))
139+
)
140+
)
141+
.limit(10)
142+
)
143+
144+
parameter = ListParameter(placeholder=lambda idx: f'&{idx+1}')
145+
sql = q.get_sql(parameter=parameter)
146+
self.assertEqual(
147+
'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10',
148+
sql,
149+
)
150+
self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters())
151+
152+
def test_join(self):
153+
subquery = (
154+
Query.from_(self.table_efg)
155+
.select(self.table_efg.fiz, self.table_efg.buz)
156+
.where(self.table_efg.buz == 'buz')
157+
)
158+
159+
q = (
160+
Query.from_(self.table_abc)
161+
.join(subquery)
162+
.on(self.table_abc.bar == subquery.buz)
163+
.select(self.table_abc.foo, subquery.fiz)
164+
.where(self.table_abc.bar == 'bar')
165+
)
166+
167+
parameter = NamedParameter()
168+
sql = q.get_sql(parameter=parameter)
169+
self.assertEqual(
170+
'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)'
171+
' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2',
172+
sql,
173+
)
174+
self.assertEqual({'param1': 'buz', 'param2': 'bar'}, parameter.get_parameters())
175+
176+
def test_join_with_parameter_value_wrapper(self):
177+
subquery = (
178+
Query.from_(self.table_efg)
179+
.select(self.table_efg.fiz, self.table_efg.buz)
180+
.where(self.table_efg.buz == ParameterValueWrapper(Parameter(':buz'), 'buz'))
181+
)
182+
183+
q = (
184+
Query.from_(self.table_abc)
185+
.join(subquery)
186+
.on(self.table_abc.bar == subquery.buz)
187+
.select(self.table_abc.foo, subquery.fiz)
188+
.where(self.table_abc.bar == ParameterValueWrapper(NamedParameter('bar'), 'bar'))
189+
)
190+
191+
parameter = NamedParameter()
192+
sql = q.get_sql(parameter=parameter)
193+
self.assertEqual(
194+
'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)'
195+
' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar',
196+
sql,
197+
)
198+
self.assertEqual({':buz': 'buz', 'bar': 'bar'}, parameter.get_parameters())
199+
200+
def test_pyformat_parameter(self):
201+
q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo')
202+
203+
parameter = PyformatParameter()
204+
sql = q.get_sql(parameter=parameter)
205+
self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql)
206+
self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters())

pypika/tests/test_terms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_init_with_str_table(self):
2020
test_table_name = "test_table"
2121
field = Field(name="name", table=test_table_name)
2222
self.assertEqual(field.table, Table(name=test_table_name))
23-
23+
2424

2525
class FieldHashingTests(TestCase):
2626
def test_tabled_eq_fields_equally_hashed(self):

0 commit comments

Comments
 (0)