1
+ import copy
1
2
import logging
2
3
import importlib
3
4
4
5
from django .db import connections
5
6
7
+ from aws_xray_sdk .core import xray_recorder
6
8
from aws_xray_sdk .ext .dbapi2 import XRayTracedCursor
7
9
8
10
log = logging .getLogger (__name__ )
9
11
10
12
11
13
def patch_db ():
12
-
13
14
for conn in connections .all ():
14
15
module = importlib .import_module (conn .__module__ )
15
16
_patch_conn (getattr (module , conn .__class__ .__name__ ))
16
17
17
18
18
- def _patch_conn (conn ):
19
-
20
- attr = '_xray_original_cursor'
19
+ class DjangoXRayTracedCursor (XRayTracedCursor ):
20
+ def execute (self , query , * args , ** kwargs ):
21
+ if xray_recorder .stream_sql :
22
+ _previous_meta = copy .copy (self ._xray_meta )
23
+ self ._xray_meta ['sanitized_query' ] = query
24
+ result = super (DjangoXRayTracedCursor , self ).execute (query , * args , ** kwargs )
25
+ if xray_recorder .stream_sql :
26
+ self ._xray_meta = _previous_meta
27
+ return result
28
+
29
+ def executemany (self , query , * args , ** kwargs ):
30
+ if xray_recorder .stream_sql :
31
+ _previous_meta = copy .copy (self ._xray_meta )
32
+ self ._xray_meta ['sanitized_query' ] = query
33
+ result = super (DjangoXRayTracedCursor , self ).executemany (query , * args , ** kwargs )
34
+ if xray_recorder .stream_sql :
35
+ self ._xray_meta = _previous_meta
36
+ return result
37
+
38
+ def callproc (self , proc , args ):
39
+ if xray_recorder .stream_sql :
40
+ _previous_meta = copy .copy (self ._xray_meta )
41
+ self ._xray_meta ['sanitized_query' ] = proc
42
+ result = super (DjangoXRayTracedCursor , self ).callproc (proc , args )
43
+ if xray_recorder .stream_sql :
44
+ self ._xray_meta = _previous_meta
45
+ return result
46
+
47
+
48
+ def _patch_cursor (cursor_name , conn ):
49
+ attr = '_xray_original_{}' .format (cursor_name )
21
50
22
51
if hasattr (conn , attr ):
23
- log .debug ('django built-in db already patched' )
52
+ log .debug ('django built-in db {} already patched' .format (cursor_name ))
53
+ return
54
+
55
+ if not hasattr (conn , cursor_name ):
56
+ log .debug ('django built-in db does not have {}' .format (cursor_name ))
24
57
return
25
58
26
- setattr (conn , attr , conn . cursor )
59
+ setattr (conn , attr , getattr ( conn , cursor_name ) )
27
60
28
61
meta = {}
29
62
@@ -45,7 +78,12 @@ def cursor(self, *args, **kwargs):
45
78
if user :
46
79
meta ['user' ] = user
47
80
48
- return XRayTracedCursor (
49
- self ._xray_original_cursor (* args , ** kwargs ), meta )
81
+ original_cursor = getattr (self , attr )(* args , ** kwargs )
82
+ return DjangoXRayTracedCursor (original_cursor , meta )
83
+
84
+ setattr (conn , cursor_name , cursor )
50
85
51
- conn .cursor = cursor
86
+
87
+ def _patch_conn (conn ):
88
+ _patch_cursor ('cursor' , conn )
89
+ _patch_cursor ('chunked_cursor' , conn )
0 commit comments