Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions protollm/connectors/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ It is also possible to pass additional parameters for the model. Available para
- `top_p` (not available for self-hosted models)
- `max_tokens`

There is a separate parameter, extra_body, which some services use to specify permitted providers. Since not all providers from a given service are available in every region, this parameter allows you to define the providers accessible in your region.
For example:
```codeblock
ALLOWED_PROVIDERS='["google-vertex", "azure"]'
```

Example of how to use the function:
```codeblock
from protollm.connectors.connector_creator import create_llm_connector
Expand Down
3 changes: 2 additions & 1 deletion protollm/connectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .connector_creator import create_llm_connector, CustomChatOpenAI
from .rest_server import ChatRESTServer
from .rest_server import ChatRESTServer
from .utils import get_allowed_providers
9 changes: 8 additions & 1 deletion protollm/connectors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,11 @@ def handle_system_prompt(msgs, sys_prompt):
else:
idx = next((index for index, obj in enumerate(msgs) if isinstance(obj, SystemMessage)), 0)
msgs[idx].content += "\n\n" + sys_prompt
return msgs
return msgs


def get_allowed_providers() -> list | None:
if allowed_providers := os.getenv("ALLOWED_PROVIDERS"):
return json.loads(allowed_providers)
else:
return None
7 changes: 5 additions & 2 deletions protollm/metrics/deepeval_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel
from openai._types import NOT_GIVEN

from ..connectors import create_llm_connector
from ..connectors import create_llm_connector, get_allowed_providers


class DeepEvalConnector(DeepEvalBaseLLM):
Expand All @@ -30,7 +30,10 @@ def __init__(self, sys_prompt: str = "", *args, **kwargs):
@staticmethod
def load_model() -> BaseChatModel:
"""Returns LangChain's ChatModel for requests"""
return create_llm_connector(os.getenv("DEEPEVAL_LLM_URL", "test_model"))
return create_llm_connector(
os.getenv("DEEPEVAL_LLM_URL", "test_model"),
extra_body={"provider": {"only": get_allowed_providers()}}
Comment thread
Nunkyl marked this conversation as resolved.
)

def generate(
self,
Expand Down