# Standard library
import os
from os.path import abspath, dirname

# Third-party
import yaml
from pydantic import PrivateAttr
from pydantic_settings import BaseSettings

# Local
from .prompt import get_prompt_template

58class Settings (BaseSettings ):
69 """
7- Settings for the Chatqna-Core application.
10+ Settings class for configuring the Chatqna-Core application.
11+ This class manages application-wide configuration, including model settings, device preferences,
12+ supported file formats, and paths for caching and configuration files. It loads additional
13+ configuration from a YAML file if provided, and updates its attributes accordingly.
814
915 Attributes:
10- APP_DISPLAY_NAME (str): The display name of the application.
11- BASE_DIR (str): The base directory of the application.
12- SUPPORTED_FORMATS (set): A set of supported file formats.
16+ APP_DISPLAY_NAME (str): Display name of the application.
17+ BASE_DIR (str): Base directory of the application.
18+ SUPPORTED_FORMATS (set): Supported document file formats.
1319 DEBUG (bool): Flag to enable or disable debug mode.
14- TMP_FILE_PATH (str): The temporary file path for documents.
15- HF_ACCESS_TOKEN (str): The Hugging Face access token.
16- EMBEDDING_MODEL_ID (str): The ID of the embedding model.
17- RERANKER_MODEL_ID (str): The ID of the reranker model.
18- LLM_MODEL_ID (str): The ID of the large language model.
19- EMBEDDING_DEVICE (str): The device used for embedding.
20- RERANKER_DEVICE (str): The device used for reranker.
21- LLM_DEVICE (str): The device used for LLM inferencing.
22- CACHE_DIR (str): The directory used for caching.
23- HF_DATASETS_CACHE (str): The cache directory for Hugging Face datasets.
24- MAX_TOKENS (int): The maximum number of output tokens.
20+ HF_ACCESS_TOKEN (str): Hugging Face access token for model downloads.
21+ EMBEDDING_MODEL_ID (str): Model ID for embeddings.
22+ RERANKER_MODEL_ID (str): Model ID for reranker.
23+ LLM_MODEL_ID (str): Model ID for large language model.
24+ PROMPT_TEMPLATE (str): Prompt template for the LLM.
25+ EMBEDDING_DEVICE (str): Device to run embedding model on.
26+ RERANKER_DEVICE (str): Device to run reranker model on.
27+ LLM_DEVICE (str): Device to run LLM on.
28+ MAX_TOKENS (int): Maximum number of tokens for LLM responses.
2529 ENABLE_RERANK (bool): Flag to enable or disable reranking.
30+ _CACHE_DIR (str): Directory for model cache (private).
31+ _HF_DATASETS_CACHE (str): Directory for Hugging Face datasets cache (private).
32+ _TMP_FILE_PATH (str): Temporary file path for documents (private).
33+ _DEFAULT_MODEL_CONFIG (str): Path to default model configuration YAML (private).
34+ _MODEL_CONFIG_PATH (str): Path to user-provided model configuration YAML (private).
2635
27- Config:
28- env_file (str): The path to the environment file.
36+ Methods:
37+ __init__(**kwargs): Initializes the Settings object, loads configuration from YAML file,
38+ and updates attributes accordingly.
2939 """
3040
3141 APP_DISPLAY_NAME : str = "Chatqna-Core"
3242 BASE_DIR : str = dirname (dirname (abspath (__file__ )))
3343 SUPPORTED_FORMATS : set = {".pdf" , ".txt" , ".docx" }
3444 DEBUG : bool = False
3545
36- HF_ACCESS_TOKEN : str = ...
37- EMBEDDING_MODEL_ID : str = ...
38- RERANKER_MODEL_ID : str = ...
39- LLM_MODEL_ID : str = ...
40- EMBEDDING_DEVICE : str = ...
41- RERANKER_DEVICE : str = ...
42- LLM_DEVICE : str = ...
43- CACHE_DIR : str = ...
44- HF_DATASETS_CACHE : str = ...
45- MAX_TOKENS : int = ...
46- ENABLE_RERANK : bool = ...
47- TMP_FILE_PATH : str = ...
48-
49- class Config :
50- env_file = join (dirname (abspath (__file__ )), ".env" )
46+ HF_ACCESS_TOKEN : str = ""
47+ EMBEDDING_MODEL_ID : str = ""
48+ RERANKER_MODEL_ID : str = ""
49+ LLM_MODEL_ID : str = ""
50+ PROMPT_TEMPLATE : str = ""
51+ EMBEDDING_DEVICE : str = "CPU"
52+ RERANKER_DEVICE : str = "CPU"
53+ LLM_DEVICE : str = "CPU"
54+ MAX_TOKENS : int = 1024
55+ ENABLE_RERANK : bool = True
56+
57+ # These fields will not be affected by environment variables
58+ _CACHE_DIR : str = PrivateAttr ("/tmp/model_cache" )
59+ _HF_DATASETS_CACHE : str = PrivateAttr ("/tmp/model_cache" )
60+ _TMP_FILE_PATH : str = PrivateAttr ("/tmp/chatqna/documents" )
61+ _DEFAULT_MODEL_CONFIG : str = PrivateAttr ("/tmp/model_config/default_model.yaml" )
62+ _MODEL_CONFIG_PATH : str = PrivateAttr ("/tmp/model_config/config.yaml" )
63+
64+
65+ def __init__ (self , ** kwargs ):
66+ super ().__init__ (** kwargs )
67+
68+ # The RUN_TEST flag is used to bypass the model config loading during pytest unit testing.
69+ # If RUN_TEST is set to "True", the model config loading is skipped.
70+ # This flag is set in the conftest.py file before running the tests.
71+ if os .getenv ("RUN_TEST" , "" ).lower () == "true" :
72+ print ("INFO - Skipping model config loading in test mode." )
73+ return
74+
75+ config_file = self ._MODEL_CONFIG_PATH if os .path .isfile (self ._MODEL_CONFIG_PATH ) else self ._DEFAULT_MODEL_CONFIG
76+
77+ if config_file == self ._MODEL_CONFIG_PATH :
78+ print (f"INFO - Model configuration yaml from user found in { config_file } . Loading configuration from { config_file } " )
79+
80+ else :
81+ print ("WARNING - User did not provide model configuration yaml file via MODEL_CONFIG_PATH." )
82+ print (f"INFO - Proceeding with default settings from { config_file } " )
83+
84+ with open (config_file , 'r' ) as f :
85+ config = yaml .safe_load (f )
86+
87+ for section in ("model_settings" , "device_settings" ):
88+ for key , value in config .get (section , {}).items ():
89+ if hasattr (self , key ):
90+ setattr (self , key , value )
91+
92+ self ._validate_model_ids ()
93+
94+ self ._check_and_validate_prompt_template ()
95+
96+ def _validate_model_ids (self ):
97+ for model_name in ["EMBEDDING_MODEL_ID" , "RERANKER_MODEL_ID" , "LLM_MODEL_ID" ]:
98+ model_id = getattr (self , model_name )
99+ if not model_id :
100+ raise ValueError (f"{ model_name } must not be an empty string." )
101+
102+ def _check_and_validate_prompt_template (self ):
103+ if not self .PROMPT_TEMPLATE :
104+ print ("INFO - PROMPT_TEMPLATE is not set. Getting default prompt_template." )
105+ self .PROMPT_TEMPLATE = get_prompt_template (self .LLM_MODEL_ID )
106+
107+ # Validate PROMPT_TEMPLATE
108+ required_placeholders = ["{context}" , "{question}" ]
109+ for placeholder in required_placeholders :
110+ if placeholder not in self .PROMPT_TEMPLATE :
111+ raise ValueError (f"PROMPT_TEMPLATE must include the placeholder { placeholder } ." )
112+
113+
114+ config = Settings ()