Skip to content

Commit 45eb183

Browse files
committed
remade the db and db_utils functions
1 parent 79aa1b6 commit 45eb183

File tree

3 files changed

+127
-50
lines changed

3 files changed

+127
-50
lines changed

app/db.py

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,38 +13,58 @@ def initialize_schema(connection):
1313
try:
1414
with connection.cursor() as cursor:
1515
cursor.execute("""
16-
CREATE TABLE IF NOT EXISTS study (
17-
id SERIAL PRIMARY KEY,
18-
name VARCHAR(50) UNIQUE NOT NULL
19-
);
20-
21-
CREATE TABLE IF NOT EXISTS site (
22-
id SERIAL PRIMARY KEY,
23-
name VARCHAR(50) NOT NULL,
24-
study_id INT REFERENCES study(id) ON DELETE CASCADE
25-
);
26-
27-
CREATE TABLE IF NOT EXISTS subject (
28-
id SERIAL PRIMARY KEY,
29-
name VARCHAR(50) NOT NULL,
30-
site_id INT REFERENCES site(id) ON DELETE CASCADE
31-
);
32-
33-
CREATE TABLE IF NOT EXISTS task (
34-
id SERIAL PRIMARY KEY,
35-
name VARCHAR(50) NOT NULL,
36-
subject_id INT REFERENCES subject(id) ON DELETE CASCADE
37-
);
38-
39-
CREATE TABLE IF NOT EXISTS session (
40-
id SERIAL PRIMARY KEY,
41-
session_name VARCHAR(50) NOT NULL,
42-
category INT NOT NULL,
43-
csv_path TEXT,
44-
plot_paths TEXT[],
45-
task_id INT REFERENCES task(id) ON DELETE CASCADE,
46-
date TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
47-
);
16+
-- Drop existing tables in reverse dependency order.
17+
DROP TABLE IF EXISTS session CASCADE;
18+
DROP TABLE IF EXISTS task CASCADE;
19+
DROP TABLE IF EXISTS subject CASCADE;
20+
DROP TABLE IF EXISTS site CASCADE;
21+
DROP TABLE IF EXISTS study CASCADE;
22+
23+
-- Create table "study"
24+
CREATE TABLE study (
25+
id SERIAL PRIMARY KEY,
26+
name TEXT NOT NULL UNIQUE
27+
);
28+
29+
-- Create table "site"
30+
CREATE TABLE site (
31+
id SERIAL PRIMARY KEY,
32+
name TEXT NOT NULL,
33+
study_id INTEGER NOT NULL,
34+
UNIQUE (name, study_id),
35+
FOREIGN KEY (study_id) REFERENCES study(id) ON DELETE CASCADE
36+
);
37+
38+
-- Create table "subject"
39+
CREATE TABLE subject (
40+
id SERIAL PRIMARY KEY,
41+
name TEXT NOT NULL,
42+
site_id INTEGER NOT NULL,
43+
UNIQUE (name, site_id),
44+
FOREIGN KEY (site_id) REFERENCES site(id) ON DELETE CASCADE
45+
);
46+
47+
-- Create table "task"
48+
CREATE TABLE task (
49+
id SERIAL PRIMARY KEY,
50+
name TEXT NOT NULL,
51+
subject_id INTEGER NOT NULL,
52+
UNIQUE (name, subject_id),
53+
FOREIGN KEY (subject_id) REFERENCES subject(id) ON DELETE CASCADE
54+
);
55+
56+
-- Create table "session"
57+
CREATE TABLE session (
58+
id SERIAL PRIMARY KEY,
59+
session_name TEXT NOT NULL,
60+
category INTEGER NOT NULL,
61+
csv_path TEXT NOT NULL,
62+
task_id INTEGER NOT NULL,
63+
date TIMESTAMP,
64+
plot_paths TEXT[],
65+
FOREIGN KEY (task_id) REFERENCES task(id) ON DELETE CASCADE,
66+
UNIQUE (session_name, category, csv_path, task_id)
67+
);
4868
""")
4969
connection.commit()
5070
except Exception as e:
@@ -131,18 +151,17 @@ def populate_database(connection, data_folder):
131151

132152
# Main entry point
133153
if __name__ == "__main__":
    import os

    # Connection settings are taken from the standard libpq environment
    # variables so that no credentials are stored in source control.
    db_name = os.environ.get("PGDATABASE", "boost-beh")
    user = os.environ.get("PGUSER", "zakg04")
    password = os.environ.get("PGPASSWORD")
    if password is None:
        raise SystemExit("Set the PGPASSWORD environment variable before running this script.")

    # Only used by the (currently disabled) populate step described below.
    data_folder = "../data"

    connection = connect_to_db(db_name, user, password)
    try:
        # NOTE: initialize_schema drops and recreates every table, so running
        # this script discards any rows already stored in the database.
        initialize_schema(connection)
    finally:
        # Always release the connection, even if schema creation fails.
        connection.close()

    # Populating the database through
    # DatabaseUtils(connection, data_folder).update_database() is currently
    # disabled. If it is re-enabled it must run before connection.close().
1.43 KB
Binary file not shown.

app/main/update_db.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,100 @@
11
import os
22
import logging
3-
import psycopg
3+
import psycopg # PostgreSQL database adapter
44
import pandas as pd
55
from datetime import datetime
66

77
class DatabaseUtils:
88
def __init__(self, connection, data_folder):
    """
    Store the database connection and data folder used by the other methods.

    :param connection: PostgreSQL database connection object. ``update_database`` calls
        ``commit()`` on it and the helper methods open cursors from it.
    :param data_folder: Path to the directory containing study data; each of its
        sub-directories is treated as one study.
    """
    self.connection = connection
    self.data_folder = data_folder
1117

1218
def update_database(self):
    """
    Walk the data folder and mirror its directory hierarchy into the database.

    The folder layout is read as <data_folder>/<study>/<site>/<subject>/<task>.
    Every directory level is registered through ``_add_or_get_id``, and each task
    folder is then passed to ``_process_data_folder`` and ``_process_plot_folder``.
    Entries that are not directories are logged and skipped. The connection is
    committed once per study.
    """
    def child_dirs(parent):
        # Lazily yield (name, path) for every sub-directory of ``parent``,
        # logging a warning for any entry that is not a directory.
        for entry in os.listdir(parent):
            entry_path = os.path.join(parent, entry)
            if os.path.isdir(entry_path):
                yield entry, entry_path
            else:
                logging.warning(f"Skipping non-directory: {entry_path}")

    logging.info("Starting database update.")

    for study_name, study_path in child_dirs(self.data_folder):
        study_id = self._add_or_get_id("study", {"name": study_name})

        for site_name, site_path in child_dirs(study_path):
            site_id = self._add_or_get_id("site", {"name": site_name, "study_id": study_id})

            for subject_name, subject_path in child_dirs(site_path):
                # Numeric subject folders are stored zero-padded to four
                # characters; anything else is kept exactly as found on disk.
                try:
                    int(subject_name)
                except ValueError:
                    logging.warning(f"Subject name {subject_name} is not numeric; saving as-is.")
                else:
                    subject_name = subject_name.zfill(4)

                subject_id = self._add_or_get_id("subject", {"name": subject_name, "site_id": site_id})

                for task_name, task_path in child_dirs(subject_path):
                    task_id = self._add_or_get_id("task", {"name": task_name, "subject_id": subject_id})

                    # CSV sessions first, then the plots that attach to them.
                    self._process_data_folder(task_path, task_id)
                    self._process_plot_folder(task_path, task_id)

        # Persist everything gathered for this study before moving on.
        self.connection.commit()
        logging.info("Database committed.")

    logging.info("Database update complete.")
5483

5584
def _add_or_get_id(self, table, values):
56-
placeholders = ', '.join([f"{key} = %s" for key in values.keys()])
85+
"""
86+
Adds a new entry to the specified table or retrieves the existing entry's ID.
87+
88+
:param table: The name of the database table.
89+
:param values: A dictionary containing column names and their values.
90+
:return: The ID of the existing or newly inserted row.
91+
"""
92+
# Build a WHERE clause for checking existing rows
93+
placeholders = ' AND '.join([f"{key} = %s" for key in values.keys()])
5794
columns = ', '.join(values.keys())
5895
values_list = list(values.values())
5996

97+
# SQL query to insert a new record, avoiding conflicts
6098
query = f"""
6199
INSERT INTO {table} ({columns})
62100
VALUES ({', '.join(['%s'] * len(values))})
@@ -69,33 +107,42 @@ def _add_or_get_id(self, table, values):
69107
if result:
70108
return int(result[0])
71109

110+
# If no ID is returned, retrieve the existing record's ID.
72111
select_query = f"SELECT id FROM {table} WHERE {placeholders};"
73112
cursor.execute(select_query, values_list)
74113
return int(cursor.fetchone()[0])
75114

76115
def _process_data_folder(self, task_path, task_id):
116+
"""
117+
Processes CSV files in the "data" folder within a task directory and inserts session records into the database.
118+
119+
:param task_path: Path to the task directory.
120+
:param task_id: ID of the corresponding task in the database.
121+
"""
77122
data_folder_path = os.path.join(task_path, "data")
78123
if os.path.exists(data_folder_path):
79124
for file in os.listdir(data_folder_path):
80125
if file.endswith(".csv"):
81126
logging.debug(f"Processing file: {file}")
82127
try:
128+
# Extract session and category information from the filename
83129
parts = file.split("_")
84130
if len(parts) < 3:
85131
raise ValueError(f"Unexpected file format: {file}")
86132

87-
session_name = parts[1].split("-")[1] # Ensure split works correctly
88-
category = int(parts[2].split("-")[1].split(".")[0])
133+
session_name = parts[1].split("-")[1] # Extract session name
134+
category = int(parts[2].split("-")[1].split(".")[0]) # Extract category
89135
csv_path = os.path.join(data_folder_path, file)
90136

91-
# Extract and clean date from CSV if column 'datetime' exists
137+
# Extract and clean the date from the CSV if it contains a 'datetime' column
92138
date = None
93139
df = pd.read_csv(csv_path)
94140
if 'datetime' in df.columns:
95141
raw_date = str(df['datetime'].iloc[0])
96142
date = self._clean_date(raw_date)
97-
del df
143+
del df # Free up memory
98144

145+
# Insert session data into the database
99146
with self.connection.cursor() as cursor:
100147
cursor.execute(
101148
"""
@@ -110,6 +157,12 @@ def _process_data_folder(self, task_path, task_id):
110157
logging.error(f"Error processing file {file}: {e}")
111158

112159
def _process_plot_folder(self, task_path, task_id):
160+
"""
161+
Processes PNG image files in the "plot" folder and updates the session record with plot file paths.
162+
163+
:param task_path: Path to the task directory.
164+
:param task_id: ID of the corresponding task in the database.
165+
"""
113166
plot_folder_path = os.path.join(task_path, "plot")
114167
if os.path.exists(plot_folder_path):
115168
plots = [os.path.join(plot_folder_path, f) for f in os.listdir(plot_folder_path) if f.endswith(".png")]
@@ -126,14 +179,19 @@ def _process_plot_folder(self, task_path, task_id):
126179
logging.debug(f"Plots updated for task {task_id}: {plots}")
127180

128181
def _clean_date(self, raw_date):
182+
"""
183+
Converts a raw date string into a standardized format.
184+
185+
:param raw_date: Date string extracted from a CSV file.
186+
:return: Standardized date string or None if parsing fails.
187+
"""
129188
import re
130-
"""Converts raw date strings into a standardized format."""
131189
try:
132-
# Remove timezone information in parentheses, if any
190+
# Remove timezone information enclosed in parentheses
133191
cleaned_raw_date = re.sub(r"\s\(.*?\)", "", raw_date)
134-
# Parse the cleaned date string
192+
# Parse the cleaned date string into a datetime object
135193
clean_date = datetime.strptime(cleaned_raw_date, "%a %b %d %Y %H:%M:%S %Z%z")
136-
# Standardize to SQL-compatible format
194+
# Convert to SQL-compatible format
137195
return clean_date.strftime("%Y-%m-%d %H:%M:%S")
138196
except ValueError as e:
139197
logging.error(f"Error parsing date: {raw_date} - {e}")

0 commit comments

Comments
 (0)