import os
import sys
import json

from dotenv import load_dotenv

from utils.config_loader import load_config
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
# from langchain_openai import ChatOpenAI
from logger import GLOBAL_LOGGER as log
from exception.custom_exception import DocumentPortalException

class ApiKeyManager:
    """Resolve required third-party API keys from the environment.

    Resolution order:
      1. A JSON object in the ``API_KEYS`` env var (e.g. an ECS secret).
      2. Individual env vars named after each required key.

    Raises:
        DocumentPortalException: at construction time if any required key
            is still missing after both lookups.
    """

    # Keys that must be present for the application to start.
    REQUIRED_KEYS = ["GROQ_API_KEY", "GOOGLE_API_KEY"]

    def __init__(self):
        self.api_keys = {}
        raw = os.getenv("API_KEYS")

        if raw:
            # API_KEYS is expected to be a JSON object mapping key name -> value.
            try:
                parsed = json.loads(raw)
                if not isinstance(parsed, dict):
                    raise ValueError("API_KEYS is not a valid JSON object")
                self.api_keys = parsed
                log.info("Loaded API_KEYS from ECS secret")
            except (json.JSONDecodeError, ValueError) as e:
                # Narrowed from a bare Exception: only parse/shape errors are
                # expected here; anything else should surface, not be swallowed.
                log.warning("Failed to parse API_KEYS as JSON", error=str(e))

        # Fallback: individual env vars for any key the JSON blob did not supply.
        for key in self.REQUIRED_KEYS:
            if not self.api_keys.get(key):
                env_val = os.getenv(key)
                if env_val:
                    self.api_keys[key] = env_val
                    log.info(f"Loaded {key} from individual env var")

        # Final check: fail fast if anything is still missing.
        missing = [k for k in self.REQUIRED_KEYS if not self.api_keys.get(k)]
        if missing:
            log.error("Missing required API keys", missing_keys=missing)
            raise DocumentPortalException("Missing API keys", sys)

        # Log only a short prefix of each key so full secrets never land in logs.
        log.info("API keys loaded", keys={k: v[:6] + "..." for k, v in self.api_keys.items()})

    def get(self, key: str) -> str:
        """Return the stored API key for *key*; raise KeyError if absent/empty."""
        val = self.api_keys.get(key)
        if not val:
            raise KeyError(f"API key for {key} is missing")
        return val

class ModelLoader:
    """
    Loads embedding models and LLMs based on config and environment.
    """

    def __init__(self):
        # In local dev, pull env vars from a .env file; in production the
        # runtime (e.g. ECS) injects them, so dotenv loading is skipped.
        if os.getenv("ENV", "local").lower() != "production":
            load_dotenv()
            log.info("Running in LOCAL mode: .env loaded")
        else:
            log.info("Running in PRODUCTION mode")

        # Key validation happens inside ApiKeyManager (fails fast if missing).
        self.api_key_mgr = ApiKeyManager()
        self.config = load_config()
        log.info("YAML config loaded", config_keys=list(self.config.keys()))

    def load_embeddings(self):
        """
        Load and return the embedding model from Google Generative AI.

        Raises:
            DocumentPortalException: if the model cannot be constructed
                (e.g. missing config key or client error).
        """
        try:
            model_name = self.config["embedding_model"]["model_name"]
            log.info("Loading embedding model", model=model_name)
            return GoogleGenerativeAIEmbeddings(
                model=model_name,
                google_api_key=self.api_key_mgr.get("GOOGLE_API_KEY"),  # type: ignore
            )
        except Exception as e:
            log.error("Error loading embedding model", error=str(e))
            raise DocumentPortalException("Failed to load embedding model", sys)

    def load_llm(self):
        """
        Load and return the configured LLM model.

        The provider is selected via the LLM_PROVIDER env var (default
        "google") and must match a key under the "llm" section of the
        YAML config.

        Raises:
            ValueError: if the provider key is absent from the config or
                the configured provider is unsupported.
        """
        llm_block = self.config["llm"]
        provider_key = os.getenv("LLM_PROVIDER", "google")

        if provider_key not in llm_block:
            log.error("LLM provider not found in config", provider=provider_key)
            raise ValueError(f"LLM provider '{provider_key}' not found in config")

        llm_config = llm_block[provider_key]
        provider = llm_config.get("provider")
        model_name = llm_config.get("model_name")
        temperature = llm_config.get("temperature", 0.2)
        max_tokens = llm_config.get("max_output_tokens", 2048)

        log.info("Loading LLM", provider=provider, model=model_name)

        if provider == "google":
            return ChatGoogleGenerativeAI(
                model=model_name,
                google_api_key=self.api_key_mgr.get("GOOGLE_API_KEY"),
                temperature=temperature,
                max_output_tokens=max_tokens,
            )

        elif provider == "groq":
            return ChatGroq(
                model=model_name,
                api_key=self.api_key_mgr.get("GROQ_API_KEY"),  # type: ignore
                temperature=temperature,
            )

        # elif provider == "openai":
        #     return ChatOpenAI(
        #         model=model_name,
        #         api_key=self.api_key_mgr.get("OPENAI_API_KEY"),
        #         temperature=temperature,
        #         max_tokens=max_tokens
        #     )

        else:
            log.error("Unsupported LLM provider", provider=provider)
            raise ValueError(f"Unsupported LLM provider: {provider}")

if __name__ == "__main__":
    # Smoke test: construct the loader and exercise both model factories.
    loader = ModelLoader()

    # Test Embedding
    embeddings = loader.load_embeddings()
    print(f"✅ Embedding Model Loaded: {embeddings}")
    result = embeddings.embed_query("Hello, how are you?")
    print(f"✅ Embedding Result: {result}")

    # Test LLM
    llm = loader.load_llm()
    print(f"✅ LLM Loaded: {llm}")
    result = llm.invoke("Hello, how are you?")
    print(f"✅ LLM Result: {result.content}")