Skip to content

Commit d009d20

Browse files
enterprisify it
1 parent 67de6dd commit d009d20

File tree

2 files changed

+303
-107
lines changed

2 files changed

+303
-107
lines changed

examples/smack/postgresDB.py

Lines changed: 199 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,105 @@
1-
import psycopg2
2-
from datetime import datetime
3-
from typing import List, Optional, Tuple
1+
"""
2+
PostgreSQL Database Interface for Smack Messaging Service.
3+
4+
This module provides a robust database interface for storing and retrieving messages
5+
with proper connection management, error handling, and logging.
6+
"""
47
import os
58
import atexit
9+
import logging
10+
from datetime import datetime
11+
from typing import List, Optional, Tuple, Dict, Any, Union
12+
import time
13+
14+
import psycopg2
15+
from psycopg2 import pool
16+
17+
# Configure logging
18+
logger = logging.getLogger(__name__)
619

720
class MessageDB:
8-
"""PostgreSQL database interface for message storage with a single persistent connection."""
9-
_connection = None
21+
"""PostgreSQL database interface for message storage with connection pooling."""
22+
23+
# Class-level connection pool
24+
_connection_pool = None
25+
_pool_min_conn = 1
26+
_pool_max_conn = 10
1027

1128
def __init__(self):
12-
"""Initialize the database connection if not already initialized."""
29+
"""Initialize the database connection pool if not already initialized."""
1330
if not hasattr(self, '_initialized'):
14-
self._initialize_connection()
31+
self._initialize_connection_pool()
1532
# Register cleanup function
1633
atexit.register(self.close_connection)
1734
self._initialized = True
1835

19-
def _initialize_connection(self):
20-
"""Initialize the database connection."""
21-
# Get connection string from environment variable with fallback to individual parameters
36+
def _initialize_connection_pool(self) -> None:
37+
"""Initialize the database connection pool with retry logic."""
38+
# Get connection parameters from environment with secure defaults
2239
database_url = os.environ.get('DATABASE_URL')
40+
max_retries = int(os.environ.get('DB_MAX_RETRIES', '5'))
41+
retry_delay = int(os.environ.get('DB_RETRY_DELAY', '2'))
2342

24-
try:
25-
if database_url:
26-
# Use the complete DATABASE_URL if available
27-
self._connection = psycopg2.connect(database_url)
28-
print(f"PostgreSQL connection established using DATABASE_URL")
29-
else:
30-
# Fallback to individual parameters
31-
self.db_host = os.environ.get('DB_HOST', 'postgres')
32-
self.db_name = os.environ.get('DB_NAME', 'smack')
33-
self.db_user = os.environ.get('DB_USER', 'postgres')
34-
self.db_password = os.environ.get('DB_PASSWORD', 'postgres')
35-
self.db_port = os.environ.get('DB_PORT', '5432')
43+
# Set pool size from environment or use defaults
44+
self._pool_min_conn = int(os.environ.get('DB_MIN_CONNECTIONS', '1'))
45+
self._pool_max_conn = int(os.environ.get('DB_MAX_CONNECTIONS', '10'))
46+
47+
retry_count = 0
48+
last_error = None
49+
50+
while retry_count < max_retries:
51+
try:
52+
if database_url:
53+
# Use the complete DATABASE_URL if available
54+
self._connection_pool = pool.ThreadedConnectionPool(
55+
self._pool_min_conn,
56+
self._pool_max_conn,
57+
database_url
58+
)
59+
logger.info("PostgreSQL connection pool established using DATABASE_URL")
60+
else:
61+
# Fallback to individual parameters with secure defaults
62+
self.db_host = os.environ.get('DB_HOST', 'localhost')
63+
self.db_name = os.environ.get('DB_NAME', 'smack')
64+
self.db_user = os.environ.get('DB_USER', 'postgres')
65+
self.db_password = os.environ.get('DB_PASSWORD', '')
66+
self.db_port = os.environ.get('DB_PORT', '5432')
67+
68+
self._connection_pool = pool.ThreadedConnectionPool(
69+
self._pool_min_conn,
70+
self._pool_max_conn,
71+
host=self.db_host,
72+
database=self.db_name,
73+
user=self.db_user,
74+
password=self.db_password,
75+
port=self.db_port
76+
)
77+
logger.info(f"PostgreSQL connection pool established to {self.db_host}:{self.db_port}/{self.db_name}")
3678

37-
self._connection = psycopg2.connect(
38-
host=self.db_host,
39-
database=self.db_name,
40-
user=self.db_user,
41-
password=self.db_password,
42-
port=self.db_port
43-
)
44-
print(f"PostgreSQL connection established to {self.db_host}:{self.db_port}/{self.db_name}")
45-
46-
self._init_db()
47-
except Exception as e:
48-
print(f"Error connecting to PostgreSQL: {e}")
49-
self._connection = None
79+
# Initialize database schema
80+
self._init_db()
81+
return
82+
83+
except Exception as e:
84+
last_error = e
85+
retry_count += 1
86+
logger.warning(f"Connection attempt {retry_count}/{max_retries} failed: {e}")
87+
if retry_count < max_retries:
88+
logger.info(f"Retrying in {retry_delay} seconds...")
89+
time.sleep(retry_delay)
90+
91+
# If we get here, all retries failed
92+
logger.error(f"Failed to establish database connection after {max_retries} attempts: {last_error}")
93+
raise ConnectionError(f"Could not connect to database: {last_error}")
5094

5195
def _init_db(self) -> None:
52-
"""Initialize the database and create the messages table if it doesn't exist."""
96+
"""Initialize the database schema if it doesn't exist."""
97+
connection = None
5398
try:
54-
cursor = self._connection.cursor()
99+
connection = self._get_connection()
100+
cursor = connection.cursor()
101+
102+
# Create messages table with proper indexing
55103
cursor.execute('''
56104
CREATE TABLE IF NOT EXISTS messages (
57105
id SERIAL PRIMARY KEY,
@@ -60,38 +108,58 @@ def _init_db(self) -> None:
60108
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
61109
)
62110
''')
63-
self._connection.commit()
111+
connection.commit()
64112
cursor.close()
113+
logger.info("Database schema initialized successfully")
65114
except Exception as e:
66-
print(f"Error initializing database: {e}")
67-
if self._connection:
68-
self._connection.rollback()
115+
logger.error(f"Error initializing database schema: {e}")
116+
if connection:
117+
connection.rollback()
118+
raise
119+
finally:
120+
self._return_connection(connection)
69121

70-
def _ensure_connection(self):
71-
"""Ensure the connection is active, reconnecting if necessary."""
72-
if self._connection is None or self._connection.closed:
73-
self._initialize_connection()
74-
75-
# Test connection with a simple query
122+
def _get_connection(self):
123+
"""Get a connection from the pool with validation."""
124+
if not self._connection_pool:
125+
logger.warning("Connection pool not initialized, attempting to reconnect")
126+
self._initialize_connection_pool()
127+
76128
try:
77-
cursor = self._connection.cursor()
129+
connection = self._connection_pool.getconn()
130+
# Test connection with a simple query
131+
cursor = connection.cursor()
78132
cursor.execute("SELECT 1")
79133
cursor.close()
134+
return connection
80135
except Exception as e:
81-
print(f"Connection test failed, reconnecting: {e}")
82-
self._initialize_connection()
83-
84-
if self._connection is None or self._connection.closed:
85-
raise Exception("Failed to establish database connection")
86-
87-
return self._connection
136+
logger.error(f"Failed to get valid connection from pool: {e}")
137+
# Try to reinitialize the pool
138+
self._initialize_connection_pool()
139+
return self._connection_pool.getconn()
140+
141+
def _return_connection(self, connection):
142+
"""Return a connection to the pool safely."""
143+
if connection and self._connection_pool:
144+
try:
145+
self._connection_pool.putconn(connection)
146+
except Exception as e:
147+
logger.warning(f"Error returning connection to pool: {e}")
148+
# Try to close it directly if returning fails
149+
try:
150+
connection.close()
151+
except:
152+
pass
88153

89-
def close_connection(self):
90-
"""Close the database connection."""
91-
if self._connection:
92-
self._connection.close()
93-
print("PostgreSQL connection closed")
94-
self._connection = None
154+
async def close_connection(self):
155+
"""Close the database connection pool."""
156+
if self._connection_pool:
157+
try:
158+
self._connection_pool.closeall()
159+
logger.info("PostgreSQL connection pool closed")
160+
self._connection_pool = None
161+
except Exception as e:
162+
logger.error(f"Error closing connection pool: {e}")
95163

96164
def add_message(self, sender: str, content: str) -> bool:
97165
"""
@@ -104,41 +172,63 @@ def add_message(self, sender: str, content: str) -> bool:
104172
Returns:
105173
bool: True if message was added successfully, False otherwise
106174
"""
175+
connection = None
107176
try:
108-
conn = self._ensure_connection()
109-
cursor = conn.cursor()
177+
# Input validation
178+
if not sender or not sender.strip():
179+
logger.warning("Attempted to add message with empty sender")
180+
return False
181+
182+
if not content or not content.strip():
183+
logger.warning(f"Attempted to add empty message from {sender}")
184+
return False
185+
186+
connection = self._get_connection()
187+
cursor = connection.cursor()
110188
cursor.execute(
111-
'INSERT INTO messages (sender, content) VALUES (%s, %s)',
189+
'INSERT INTO messages (sender, content) VALUES (%s, %s) RETURNING id',
112190
(sender, content)
113191
)
114-
conn.commit()
192+
message_id = cursor.fetchone()[0]
193+
connection.commit()
115194
cursor.close()
195+
logger.info(f"Message added successfully with ID {message_id}")
116196
return True
117197
except Exception as e:
118-
print(f"Error adding message to database: {e}")
119-
if self._connection:
120-
self._connection.rollback()
198+
logger.error(f"Error adding message to database: {e}")
199+
if connection:
200+
connection.rollback()
121201
return False
202+
finally:
203+
self._return_connection(connection)
122204

123-
def get_all_messages(self) -> List[Tuple[int, str, str, str]]:
205+
def get_all_messages(self, limit: int = 100) -> List[Tuple[int, str, str, str]]:
124206
"""
125-
Retrieve all messages from the database.
207+
Retrieve messages from the database with pagination.
126208
209+
Args:
210+
limit: Maximum number of messages to retrieve (default: 100)
211+
127212
Returns:
128213
List of tuples containing (id, sender, content, timestamp)
129214
"""
215+
connection = None
130216
try:
131-
conn = self._ensure_connection()
132-
cursor = conn.cursor()
217+
connection = self._get_connection()
218+
cursor = connection.cursor()
133219
cursor.execute(
134-
'SELECT id, sender, content, timestamp FROM messages ORDER BY timestamp DESC'
220+
'SELECT id, sender, content, timestamp FROM messages ORDER BY timestamp DESC LIMIT %s',
221+
(limit,)
135222
)
136223
messages = cursor.fetchall()
137224
cursor.close()
225+
logger.info(f"Retrieved {len(messages)} messages successfully")
138226
return messages
139227
except Exception as e:
140-
print(f"Error retrieving messages from database: {e}")
228+
logger.error(f"Error retrieving messages from database: {e}")
141229
return []
230+
finally:
231+
self._return_connection(connection)
142232

143233
def get_message_by_id(self, message_id: int) -> Optional[Tuple[int, str, str, str]]:
144234
"""
@@ -150,19 +240,32 @@ def get_message_by_id(self, message_id: int) -> Optional[Tuple[int, str, str, st
150240
Returns:
151241
Tuple containing (id, sender, content, timestamp) or None if not found
152242
"""
243+
connection = None
153244
try:
154-
conn = self._ensure_connection()
155-
cursor = conn.cursor()
245+
if not isinstance(message_id, int) or message_id <= 0:
246+
logger.warning(f"Invalid message ID: {message_id}")
247+
return None
248+
249+
connection = self._get_connection()
250+
cursor = connection.cursor()
156251
cursor.execute(
157252
'SELECT id, sender, content, timestamp FROM messages WHERE id = %s',
158253
(message_id,)
159254
)
160255
message = cursor.fetchone()
161256
cursor.close()
257+
258+
if message:
259+
logger.info(f"Retrieved message with ID {message_id}")
260+
else:
261+
logger.info(f"No message found with ID {message_id}")
262+
162263
return message
163264
except Exception as e:
164-
print(f"Error retrieving message from database: {e}")
265+
logger.error(f"Error retrieving message from database: {e}")
165266
return None
267+
finally:
268+
self._return_connection(connection)
166269

167270
def delete_message(self, message_id: int) -> bool:
168271
"""
@@ -174,17 +277,30 @@ def delete_message(self, message_id: int) -> bool:
174277
Returns:
175278
bool: True if message was deleted successfully, False otherwise
176279
"""
280+
connection = None
177281
try:
178-
conn = self._ensure_connection()
179-
cursor = conn.cursor()
282+
if not isinstance(message_id, int) or message_id <= 0:
283+
logger.warning(f"Invalid message ID for deletion: {message_id}")
284+
return False
285+
286+
connection = self._get_connection()
287+
cursor = connection.cursor()
180288
cursor.execute('DELETE FROM messages WHERE id = %s', (message_id,))
181289
deleted = cursor.rowcount > 0
182-
conn.commit()
290+
connection.commit()
183291
cursor.close()
292+
293+
if deleted:
294+
logger.info(f"Message with ID {message_id} deleted successfully")
295+
else:
296+
logger.info(f"No message found with ID {message_id} for deletion")
297+
184298
return deleted
185299
except Exception as e:
186-
print(f"Error deleting message from database: {e}")
187-
if self._connection:
188-
self._connection.rollback()
300+
logger.error(f"Error deleting message from database: {e}")
301+
if connection:
302+
connection.rollback()
189303
return False
304+
finally:
305+
self._return_connection(connection)
190306

0 commit comments

Comments
 (0)