Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/huge-snails-shine.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"chatbot-evaluate": patch
---

align settings file
20 changes: 8 additions & 12 deletions apps/chatbot-evaluate/src/modules/settings.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,23 @@
import boto3
import os
import yaml
from pathlib import Path
from pydantic_settings import BaseSettings

from src.modules.utils import get_ssm_parameter


CWF = Path(__file__)
ROOT = CWF.parent.parent.parent.absolute().__str__()
PROMPTS = yaml.safe_load(open(os.path.join(ROOT, "config", "prompts.yaml"), "r"))
AWS_SESSION = boto3.Session()


class ChatbotSettings(BaseSettings):
"""Settings for the chatbot evaluation."""

# api keys
aws_access_key_id: str = os.getenv(
"AWS_ACCESS_KEY_ID", os.getenv("CHB_AWS_ACCESS_KEY_ID")
)
aws_default_region: str = os.getenv(
"AWS_REGION", os.getenv("CHB_AWS_DEFAULT_REGION")
)
aws_endpoint_url: str | None = os.getenv("CHB_AWS_SSM_ENDPOINT_URL")
aws_secret_access_key: str = os.getenv(
"AWS_SECRET_ACCESS_KEY", os.getenv("CHB_AWS_SECRET_ACCESS_KEY")
)
# api
aws_region: str = os.getenv("AWS_REGION", "us-east-1")
aws_endpoint_url: str = os.getenv("AWS_ENDPOINT_URL")
google_api_key: str = get_ssm_parameter(
os.getenv("CHB_AWS_SSM_GOOGLE_API_KEY"),
os.getenv("CHB_AWS_GOOGLE_API_KEY"),
Expand Down Expand Up @@ -52,5 +45,8 @@ class ChatbotSettings(BaseSettings):
# prompts
condense_prompt_str: str = PROMPTS["condense_prompt_str"]

# urls
website_url: str = os.getenv("CHB_WEBSITE_URL")


SETTINGS = ChatbotSettings()
14 changes: 3 additions & 11 deletions apps/chatbot-evaluate/src/modules/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
import os
import boto3
import json

from src.modules.logger import get_logger
from src.modules.models import get_llm, get_embed_model
from src.modules.judge import Judge
from src.modules.monitor import LANGFUSE_CLIENT
from src.modules.settings import SETTINGS
from src.modules.settings import AWS_SESSION, SETTINGS

LOGGER = get_logger(__name__)
JUDGE = Judge()
WEBSITE_URL = os.getenv("CHB_WEBSITE_URL")
WEBSITE_URL = SETTINGS.website_url


def test_aws_credentials() -> None:
identity = None
try:
session = boto3.Session(
aws_access_key_id=os.getenv("CHB_AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("CHB_AWS_SECRET_ACCESS_KEY"),
)
sts = session.client("sts")
sts = AWS_SESSION.client("sts")
identity = sts.get_caller_identity()

except Exception as e:
Expand Down
20 changes: 6 additions & 14 deletions apps/chatbot-evaluate/src/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,8 @@

from src.modules.logger import get_logger


LOGGER = get_logger(__name__)
AWS_ACCESS_KEY_ID = os.getenv("CHB_AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("CHB_AWS_SECRET_ACCESS_KEY")
AWS_DEFAULT_REGION = os.getenv("CHB_AWS_DEFAULT_REGION")
AWS_ENDPOINT_URL = os.getenv("CHB_AWS_SSM_ENDPOINT_URL", None)
SSM_CLIENT = boto3.client(
"ssm",
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
region_name=AWS_DEFAULT_REGION,
endpoint_url=AWS_ENDPOINT_URL,
)
SSM_CLIENT = boto3.client("ssm")


def get_ssm_parameter(name: str | None, default: str | None = None) -> str | None:
Expand All @@ -27,17 +16,20 @@ def get_ssm_parameter(name: str | None, default: str | None = None) -> str | Non
:return: The value of the requested parameter.
"""

LOGGER.info(f"get_ssm_parameter {name}...")

if name is None:
name = "none-params-in-ssm"
try:
# Get the requested parameter
response = SSM_CLIENT.get_parameter(Name=name, WithDecryption=True)
value = response["Parameter"]["Value"]
except SSM_CLIENT.exceptions.ParameterNotFound:
LOGGER.info(f"Parameter {name} not found in SSM, returning default: {default}")
LOGGER.warning(
f"Parameter {name} not found in SSM, returning default: {default}"
)
return default

LOGGER.info(f"SSM Parameter {name} retrieved.")
return value


Expand Down