From 261c56a2d02f7712de3da98c07ec619d2b704ee9 Mon Sep 17 00:00:00 2001 From: Kejun Liu <119113065+dawnkisser@users.noreply.github.com> Date: Thu, 22 Aug 2024 03:31:53 +0800 Subject: [PATCH] Enhance robustness (#11) --- src/turtle_agent/scripts/llm.py | 40 ++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/turtle_agent/scripts/llm.py b/src/turtle_agent/scripts/llm.py index f19e5e7..2c41536 100644 --- a/src/turtle_agent/scripts/llm.py +++ b/src/turtle_agent/scripts/llm.py @@ -23,17 +23,17 @@ def get_llm(): """A helper function to get the LLM instance.""" dotenv.load_dotenv(dotenv.find_dotenv()) - APIM_SUBSCRIPTION_KEY = os.getenv("APIM_SUBSCRIPTION_KEY") + APIM_SUBSCRIPTION_KEY = get_env_variable("APIM_SUBSCRIPTION_KEY") default_headers = {} - if APIM_SUBSCRIPTION_KEY != None: + if APIM_SUBSCRIPTION_KEY is not None: # only set this if the APIM API requires a subscription... default_headers["Ocp-Apim-Subscription-Key"] = APIM_SUBSCRIPTION_KEY # Set up authority and credentials for Azure authentication credential = ClientSecretCredential( - tenant_id=os.getenv("AZURE_TENANT_ID"), - client_id=os.getenv("AZURE_CLIENT_ID"), - client_secret=os.getenv("AZURE_CLIENT_SECRET"), + tenant_id=get_env_variable("AZURE_TENANT_ID"), + client_id=get_env_variable("AZURE_CLIENT_ID"), + client_secret=get_env_variable("AZURE_CLIENT_SECRET"), authority="https://login.microsoftonline.com", ) @@ -42,12 +42,36 @@ def get_llm(): ) llm = AzureChatOpenAI( - azure_deployment=os.getenv("DEPLOYMENT_ID"), + azure_deployment=get_env_variable("DEPLOYMENT_ID"), azure_ad_token_provider=token_provider, openai_api_type="azure_ad", - api_version=os.getenv("API_VERSION"), - azure_endpoint=os.getenv("API_ENDPOINT"), + api_version=get_env_variable("API_VERSION"), + azure_endpoint=get_env_variable("API_ENDPOINT"), default_headers=default_headers, ) return llm + + +def get_env_variable(var_name): + """ + Retrieves the value of the specified environment variable. + + Args: + var_name (str): The name of the environment variable to retrieve. + + Returns: + str: The value of the environment variable. + + Raises: + ValueError: If the environment variable is not set. + + This function provides a consistent and safe way to retrieve environment variables. + By using this function, we ensure that all required environment variables are present + before proceeding with any operations. If a variable is not set, the function will + raise a ValueError, making it easier to debug configuration issues. + """ + value = os.getenv(var_name) + if value is None: + raise ValueError(f"Environment variable {var_name} is not set.") + return value