Skip to content

Commit 793846c

Browse files
committed
redesigned logic and added tests
1 parent 37c147e commit 793846c

File tree

2 files changed

+137
-45
lines changed

2 files changed

+137
-45
lines changed

src/capture_db_queries/wrappers.py

+43-30
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
import json
44
import time
5+
import traceback
56
from typing import TYPE_CHECKING, Any, NamedTuple
67

78
from django.utils.regex_helper import _lazy_re_compile
89

10+
from ._logging import log
911
from .timers import ContextTimer
1012

1113
if TYPE_CHECKING:
@@ -19,33 +21,48 @@
1921

2022

2123
class BaseExecutionWrapper:
22-
def __init__(self, queries_log: QueriesLog, *args: Any, **kwargs: Any) -> None:
23-
self.timer = ContextTimer(time.perf_counter)
24+
def __init__(
25+
self, connection: BaseDatabaseWrapper, queries_log: QueriesLog, *args: Any, **kwargs: Any
26+
) -> None:
27+
log.debug('')
28+
29+
self.connection = connection
2430
self.queries_log = queries_log
31+
self.timer = ContextTimer(time.perf_counter)
32+
33+
# Wrappers for specific databases are stored at addresses:
34+
# django.db.backends.sqlite3.operations.DatabaseOperations
35+
# django.db.backends.postgresql.operations.DatabaseOperations
36+
self.db_operations: BaseDatabaseOperations = self.connection.ops
2537

2638
def __call__(
2739
self,
28-
execute: Callable[..., CursorWrapper | None],
40+
execute: Callable[..., CursorWrapper],
2941
sql: str,
3042
params: tuple[Any, ...],
3143
many: bool,
3244
context: dict[str, Any],
3345
) -> CursorWrapper | None:
3446
"""
35-
Выполняет оригинальный запрос
47+
Executes the original SQL request.
3648
"""
3749
with self.timer as timer:
3850
try:
3951
result = execute(sql, params, many, context)
4052
except Exception as exc:
41-
print('Что-то пошло не так:', exc)
53+
print('Something went wrong:', exc)
4254
return None
4355

44-
# from django.db.backends.utils import CursorDebugWrapper(connection.cursor(), connection)
56+
if not many:
57+
# Get filled SQL with params
58+
sql = self.db_operations.last_executed_query(context['cursor'], sql, params)
59+
4560
query: Query = {'sql': sql, 'time': timer.execution_time}
4661
query = self.update_query(query)
4762
self.queries_log.append(query)
4863

64+
log.trace('Location of SQL Call:\n%s', ''.join(traceback.format_stack()))
65+
4966
return result
5067

5168
def update_query(self, query: Query) -> Query:
@@ -54,12 +71,13 @@ def update_query(self, query: Query) -> Query:
5471

5572
class ExplainExecutionWrapper(BaseExecutionWrapper):
5673
"""
57-
Класс для вызова EXPLAIN на каждом SELECT-запросе.
58-
С сохранением полной функциональности метода QuerySet.explain().
74+
A class for calling EXPLAIN before each original SELECT request.
75+
While maintaining the full functionality of the Query Set.explain() method.
76+
77+
The EXPLAIN call is not fixed in any way and does not affect the measurement results.
5978
60-
https://docs.djangoproject.com/en/5.1/topics/db/instrumentation/#connection-execute-wrapper
6179
https://docs.djangoproject.com/en/5.1/ref/models/querysets/#explain
62-
from django_extensions.management.debug_cursor import monkey_patch_cursordebugwrapper
80+
https://www.postgresql.org/docs/current/sql-explain.html
6381
"""
6482

6583
# Inspired from
@@ -72,20 +90,18 @@ class ExplainInfo(NamedTuple):
7290

7391
def __init__(
7492
self,
75-
queries_log: QueriesLog,
7693
connection: BaseDatabaseWrapper,
94+
queries_log: QueriesLog,
7795
explain_opts: dict[str, Any],
7896
*args: Any,
7997
**kwargs: Any,
8098
) -> None:
81-
# # https://www.postgresql.org/docs/current/sql-explain.html
82-
super().__init__(queries_log, *args, **kwargs)
83-
self.connection = connection
99+
super().__init__(connection, queries_log, *args, **kwargs)
84100
self.explain_info = self.build_explain_info(**explain_opts)
85101

86102
def __call__(
87103
self,
88-
execute: Callable[..., CursorWrapper | None],
104+
execute: Callable[..., CursorWrapper],
89105
sql: str,
90106
params: tuple[Any, ...],
91107
many: bool,
@@ -100,36 +116,31 @@ def update_query(self, query: Query) -> Query:
100116
return query
101117

102118
def _execute_explain(self, sql: str, params: tuple[Any, ...], many: bool) -> str | None:
103-
# Проверяем, является ли запрос SELECT-запросом
119+
# Checking whether the request is a SELECT request
104120
if not sql.strip().lower().startswith('select'):
105121
return None
106122

107-
# Обёртки для конкретных бд хранятся по адресам:
108-
# django.db.backends.sqlite3.operations.DatabaseOperations
109-
# django.db.backends.postgresql.operations.DatabaseOperations
110-
db_operations: BaseDatabaseOperations = self.connection.ops
111-
112-
explain_query = db_operations.explain_query_prefix(
123+
explain_query = self.db_operations.explain_query_prefix(
113124
self.explain_info.format, **self.explain_info.options
114125
)
115126
explain_query = f'{explain_query} {sql}'
116127

117128
try:
118129
raw_explain = self.__execute(explain_query, params, many)
119130
except Exception as exc:
120-
print('Что-то пошло не так:', exc)
131+
print('Something went wrong:', exc)
121132
return None
122133
else:
123134
return '\n'.join(self.format_explain(raw_explain))
124135

125136
def __execute(self, explain_query: str, params: tuple[Any, ...], many: bool) -> Explain:
126137
with self.connection.cursor() as cursor:
127-
# нельзя вызывать execute или executemany
128-
# который вызывает внутри себя обёртку,
129-
# потому что это приведёт к дублированию вызова,
130-
# и вызовет __call__ текущей обёртки
138+
# you cannot call execute or executemany
139+
# which invokes a wrapper inside itself,
140+
# because it will result in duplicate call,
141+
# and will call __call__ of the current wrapper
131142

132-
# данные запросы идут в обход обёрток
143+
# these requests bypass wrappers
133144
if many:
134145
cursor._executemany(explain_query, params)
135146
else:
@@ -139,15 +150,17 @@ def __execute(self, explain_query: str, params: tuple[Any, ...], many: bool) ->
139150

140151
def build_explain_info(self, *, format: str | None = None, **options: dict[str, Any]) -> ExplainInfo: # noqa: A002
141152
"""
142-
Runs an EXPLAIN on the SQL query this QuerySet would perform, and
143-
returns the results.
153+
Validates explain options and build ExplainInfo object.
144154
"""
145155
for option_name in options:
146156
if not self.EXPLAIN_OPTIONS_PATTERN.fullmatch(option_name) or '--' in option_name:
147157
raise ValueError(f'Invalid option name: {option_name!r}.')
148158
return self.ExplainInfo(format, options)
149159

150160
def format_explain(self, result: Explain) -> Generator[str, None, None]:
161+
"""
162+
Splits the explain tuple into its components and collects the final explain string from them.
163+
"""
151164
nested_result = [list(result)]
152165
# Some backends return 1 item tuples with strings, and others return
153166
# tuples with integers and strings. Flatten them out into strings.

tests/test_decorators.py

+94-15
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ def _select(reporter_id: UUID, article_id: UUID) -> None:
4141

4242

4343
class BasicTestsFor3ChoicesOfCaptureQueries:
44-
"""Обязательно должен быть передан аргумент -s при запуске pytest, для вывода output"""
44+
""""""
4545

4646
def setup_method(self, method: Callable[..., Any]) -> None:
4747
self.reporter, self.article = request_to_db()
4848

4949
def call_capture_queries(self, **kwargs: Any) -> CaptureQueries:
5050
raise NotImplementedError
5151

52+
# @pytest.mark.usefixtures('_debug_true')
5253
def test_basic_logic(self) -> None:
5354
obj = self.call_capture_queries()
5455

@@ -74,14 +75,14 @@ def test_basic_logic(self) -> None:
7475
data = obj.queries_log[0]['sql']
7576
assert obj.queries_log[0]['sql'] == (
7677
'SELECT "tests_reporter"."id", "tests_reporter"."full_name" '
77-
'FROM "tests_reporter" WHERE "tests_reporter"."id" = %s'
78+
'FROM "tests_reporter" WHERE "tests_reporter"."id" = %s' % self.reporter.pk
7879
), data
7980

8081
data = obj.queries_log[1]['sql']
8182
assert obj.queries_log[1]['sql'] == (
8283
'SELECT "tests_article"."id", "tests_article"."pub_date", "tests_article"."headline", '
8384
'"tests_article"."content", "tests_article"."reporter_id" '
84-
'FROM "tests_article" WHERE "tests_article"."id" = %s'
85+
'FROM "tests_article" WHERE "tests_article"."id" = %s' % self.article.pk
8586
), data
8687

8788
with pytest.raises(KeyError, match='explain'):
@@ -110,14 +111,14 @@ def test_param__explain(self) -> None:
110111
data = obj.queries_log[0]['sql']
111112
assert data == (
112113
'SELECT "tests_reporter"."id", "tests_reporter"."full_name" '
113-
'FROM "tests_reporter" WHERE "tests_reporter"."id" = %s'
114+
'FROM "tests_reporter" WHERE "tests_reporter"."id" = %s' % self.reporter.pk
114115
), data
115116

116117
data = obj.queries_log[1]['sql']
117118
assert data == (
118119
'SELECT "tests_article"."id", "tests_article"."pub_date", "tests_article"."headline", '
119120
'"tests_article"."content", "tests_article"."reporter_id" '
120-
'FROM "tests_article" WHERE "tests_article"."id" = %s'
121+
'FROM "tests_article" WHERE "tests_article"."id" = %s' % self.article.pk
121122
), data
122123

123124
# data = obj.queries_log[0]['explain']
@@ -134,6 +135,7 @@ def test_param__connection(self) -> None:
134135
class FakeConnection:
135136
vendor = 'fake_vendor'
136137
queries_limit = 4
138+
ops = connection.ops
137139

138140
@contextmanager
139141
def execute_wrapper(
@@ -177,6 +179,31 @@ def test_param__number_runs(self) -> None:
177179
gt, lt = 0.08, 0.13
178180
assert gt < data < lt, f'{gt} < {data} < {lt}'
179181

182+
def test_execute_raw_sql(self) -> None:
183+
reporter_raw_sql = (
184+
'SELECT "tests_reporter"."id", "tests_reporter"."full_name" '
185+
'FROM "tests_reporter" WHERE "tests_reporter"."id" = %s' % self.reporter.pk
186+
)
187+
article_raw_sql = (
188+
'SELECT "tests_article"."id", "tests_article"."pub_date", "tests_article"."headline", '
189+
'"tests_article"."content", "tests_article"."reporter_id" '
190+
'FROM "tests_article" WHERE "tests_article"."id" = %s' % self.article.pk
191+
)
192+
obj = CaptureQueries()
193+
for _ in obj:
194+
list(Reporter.objects.raw(reporter_raw_sql))
195+
list(Article.objects.raw(article_raw_sql))
196+
197+
data = obj.queries_log[0]['sql']
198+
assert data == (reporter_raw_sql), data
199+
200+
data = obj.queries_log[1]['sql']
201+
assert data == (article_raw_sql), data
202+
203+
def test_without_requests(self) -> None:
204+
for _ in CaptureQueries(advanced_verb=True):
205+
pass # no have requests
206+
180207

181208
@pytest.mark.django_db(transaction=True)
182209
class TestDecoratorCaptureQueries(BasicTestsFor3ChoicesOfCaptureQueries):
@@ -230,6 +257,34 @@ def test_param__auto_call_func(self) -> None:
230257
gt, lt = 0.08, 0.13
231258
assert gt < data < lt, f'{gt} < {data} < {lt}'
232259

260+
def test_execute_raw_sql(self) -> None:
261+
reporter_raw_sql = (
262+
'SELECT "tests_reporter"."id", "tests_reporter"."full_name" '
263+
'FROM "tests_reporter" WHERE "tests_reporter"."id" = %s' % self.reporter.pk
264+
)
265+
article_raw_sql = (
266+
'SELECT "tests_article"."id", "tests_article"."pub_date", "tests_article"."headline", '
267+
'"tests_article"."content", "tests_article"."reporter_id" '
268+
'FROM "tests_article" WHERE "tests_article"."id" = %s' % self.article.pk
269+
)
270+
obj = CaptureQueries(auto_call_func=True)
271+
272+
@obj
273+
def _() -> None:
274+
list(Reporter.objects.raw(reporter_raw_sql))
275+
list(Article.objects.raw(article_raw_sql))
276+
277+
data = obj.queries_log[0]['sql']
278+
assert data == (reporter_raw_sql), data
279+
280+
data = obj.queries_log[1]['sql']
281+
assert data == (article_raw_sql), data
282+
283+
def test_without_requests(self) -> None:
284+
@CaptureQueries(advanced_verb=True, auto_call_func=True)
285+
def func() -> None:
286+
pass # no have requests
287+
233288

234289
@pytest.mark.django_db(transaction=True)
235290
class TestContextManagerCaptureQueries(BasicTestsFor3ChoicesOfCaptureQueries):
@@ -238,26 +293,50 @@ class TestContextManagerCaptureQueries(BasicTestsFor3ChoicesOfCaptureQueries):
238293
def call_capture_queries(self, **kwargs: Any) -> CaptureQueries:
239294
with slow_down_execute(0.1): # noqa: SIM117
240295
with CaptureQueries(**kwargs) as obj:
241-
obj.current_iteration = 1 # XXX(Ars): Временный хак для тестов, позже надо переработать
296+
obj.current_iteration = 1 # XXX(Ars): A temporary hack only for tests, reworked later
242297
_select(self.reporter.pk, self.article.pk)
243298
return obj
244299

245-
# @pytest.mark.filterwarnings("ignore::UserWarning") # warn не отображается, и не вызывает ошибки
246-
# @pytest.mark.filterwarnings('default::UserWarning') # warn отображается, и не вызывает ошибки
300+
# @pytest.mark.filterwarnings("ignore::UserWarning") # warn not show, and not raise exc
301+
# @pytest.mark.filterwarnings('default::UserWarning') # warn show, and not raise exc
247302
def test_param__number_runs(self) -> None:
248303
with pytest.raises(
249304
UserWarning,
250305
match=(
251-
'При использовании: CaptureQueries как контекстного менеджера, '
252-
'параметр number_runs > 1 не используеться.'
306+
'When using: CaptureQueries as a context manager,'
307+
' the number_runs > 1 parameter is not used.'
253308
),
254309
):
255310
self.call_capture_queries(number_runs=3)
256311

312+
def test_execute_raw_sql(self) -> None:
313+
reporter_raw_sql = (
314+
'SELECT "tests_reporter"."id", "tests_reporter"."full_name" '
315+
'FROM "tests_reporter" WHERE "tests_reporter"."id" = %s' % self.reporter.pk
316+
)
317+
article_raw_sql = (
318+
'SELECT "tests_article"."id", "tests_article"."pub_date", "tests_article"."headline", '
319+
'"tests_article"."content", "tests_article"."reporter_id" '
320+
'FROM "tests_article" WHERE "tests_article"."id" = %s' % self.article.pk
321+
)
322+
with CaptureQueries() as obj:
323+
list(Reporter.objects.raw(reporter_raw_sql))
324+
list(Article.objects.raw(article_raw_sql))
325+
326+
data = obj.queries_log[0]['sql']
327+
assert data == (reporter_raw_sql), data
328+
329+
data = obj.queries_log[1]['sql']
330+
assert data == (article_raw_sql), data
331+
332+
def test_without_requests(self) -> None:
333+
with CaptureQueries(advanced_verb=True):
334+
pass # no have requests
335+
257336

258337
@pytest.mark.django_db(transaction=True)
259338
class TestOutputCaptureQueries:
260-
"""Обязательно должен быть передан аргумент -s при запуске pytest, для вывода output"""
339+
"""The -s argument must be passed when running py test to output output"""
261340

262341
# @pytest.mark.usefixtures('_debug_true')
263342
def test_capture_queries_loop(self, intercept_output: StringIO) -> None:
@@ -271,7 +350,7 @@ def test_capture_queries_loop(self, intercept_output: StringIO) -> None:
271350
assert re.match(
272351
f'\n\nTests count: 100 | Total queries count: 200 | Total execution time: {ANYNUM}s | Median time one test is: {ANYNUM}s\n', # noqa: E501
273352
output,
274-
), 'incorrect output'
353+
), f'incorrect output = {output}'
275354

276355
# @pytest.mark.usefixtures('_debug_true')
277356
def test_capture_queries_decorator(self, intercept_output: StringIO) -> None:
@@ -286,7 +365,7 @@ def _() -> None:
286365
assert re.match(
287366
f'\n\nTests count: 100 | Total queries count: 200 | Total execution time: {ANYNUM}s | Median time one test is: {ANYNUM}s\n', # noqa: E501
288367
output,
289-
), 'incorrect output'
368+
), f'incorrect output = {output}'
290369

291370
# @pytest.mark.usefixtures('_debug_true')
292371
def test_capture_queries_context_manager(self, intercept_output: StringIO) -> None:
@@ -296,9 +375,9 @@ def test_capture_queries_context_manager(self, intercept_output: StringIO) -> No
296375
output = intercept_output.getvalue()
297376

298377
assert re.match(
299-
f'Queries count: 2 | Execution time: {ANYNUM}s',
378+
f'\nQueries count: 2 | Execution time: {ANYNUM}s | Vendor: sqlite\n\n',
300379
output,
301-
), 'incorrect output'
380+
), f'incorrect output = {output}'
302381

303382

304383
@pytest.mark.django_db(transaction=True)

0 commit comments

Comments
 (0)