Skip to content

Commit c6baddb

Browse files
author
Diya Kadakia
committed
fix classification route
1 parent 7c5083c commit c6baddb

File tree

3 files changed

+28
-30
lines changed

3 files changed

+28
-30
lines changed

backend/app/routes/classification_routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def create_classifications(
9191
)
9292

9393
classification_names: list[str] = await create_classifications_helper(
94-
tenant_id,
94+
extracted_files,
9595
[classification.name for classification in initial_classifications],
9696
)
9797

backend/app/utils/classification/create_classifications.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from app.schemas.classification_schemas import ExtractedFile
2-
from uuid import UUID
31
import hdbscan
42
import numpy as np
5-
from app.core.litellm import EmbeddingModelType, LLMClient
6-
from app.services.classification_service import ClassificationService
73
from sklearn.preprocessing import normalize
84

5+
from app.core.litellm import LLMClient
6+
from app.schemas.classification_schemas import ExtractedFile
7+
98

109
async def create_classifications(
1110
extracted_files: list[ExtractedFile],
@@ -26,19 +25,21 @@ async def create_classifications(
2625
valid_files.append(file)
2726

2827
if len(embeddings) < 3:
29-
print(f"Not enough files for clustering ({len(embeddings)}), returning initial classifications")
28+
print(
29+
f"Not enough files for clustering ({len(embeddings)}), returning initial classifications"
30+
)
3031
return initialClassifications
31-
32+
3233
embeddings_array = np.array(embeddings)
3334

3435
# Normalize embeddings so that cosine similarity ≈ euclidean distance
35-
normalized_embeddings = normalize(embeddings_array) # L2 normalization
36+
normalized_embeddings = normalize(embeddings_array) # L2 normalization
3637

3738
clusterer = hdbscan.HDBSCAN(
38-
min_cluster_size = 2,
39-
min_samples = 1,
40-
metric = 'euclidean',
41-
cluster_selection_method = 'eom'
39+
min_cluster_size=2,
40+
min_samples=1,
41+
metric="euclidean",
42+
cluster_selection_method="eom",
4243
)
4344

4445
cluster_labels = clusterer.fit_predict(normalized_embeddings)
@@ -55,29 +56,28 @@ async def create_classifications(
5556
outliers = clusters.pop(-1, []) # Remove -1 cluster if it exists
5657
print(f"Found {len(clusters)} clusters, {len(outliers)} outliers")
5758

58-
5959
client = LLMClient()
6060
classification_names = []
6161

6262
for cluster_id, files_in_cluster in clusters.items():
6363
print(f"Analyzing cluster {cluster_id} with {len(files_in_cluster)} files...")
64-
64+
6565
# Get sample documents from cluster (up to 5 for context)
6666
sample_texts = []
6767
for file in files_in_cluster[:5]:
6868
text = _extract_text_from_file(file)
6969
sample_texts.append(text[:500]) # Limit text length
70-
70+
7171
# Use LLM to name the cluster
7272
prompt = f"""Analyze these similar documents and provide a single, concise classification name.
7373
7474
Sample documents from this cluster:
7575
76-
{chr(10).join(f"Document {i+1}: {text}" for i, text in enumerate(sample_texts))}
76+
{chr(10).join(f"Document {i + 1}: {text}" for i, text in enumerate(sample_texts))}
7777
78-
What type of documents are these? Respond with ONLY the category name (e.g., "Invoice", "Purchase Order", "Quote").
78+
What type of documents are these? Respond with ONLY the category name.
7979
Do not include any explanation or punctuation."""
80-
80+
8181
try:
8282
response = await client.chat(prompt, temperature=0.3, max_tokens=50)
8383
category_name = response.choices[0].message.content.strip()
@@ -101,7 +101,7 @@ async def create_classifications(
101101
{text}
102102
103103
Respond with ONLY the category name."""
104-
104+
105105
try:
106106
response = await client.chat(prompt, temperature=0.3, max_tokens=50)
107107
category_name = response.choices[0].message.content.strip()
@@ -117,34 +117,32 @@ async def create_classifications(
117117
fallback_name = f"Document Type Outlier {i}"
118118
print(f" → Outlier named: {fallback_name}")
119119

120-
121120
all_classifications = classification_names + initialClassifications
122-
final_classifications = list(set(all_classifications))
123-
121+
final_classifications = list(set(all_classifications))
122+
124123
print(f"Final classifications: {final_classifications}")
125124
return final_classifications
126125

127126

128127
def _extract_text_from_file(file: ExtractedFile) -> str:
129128
"""Convert extracted file to text representation for analysis."""
130129
parts = []
131-
130+
132131
# Add filename
133132
if file.name:
134133
parts.append(f"Filename: {file.name}")
135134

136-
137135
# Add extracted content
138136
if isinstance(file.extracted_data, dict):
139137
for key, value in file.extracted_data.items():
140138
if isinstance(value, (dict, list)):
141-
continue
139+
continue
142140
parts.append(f"{key}: {value}")
143141
elif isinstance(file.extracted_data, list):
144-
parts.append(f"Items: {', '.join(str(item) for item in file.extracted_data[:5])}")
142+
parts.append(
143+
f"Items: {', '.join(str(item) for item in file.extracted_data[:5])}"
144+
)
145145
else:
146146
parts.append(str(file.extracted_data))
147147

148-
149-
150-
return " ".join(parts)
148+
return " ".join(parts)

package-lock.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)