Skip to content

Commit ccddc6e

Browse files
fabnemEPFLleagriederMikaelKalajdzic
authored
Fixes rag & retriever API (#93)
* Added retrieving among a set of specified documents * Adapted run_retriever * simplified code in retriever * simplified run_retriever.py * misc style changes * adapted retriever to rag api * cherry picking from RAG interface is now possible by providing document_ids in the json file * Preparation for the retriever api * basic retriever api * Updated retrieval api + documentation Co-authored-by: MikaelKalajdzic <mikael@kalajdzic.ch> Co-authored-by: fabnemEPFL <fabrice.nemo@epfl.ch> * adapted retriever to effectively use the arguments minSimilarity and maxMatches * execution state shut down at the end of processing * import fixes on run_retriever.py * removed useless enumerate * recovered the right version of retriever.py * enhanced rag docs * removed duplicates in run_retriever.py * changed the example for a public version of meditron * small fix in docs * truncate the output of the model in the RAG pipeline * upgraded the requirement of transformers * fix * reformatting using black --------- Co-authored-by: leagrieder <lea@grieder.org> Co-authored-by: MikaelKalajdzic <mikael@kalajdzic.ch>
1 parent f21f62b commit ccddc6e

11 files changed

Lines changed: 237 additions & 15 deletions

File tree

docs/rag.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212

1313
Here is a minimal example to create a RAG pipeline hosted through [LangServe](https://python.langchain.com/docs/langserve/) servers.
1414

15-
1. Create your RAG Inference config file based on the [local example](/examples/rag/config.yaml) or the [API example](/examples/rag/config_api.yaml).
15+
1. Create your RAG Inference config file based on the [local example](/examples/rag/config.yaml) or the [API example](/examples/rag/config_api.yaml). You can check the structure of the configuration file with the dataclass [RAGConfig](/src/mmore/rag/pipeline.py).
1616

1717
2. Start your RAG pipeline using the `run_rag.py` script and your config file
1818
```bash
1919
python3 -m mmore rag --config_file /path/to/config.yaml
2020
```
2121

22-
3. Query the server like any other LangServe server
22+
3. In API mode, query the server like any other LangServe server:
2323
```bash
2424
curl --location --request POST http://localhost:8000/rag/invoke \
2525
-H 'Content-Type: application/json' \
2626
-d '{
2727
"input": {
2828
"input": "What is Meditron?",
29-
"collection_name": "med_docs"
29+
"collection_name": "my_docs"
3030
}
3131
}'
3232
```
@@ -36,6 +36,8 @@ Here is a minimal example to create a RAG pipeline hosted through [LangServe](ht
3636
-H 'Content-Type: application/json'
3737
```
3838

39+
In local mode, the pipeline is run directly with the input data specified in the configuration file and the result is saved at the specified path.
40+
3941
See [`examples/rag`](/examples/rag/) for other use cases.
4042

4143
## :mag: Modules

docs/retriever_api_specs.yaml

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
openapi: 3.1.1
2+
info:
3+
title: mmore Retriever API
4+
description: |
5+
This API is based on the OpenAPI 3.1 specification. You can find out more about Swagger at [https://swagger.io](https://swagger.io).
6+
7+
## Overview
8+
9+
This API defines the retriever API of mmore, handling:
10+
11+
1. **File Operations** - Direct file management within mmore.
12+
2. **Context Retrieval** - Semantic search based on the subset of documents that the user wants.
13+
14+
## API Versioning
15+
16+
All requests must be prefixed with `/v1` for this current version of the API.
17+
18+
## Roadmap & Considerations
19+
- Authorization layer for mmore operations
20+
- Permission control on some operation such as file deletion
21+
- Advanced query for retrieval (send more sophisticated object, for instance representing a whole conversation, rather than one string)
22+
23+
version: 1.0.0
24+
servers:
25+
- url: /v1
26+
description: API server
27+
28+
paths:
29+
/files:
30+
post:
31+
tags:
32+
- File Operations
33+
summary: Upload a file
34+
description: |
35+
Upload a new file
36+
37+
**Requirements**:
38+
- Unique fileId
39+
requestBody:
40+
content:
41+
multipart/form-data:
42+
schema:
43+
type: object
44+
properties:
45+
fileId:
46+
type: string
47+
description: Unique identifier for the file
48+
file:
49+
type: string
50+
format: binary
51+
description: The file content
52+
required:
53+
- fileId
54+
- file
55+
responses:
56+
'201':
57+
description: File successfully uploaded
58+
59+
/files/{id}:
60+
put:
61+
tags:
62+
- File Operations
63+
summary: Update a file
64+
description: Replace an existing file with a new version
65+
parameters:
66+
- name: id
67+
in: path
68+
required: true
69+
schema:
70+
type: string
71+
requestBody:
72+
content:
73+
multipart/form-data:
74+
schema:
75+
type: object
76+
properties:
77+
file:
78+
type: string
79+
format: binary
80+
description: The new file content
81+
required:
82+
- file
83+
responses:
84+
'200':
85+
description: File successfully updated
86+
87+
delete:
88+
tags:
89+
- File Operations
90+
summary: Remove a file
91+
description: |
92+
Delete a file from the system.
93+
94+
**Warning**: Not limited operation.
95+
parameters:
96+
- name: id
97+
in: path
98+
required: true
99+
schema:
100+
type: string
101+
responses:
102+
'200':
103+
description: File successfully deleted
104+
105+
get:
106+
tags:
107+
- File Operations
108+
summary: Download a file
109+
description: Download a file from the system
110+
parameters:
111+
- name: id
112+
in: path
113+
required: true
114+
schema:
115+
type: string
116+
responses:
117+
'200':
118+
description: File content
119+
content:
120+
application/octet-stream:
121+
schema:
122+
type: string
123+
format: binary
124+
125+
/retrieve:
126+
post:
127+
tags:
128+
- Context Retrieval
129+
summary: Retrieve on MMORE
130+
description: |
131+
Search for files based on content similarity.
132+
133+
This request is meant to be called by the Gateway to MMORE service upon chat completions request from Moove.
134+
requestBody:
135+
required: true
136+
content:
137+
application/json:
138+
schema:
139+
type: object
140+
properties:
141+
fileIds:
142+
type: array
143+
items:
144+
type: string
145+
description: List of file IDs to search within
146+
maxMatches:
147+
type: integer
148+
minimum: 1
149+
description: Maximum number of matches to return
150+
minSimilarity:
151+
type: number
152+
format: float
153+
minimum: -1.0
154+
maximum: 1.0
155+
default: 0.0
156+
description: Minimum similarity score for results (-1.0 to 1.0)
157+
query:
158+
type: string
159+
description: Search query
160+
required:
161+
- fileIds
162+
- query
163+
- maxMatches
164+
responses:
165+
'200':
166+
description: List of matching files with related content, sorted by highest similarity first.
167+
content:
168+
application/json:
169+
schema:
170+
type: array
171+
items:
172+
type: object
173+
properties:
174+
fileId:
175+
type: string
176+
content:
177+
type: string
178+
similarity:
179+
type: number
180+
format: float
181+
minimum: -1.0
182+
maximum: 1.0
183+
description: Results sorted by similarity (highest first)

examples/rag/config_api.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
rag:
33
# LLM Config
44
llm:
5-
llm_name: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # "OpenMeditron/meditron3-8b" # "gpt-4o-mini" # Anything supported
5+
llm_name: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # "epfl-llm/meditron-70b" # "gpt-4o-mini" # Anything supported
66
max_new_tokens: 100
77
temperature: 0.8
88
# Retriever Config

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dependencies = [
2727
"numpy==1.26.3",
2828
"pandas==2.2.3",
2929
"datasets==2.19.1",
30-
"transformers==4.47.0",
30+
"transformers==4.52",
3131
"fastapi[standard]",
3232
"fastapi==0.115.5",
3333
"fasteners==0.19",

src/mmore/cli.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ def process(config_file: str):
3333
help="Path to the config file for post-processing.",
3434
)
3535
@click.option(
36-
"--input-data", type=str, required=True, help="Path to the input JSONL file of documents."
36+
"--input-data",
37+
type=str,
38+
required=True,
39+
help="Path to the input JSONL file of documents.",
3740
)
3841
def postprocess(config_file: str, input_data: str):
3942
"""Run the post-processors pipeline.
@@ -59,7 +62,11 @@ def postprocess(config_file: str, input_data: str):
5962
help="Path to the config file for indexing.",
6063
)
6164
@click.option(
62-
"--documents-path", "-f", type=str, required=False, help="Path to the JSONL file of the (post)processed documents."
65+
"--documents-path",
66+
"-f",
67+
type=str,
68+
required=False,
69+
help="Path to the JSONL file of the (post)processed documents.",
6370
)
6471
@click.option(
6572
"--collection-name",
@@ -93,10 +100,18 @@ def index(config_file: str, documents_path: str, collection_name: str):
93100
help="Dispatcher configuration file path.",
94101
)
95102
@click.option(
96-
"--input-file", "-f", type=str, required=True, help="Path to the JSONL file of the input queries."
103+
"--input-file",
104+
"-f",
105+
type=str,
106+
required=True,
107+
help="Path to the JSONL file of the input queries.",
97108
)
98109
@click.option(
99-
"--output-file", "-o", type=str, required=True, help="Path to which save the results of the retriever as a JSON."
110+
"--output-file",
111+
"-o",
112+
type=str,
113+
required=True,
114+
help="Path to which save the results of the retriever as a JSON.",
100115
)
101116
def retrieve(config_file: str, input_file: str, output_file: str):
102117
"""Retrieve documents for specified queries.
@@ -171,8 +186,8 @@ def dashboard_backend(host, port):
171186
"""Run the dashboard backend.
172187
173188
Args:
174-
host:
175-
port:
189+
host:
190+
port:
176191
177192
Returns:
178193

src/mmore/process/dispatcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ def batch_list(
342342
else:
343343
results = list(self._dispatch_local(task_lists))
344344

345+
ExecutionState.shutdown()
346+
345347
return results
346348

347349
def __call__(self) -> List[List[MultimodalSample]]:

src/mmore/process/execution_state.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ def initialize(distributed_mode=False, client=None):
4747
ExecutionState._local_state = False
4848
logger.info("Execution state initialized (local mode)")
4949

50+
@staticmethod
51+
def shutdown():
52+
ExecutionState._use_dask = None
53+
ExecutionState._dask_var = None
54+
ExecutionState._local_state = False
55+
5056
@staticmethod
5157
def get_should_stop_execution() -> bool:
5258
"""Returns the global execution state (True if it should stop)"""

src/mmore/rag/pipeline.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,15 @@ def _build_chain(retriever, format_docs, prompt, llm) -> Runnable:
8787
validate_input = RunnableLambda(
8888
lambda x: MMOREInput.model_validate(x).model_dump()
8989
)
90-
validate_output = RunnableLambda(
91-
lambda x: MMOREOutput.model_validate(x).model_dump()
92-
)
90+
91+
def make_output(x):
92+
"""Validate the output of the LLM and keep only the actual answer of the assistant"""
93+
res_dict = MMOREOutput.model_validate(x).model_dump()
94+
res_dict["answer"] = res_dict["answer"].split("<|assistant|>\n")[-1]
95+
96+
return res_dict
97+
98+
validate_output = RunnableLambda(make_output)
9399

94100
rag_chain_from_docs = prompt | llm | StrOutputParser()
95101

src/mmore/rag/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ class MMOREInput(BaseModel):
1717
...,
1818
description="The collection",
1919
)
20+
document_ids: List[str] = Field(
21+
default_factory=list, # Set default to an empty list if not provided
22+
description="List of document IDs, defaults to an empty list if not provided.",
23+
)
2024

2125

2226
# ------------------------------- Simple Output ------------------------------ #

src/mmore/run_index_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ async def upload_files(
131131
with tempfile.TemporaryDirectory() as temp_dir:
132132
logging.info(f"Starting to process {len(files)} files with custom IDs")
133133

134-
for i, (file, file_id) in enumerate(zip(files, listIds)):
134+
for file, file_id in zip(files, listIds):
135135
if file.filename is None:
136136
raise HTTPException(
137137
status_code=422,

0 commit comments

Comments
 (0)