Skip to content

Commit 1654a93

Browse files
hteeyeoh, 14pankaj, madhuri-rai07
authored
ChatQnA fixes (#553)
Signed-off-by: Yeoh, Hoong Tee <hoong.tee.yeoh@intel.com> Co-authored-by: 14pankaj <pankaj.kumar.singh@intel.com> Co-authored-by: Pankaj Kumar Singh <97222471+14pankaj@users.noreply.github.com> Co-authored-by: Madhuri Kumari <madhuri.rai07@gmail.com>
1 parent aa19e39 commit 1654a93

64 files changed

Lines changed: 1580 additions & 1154 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
*.pyc
2+
__pycache__/
3+
.*_cache/
4+
**/charts
5+
.vscode
6+
.venv
7+
coverage
8+
.coverage
9+
.coverage-report/
10+
*.lock
11+
!poetry.lock
12+
.vscode

sample-applications/chat-question-and-answer-core/app/.env

Lines changed: 0 additions & 8 deletions
This file was deleted.

sample-applications/chat-question-and-answer-core/app/chain.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .config import Settings
1+
from .config import config
22
from .utils import login_to_huggingface, download_huggingface_model, convert_model
33
from .document import load_file_document
44
from .logger import logger
@@ -14,73 +14,55 @@
1414
import os
1515
import pandas as pd
1616

17-
config = Settings()
1817
vectorstore = None
1918

2019
# The RUN_TEST flag is used to bypass the model download and conversion steps during pytest unit testing.
21-
# By default, the flag is set to 'false', enabling the model download and conversion process in a normal run.
22-
# To skip these steps, set the flag to 'true'.
23-
# Check environment flag
24-
RUN_TEST = os.getenv('RUN_TEST', False)
25-
26-
if not RUN_TEST:
20+
# If RUN_TEST is set to "True", the model download and conversion steps are skipped.
21+
# This flag is set in the conftest.py file before running the tests.
22+
if os.getenv("RUN_TEST", "").lower() != "true":
2723
# login huggingface
2824
login_to_huggingface(config.HF_ACCESS_TOKEN)
2925

3026
# Download convert the model to openvino optimized
31-
download_huggingface_model(config.EMBEDDING_MODEL_ID, config.CACHE_DIR)
32-
download_huggingface_model(config.RERANKER_MODEL_ID, config.CACHE_DIR)
33-
download_huggingface_model(config.LLM_MODEL_ID, config.CACHE_DIR)
27+
download_huggingface_model(config.EMBEDDING_MODEL_ID, config._CACHE_DIR)
28+
download_huggingface_model(config.RERANKER_MODEL_ID, config._CACHE_DIR)
29+
download_huggingface_model(config.LLM_MODEL_ID, config._CACHE_DIR)
3430

3531
# Convert to openvino IR
36-
convert_model(config.EMBEDDING_MODEL_ID, config.CACHE_DIR, "embedding")
37-
convert_model(config.RERANKER_MODEL_ID, config.CACHE_DIR, "reranker")
38-
convert_model(config.LLM_MODEL_ID, config.CACHE_DIR, "llm")
39-
40-
# Define RAG prompt
41-
template = """
42-
Use the following pieces of context from retrieved
43-
dataset to answer the question. Do not make up an answer if there is no
44-
context provided to help answer it.
32+
convert_model(config.EMBEDDING_MODEL_ID, config._CACHE_DIR, "embedding")
33+
convert_model(config.RERANKER_MODEL_ID, config._CACHE_DIR, "reranker")
34+
convert_model(config.LLM_MODEL_ID, config._CACHE_DIR, "llm")
4535

46-
Context:
47-
---------
48-
{context}
4936

50-
---------
51-
Question: {question}
52-
---------
53-
54-
Answer:
55-
"""
37+
template = config.PROMPT_TEMPLATE
5638

5739
prompt = ChatPromptTemplate.from_template(template)
5840

5941
# Initialize Embedding Model
6042
embedding = OpenVINOBgeEmbeddings(
61-
model_name_or_path=f"{config.CACHE_DIR}/{config.EMBEDDING_MODEL_ID}",
43+
model_name_or_path=f"{config._CACHE_DIR}/{config.EMBEDDING_MODEL_ID}",
6244
model_kwargs={"device": config.EMBEDDING_DEVICE, "compile": False},
6345
)
6446
embedding.ov_model.compile()
6547

6648
# Initialize Reranker Model
6749
reranker = OpenVINOReranker(
68-
model_name_or_path=f"{config.CACHE_DIR}/{config.RERANKER_MODEL_ID}",
50+
model_name_or_path=f"{config._CACHE_DIR}/{config.RERANKER_MODEL_ID}",
6951
model_kwargs={"device": config.RERANKER_DEVICE},
7052
top_n=2,
7153
)
7254

7355
# Initialize LLM
7456
llm = HuggingFacePipeline.from_model_id(
75-
model_id=f"{config.CACHE_DIR}/{config.LLM_MODEL_ID}",
57+
model_id=f"{config._CACHE_DIR}/{config.LLM_MODEL_ID}",
7658
task="text-generation",
7759
backend="openvino",
7860
model_kwargs={
7961
"device": config.LLM_DEVICE,
8062
"ov_config": {
8163
"PERFORMANCE_HINT": "LATENCY",
8264
"NUM_STREAMS": "1",
83-
"CACHE_DIR": f"{config.CACHE_DIR}/{config.LLM_MODEL_ID}/model_cache",
65+
"CACHE_DIR": f"{config._CACHE_DIR}/{config.LLM_MODEL_ID}/model_cache",
8466
},
8567
"trust_remote_code": True,
8668
},
@@ -287,4 +269,4 @@ def delete_embedding_from_vectordb(document: str = "", delete_all: bool = False)
287269

288270
vectorstore.delete(chunk_list)
289271

290-
return True
272+
return True
Lines changed: 98 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,114 @@
from pydantic import PrivateAttr
from pydantic_settings import BaseSettings
from os.path import dirname, abspath
from .prompt import get_prompt_template
import os
import yaml


class Settings(BaseSettings):
    """
    Settings class for configuring the Chatqna-Core application.

    This class manages application-wide configuration, including model settings, device preferences,
    supported file formats, and paths for caching and configuration files. It loads additional
    configuration from a YAML file if provided, and updates its attributes accordingly.

    Attributes:
        APP_DISPLAY_NAME (str): Display name of the application.
        BASE_DIR (str): Base directory of the application.
        SUPPORTED_FORMATS (set): Supported document file formats.
        DEBUG (bool): Flag to enable or disable debug mode.
        HF_ACCESS_TOKEN (str): Hugging Face access token for model downloads.
        EMBEDDING_MODEL_ID (str): Model ID for embeddings.
        RERANKER_MODEL_ID (str): Model ID for reranker.
        LLM_MODEL_ID (str): Model ID for large language model.
        PROMPT_TEMPLATE (str): Prompt template for the LLM.
        EMBEDDING_DEVICE (str): Device to run embedding model on.
        RERANKER_DEVICE (str): Device to run reranker model on.
        LLM_DEVICE (str): Device to run LLM on.
        MAX_TOKENS (int): Maximum number of tokens for LLM responses.
        ENABLE_RERANK (bool): Flag to enable or disable reranking.
        _CACHE_DIR (str): Directory for model cache (private).
        _HF_DATASETS_CACHE (str): Directory for Hugging Face datasets cache (private).
        _TMP_FILE_PATH (str): Temporary file path for documents (private).
        _DEFAULT_MODEL_CONFIG (str): Path to default model configuration YAML (private).
        _MODEL_CONFIG_PATH (str): Path to user-provided model configuration YAML (private).

    Raises:
        ValueError: If any model ID is empty after loading, or if PROMPT_TEMPLATE is
            missing a required placeholder ({context} or {question}).
    """

    APP_DISPLAY_NAME: str = "Chatqna-Core"
    BASE_DIR: str = dirname(dirname(abspath(__file__)))
    SUPPORTED_FORMATS: set = {".pdf", ".txt", ".docx"}
    DEBUG: bool = False

    HF_ACCESS_TOKEN: str = ""
    EMBEDDING_MODEL_ID: str = ""
    RERANKER_MODEL_ID: str = ""
    LLM_MODEL_ID: str = ""
    PROMPT_TEMPLATE: str = ""
    EMBEDDING_DEVICE: str = "CPU"
    RERANKER_DEVICE: str = "CPU"
    LLM_DEVICE: str = "CPU"
    MAX_TOKENS: int = 1024
    ENABLE_RERANK: bool = True

    # PrivateAttr fields are not populated from environment variables by
    # pydantic-settings; they are fixed deployment paths.
    _CACHE_DIR: str = PrivateAttr("/tmp/model_cache")
    _HF_DATASETS_CACHE: str = PrivateAttr("/tmp/model_cache")
    _TMP_FILE_PATH: str = PrivateAttr("/tmp/chatqna/documents")
    _DEFAULT_MODEL_CONFIG: str = PrivateAttr("/tmp/model_config/default_model.yaml")
    _MODEL_CONFIG_PATH: str = PrivateAttr("/tmp/model_config/config.yaml")

    def __init__(self, **kwargs):
        """Initialize settings, then overlay values from a model-config YAML file.

        Falls back to the bundled default YAML when the user-provided config
        file is absent. Skipped entirely when RUN_TEST=true (set by conftest.py).
        """
        super().__init__(**kwargs)

        # The RUN_TEST flag is used to bypass the model config loading during pytest unit testing.
        # If RUN_TEST is set to "True", the model config loading is skipped.
        # This flag is set in the conftest.py file before running the tests.
        if os.getenv("RUN_TEST", "").lower() == "true":
            print("INFO - Skipping model config loading in test mode.")
            return

        config_file = self._MODEL_CONFIG_PATH if os.path.isfile(self._MODEL_CONFIG_PATH) else self._DEFAULT_MODEL_CONFIG

        if config_file == self._MODEL_CONFIG_PATH:
            print(f"INFO - Model configuration yaml from user found in {config_file}. Loading configuration from {config_file}")

        else:
            print("WARNING - User did not provide model configuration yaml file via MODEL_CONFIG_PATH.")
            print(f"INFO - Proceeding with default settings from {config_file}")

        with open(config_file, 'r') as f:
            # safe_load returns None for an empty file; fall back to {} so the
            # .get() lookups below cannot raise AttributeError.
            # NOTE: renamed from `config` to avoid shadowing the module-level
            # `config` singleton defined at the bottom of this file.
            loaded_config = yaml.safe_load(f) or {}

        # Only overlay keys that correspond to declared Settings attributes;
        # unknown YAML keys are silently ignored.
        for section in ("model_settings", "device_settings"):
            for key, value in loaded_config.get(section, {}).items():
                if hasattr(self, key):
                    setattr(self, key, value)

        self._validate_model_ids()

        self._check_and_validate_prompt_template()

    def _validate_model_ids(self):
        """Raise ValueError if any required model ID is still an empty string."""
        for model_name in ["EMBEDDING_MODEL_ID", "RERANKER_MODEL_ID", "LLM_MODEL_ID"]:
            model_id = getattr(self, model_name)
            if not model_id:
                raise ValueError(f"{model_name} must not be an empty string.")

    def _check_and_validate_prompt_template(self):
        """Fill in a default prompt template if unset, then validate placeholders."""
        if not self.PROMPT_TEMPLATE:
            print("INFO - PROMPT_TEMPLATE is not set. Getting default prompt_template.")
            self.PROMPT_TEMPLATE = get_prompt_template(self.LLM_MODEL_ID)

        # Validate PROMPT_TEMPLATE: both placeholders are required by the RAG chain.
        required_placeholders = ["{context}", "{question}"]
        for placeholder in required_placeholders:
            if placeholder not in self.PROMPT_TEMPLATE:
                raise ValueError(f"PROMPT_TEMPLATE must include the placeholder {placeholder}.")


# Module-level singleton imported by chain.py, document.py, and logger.py.
config = Settings()

sample-applications/chat-question-and-answer-core/app/document.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from .config import Settings
2+
from .config import config
33
from .logger import logger
44
from pathlib import Path
55
from fastapi import UploadFile
@@ -9,8 +9,6 @@
99
TextLoader
1010
)
1111

12-
config = Settings()
13-
1412

1513
def validate_document(file_object: UploadFile):
1614
"""
@@ -45,7 +43,7 @@ async def save_document(file_object: UploadFile):
4543
If the file is saved successfully, the error will be None. If an error occurs, the path will be None.
4644
"""
4745

48-
tmp_path = Path(config.TMP_FILE_PATH) / file_object.filename
46+
tmp_path = Path(config._TMP_FILE_PATH) / file_object.filename
4947
if not tmp_path.parent.exists():
5048
tmp_path.parent.mkdir(parents=True, exist_ok=True)
5149

@@ -85,4 +83,4 @@ def load_file_document(file_path):
8583

8684
docs = loader.load()
8785

88-
return docs
86+
return docs

sample-applications/chat-question-and-answer-core/app/logger.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
from .config import Settings
1+
from .config import config
22
from typing import Optional
33
import logging
44
import sys
55

6-
config = Settings()
76

87
def initialize_logger(name: Optional[str] = None) -> logging.Logger:
98
"""

0 commit comments

Comments (0)