1- import psycopg2
2- from datetime import datetime
3- from typing import List , Optional , Tuple
1+ """
2+ PostgreSQL Database Interface for Smack Messaging Service.
3+
4+ This module provides a robust database interface for storing and retrieving messages
5+ with proper connection management, error handling, and logging.
6+ """
47import os
58import atexit
9+ import logging
10+ from datetime import datetime
11+ from typing import List , Optional , Tuple , Dict , Any , Union
12+ import time
13+
14+ import psycopg2
15+ from psycopg2 import pool
16+
17+ # Configure logging
18+ logger = logging .getLogger (__name__ )
619
720class MessageDB :
8- """PostgreSQL database interface for message storage with a single persistent connection."""
9- _connection = None
21+ """PostgreSQL database interface for message storage with connection pooling."""
22+
23+ # Class-level connection pool
24+ _connection_pool = None
25+ _pool_min_conn = 1
26+ _pool_max_conn = 10
1027
1128 def __init__ (self ):
12- """Initialize the database connection if not already initialized."""
29+ """Initialize the database connection pool if not already initialized."""
1330 if not hasattr (self , '_initialized' ):
14- self ._initialize_connection ()
31+ self ._initialize_connection_pool ()
1532 # Register cleanup function
1633 atexit .register (self .close_connection )
1734 self ._initialized = True
1835
19- def _initialize_connection (self ):
20- """Initialize the database connection."""
21- # Get connection string from environment variable with fallback to individual parameters
36+ def _initialize_connection_pool (self ) -> None :
37+ """Initialize the database connection pool with retry logic ."""
38+ # Get connection parameters from environment with secure defaults
2239 database_url = os .environ .get ('DATABASE_URL' )
40+ max_retries = int (os .environ .get ('DB_MAX_RETRIES' , '5' ))
41+ retry_delay = int (os .environ .get ('DB_RETRY_DELAY' , '2' ))
2342
24- try :
25- if database_url :
26- # Use the complete DATABASE_URL if available
27- self ._connection = psycopg2 .connect (database_url )
28- print (f"PostgreSQL connection established using DATABASE_URL" )
29- else :
30- # Fallback to individual parameters
31- self .db_host = os .environ .get ('DB_HOST' , 'postgres' )
32- self .db_name = os .environ .get ('DB_NAME' , 'smack' )
33- self .db_user = os .environ .get ('DB_USER' , 'postgres' )
34- self .db_password = os .environ .get ('DB_PASSWORD' , 'postgres' )
35- self .db_port = os .environ .get ('DB_PORT' , '5432' )
43+ # Set pool size from environment or use defaults
44+ self ._pool_min_conn = int (os .environ .get ('DB_MIN_CONNECTIONS' , '1' ))
45+ self ._pool_max_conn = int (os .environ .get ('DB_MAX_CONNECTIONS' , '10' ))
46+
47+ retry_count = 0
48+ last_error = None
49+
50+ while retry_count < max_retries :
51+ try :
52+ if database_url :
53+ # Use the complete DATABASE_URL if available
54+ self ._connection_pool = pool .ThreadedConnectionPool (
55+ self ._pool_min_conn ,
56+ self ._pool_max_conn ,
57+ database_url
58+ )
59+ logger .info ("PostgreSQL connection pool established using DATABASE_URL" )
60+ else :
61+ # Fallback to individual parameters with secure defaults
62+ self .db_host = os .environ .get ('DB_HOST' , 'localhost' )
63+ self .db_name = os .environ .get ('DB_NAME' , 'smack' )
64+ self .db_user = os .environ .get ('DB_USER' , 'postgres' )
65+ self .db_password = os .environ .get ('DB_PASSWORD' , '' )
66+ self .db_port = os .environ .get ('DB_PORT' , '5432' )
67+
68+ self ._connection_pool = pool .ThreadedConnectionPool (
69+ self ._pool_min_conn ,
70+ self ._pool_max_conn ,
71+ host = self .db_host ,
72+ database = self .db_name ,
73+ user = self .db_user ,
74+ password = self .db_password ,
75+ port = self .db_port
76+ )
77+ logger .info (f"PostgreSQL connection pool established to { self .db_host } :{ self .db_port } /{ self .db_name } " )
3678
37- self ._connection = psycopg2 .connect (
38- host = self .db_host ,
39- database = self .db_name ,
40- user = self .db_user ,
41- password = self .db_password ,
42- port = self .db_port
43- )
44- print (f"PostgreSQL connection established to { self .db_host } :{ self .db_port } /{ self .db_name } " )
45-
46- self ._init_db ()
47- except Exception as e :
48- print (f"Error connecting to PostgreSQL: { e } " )
49- self ._connection = None
79+ # Initialize database schema
80+ self ._init_db ()
81+ return
82+
83+ except Exception as e :
84+ last_error = e
85+ retry_count += 1
86+ logger .warning (f"Connection attempt { retry_count } /{ max_retries } failed: { e } " )
87+ if retry_count < max_retries :
88+ logger .info (f"Retrying in { retry_delay } seconds..." )
89+ time .sleep (retry_delay )
90+
91+ # If we get here, all retries failed
92+ logger .error (f"Failed to establish database connection after { max_retries } attempts: { last_error } " )
93+ raise ConnectionError (f"Could not connect to database: { last_error } " )
5094
5195 def _init_db (self ) -> None :
52- """Initialize the database and create the messages table if it doesn't exist."""
96+ """Initialize the database schema if it doesn't exist."""
97+ connection = None
5398 try :
54- cursor = self ._connection .cursor ()
99+ connection = self ._get_connection ()
100+ cursor = connection .cursor ()
101+
102+ # Create messages table with proper indexing
55103 cursor .execute ('''
56104 CREATE TABLE IF NOT EXISTS messages (
57105 id SERIAL PRIMARY KEY,
@@ -60,38 +108,58 @@ def _init_db(self) -> None:
60108 timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
61109 )
62110 ''' )
63- self . _connection .commit ()
111+ connection .commit ()
64112 cursor .close ()
113+ logger .info ("Database schema initialized successfully" )
65114 except Exception as e :
66- print (f"Error initializing database: { e } " )
67- if self ._connection :
68- self ._connection .rollback ()
115+ logger .error (f"Error initializing database schema: { e } " )
116+ if connection :
117+ connection .rollback ()
118+ raise
119+ finally :
120+ self ._return_connection (connection )
69121
70- def _ensure_connection (self ):
71- """Ensure the connection is active, reconnecting if necessary ."""
72- if self . _connection is None or self ._connection . closed :
73- self . _initialize_connection ( )
74-
75- # Test connection with a simple query
122+ def _get_connection (self ):
123+ """Get a connection from the pool with validation ."""
124+ if not self ._connection_pool :
125+ logger . warning ( "Connection pool not initialized, attempting to reconnect" )
126+ self . _initialize_connection_pool ()
127+
76128 try :
77- cursor = self ._connection .cursor ()
129+ connection = self ._connection_pool .getconn ()
130+ # Test connection with a simple query
131+ cursor = connection .cursor ()
78132 cursor .execute ("SELECT 1" )
79133 cursor .close ()
134+ return connection
80135 except Exception as e :
81- print (f"Connection test failed, reconnecting: { e } " )
82- self ._initialize_connection ()
83-
84- if self ._connection is None or self ._connection .closed :
85- raise Exception ("Failed to establish database connection" )
86-
87- return self ._connection
136+ logger .error (f"Failed to get valid connection from pool: { e } " )
137+ # Try to reinitialize the pool
138+ self ._initialize_connection_pool ()
139+ return self ._connection_pool .getconn ()
140+
141+ def _return_connection (self , connection ):
142+ """Return a connection to the pool safely."""
143+ if connection and self ._connection_pool :
144+ try :
145+ self ._connection_pool .putconn (connection )
146+ except Exception as e :
147+ logger .warning (f"Error returning connection to pool: { e } " )
148+ # Try to close it directly if returning fails
149+ try :
150+ connection .close ()
151+ except :
152+ pass
88153
89- def close_connection (self ):
90- """Close the database connection."""
91- if self ._connection :
92- self ._connection .close ()
93- print ("PostgreSQL connection closed" )
94- self ._connection = None
154+ async def close_connection (self ):
155+ """Close the database connection pool."""
156+ if self ._connection_pool :
157+ try :
158+ self ._connection_pool .closeall ()
159+ logger .info ("PostgreSQL connection pool closed" )
160+ self ._connection_pool = None
161+ except Exception as e :
162+ logger .error (f"Error closing connection pool: { e } " )
95163
96164 def add_message (self , sender : str , content : str ) -> bool :
97165 """
@@ -104,41 +172,63 @@ def add_message(self, sender: str, content: str) -> bool:
104172 Returns:
105173 bool: True if message was added successfully, False otherwise
106174 """
175+ connection = None
107176 try :
108- conn = self ._ensure_connection ()
109- cursor = conn .cursor ()
177+ # Input validation
178+ if not sender or not sender .strip ():
179+ logger .warning ("Attempted to add message with empty sender" )
180+ return False
181+
182+ if not content or not content .strip ():
183+ logger .warning (f"Attempted to add empty message from { sender } " )
184+ return False
185+
186+ connection = self ._get_connection ()
187+ cursor = connection .cursor ()
110188 cursor .execute (
111- 'INSERT INTO messages (sender, content) VALUES (%s, %s)' ,
189+ 'INSERT INTO messages (sender, content) VALUES (%s, %s) RETURNING id ' ,
112190 (sender , content )
113191 )
114- conn .commit ()
192+ message_id = cursor .fetchone ()[0 ]
193+ connection .commit ()
115194 cursor .close ()
195+ logger .info (f"Message added successfully with ID { message_id } " )
116196 return True
117197 except Exception as e :
118- print (f"Error adding message to database: { e } " )
119- if self . _connection :
120- self . _connection .rollback ()
198+ logger . error (f"Error adding message to database: { e } " )
199+ if connection :
200+ connection .rollback ()
121201 return False
202+ finally :
203+ self ._return_connection (connection )
122204
123- def get_all_messages (self ) -> List [Tuple [int , str , str , str ]]:
205+ def get_all_messages (self , limit : int = 100 ) -> List [Tuple [int , str , str , str ]]:
124206 """
125- Retrieve all messages from the database.
207+ Retrieve messages from the database with pagination .
126208
209+ Args:
210+ limit: Maximum number of messages to retrieve (default: 100)
211+
127212 Returns:
128213 List of tuples containing (id, sender, content, timestamp)
129214 """
215+ connection = None
130216 try :
131- conn = self ._ensure_connection ()
132- cursor = conn .cursor ()
217+ connection = self ._get_connection ()
218+ cursor = connection .cursor ()
133219 cursor .execute (
134- 'SELECT id, sender, content, timestamp FROM messages ORDER BY timestamp DESC'
220+ 'SELECT id, sender, content, timestamp FROM messages ORDER BY timestamp DESC LIMIT %s' ,
221+ (limit ,)
135222 )
136223 messages = cursor .fetchall ()
137224 cursor .close ()
225+ logger .info (f"Retrieved { len (messages )} messages successfully" )
138226 return messages
139227 except Exception as e :
140- print (f"Error retrieving messages from database: { e } " )
228+ logger . error (f"Error retrieving messages from database: { e } " )
141229 return []
230+ finally :
231+ self ._return_connection (connection )
142232
143233 def get_message_by_id (self , message_id : int ) -> Optional [Tuple [int , str , str , str ]]:
144234 """
@@ -150,19 +240,32 @@ def get_message_by_id(self, message_id: int) -> Optional[Tuple[int, str, str, st
150240 Returns:
151241 Tuple containing (id, sender, content, timestamp) or None if not found
152242 """
243+ connection = None
153244 try :
154- conn = self ._ensure_connection ()
155- cursor = conn .cursor ()
245+ if not isinstance (message_id , int ) or message_id <= 0 :
246+ logger .warning (f"Invalid message ID: { message_id } " )
247+ return None
248+
249+ connection = self ._get_connection ()
250+ cursor = connection .cursor ()
156251 cursor .execute (
157252 'SELECT id, sender, content, timestamp FROM messages WHERE id = %s' ,
158253 (message_id ,)
159254 )
160255 message = cursor .fetchone ()
161256 cursor .close ()
257+
258+ if message :
259+ logger .info (f"Retrieved message with ID { message_id } " )
260+ else :
261+ logger .info (f"No message found with ID { message_id } " )
262+
162263 return message
163264 except Exception as e :
164- print (f"Error retrieving message from database: { e } " )
265+ logger . error (f"Error retrieving message from database: { e } " )
165266 return None
267+ finally :
268+ self ._return_connection (connection )
166269
167270 def delete_message (self , message_id : int ) -> bool :
168271 """
@@ -174,17 +277,30 @@ def delete_message(self, message_id: int) -> bool:
174277 Returns:
175278 bool: True if message was deleted successfully, False otherwise
176279 """
280+ connection = None
177281 try :
178- conn = self ._ensure_connection ()
179- cursor = conn .cursor ()
282+ if not isinstance (message_id , int ) or message_id <= 0 :
283+ logger .warning (f"Invalid message ID for deletion: { message_id } " )
284+ return False
285+
286+ connection = self ._get_connection ()
287+ cursor = connection .cursor ()
180288 cursor .execute ('DELETE FROM messages WHERE id = %s' , (message_id ,))
181289 deleted = cursor .rowcount > 0
182- conn .commit ()
290+ connection .commit ()
183291 cursor .close ()
292+
293+ if deleted :
294+ logger .info (f"Message with ID { message_id } deleted successfully" )
295+ else :
296+ logger .info (f"No message found with ID { message_id } for deletion" )
297+
184298 return deleted
185299 except Exception as e :
186- print (f"Error deleting message from database: { e } " )
187- if self . _connection :
188- self . _connection .rollback ()
300+ logger . error (f"Error deleting message from database: { e } " )
301+ if connection :
302+ connection .rollback ()
189303 return False
304+ finally :
305+ self ._return_connection (connection )
190306
0 commit comments