Skip to content

Commit 6b41431

Browse files
Fix tenant isolation in DROP TABLE migrations
1 parent 402e6e7 commit 6b41431

File tree

2 files changed

+213
-63
lines changed

2 files changed

+213
-63
lines changed

backend/app/routes/migration_routes.py

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from collections import defaultdict
12
from uuid import UUID
23

34
from fastapi import APIRouter, Depends, HTTPException
5+
from supabase._async.client import AsyncClient
46

57
from app.core.dependencies import get_current_admin
6-
from app.schemas.classification_schemas import Classification
8+
from app.core.supabase import get_async_supabase
9+
from app.schemas.classification_schemas import Classification, ExtractedFile
710
from app.schemas.migration_schemas import Migration, MigrationCreate
811
from app.schemas.relationship_schemas import Relationship
912
from app.services.classification_service import (
@@ -18,7 +21,7 @@
1821
RelationshipService,
1922
get_relationship_service,
2023
)
21-
from app.utils.migrations import create_migrations
24+
from app.utils.migrations import _table_name_for_classification, create_migrations
2225

2326
router = APIRouter(prefix="/migrations", tags=["Migrations"])
2427

@@ -56,7 +59,6 @@ async def generate_migrations(
5659
Then insert the new migrations into the `migrations` table and return them.
5760
"""
5861
try:
59-
# 1) Load current state from DB
6062
classifications: list[
6163
Classification
6264
] = await classification_service.get_classifications(tenant_id)
@@ -72,19 +74,15 @@ async def generate_migrations(
7274
status_code=404, detail="No classifications found for tenant"
7375
)
7476

75-
# 2) Compute *new* migrations (pure function)
76-
# IMPORTANT: this should return list[MigrationCreate]
7777
new_migration_creates: list[MigrationCreate] = create_migrations(
7878
classifications=classifications,
7979
relationships=relationships,
8080
initial_migrations=existing_migrations,
8181
)
8282

8383
if not new_migration_creates:
84-
# Nothing new to add
8584
return []
8685

87-
# 3) Insert into DB and return the created migrations
8886
created: list[Migration] = []
8987
for m in new_migration_creates:
9088
new_id = await migration_service.create_migration(m)
@@ -122,6 +120,78 @@ async def execute_migrations(
122120
raise HTTPException(status_code=500, detail=str(e)) from e
123121

124122

123+
@router.post("/load_data/{tenant_id}")
124+
async def load_data_for_tenant(
125+
tenant_id: UUID,
126+
classification_service: ClassificationService = Depends(get_classification_service),
127+
supabase: AsyncClient = Depends(get_async_supabase),
128+
admin=Depends(get_current_admin),
129+
) -> dict:
130+
"""
131+
Full data sync for a tenant:
132+
133+
- Fetch all extracted files + their classifications
134+
- Group by classification
135+
- For each classification:
136+
* derive table name (same as migrations)
137+
* DELETE existing rows for that tenant
138+
* INSERT rows for each file in that classification
139+
"""
140+
try:
141+
extracted_files: list[
142+
ExtractedFile
143+
] = await classification_service.get_extracted_files(tenant_id)
144+
145+
if not extracted_files:
146+
return {
147+
"status": "ok",
148+
"tables_updated": [],
149+
"message": "No extracted files found",
150+
}
151+
152+
files_by_class_id: dict[UUID, list[ExtractedFile]] = defaultdict(list)
153+
154+
for ef in extracted_files:
155+
if ef.classification is None:
156+
continue
157+
files_by_class_id[ef.classification.classification_id].append(ef)
158+
159+
updated_tables: list[str] = []
160+
161+
for class_files in files_by_class_id.values():
162+
classification = class_files[0].classification
163+
table_name = _table_name_for_classification(classification)
164+
165+
await (
166+
supabase.table(table_name)
167+
.delete()
168+
.eq("tenant_id", str(tenant_id))
169+
.execute()
170+
)
171+
172+
rows = [
173+
{
174+
"id": str(f.extracted_file_id),
175+
"tenant_id": str(tenant_id),
176+
"data": f.extracted_data,
177+
}
178+
for f in class_files
179+
]
180+
181+
if rows:
182+
await supabase.table(table_name).insert(rows).execute()
183+
184+
updated_tables.append(table_name)
185+
186+
return {
187+
"status": "ok",
188+
"tables_updated": updated_tables,
189+
"message": "Data synced from extracted_files into generated tables",
190+
}
191+
except Exception as e:
192+
raise HTTPException(status_code=500, detail=str(e)) from e
193+
194+
125195
@router.get("/connection-url/{tenant_id}")
126196
async def get_tenant_connection_url(
127197
tenant_id: UUID,
@@ -130,15 +200,6 @@ async def get_tenant_connection_url(
130200
) -> dict:
131201
"""
132202
Get a PostgreSQL connection URL for a specific tenant.
133-
134-
This URL is scoped to only show the tenant's generated tables.
135-
136-
Query params:
137-
include_public: If true, also include public schema (for shared tables)
138-
139-
Example:
140-
GET /migrations/connection-url/{tenant_id}
141-
GET /migrations/connection-url/{tenant_id}?include_public=true
142203
"""
143204
from app.utils.tenant_connection import get_schema_name, get_tenant_connection_url
144205

backend/app/utils/migrations.py

Lines changed: 136 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,34 @@ def _get_schema_name(tenant_id) -> str:
2121
return f"tenant_{str(tenant_id).replace('-', '_')}"
2222

2323

24+
def _get_created_tables(migrations: list[Migration], schema_name: str) -> set[str]:
25+
"""
26+
Get all table names that have been created by migrations for this schema.
27+
Returns: set of table names (without schema prefix)
28+
"""
29+
created_tables = set()
30+
prefix = f"create_table_{schema_name}_"
31+
for m in migrations:
32+
if m.name.startswith(prefix):
33+
table_name = m.name.replace(prefix, "")
34+
created_tables.add(table_name)
35+
return created_tables
36+
37+
38+
def _get_dropped_tables(migrations: list[Migration], schema_name: str) -> set[str]:
39+
"""
40+
Get table names that have already been dropped for this schema.
41+
Returns: set of table names (without schema prefix)
42+
"""
43+
dropped = set()
44+
prefix = f"drop_table_{schema_name}_"
45+
for m in migrations:
46+
if m.name.startswith(prefix):
47+
table_name = m.name.replace(prefix, "")
48+
dropped.add(table_name)
49+
return dropped
50+
51+
2452
def create_migrations(
2553
classifications: list[Classification],
2654
relationships: list[Relationship],
@@ -30,16 +58,20 @@ def create_migrations(
3058
PURE FUNCTION.
3159
3260
Given:
33-
- classifications: what tables we conceptually want
61+
- classifications: what tables we conceptually want NOW
3462
- relationships: how those tables relate (1-1, 1-many, many-many)
3563
- initial_migrations: migrations that already exist in DB
3664
3765
Returns:
3866
- list[MigrationCreate] = new migrations to append on top
3967
40-
NOW WITH SCHEMA-PER-TENANT:
41-
- First migration creates the tenant schema
42-
- All tables are created within that schema
68+
This function handles:
69+
1. CREATE SCHEMA for the tenant
70+
2. CREATE TABLE for new classifications
71+
3. DROP TABLE for removed classifications
72+
4. Relationship migrations
73+
74+
All SQL is schema-qualified for tenant isolation.
4375
"""
4476
if not classifications:
4577
return []
@@ -52,11 +84,16 @@ def create_migrations(
5284

5385
new_migrations: list[MigrationCreate] = []
5486

55-
# All classifications belong to the same tenant
56-
tenant_id = classifications[0].tenant_id
57-
schema_name = _get_schema_name(tenant_id)
87+
# Get tenant info and schema name
88+
tenant_id = classifications[0].tenant_id if classifications else None
89+
if not tenant_id:
90+
# If no classifications exist, try to get tenant_id from migrations
91+
if initial_migrations:
92+
tenant_id = initial_migrations[0].tenant_id
5893

59-
# ===== STEP 1: CREATE SCHEMA =====
94+
schema_name = _get_schema_name(tenant_id) if tenant_id else "public"
95+
96+
# ===== STEP 0: CREATE SCHEMA =====
6097
schema_migration_name = f"create_schema_{schema_name}"
6198

6299
if schema_migration_name not in existing_names:
@@ -71,7 +108,47 @@ def create_migrations(
71108
existing_names.add(schema_migration_name)
72109
next_seq += 1
73110

111+
# ===== STEP 1: Handle DROP migrations for removed classifications =====
112+
# Get current state of tables from migrations (passing schema_name)
113+
created_tables = _get_created_tables(initial_migrations, schema_name)
114+
dropped_tables = _get_dropped_tables(initial_migrations, schema_name)
115+
active_tables = created_tables - dropped_tables
116+
117+
# Build current classification table names
118+
current_classification_tables = {
119+
_table_name_for_classification(c) for c in classifications
120+
}
121+
122+
# Tables that were created but no longer in classifications = should be dropped
123+
tables_to_drop = active_tables - current_classification_tables
124+
125+
for table_name in sorted(tables_to_drop):
126+
# Remove schema prefix if present (helper functions might include it)
127+
clean_table_name = (
128+
table_name.split(".")[-1] if "." in table_name else table_name
129+
)
130+
mig_name = f"drop_table_{schema_name}_{clean_table_name}"
131+
132+
if mig_name in existing_names:
133+
continue
134+
135+
# Schema-qualified DROP with CASCADE
136+
sql = f"DROP TABLE IF EXISTS {schema_name}.{clean_table_name} CASCADE;"
137+
138+
if tenant_id:
139+
new_migrations.append(
140+
MigrationCreate(
141+
tenant_id=tenant_id,
142+
name=mig_name,
143+
sql=sql,
144+
sequence=next_seq,
145+
)
146+
)
147+
existing_names.add(mig_name)
148+
next_seq += 1
149+
74150
# ===== STEP 2: CREATE TABLES (in tenant schema) =====
151+
75152
for c in classifications:
76153
table_name = _table_name_for_classification(c)
77154
qualified_table_name = f"{schema_name}.{table_name}"
@@ -80,14 +157,15 @@ def create_migrations(
80157
if mig_name in existing_names:
81158
continue
82159

160+
# Schema-qualified CREATE
83161
sql = f"""
84-
CREATE TABLE IF NOT EXISTS {qualified_table_name} (
85-
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
86-
tenant_id UUID NOT NULL,
87-
data JSONB NOT NULL,
88-
created_at TIMESTAMPTZ DEFAULT NOW()
162+
CREATE TABLE IF NOT EXISTS {qualified_table_name} (
163+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
164+
tenant_id UUID NOT NULL,
165+
data JSONB NOT NULL,
166+
created_at TIMESTAMPTZ DEFAULT NOW()
89167
);
90-
""".strip()
168+
""".strip()
91169

92170
new_migrations.append(
93171
MigrationCreate(
@@ -105,54 +183,65 @@ def create_migrations(
105183
from_table = _table_name_for_classification(rel.from_classification)
106184
to_table = _table_name_for_classification(rel.to_classification)
107185

186+
# Skip relationships where either table doesn't exist anymore
187+
if (
188+
from_table not in current_classification_tables
189+
or to_table not in current_classification_tables
190+
):
191+
continue
192+
108193
qualified_from = f"{schema_name}.{from_table}"
109194
qualified_to = f"{schema_name}.{to_table}"
110195

111196
# Support both Enum and plain string for rel.type
112-
rel_type = getattr(rel.type, "value", rel.type)
197+
raw_type = getattr(rel.type, "value", rel.type)
198+
rel_type_norm = str(raw_type).upper().replace("-", "_")
113199

114-
mig_name = f"rel_{rel_type.lower()}_{schema_name}_{from_table}_{to_table}"
200+
mig_name = f"rel_{rel_type_norm.lower()}_{schema_name}_{from_table}_{to_table}"
115201

116202
if mig_name in existing_names:
117203
continue
118204

119-
if rel_type == "ONE_TO_MANY":
205+
if rel_type_norm == "ONE_TO_MANY":
206+
# Schema-qualified ALTER TABLE for one-to-many
120207
sql = f"""
121-
ALTER TABLE {qualified_from}
122-
ADD COLUMN IF NOT EXISTS {to_table}_id UUID,
123-
ADD CONSTRAINT fk_{schema_name}_{from_table}_{to_table}
124-
FOREIGN KEY ({to_table}_id)
125-
REFERENCES {qualified_to}(id);
126-
""".strip()
127-
128-
elif rel_type == "ONE_TO_ONE":
208+
ALTER TABLE {qualified_from}
209+
ADD COLUMN IF NOT EXISTS {to_table}_id UUID,
210+
ADD CONSTRAINT fk_{schema_name}_{from_table}_{to_table}
211+
FOREIGN KEY ({to_table}_id)
212+
REFERENCES {qualified_to}(id);
213+
""".strip()
214+
215+
elif rel_type_norm == "ONE_TO_ONE":
216+
# Schema-qualified ALTER TABLE for one-to-one
129217
sql = f"""
130-
ALTER TABLE {qualified_from}
131-
ADD COLUMN IF NOT EXISTS {to_table}_id UUID UNIQUE,
132-
ADD CONSTRAINT fk_{schema_name}_{from_table}_{to_table}
133-
FOREIGN KEY ({to_table}_id)
134-
REFERENCES {qualified_to}(id);
135-
""".strip()
136-
137-
elif rel_type == "MANY_TO_MANY":
218+
ALTER TABLE {qualified_from}
219+
ADD COLUMN IF NOT EXISTS {to_table}_id UUID UNIQUE,
220+
ADD CONSTRAINT fk_{schema_name}_{from_table}_{to_table}
221+
FOREIGN KEY ({to_table}_id)
222+
REFERENCES {qualified_to}(id);
223+
""".strip()
224+
225+
elif rel_type_norm == "MANY_TO_MANY":
226+
# Schema-qualified CREATE TABLE for join table
138227
join_table = f"{from_table}_{to_table}_join"
139228
qualified_join = f"{schema_name}.{join_table}"
140229

141230
sql = f"""
142-
CREATE TABLE IF NOT EXISTS {qualified_join} (
143-
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
144-
{from_table}_id UUID NOT NULL,
145-
{to_table}_id UUID NOT NULL,
146-
CONSTRAINT fk_{schema_name}_{join_table}_{from_table}
147-
FOREIGN KEY ({from_table}_id)
148-
REFERENCES {qualified_from}(id),
149-
CONSTRAINT fk_{schema_name}_{join_table}_{to_table}
150-
FOREIGN KEY ({to_table}_id)
151-
REFERENCES {qualified_to}(id),
152-
CONSTRAINT uniq_{schema_name}_{join_table}
153-
UNIQUE ({from_table}_id, {to_table}_id)
154-
);
155-
""".strip()
231+
CREATE TABLE IF NOT EXISTS {qualified_join} (
232+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
233+
{from_table}_id UUID NOT NULL,
234+
{to_table}_id UUID NOT NULL,
235+
CONSTRAINT fk_{schema_name}_{join_table}_{from_table}
236+
FOREIGN KEY ({from_table}_id)
237+
REFERENCES {qualified_from}(id),
238+
CONSTRAINT fk_{schema_name}_{join_table}_{to_table}
239+
FOREIGN KEY ({to_table}_id)
240+
REFERENCES {qualified_to}(id),
241+
CONSTRAINT uniq_{schema_name}_{join_table}
242+
UNIQUE ({from_table}_id, {to_table}_id)
243+
);
244+
""".strip()
156245
else:
157246
sql = f"-- TODO: implement SQL for relationship {mig_name}"
158247

0 commit comments

Comments
 (0)