11import datetime
2+ import hashlib
23import json
34
45import sqlparse
56import websocket # Using websocket-client library for synchronous operations
7+ from django .conf import settings
68from django .db import IntegrityError , DatabaseError
79from django .db .backends .sqlite3 .base import DatabaseWrapper as SQLiteDatabaseWrapper
810from django .db .backends .sqlite3 .client import DatabaseClient as SQLiteDatabaseClient
1416from django .utils import timezone
1517from sqlparse .sql import IdentifierList , Identifier
1618from sqlparse .tokens import DML
19+ from django .core .cache import cache
1720
1821
1922class DatabaseFeatures (SQLiteDatabaseFeatures ):
2023 supports_transactions = True
2124 supports_savepoints = False
25+ max_query_params = 100
2226
2327
2428class DatabaseOperations (SQLiteDatabaseOperations ):
@@ -140,6 +144,13 @@ def _convert_results(self, results):
140144 converted_results .append (tuple (converted_row ))
141145 return converted_results
142146
147+ def generate_cache_key (self , sql_query : str ) -> str :
148+ # Create a SHA256 hash of the combined string
149+ hash_object = hashlib .sha256 (sql_query .encode ())
150+
151+ # Return the hexadecimal representation of the hash
152+ return f"workers_dbms_{ hash_object .hexdigest ()} "
153+
143154 def raw_query (self , websocket , query , params = None ):
144155 if params == None :
145156 if query .strip () == 'PRAGMA foreign_keys = OFF' :
@@ -167,7 +178,7 @@ def raw_query(self, websocket, query, params=None):
167178 if params :
168179 sql , params = self ._format_params (sql , params )
169180
170- websocket . send ( json .dumps ({
181+ socket_input = json .dumps ({
171182 "type" : "request" ,
172183 "request" : {
173184 "type" : "execute" ,
@@ -176,28 +187,55 @@ def raw_query(self, websocket, query, params=None):
176187 "query" : sql
177188 }
178189 }
179- }))
190+ })
180191
181- if self .connection .debug is True :
182- print (sql )
183- print (params )
192+ should_cache = False
193+ response = None
194+ if self .connection .cache is True :
195+ upper_query = query .upper ()
196+ user_model = settings .AUTH_USER_MODEL .replace ('.' , '_' ).upper ()
197+ try :
198+ if 'FROM "DJANGO_SESSION"' in upper_query or f'FROM "{ user_model } "' in upper_query :
199+ if "UPDATE" not in upper_query and "DELETE" not in upper_query and "INSERT" not in upper_query :
200+ cache_key = self .generate_cache_key (socket_input )
201+ response = cache .get (cache_key )
184202
185- response = websocket .recv ()
186- parsed_response = json .loads (response )
203+ if not response :
204+ should_cache = True
205+ else :
206+ pass # TODO: clear cache
207+ # cache.clear(prefix="workers_dbms_")
208+ except TypeError :
209+ pass
187210
188- if self .connection .debug is True :
189- print (parsed_response )
190- print ('---' )
211+ if not response :
212+ websocket .send (socket_input )
213+
214+ if self .connection .debug is True :
215+ print (sql )
216+ print (params )
217+
218+ response = websocket .recv ()
219+
220+ parsed_response = json .loads (response )
191221
192222 if parsed_response ["type" ] == "response_error" :
193223 if "unique constraint failed" in parsed_response ["error" ].lower ():
194224 raise IntegrityError (parsed_response ["error" ])
195225
196226 raise DatabaseError (parsed_response ["error" ] + "\n " + sql )
197227
228+ if self .connection .cache is True and should_cache is True :
229+ cache_key = self .generate_cache_key (socket_input )
230+ cache .set (cache_key , response , 300 )
231+
198232 results = self ._convert_results (list (tuple (row ) for row in parsed_response ["result" ]["results" ]))
199233 meta = parsed_response ["result" ].get ("meta" )
200234
235+ if self .connection .debug is True :
236+ print (results )
237+ print ('---' )
238+
201239 return results , meta
202240
203241 def quote_name (self , name ):
@@ -259,6 +297,7 @@ def bulk_insert_sql(self, fields, placeholder_rows):
259297class DatabaseWrapper (SQLiteDatabaseWrapper ):
260298 vendor = 'websocket'
261299 debug = False
300+ cache = False
262301
263302 def __init__ (self , * args , ** kwargs ):
264303 super ().__init__ (* args , ** kwargs )
@@ -277,11 +316,14 @@ def get_connection_params(self):
277316 'endpoint_url' : settings_dict ['WORKERS_DBMS_ENDPOINT' ],
278317 'access_id' : settings_dict .get ('WORKERS_DBMS_ACCESS_ID' ),
279318 'access_secret' : settings_dict .get ('WORKERS_DBMS_ACCESS_SECRET' ),
319+ 'cache' : settings_dict .get ('WORKERS_DBMS_CACHE' , True ),
280320 'debug' : settings_dict .get ('WORKERS_DBMS_DEBUG' ),
281321 }
282322
283323 def get_new_connection (self , conn_params ):
284324 headers = []
325+ if conn_params ['cache' ]:
326+ self .cache = conn_params ['cache' ]
285327 if conn_params ['debug' ]:
286328 self .debug = conn_params ['debug' ] is True
287329
@@ -347,19 +389,23 @@ def close(self):
347389 # self.websocket.close()
348390
349391 def execute (self , sql , params = None ):
350- result , meta = self .ops .raw_query (self .websocket , sql , params )
392+ try :
393+ result , meta = self .ops .raw_query (self .websocket , sql , params )
351394
352- self .results = result
395+ self .results = result
353396
354- # Update rowcount based on the operation type
355- if meta :
356- if "INSERT" in sql .upper ():
357- self .rowcount = meta .get ("rows_written" , 0 )
358- # self.connection.ops.last_insert_id = meta.get("last_insert_id") # TODO: implement last insert id
359- elif "UPDATE" in sql .upper () or "DELETE" in sql .upper ():
360- self .rowcount = meta .get ("rows_written" , 0 )
361- else :
362- self .rowcount = meta .get ("rows_read" , 0 )
397+ # Update rowcount based on the operation type
398+ if meta :
399+ if "INSERT" in sql .upper ():
400+ self .rowcount = meta .get ("rows_written" , 0 )
401+ # self.connection.ops.last_insert_id = meta.get("last_insert_id") # TODO: implement last insert id
402+ elif "UPDATE" in sql .upper () or "DELETE" in sql .upper ():
403+ self .rowcount = meta .get ("rows_written" , 0 )
404+ else :
405+ self .rowcount = meta .get ("rows_read" , 0 )
406+ except Exception as e :
407+ self .results = []
408+ raise DatabaseError (str (e ))
363409
364410 return self
365411
@@ -380,12 +426,13 @@ def fetchmany(self, size=None):
380426 def fetchall (self ):
381427 if self .results :
382428 results = self .results
383- self .results = None
429+ self .results = [] # Clear the results after fetching
384430 return results
385431 return []
386432
387433 def __iter__ (self ):
388- return iter (self .fetchall ())
434+ while self .results :
435+ yield self .fetchone ()
389436
390437 @property
391438 def rowcount (self ):
0 commit comments