|
3 | 3 | import uuid
|
4 | 4 | from datetime import date
|
5 | 5 | from enum import Enum
|
6 |
| -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union |
| 6 | +from typing import ( |
| 7 | + TYPE_CHECKING, |
| 8 | + Any, |
| 9 | + Callable, |
| 10 | + Iterable, |
| 11 | + Iterator, |
| 12 | + List, |
| 13 | + Optional, |
| 14 | + Sequence, |
| 15 | + Set, |
| 16 | + Tuple, |
| 17 | + Type, |
| 18 | + TypeVar, |
| 19 | + Union, |
| 20 | +) |
7 | 21 |
|
8 | 22 | from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order
|
9 | 23 | from pypika.utils import (
|
@@ -288,57 +302,111 @@ def get_sql(self, **kwargs: Any) -> str:
|
288 | 302 | raise NotImplementedError()
|
289 | 303 |
|
290 | 304 |
|
| 305 | +def idx_placeholder_gen(idx: int) -> str: |
| 306 | + return str(idx + 1) |
| 307 | + |
| 308 | + |
| 309 | +def named_placeholder_gen(idx: int) -> str: |
| 310 | + return f'param{idx + 1}' |
| 311 | + |
| 312 | + |
291 | 313 | class Parameter(Term):
|
292 | 314 | is_aggregate = None
|
293 | 315 |
|
294 | 316 | def __init__(self, placeholder: Union[str, int]) -> None:
|
295 | 317 | super().__init__()
|
296 |
| - self.placeholder = placeholder |
| 318 | + self._placeholder = placeholder |
| 319 | + |
| 320 | + @property |
| 321 | + def placeholder(self): |
| 322 | + return self._placeholder |
297 | 323 |
|
298 | 324 | def get_sql(self, **kwargs: Any) -> str:
|
299 | 325 | return str(self.placeholder)
|
300 | 326 |
|
| 327 | + def update_parameters(self, param_key: Any, param_value: Any, **kwargs): |
| 328 | + pass |
301 | 329 |
|
302 |
| -class QmarkParameter(Parameter): |
303 |
| - """Question mark style, e.g. ...WHERE name=?""" |
| 330 | + def get_param_key(self, placeholder: Any, **kwargs): |
| 331 | + return placeholder |
304 | 332 |
|
305 |
| - def __init__(self) -> None: |
306 |
| - pass |
307 | 333 |
|
308 |
| - def get_sql(self, **kwargs: Any) -> str: |
309 |
| - return "?" |
| 334 | +class ListParameter(Parameter): |
| 335 | + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None: |
| 336 | + super().__init__(placeholder=placeholder) |
| 337 | + self._parameters = list() |
310 | 338 |
|
| 339 | + @property |
| 340 | + def placeholder(self) -> str: |
| 341 | + if callable(self._placeholder): |
| 342 | + return self._placeholder(len(self._parameters)) |
311 | 343 |
|
312 |
| -class NumericParameter(Parameter): |
313 |
| - """Numeric, positional style, e.g. ...WHERE name=:1""" |
| 344 | + return str(self._placeholder) |
314 | 345 |
|
315 |
| - def get_sql(self, **kwargs: Any) -> str: |
316 |
| - return ":{placeholder}".format(placeholder=self.placeholder) |
| 346 | + def get_parameters(self, **kwargs): |
| 347 | + return self._parameters |
317 | 348 |
|
| 349 | + def update_parameters(self, value: Any, **kwargs): |
| 350 | + self._parameters.append(value) |
318 | 351 |
|
319 |
| -class NamedParameter(Parameter): |
320 |
| - """Named style, e.g. ...WHERE name=:name""" |
| 352 | + |
| 353 | +class DictParameter(Parameter): |
| 354 | + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None: |
| 355 | + super().__init__(placeholder=placeholder) |
| 356 | + self._parameters = dict() |
| 357 | + |
| 358 | + @property |
| 359 | + def placeholder(self) -> str: |
| 360 | + if callable(self._placeholder): |
| 361 | + return self._placeholder(len(self._parameters)) |
| 362 | + |
| 363 | + return str(self._placeholder) |
| 364 | + |
| 365 | + def get_parameters(self, **kwargs): |
| 366 | + return self._parameters |
| 367 | + |
| 368 | + def get_param_key(self, placeholder: Any, **kwargs): |
| 369 | + return placeholder[1:] |
| 370 | + |
| 371 | + def update_parameters(self, param_key: Any, value: Any, **kwargs): |
| 372 | + self._parameters[param_key] = value |
| 373 | + |
| 374 | + |
| 375 | +class QmarkParameter(ListParameter): |
| 376 | + def get_sql(self, **kwargs): |
| 377 | + return '?' |
| 378 | + |
| 379 | + |
| 380 | +class NumericParameter(ListParameter): |
| 381 | + """Numeric, positional style, e.g. ...WHERE name=:1""" |
321 | 382 |
|
322 | 383 | def get_sql(self, **kwargs: Any) -> str:
|
323 | 384 | return ":{placeholder}".format(placeholder=self.placeholder)
|
324 | 385 |
|
325 | 386 |
|
326 |
| -class FormatParameter(Parameter): |
| 387 | +class FormatParameter(ListParameter): |
327 | 388 | """ANSI C printf format codes, e.g. ...WHERE name=%s"""
|
328 | 389 |
|
329 |
| - def __init__(self) -> None: |
330 |
| - pass |
331 |
| - |
332 | 390 | def get_sql(self, **kwargs: Any) -> str:
|
333 | 391 | return "%s"
|
334 | 392 |
|
335 | 393 |
|
336 |
| -class PyformatParameter(Parameter): |
| 394 | +class NamedParameter(DictParameter): |
| 395 | + """Named style, e.g. ...WHERE name=:name""" |
| 396 | + |
| 397 | + def get_sql(self, **kwargs: Any) -> str: |
| 398 | + return ":{placeholder}".format(placeholder=self.placeholder) |
| 399 | + |
| 400 | + |
| 401 | +class PyformatParameter(DictParameter): |
337 | 402 | """Python extended format codes, e.g. ...WHERE name=%(name)s"""
|
338 | 403 |
|
339 | 404 | def get_sql(self, **kwargs: Any) -> str:
|
340 | 405 | return "%({placeholder})s".format(placeholder=self.placeholder)
|
341 | 406 |
|
| 407 | + def get_param_key(self, placeholder: Any, **kwargs): |
| 408 | + return placeholder[2:-2] |
| 409 | + |
342 | 410 |
|
343 | 411 | class Negative(Term):
|
344 | 412 | def __init__(self, term: Term) -> None:
|
@@ -385,9 +453,44 @@ def get_formatted_value(cls, value: Any, **kwargs):
|
385 | 453 | return "null"
|
386 | 454 | return str(value)
|
387 | 455 |
|
388 |
| - def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str: |
389 |
| - sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) |
390 |
| - return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) |
| 456 | + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: |
| 457 | + param_sql = parameter.get_sql(**kwargs) |
| 458 | + param_key = parameter.get_param_key(placeholder=param_sql) |
| 459 | + |
| 460 | + return param_sql, param_key |
| 461 | + |
| 462 | + def get_sql( |
| 463 | + self, |
| 464 | + quote_char: Optional[str] = None, |
| 465 | + secondary_quote_char: str = "'", |
| 466 | + parameter: Parameter = None, |
| 467 | + **kwargs: Any, |
| 468 | + ) -> str: |
| 469 | + if parameter is None: |
| 470 | + sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) |
| 471 | + return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) |
| 472 | + |
| 473 | + # Don't stringify numbers when using a parameter |
| 474 | + if isinstance(self.value, (int, float)): |
| 475 | + value_sql = self.value |
| 476 | + else: |
| 477 | + value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) |
| 478 | + param_sql, param_key = self._get_param_data(parameter, **kwargs) |
| 479 | + parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) |
| 480 | + |
| 481 | + return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) |
| 482 | + |
| 483 | + |
| 484 | +class ParameterValueWrapper(ValueWrapper): |
| 485 | + def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None: |
| 486 | + super().__init__(value, alias) |
| 487 | + self._parameter = parameter |
| 488 | + |
| 489 | + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: |
| 490 | + param_sql = self._parameter.get_sql(**kwargs) |
| 491 | + param_key = self._parameter.get_param_key(placeholder=param_sql) |
| 492 | + |
| 493 | + return param_sql, param_key |
391 | 494 |
|
392 | 495 |
|
393 | 496 | class JSON(Term):
|
@@ -551,6 +654,7 @@ def __init__(
|
551 | 654 | if isinstance(table, str):
|
552 | 655 | # avoid circular import at load time
|
553 | 656 | from pypika.queries import Table
|
| 657 | + |
554 | 658 | table = Table(table)
|
555 | 659 | self.table = table
|
556 | 660 |
|
|
0 commit comments