Skip to content

Commit 1cd0a99

Browse files
committed
code updated with model loaders
1 parent 4dbaa1f commit 1cd0a99

File tree

4 files changed

+140
-74
lines changed

4 files changed

+140
-74
lines changed

.github/workflows/task_definition.json

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,7 @@
2828
],
2929
"secrets": [
3030
{
31-
"name": "GROQ_API_KEY",
32-
"valueFrom": "arn:aws:secretsmanager:ap-southeast-2:459497895986:secret:api_keys-nZTtj8"
33-
},
34-
35-
{
36-
"name": "GOOGLE_API_KEY",
31+
"name": "API_KEYS",
3732
"valueFrom": "arn:aws:secretsmanager:ap-southeast-2:459497895986:secret:api_keys-nZTtj8"
3833
}
3934

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ logs/
77
*.pyc
88
faiss_index/
99
main_archive/
10-
data/
10+
data/
11+
archive/

test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,58 @@
241241
# # secret = get_secret_value_response['SecretString']
242242

243243
# # # Your code goes here.
244+
245+
246+
247+
# {
248+
# "family": "documentportaltd",
249+
# "networkMode": "awsvpc",
250+
# "executionRoleArn": "arn:aws:iam::459497895986:role/ecsTaskExecutionRole",
251+
# "requiresCompatibilities": ["FARGATE"],
252+
# "cpu": "1024",
253+
# "memory": "8192",
254+
# "containerDefinitions": [
255+
# {
256+
# "name": "document-portal-container",
257+
# "image": "459497895986.dkr.ecr.ap-southeast-2.amazonaws.com/documentportalliveclass",
258+
# "cpu": 1024,
259+
# "essential": true,
260+
# "portMappings": [
261+
# {
262+
# "containerPort": 8080,
263+
# "hostPort": 8080,
264+
# "protocol": "tcp",
265+
# "name": "document-portal-container-8080-tcp",
266+
# "appProtocol": "http"
267+
# }
268+
# ],
269+
# "environment": [
270+
# {
271+
# "name": "ENV",
272+
# "value": "production"
273+
# }
274+
# ],
275+
# "secrets": [
276+
# {
277+
# "name": "GROQ_API_KEY",
278+
# "valueFrom": "arn:aws:secretsmanager:ap-southeast-2:459497895986:secret:api_keys-nZTtj8"
279+
# },
280+
281+
# {
282+
# "name": "GOOGLE_API_KEY",
283+
# "valueFrom": "arn:aws:secretsmanager:ap-southeast-2:459497895986:secret:api_keys-nZTtj8"
284+
# }
285+
286+
# ],
287+
# "logConfiguration": {
288+
# "logDriver": "awslogs",
289+
# "options": {
290+
# "awslogs-group": "/ecs/documentportaltd",
291+
# "awslogs-region": "ap-southeast-2",
292+
# "awslogs-stream-prefix": "ecs",
293+
# "awslogs-create-group": "true"
294+
# }
295+
# }
296+
# }
297+
# ]
298+
# }

utils/model_loader.py

Lines changed: 82 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,127 +1,142 @@
1-
21
import os
32
import sys
3+
import json
44
from dotenv import load_dotenv
55
from utils.config_loader import load_config
6-
from .config_loader import load_config
7-
from langchain_google_genai import GoogleGenerativeAIEmbeddings
8-
from langchain_google_genai import ChatGoogleGenerativeAI
6+
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
97
from langchain_groq import ChatGroq
10-
#from langchain_openai import ChatOpenAI
118
from logger import GLOBAL_LOGGER as log
129
from exception.custom_exception import DocumentPortalException
1310

11+
12+
class ApiKeyManager:
    """Resolve the API keys the application needs from the environment.

    Resolution order:
      1. ``API_KEYS`` env var holding a JSON object (the ECS secret payload).
      2. Individual env vars (``GROQ_API_KEY``, ``GOOGLE_API_KEY``) as fallback
         for any key the JSON blob did not supply.

    Raises:
        DocumentPortalException: from ``__init__`` when any required key is
            still missing after both sources were checked.
    """

    # Keys that must be present for the app to start.
    REQUIRED_KEYS = ["GROQ_API_KEY", "GOOGLE_API_KEY"]

    def __init__(self):
        self.api_keys = {}
        raw = os.getenv("API_KEYS")

        if raw:
            try:
                parsed = json.loads(raw)
                if not isinstance(parsed, dict):
                    raise ValueError("API_KEYS is not a valid JSON object")
                self.api_keys = parsed
                log.info("Loaded API_KEYS from ECS secret")
            except Exception as e:
                # Best-effort: a malformed secret falls through to the
                # individual env-var lookup below instead of aborting startup.
                log.warning("Failed to parse API_KEYS as JSON", error=str(e))

        # Fallback to individual env vars
        for key in self.REQUIRED_KEYS:
            if not self.api_keys.get(key):
                env_val = os.getenv(key)
                if env_val:
                    self.api_keys[key] = env_val
                    log.info(f"Loaded {key} from individual env var")

        # Final check: fail fast if anything required is still missing.
        missing = [k for k in self.REQUIRED_KEYS if not self.api_keys.get(k)]
        if missing:
            log.error("Missing required API keys", missing_keys=missing)
            raise DocumentPortalException("Missing API keys", sys)

        # Log key NAMES only -- never any part of the secret values.
        # (Previously this logged the first 6 characters of each key, which
        # leaks secret material into log storage and also raised TypeError
        # when a JSON value was not a string.)
        log.info("API keys loaded", keys=list(self.api_keys.keys()))

    def get(self, key: str) -> str:
        """Return the value for *key*; raise KeyError when absent or empty."""
        val = self.api_keys.get(key)
        if not val:
            raise KeyError(f"API key for {key} is missing")
        return val
51+
52+
1453
class ModelLoader:
    """
    Loads embedding models and LLMs based on config and environment.
    """

    def __init__(self):
        # Outside production, hydrate the process environment from .env.
        env_name = os.getenv("ENV", "local").lower()
        if env_name != "production":
            load_dotenv()
            log.info("Running in LOCAL mode: .env loaded")
        else:
            log.info("Running in PRODUCTION mode")

        # Key manager validates required keys up front (raises on missing).
        self.api_key_mgr = ApiKeyManager()
        self.config = load_config()
        log.info("YAML config loaded", config_keys=list(self.config.keys()))

    def load_embeddings(self):
        """
        Load and return embedding model from Google Generative AI.
        """
        try:
            embed_name = self.config["embedding_model"]["model_name"]
            log.info("Loading embedding model", model=embed_name)
            google_key = self.api_key_mgr.get("GOOGLE_API_KEY")
            return GoogleGenerativeAIEmbeddings(
                model=embed_name,
                google_api_key=google_key,  # type: ignore
            )
        except Exception as e:
            log.error("Error loading embedding model", error=str(e))
            raise DocumentPortalException("Failed to load embedding model", sys)

    def load_llm(self):
        """
        Load and return the configured LLM model.

        The provider is chosen via the LLM_PROVIDER env var (default
        "google") and must name an entry under the config's "llm" section.
        """
        providers_cfg = self.config["llm"]
        selected = os.getenv("LLM_PROVIDER", "google")

        # Guard: the selected provider must exist in the YAML config.
        if selected not in providers_cfg:
            log.error("LLM provider not found in config", provider=selected)
            raise ValueError(f"LLM provider '{selected}' not found in config")

        cfg = providers_cfg[selected]
        provider = cfg.get("provider")
        model_name = cfg.get("model_name")
        temperature = cfg.get("temperature", 0.2)
        max_tokens = cfg.get("max_output_tokens", 2048)

        log.info("Loading LLM", provider=provider, model=model_name)

        if provider == "google":
            return ChatGoogleGenerativeAI(
                model=model_name,
                google_api_key=self.api_key_mgr.get("GOOGLE_API_KEY"),
                temperature=temperature,
                max_output_tokens=max_tokens,
            )

        if provider == "groq":
            return ChatGroq(
                model=model_name,
                api_key=self.api_key_mgr.get("GROQ_API_KEY"),  # type: ignore
                temperature=temperature,
            )

        # elif provider == "openai":
        #     return ChatOpenAI(
        #         model=model_name,
        #         api_key=self.api_key_mgr.get("OPENAI_API_KEY"),
        #         temperature=temperature,
        #         max_tokens=max_tokens
        #     )

        log.error("Unsupported LLM provider", provider=provider)
        raise ValueError(f"Unsupported LLM provider: {provider}")
107-
108-
109-
127+
128+
110129
if __name__ == "__main__":
    # Smoke test: exercise both loaders end to end.
    loader = ModelLoader()

    # Embedding model round-trip.
    embeddings = loader.load_embeddings()
    print(f"✅ Embedding Model Loaded: {embeddings}")
    vector = embeddings.embed_query("Hello, how are you?")
    print(f"✅ Embedding Result: {vector}")

    # LLM round-trip.
    llm = loader.load_llm()
    print(f"✅ LLM Loaded: {llm}")
    reply = llm.invoke("Hello, how are you?")
    print(f"✅ LLM Result: {reply.content}")

0 commit comments

Comments
 (0)