Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/huge-snails-shine.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"chatbot-evaluate": patch
---

align settings file
17 changes: 5 additions & 12 deletions apps/chatbot-evaluate/src/modules/settings.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,23 @@
import boto3
import os
import yaml
from pathlib import Path
from pydantic_settings import BaseSettings

from src.modules.utils import get_ssm_parameter


CWF = Path(__file__)
ROOT = CWF.parent.parent.parent.absolute().__str__()
PROMPTS = yaml.safe_load(open(os.path.join(ROOT, "config", "prompts.yaml"), "r"))
AWS_SESSION = boto3.Session()


class ChatbotSettings(BaseSettings):
"""Settings for the chatbot evaluation."""

# api keys
aws_access_key_id: str = os.getenv(
"AWS_ACCESS_KEY_ID", os.getenv("CHB_AWS_ACCESS_KEY_ID")
)
aws_default_region: str = os.getenv(
"AWS_REGION", os.getenv("CHB_AWS_DEFAULT_REGION")
)
aws_endpoint_url: str | None = os.getenv("CHB_AWS_SSM_ENDPOINT_URL")
aws_secret_access_key: str = os.getenv(
"AWS_SECRET_ACCESS_KEY", os.getenv("CHB_AWS_SECRET_ACCESS_KEY")
)
# api
aws_region: str = os.getenv("AWS_REGION", "us-east-1")
aws_endpoint_url: str = os.getenv("AWS_ENDPOINT_URL")
google_api_key: str = get_ssm_parameter(
os.getenv("CHB_AWS_SSM_GOOGLE_API_KEY"),
os.getenv("CHB_AWS_GOOGLE_API_KEY"),
Expand Down
8 changes: 1 addition & 7 deletions apps/chatbot-evaluate/src/modules/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import boto3
import json

from src.modules.logger import get_logger
from src.modules.models import get_llm, get_embed_model
Expand All @@ -16,11 +14,7 @@
def test_aws_credentials() -> None:
identity = None
try:
session = boto3.Session(
aws_access_key_id=os.getenv("CHB_AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("CHB_AWS_SECRET_ACCESS_KEY"),
)
sts = session.client("sts")
sts = SETTINGS.boto3_session.client("sts")
identity = sts.get_caller_identity()

except Exception as e:
Expand Down
20 changes: 6 additions & 14 deletions apps/chatbot-evaluate/src/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,8 @@

from src.modules.logger import get_logger


LOGGER = get_logger(__name__)
AWS_ACCESS_KEY_ID = os.getenv("CHB_AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("CHB_AWS_SECRET_ACCESS_KEY")
AWS_DEFAULT_REGION = os.getenv("CHB_AWS_DEFAULT_REGION")
AWS_ENDPOINT_URL = os.getenv("CHB_AWS_SSM_ENDPOINT_URL", None)
SSM_CLIENT = boto3.client(
"ssm",
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
region_name=AWS_DEFAULT_REGION,
endpoint_url=AWS_ENDPOINT_URL,
)
SSM_CLIENT = boto3.client("ssm")


def get_ssm_parameter(name: str | None, default: str | None = None) -> str | None:
Expand All @@ -27,17 +16,20 @@ def get_ssm_parameter(name: str | None, default: str | None = None) -> str | Non
:return: The value of the requested parameter.
"""

LOGGER.info(f"get_ssm_parameter {name}...")

if name is None:
name = "none-params-in-ssm"
try:
# Get the requested parameter
response = SSM_CLIENT.get_parameter(Name=name, WithDecryption=True)
value = response["Parameter"]["Value"]
except SSM_CLIENT.exceptions.ParameterNotFound:
LOGGER.info(f"Parameter {name} not found in SSM, returning default: {default}")
LOGGER.warning(
f"Parameter {name} not found in SSM, returning default: {default}"
)
return default

LOGGER.info(f"SSM Parameter {name} retrieved.")
return value


Expand Down
2 changes: 1 addition & 1 deletion apps/chatbot/src/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_ssm_parameter(name: str | None, default: str | None = None) -> str | Non
LOGGER.info(f"get_ssm_parameter {name}...")

if name is None:
name = "/none/param"
name = "none-params-in-ssm"
try:
# Get the requested parameter
response = SSM_CLIENT.get_parameter(Name=name, WithDecryption=True)
Expand Down
Loading