Skip to content

Commit d57f025

Browse files
author
Bharat Sinha
committed
use include block when loading config from bucket
1 parent e865a4c commit d57f025

File tree

4 files changed

+127
-6
lines changed

4 files changed

+127
-6
lines changed

litellm/proxy/common_utils/load_config_utils.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from typing import Callable
2+
13
import yaml
24

35
from litellm._logging import verbose_proxy_logger
46

57

6-
def get_file_contents_from_s3(bucket_name, object_key):
8+
async def get_file_contents_from_s3(bucket_name, object_key):
79
try:
810
# v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc
911
import tempfile
@@ -20,26 +22,33 @@ def get_file_contents_from_s3(bucket_name, object_key):
2022
aws_secret_access_key=credentials.secret_key,
2123
aws_session_token=credentials.token, # Optional, if using temporary credentials
2224
)
23-
verbose_proxy_logger.debug(
25+
verbose_proxy_logger.error(
2426
f"Retrieving {object_key} from S3 bucket: {bucket_name}"
2527
)
2628
response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
27-
verbose_proxy_logger.debug(f"Response: {response}")
29+
verbose_proxy_logger.error(f"Response: {response}")
2830

2931
# Read the file contents
3032
file_contents = response["Body"].read().decode("utf-8")
31-
verbose_proxy_logger.debug("File contents retrieved from S3")
33+
verbose_proxy_logger.error("File contents retrieved from S3")
3234

3335
# Create a temporary file with YAML extension
3436
with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as temp_file:
3537
temp_file.write(file_contents.encode("utf-8"))
3638
temp_file_path = temp_file.name
37-
verbose_proxy_logger.debug(f"File stored temporarily at: {temp_file_path}")
39+
verbose_proxy_logger.error(f"File stored temporarily at: {temp_file_path}")
3840

3941
# Load the YAML file content
4042
with open(temp_file_path, "r") as yaml_file:
4143
config = yaml.safe_load(yaml_file)
4244

45+
# include file config
46+
config = await process_includes_from_bucket(
47+
config=config,
48+
get_file_method=get_file_contents_from_s3,
49+
bucket_name=bucket_name,
50+
)
51+
4352
return config
4453
except ImportError as e:
4554
# this is most likely if a user is not using the litellm docker container
@@ -64,13 +73,57 @@ async def get_config_file_contents_from_gcs(bucket_name, object_key):
6473
file_contents = file_contents.decode("utf-8")
6574
# convert to yaml
6675
config = yaml.safe_load(file_contents)
76+
# include file config
77+
config = await process_includes_from_bucket(
78+
config=config,
79+
get_file_method=get_config_file_contents_from_gcs,
80+
bucket_name=bucket_name,
81+
)
6782
return config
6883

6984
except Exception as e:
7085
verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
7186
return None
7287

7388

89+
async def process_includes_from_bucket(
90+
config: dict, get_file_method: Callable, bucket_name: str
91+
) -> dict:
92+
"""
93+
Process includes by appending their contents to the main config
94+
95+
Handles nested config.yamls with `include` section
96+
97+
Example config: This will get the contents from files in `include` and append it
98+
```yaml
99+
include:
100+
- /path/to/key/model_config.yaml
101+
102+
litellm_settings:
103+
callbacks: ["prometheus"]
104+
```
105+
"""
106+
if "include" not in config:
107+
return config
108+
109+
if not isinstance(config["include"], list):
110+
raise ValueError("'include' must be a list of file paths")
111+
112+
# Load and append all included files
113+
for include_file in config["include"]:
114+
included_config = await get_file_method(bucket_name, include_file)
115+
# Simply update/extend the main config with included config
116+
for key, value in included_config.items():
117+
if isinstance(value, list) and key in config:
118+
config[key].extend(value)
119+
else:
120+
config[key] = value
121+
122+
# Remove the include directive
123+
del config["include"]
124+
return config
125+
126+
74127
# # Example usage
75128
# bucket_name = 'litellm-proxy'
76129
# object_key = 'litellm_proxy_config.yaml'

litellm/proxy/proxy_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1489,7 +1489,7 @@ async def get_config(self, config_file_path: Optional[str] = None) -> dict:
14891489
bucket_name=bucket_name, object_key=object_key
14901490
)
14911491
else:
1492-
config = get_file_contents_from_s3(
1492+
config = await get_file_contents_from_s3(
14931493
bucket_name=bucket_name, object_key=object_key
14941494
)
14951495

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ types-PyYAML = "*"
117117
opentelemetry-api = "1.25.0"
118118
opentelemetry-sdk = "1.25.0"
119119
opentelemetry-exporter-otlp = "1.25.0"
120+
moto = "5.0.24"
120121

121122
[tool.poetry.group.proxy-dev.dependencies]
122123
prisma = "0.11.0"

tests/proxy_unit_tests/test_proxy_config_unit_test.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pytest
66

77
from dotenv import load_dotenv
8+
from moto import mock_aws
9+
import boto3
810

911
import litellm.proxy
1012
import litellm.proxy.proxy_server
@@ -264,3 +266,68 @@ def test_add_callbacks_invalid_input():
264266
# Cleanup
265267
litellm.success_callback = []
266268
litellm.failure_callback = []
269+
270+
271+
@pytest.mark.asyncio
272+
async def test_reading_configs_with_includes_from_s3():
273+
"""
274+
Test that the config is read correctly from the S3 and read includes
275+
"""
276+
BUCKET_NAME = "config-bucket"
277+
config_files = ("config_with_multiple_includes.yaml", "models_file_1.yaml", "models_file_2.yaml")
278+
os.environ["LITELLM_CONFIG_BUCKET_NAME"] = BUCKET_NAME
279+
os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"] = config_files[0]
280+
os.environ["LITELLM_CONFIG_BUCKET_TYPE"] = "S3"
281+
with mock_aws():
282+
# setup s3 bucket and put config files
283+
current_path = os.path.dirname(os.path.abspath(__file__))
284+
s3_client = boto3.client("s3")
285+
s3_client.create_bucket(Bucket=BUCKET_NAME)
286+
for file_name in config_files:
287+
config_path = os.path.join(
288+
current_path, "example_config_yaml", file_name,
289+
)
290+
s3_client.put_object(
291+
Bucket="config-bucket", Key=file_name, Body=open(config_path, 'rb')
292+
)
293+
proxy_config_instance = ProxyConfig()
294+
config = await proxy_config_instance.get_config()
295+
296+
assert config == {
297+
'model_list': [
298+
{'model_name': 'included-model-1', 'litellm_params': {'model': 'gpt-4'}},
299+
{'model_name': 'included-model-2', 'litellm_params': {'model': 'gpt-3.5-turbo'}}
300+
],
301+
'litellm_settings': {'callbacks': ['prometheus']},
302+
}
303+
304+
# unset the env variable to avoid side-effects on other tests
305+
del os.environ["LITELLM_CONFIG_BUCKET_NAME"]
306+
del os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"]
307+
del os.environ["LITELLM_CONFIG_BUCKET_TYPE"]
308+
309+
310+
@pytest.mark.asyncio
311+
async def test_reading_configs_from_s3_file_not_found():
312+
"""
313+
Test that the config is not present S3
314+
"""
315+
BUCKET_NAME = "config-bucket"
316+
config_files = ("config_with_multiple_includes.yaml", "models_file_1.yaml", "models_file_2.yaml")
317+
os.environ["LITELLM_CONFIG_BUCKET_NAME"] = BUCKET_NAME
318+
os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"] = config_files[0]
319+
os.environ["LITELLM_CONFIG_BUCKET_TYPE"] = "S3"
320+
with mock_aws():
321+
# setup s3 bucket but do not put file
322+
s3_client = boto3.client("s3")
323+
s3_client.create_bucket(Bucket=BUCKET_NAME)
324+
proxy_config_instance = ProxyConfig()
325+
with pytest.raises(Exception) as ex:
326+
await proxy_config_instance.get_config()
327+
328+
assert str(ex.value) == "Unable to load config from given source."
329+
330+
# unset the env variable to avoid side-effects on other tests
331+
del os.environ["LITELLM_CONFIG_BUCKET_NAME"]
332+
del os.environ["LITELLM_CONFIG_BUCKET_OBJECT_KEY"]
333+
del os.environ["LITELLM_CONFIG_BUCKET_TYPE"]

0 commit comments

Comments
 (0)