Skip to content

Commit a82740b

Browse files
committed
fix(main): move qdrant collection_name to retriever_config
1 parent 09690f1 commit a82740b

File tree

5 files changed

+25
-24
lines changed

5 files changed

+25
-24
lines changed

src/flare_ai_rag/main.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def setup_retriever(
3232
qdrant_client: QdrantClient,
3333
input_config: dict,
3434
df_docs: pd.DataFrame,
35-
collection_name: str | None = None,
3635
) -> QdrantRetriever:
3736
"""Initialize the Qdrant retriever."""
3837
# Set up Qdrant config
@@ -41,17 +40,16 @@ def setup_retriever(
4140
# Set up Gemini Embedding client
4241
embedding_client = GeminiEmbedding(settings.gemini_api_key)
4342
# (Re)generate qdrant collection
44-
if collection_name:
45-
generate_collection(
46-
df_docs,
47-
qdrant_client,
48-
retriever_config,
49-
collection_name=collection_name,
50-
embedding_client=embedding_client,
51-
)
52-
logger.info(
53-
"The Qdrant collection has been generated.", collection_name=collection_name
54-
)
43+
generate_collection(
44+
df_docs,
45+
qdrant_client,
46+
retriever_config,
47+
embedding_client=embedding_client,
48+
)
49+
logger.info(
50+
"The Qdrant collection has been generated.",
51+
collection_name=retriever_config.collection_name,
52+
)
5553
# Return retriever
5654
return QdrantRetriever(
5755
client=qdrant_client,
@@ -100,9 +98,7 @@ def main() -> None:
10098
qdrant_client = setup_qdrant(input_config)
10199

102100
# Set up retriever. (Use Gemini Embedding.)
103-
retriever = setup_retriever(
104-
qdrant_client, input_config, df_docs, collection_name="docs_collection"
105-
)
101+
retriever = setup_retriever(qdrant_client, input_config, df_docs)
106102

107103
# Set up responder. (Use Gemini Provider.)
108104
responder = setup_responder(input_config)

src/flare_ai_rag/responder/prompts.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
77
Guidelines:
88
- Use the provided context to support your answer. If applicable,
9-
include citations referring to the context (e.g., "[Document <name>]" or "[Source <name>]").
9+
include citations referring to the context (e.g., "[Document <name>]" or
10+
"[Source <name>]").
1011
- Be clear, factual, and concise. Do not introduce any information that isn't
1112
explicitly supported by the context.
1213
- Maintain a professional tone and ensure that all technical details are accurate.

src/flare_ai_rag/retriever/qdrant_collection.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@ def generate_collection(
2828
df_docs: pd.DataFrame,
2929
qdrant_client: QdrantClient,
3030
retriever_config: RetrieverConfig,
31-
collection_name: str,
3231
embedding_client: GeminiEmbedding,
3332
) -> None:
3433
"""Routine for generating a Qdrant collection for a specific CSV file type."""
3534
# Create the collection.
36-
_create_collection(qdrant_client, collection_name, retriever_config.vector_size)
37-
logger.info("Created the collection.", collection_name=collection_name)
35+
_create_collection(
36+
qdrant_client, retriever_config.collection_name, retriever_config.vector_size
37+
)
38+
logger.info(
39+
"Created the collection.", collection_name=retriever_config.collection_name
40+
)
3841

3942
# For each document in the CSV, compute its embedding and prepare a Qdrant point.
4043
points = []
@@ -74,10 +77,12 @@ def generate_collection(
7477

7578
if points:
7679
# Upload the points into the Qdrant collection.
77-
qdrant_client.upsert(collection_name=collection_name, points=points)
80+
qdrant_client.upsert(
81+
collection_name=retriever_config.collection_name, points=points
82+
)
7883
logger.info(
7984
"Collection generated and documents inserted into Qdrant successfully.",
80-
collection_name=collection_name,
85+
collection_name=retriever_config.collection_name,
8186
num_points=len(points),
8287
)
8388
else:

tests/test_generate_collection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
def main() -> None:
1515
# Load Qdrant config
1616
config_json = load_json(settings.input_path / "input_parameters.json")
17-
retriever_config = RetrieverConfig.load(config_json["qdrant_config"])
17+
retriever_config = RetrieverConfig.load(config_json["retriever_config"])
1818

1919
# Load the CSV file.
2020
df_docs = pd.read_csv(settings.data_path / "docs.csv", delimiter=",")
@@ -30,7 +30,6 @@ def main() -> None:
3030
df_docs,
3131
client,
3232
retriever_config,
33-
collection_name="docs_collection",
3433
embedding_client=embedding_client,
3534
)
3635

tests/test_qdrant_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
def main() -> None:
1313
# Load Qdrant config
1414
config_json = load_json(settings.input_path / "input_parameters.json")
15-
retriever_config = RetrieverConfig.load(config_json["qdrant_config"])
15+
retriever_config = RetrieverConfig.load(config_json["retriever_config"])
1616

1717
# Initialize Qdrant client
1818
client = QdrantClient(host=retriever_config.host, port=retriever_config.port)

0 commit comments

Comments
 (0)