Skip to content

Commit

Permalink
Enhance robustness (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
dawnkisser authored Aug 21, 2024
1 parent add7e7c commit 261c56a
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions src/turtle_agent/scripts/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ def get_llm():
"""A helper function to get the LLM instance."""
dotenv.load_dotenv(dotenv.find_dotenv())

APIM_SUBSCRIPTION_KEY = os.getenv("APIM_SUBSCRIPTION_KEY")
APIM_SUBSCRIPTION_KEY = get_env_variable("APIM_SUBSCRIPTION_KEY")
default_headers = {}
if APIM_SUBSCRIPTION_KEY != None:
if APIM_SUBSCRIPTION_KEY is not None:
# only set this if the APIM API requires a subscription...
default_headers["Ocp-Apim-Subscription-Key"] = APIM_SUBSCRIPTION_KEY

# Set up authority and credentials for Azure authentication
credential = ClientSecretCredential(
tenant_id=os.getenv("AZURE_TENANT_ID"),
client_id=os.getenv("AZURE_CLIENT_ID"),
client_secret=os.getenv("AZURE_CLIENT_SECRET"),
tenant_id=get_env_variable("AZURE_TENANT_ID"),
client_id=get_env_variable("AZURE_CLIENT_ID"),
client_secret=get_env_variable("AZURE_CLIENT_SECRET"),
authority="https://login.microsoftonline.com",
)

Expand All @@ -42,12 +42,36 @@ def get_llm():
)

llm = AzureChatOpenAI(
azure_deployment=os.getenv("DEPLOYMENT_ID"),
azure_deployment=get_env_variable("DEPLOYMENT_ID"),
azure_ad_token_provider=token_provider,
openai_api_type="azure_ad",
api_version=os.getenv("API_VERSION"),
azure_endpoint=os.getenv("API_ENDPOINT"),
api_version=get_env_variable("API_VERSION"),
azure_endpoint=get_env_variable("API_ENDPOINT"),
default_headers=default_headers,
)

return llm


def get_env_variable(var_name):
"""
Retrieves the value of the specified environment variable.
Args:
var_name (str): The name of the environment variable to retrieve.
Returns:
str: The value of the environment variable.
Raises:
ValueError: If the environment variable is not set.
This function provides a consistent and safe way to retrieve environment variables.
By using this function, we ensure that all required environment variables are present
before proceeding with any operations. If a variable is not set, the function will
raise a ValueError, making it easier to debug configuration issues.
"""
value = os.getenv(var_name)
if value is None:
raise ValueError(f"Environment variable {var_name} is not set.")
return value

0 comments on commit 261c56a

Please sign in to comment.