Skip to content

Commit 94b7a45

Browse files
author
r0BIT
committed
fix(cache): thread-safe SQLite with per-thread connections (B5)
- Use threading.local() for per-thread SQLite connections
- Each worker thread gets its own connection to shared DB file
- WAL mode ensures concurrent access works correctly
- Track all connections in _connections list for cleanup
- Add _get_conn() method for lazy connection creation

Tests:
- Update tests to use _get_conn() instead of self.conn
- Add test_concurrent_persistent_access for thread safety
- All 1340 tests pass

Fixes 'SQLite objects created in a thread' errors when using --threads
1 parent 44eb1d5 commit 94b7a45

File tree

2 files changed

+163
-56
lines changed

2 files changed

+163
-56
lines changed

taskhound/utils/cache_manager.py

Lines changed: 105 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
99
Thread-safety:
1010
- Session cache protected by RLock for concurrent access
11-
- SQLite uses WAL mode for concurrent reads/writes
11+
- SQLite uses per-thread connections via threading.local() for thread safety
12+
- SQLite uses WAL mode for concurrent reads/writes across threads
1213
"""
1314

1415
import contextlib
@@ -28,7 +29,8 @@ class CacheManager:
2829
2930
Thread-safe:
3031
- Session cache (in-memory dict) protected by RLock
31-
- SQLite uses WAL mode for concurrent access
32+
- SQLite uses per-thread connections (threading.local) for thread affinity
33+
- SQLite uses WAL mode for concurrent access across threads
3234
"""
3335

3436
def __init__(self, cache_file: Optional[Path] = None, ttl_hours: int = 24, enabled: bool = True):
@@ -42,7 +44,9 @@ def __init__(self, cache_file: Optional[Path] = None, ttl_hours: int = 24, enabl
4244
"""
4345
self.ttl_hours = ttl_hours
4446
self.persistent_enabled = enabled
45-
self.conn = None
47+
48+
# Thread-local storage for per-thread SQLite connections
49+
self._local = threading.local()
4650

4751
# Thread-safety: RLock for session cache access
4852
self._session_lock = threading.RLock()
@@ -68,20 +72,61 @@ def __init__(self, cache_file: Optional[Path] = None, ttl_hours: int = 24, enabl
6872
"expired": 0,
6973
}
7074

71-
# Initialize persistent cache if enabled
75+
# Track all connections for cleanup (protected by _session_lock)
76+
self._connections: list[sqlite3.Connection] = []
77+
78+
# Initialize persistent cache if enabled (creates schema)
7279
if self.persistent_enabled:
7380
self._init_db()
7481

82+
def _get_conn(self) -> Optional[sqlite3.Connection]:
83+
"""
84+
Get thread-local SQLite connection, creating one if needed.
85+
86+
SQLite connections have thread affinity - they can only be used in the
87+
thread that created them. This method ensures each thread gets its own
88+
connection to the shared database file.
89+
90+
Returns:
91+
SQLite connection for current thread, or None if disabled/failed
92+
"""
93+
if not self.persistent_enabled:
94+
return None
95+
96+
# Check for existing connection in this thread
97+
conn = getattr(self._local, "conn", None)
98+
if conn is not None:
99+
return conn
100+
101+
# Create new connection for this thread
102+
try:
103+
conn = sqlite3.connect(self.cache_file, timeout=10.0)
104+
# WAL mode allows concurrent reads while writing
105+
conn.execute("PRAGMA journal_mode=WAL")
106+
conn.execute("PRAGMA synchronous=NORMAL")
107+
self._local.conn = conn
108+
109+
# Track for cleanup
110+
with self._session_lock:
111+
self._connections.append(conn)
112+
113+
debug(f"Created new cache connection for thread {threading.current_thread().name}")
114+
return conn
115+
except Exception as e:
116+
debug(f"Failed to create cache connection: {e}")
117+
return None
118+
75119
def _init_db(self):
76-
"""Initialize SQLite database and schema."""
120+
"""Initialize SQLite database schema (called once at startup)."""
77121
try:
78-
self.conn = sqlite3.connect(self.cache_file, timeout=10.0)
79-
# Enable WAL mode for better concurrency
80-
self.conn.execute("PRAGMA journal_mode=WAL")
81-
self.conn.execute("PRAGMA synchronous=NORMAL")
122+
# Get connection for main thread (also creates schema)
123+
conn = self._get_conn()
124+
if not conn:
125+
self.persistent_enabled = False
126+
return
82127

83128
# Create table if not exists
84-
self.conn.execute("""
129+
conn.execute("""
85130
CREATE TABLE IF NOT EXISTS cache (
86131
category TEXT,
87132
key TEXT,
@@ -92,14 +137,13 @@ def _init_db(self):
92137
""")
93138

94139
# Create index for expiration cleanup
95-
self.conn.execute("""
140+
conn.execute("""
96141
CREATE INDEX IF NOT EXISTS idx_expires_at ON cache(expires_at)
97142
""")
98143

99-
self.conn.commit()
144+
conn.commit()
100145

101-
# Opportunistic cleanup of expired entries (1% chance on init to avoid thundering herd)
102-
# Or just do it always since it's fast with an index
146+
# Opportunistic cleanup of expired entries
103147
self._prune_expired()
104148

105149
except Exception as e:
@@ -108,16 +152,18 @@ def _init_db(self):
108152

109153
def _prune_expired(self):
110154
"""Remove expired entries from database."""
111-
if not self.conn:
155+
conn = self._get_conn()
156+
if not conn:
112157
return
113158

114159
try:
115160
now = time.time()
116-
cursor = self.conn.execute("DELETE FROM cache WHERE expires_at < ?", (now,))
161+
cursor = conn.execute("DELETE FROM cache WHERE expires_at < ?", (now,))
117162
if cursor.rowcount > 0:
118163
debug(f"Pruned {cursor.rowcount} expired cache entries")
119-
self.stats["expired"] += cursor.rowcount
120-
self.conn.commit()
164+
with self._session_lock:
165+
self.stats["expired"] += cursor.rowcount
166+
conn.commit()
121167
except Exception as e:
122168
debug(f"Error pruning cache: {e}")
123169

@@ -144,11 +190,12 @@ def get(self, category: str, key: str) -> Optional[Any]:
144190
return self.session[session_key]
145191
self.stats["session_misses"] += 1
146192

147-
# Tier 2: Check persistent cache (SQLite handles concurrency via WAL)
148-
if self.persistent_enabled and self.conn:
193+
# Tier 2: Check persistent cache (thread-local connection)
194+
conn = self._get_conn()
195+
if conn:
149196
try:
150197
now = time.time()
151-
cursor = self.conn.execute(
198+
cursor = conn.execute(
152199
"SELECT value, expires_at FROM cache WHERE category=? AND key=?", (category, key)
153200
)
154201
row = cursor.fetchone()
@@ -162,8 +209,8 @@ def get(self, category: str, key: str) -> Optional[Any]:
162209
with self._session_lock:
163210
self.stats["expired"] += 1
164211
# Lazy delete
165-
self.conn.execute("DELETE FROM cache WHERE category=? AND key=?", (category, key))
166-
self.conn.commit()
212+
conn.execute("DELETE FROM cache WHERE category=? AND key=?", (category, key))
213+
conn.commit()
167214
return None
168215

169216
# Valid hit
@@ -189,7 +236,7 @@ def set(self, category: str, key: str, value: Any, ttl_hours: Optional[int] = No
189236
"""
190237
Store value in both session and persistent caches.
191238
192-
Thread-safe: Uses RLock for session cache, SQLite handles persistence.
239+
Thread-safe: Uses RLock for session cache, per-thread SQLite connection.
193240
194241
Args:
195242
category: Cache category ("computers", "users", "sids")
@@ -203,18 +250,19 @@ def set(self, category: str, key: str, value: Any, ttl_hours: Optional[int] = No
203250
with self._session_lock:
204251
self.session[session_key] = value
205252

206-
# Store in persistent cache if enabled (SQLite handles concurrency)
207-
if self.persistent_enabled and self.conn:
253+
# Store in persistent cache (thread-local connection)
254+
conn = self._get_conn()
255+
if conn:
208256
try:
209257
ttl = ttl_hours if ttl_hours is not None else self.ttl_hours
210258
expires_at = time.time() + (ttl * 3600)
211259
value_json = json.dumps(value)
212260

213-
self.conn.execute(
261+
conn.execute(
214262
"INSERT OR REPLACE INTO cache (category, key, value, expires_at) VALUES (?, ?, ?, ?)",
215263
(category, key, value_json, expires_at),
216264
)
217-
self.conn.commit()
265+
conn.commit()
218266
debug(f"Cache store: {category}:{key}")
219267
except Exception as e:
220268
debug(f"Cache write error: {e}")
@@ -223,7 +271,7 @@ def delete(self, category: str, key: str):
223271
"""
224272
Remove value from both session and persistent caches.
225273
226-
Thread-safe: Uses RLock for session cache.
274+
Thread-safe: Uses RLock for session cache, per-thread SQLite connection.
227275
228276
Args:
229277
category: Cache category
@@ -236,11 +284,12 @@ def delete(self, category: str, key: str):
236284
if session_key in self.session:
237285
del self.session[session_key]
238286

239-
# Remove from persistent cache (SQLite handles concurrency)
240-
if self.persistent_enabled and self.conn:
287+
# Remove from persistent cache (thread-local connection)
288+
conn = self._get_conn()
289+
if conn:
241290
try:
242-
self.conn.execute("DELETE FROM cache WHERE category=? AND key=?", (category, key))
243-
self.conn.commit()
291+
conn.execute("DELETE FROM cache WHERE category=? AND key=?", (category, key))
292+
conn.commit()
244293
except Exception as e:
245294
debug(f"Cache delete error: {e}")
246295

@@ -256,12 +305,13 @@ def get_all(self, category: str) -> Dict[str, Any]:
256305
"""
257306
result: Dict[str, Any] = {}
258307

259-
if not self.persistent_enabled or not self.conn:
308+
conn = self._get_conn()
309+
if not conn:
260310
return result
261311

262312
try:
263313
now = time.time()
264-
cursor = self.conn.execute(
314+
cursor = conn.execute(
265315
"SELECT key, value, expires_at FROM cache WHERE category=?", (category,)
266316
)
267317

@@ -284,10 +334,10 @@ def get_all(self, category: str) -> Dict[str, Any]:
284334
# Clean up expired entries
285335
if expired_keys:
286336
for key in expired_keys:
287-
self.conn.execute(
337+
conn.execute(
288338
"DELETE FROM cache WHERE category=? AND key=?", (category, key)
289339
)
290-
self.conn.commit()
340+
conn.commit()
291341
debug(f"Cleaned up {len(expired_keys)} expired {category} cache entries")
292342

293343
except Exception as e:
@@ -313,29 +363,33 @@ def invalidate(self, category: Optional[str] = None, key: Optional[str] = None):
313363
session_key = f"{category}:{key}"
314364
self.session.pop(session_key, None)
315365

316-
# Clear persistent cache
317-
if self.persistent_enabled and self.conn:
366+
# Clear persistent cache (thread-local connection)
367+
conn = self._get_conn()
368+
if conn:
318369
try:
319370
if category is None and key is None:
320-
self.conn.execute("DELETE FROM cache")
371+
conn.execute("DELETE FROM cache")
321372
info("Cache cleared (all entries)")
322373
elif category is not None and key is None:
323-
self.conn.execute("DELETE FROM cache WHERE category=?", (category,))
374+
conn.execute("DELETE FROM cache WHERE category=?", (category,))
324375
info(f"Cache cleared (category: {category})")
325376
elif category is not None and key is not None:
326-
self.conn.execute("DELETE FROM cache WHERE category=? AND key=?", (category, key))
377+
conn.execute("DELETE FROM cache WHERE category=? AND key=?", (category, key))
327378
debug(f"Cache invalidated: {category}:{key}")
328379

329-
self.conn.commit()
380+
conn.commit()
330381
except Exception as e:
331382
warn(f"Cache invalidation error: {e}")
332383

333384
def close(self):
334-
"""Close database connection."""
335-
if self.conn:
336-
with contextlib.suppress(Exception):
337-
self.conn.close()
338-
self.conn = None
385+
"""Close all database connections (one per thread)."""
386+
with self._session_lock:
387+
for conn in self._connections:
388+
with contextlib.suppress(Exception):
389+
conn.close()
390+
self._connections.clear()
391+
# Clear thread-local connection reference
392+
self._local.conn = None
339393

340394
# ==========================================
341395
# Host Deduplication (Session-only)
@@ -385,9 +439,10 @@ def print_stats(self):
385439
info(f" Misses: {self.stats['persistent_misses']}")
386440
info(f" Expired: {self.stats['expired']}")
387441

388-
if self.persistent_enabled and self.conn:
442+
conn = self._get_conn()
443+
if conn:
389444
try:
390-
cursor = self.conn.execute("SELECT COUNT(*) FROM cache")
445+
cursor = conn.execute("SELECT COUNT(*) FROM cache")
391446
total_cached = cursor.fetchone()[0]
392447
info(f" Persistent cache size: {total_cached} entries")
393448
except Exception:

0 commit comments

Comments
 (0)