Skip to content

Commit 91cfe65

Browse files
committed
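fix: parse all comma-separated upload IDs, match processed documents to uploaded files by filename, and report per-file chunk counts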
1 parent c9e547d commit 91cfe65

2 files changed

Lines changed: 53 additions & 11 deletions

File tree

src/mmore/run_index_api.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import shutil
55
import tempfile
66
from pathlib import Path as FilePath
7-
from typing import List, cast
7+
from typing import List
88

99
import uvicorn
1010
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Path, UploadFile
@@ -146,7 +146,12 @@ async def upload_files(
146146
Upload multiple files with custom IDs and index them.
147147
"""
148148
try:
149-
listIds = listIds[0].split(",")
149+
listIds = [
150+
file_id.strip()
151+
for ids in listIds
152+
for file_id in ids.split(",")
153+
if file_id.strip()
154+
]
150155
# Check if IDs and files match in number
151156
if len(listIds) != len(files):
152157
raise HTTPException(
@@ -157,12 +162,17 @@ async def upload_files(
157162
with tempfile.TemporaryDirectory() as temp_dir:
158163
logging.info(f"Starting to process {len(files)} files with custom IDs")
159164

165+
uploaded_files = []
166+
id_by_filename = {}
160167
for file, file_id in zip(files, listIds):
161168
if file.filename is None:
162169
raise HTTPException(
163170
status_code=422,
164171
detail=f"File {file_id} does not have a filename",
165172
)
173+
filename = file.filename
174+
uploaded_files.append({"fileId": file_id, "filename": filename})
175+
id_by_filename[filename] = file_id
166176

167177
# Check if file with this ID already exists
168178
file_storage_path = FilePath(UPLOAD_DIR) / file_id
@@ -173,7 +183,7 @@ async def upload_files(
173183
)
174184

175185
# Save to temp directory
176-
file_name = FilePath(temp_dir) / file.filename
186+
file_name = FilePath(temp_dir) / filename
177187
with file_name.open("wb") as buffer:
178188
shutil.copyfileobj(file.file, buffer)
179189

@@ -187,18 +197,30 @@ async def upload_files(
187197

188198
# Process the documents
189199
file_extensions = [
190-
FilePath(cast(str, file.filename)).suffix.lower() for file in files
200+
FilePath(file_info["filename"]).suffix.lower()
201+
for file_info in uploaded_files
191202
]
192203
documents = process_files_default(
193204
temp_dir, COLLECTION_NAME, file_extensions
194205
)
195206

196207
# Change the IDs to match the ones from the client
197208
modified_documents = []
198-
for doc, docId in zip(documents, listIds):
199-
defDocId = doc.document_id
200-
doc.document_id = docId
201-
doc.id = doc.id.replace(defDocId, docId)
209+
text_by_file_id = {}
210+
chunks_by_file_id = {
211+
file_info["fileId"]: 0 for file_info in uploaded_files
212+
}
213+
for doc in documents:
214+
filename = FilePath(doc.metadata.file_path).name
215+
doc_id = id_by_filename.get(filename)
216+
if doc_id is None:
217+
raise HTTPException(
218+
status_code=500,
219+
detail=f"Could not match processed document {filename} to an uploaded file",
220+
)
221+
_apply_uploaded_file_metadata([doc], doc_id, filename)
222+
text_by_file_id.setdefault(doc_id, doc.text)
223+
chunks_by_file_id[doc_id] += 1
202224
modified_documents.append(doc)
203225

204226
logging.info("Indexing the files")
@@ -215,10 +237,16 @@ async def upload_files(
215237

216238
return {
217239
"status": "success",
218-
"message": f"Successfully processed and indexed {len(modified_documents)} documents",
240+
"message": f"Successfully processed and indexed {len(uploaded_files)} files",
219241
"documents": [
220-
{"fileId": doc.document_id, "text": doc.text[:50] + "..."}
221-
for doc in modified_documents
242+
{
243+
"fileId": file_info["fileId"],
244+
"filename": file_info["filename"],
245+
"text": text_by_file_id.get(file_info["fileId"], "")[:50]
246+
+ "...",
247+
"chunks": chunks_by_file_id[file_info["fileId"]],
248+
}
249+
for file_info in uploaded_files
222250
],
223251
}
224252

tests/test_live_retriever_api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,13 @@ def test_upload_bulk_files_success(indexer_client):
519519
"mmore.run_index_api.process_files_default",
520520
return_value=[
521521
_fake_doc(fake_path_1, "bulk-1"),
522+
MultimodalSample(
523+
id="bulk-1+1",
524+
document_id="bulk-1",
525+
text="Second chunk from the first bulk document.",
526+
modalities=[],
527+
metadata=DocumentMetadata(file_path=fake_path_1),
528+
),
522529
_fake_doc(fake_path_2, "bulk-2"),
523530
],
524531
):
@@ -532,6 +539,13 @@ def test_upload_bulk_files_success(indexer_client):
532539
)
533540

534541
assert response.status_code == 201
542+
data = response.json()
543+
documents_by_id = {doc["fileId"]: doc for doc in data["documents"]}
544+
assert set(documents_by_id) == {"bulk-1", "bulk-2"}
545+
assert documents_by_id["bulk-1"]["filename"] == "bulk-1.txt"
546+
assert documents_by_id["bulk-1"]["chunks"] == 2
547+
assert documents_by_id["bulk-2"]["filename"] == "bulk-2.txt"
548+
assert documents_by_id["bulk-2"]["chunks"] == 1
535549

536550

537551
def test_upload_bulk_mismatched_ids_returns_400(indexer_client):

0 commit comments

Comments
 (0)