Commit 6853f8d
Author: Siba Rajendran
prompt tuner changes
1 parent 79d1ef2 commit 6853f8d

24 files changed: +1241 −97 lines
@@ -0,0 +1,93 @@
from deepeval.models import DeepEvalBaseLLM
from langchain_core.messages import BaseMessage, HumanMessage
from pydantic import BaseModel

from fmcore.experimental.llm.base_llm import BaseLLM


class DeepEvalLLMAdapter(DeepEvalBaseLLM):
    """
    Adapter that bridges BaseLLM implementations with DeepEval's evaluation framework.

    This adapter implements DeepEval's base LLM interface, allowing any BaseLLM instance
    to be used with DeepEval's evaluation tools. It handles the conversion between the
    two message formats and provides both synchronous and asynchronous generation.

    Note: This adapter currently supports only text generation metrics. Support for
    multimodal metrics will be added soon.

    Attributes:
        llm (BaseLLM): The underlying language model implementation to be adapted.
    """

    llm: BaseLLM

    def __init__(self, llm: BaseLLM):
        """
        Initialize the adapter with a BaseLLM instance.

        Args:
            llm (BaseLLM): The language model implementation to be wrapped.
        """
        self.llm = llm

    def load_model(self):
        """
        Provide access to the underlying LLM instance.

        This method is required by the DeepEval interface to access the model implementation.

        Returns:
            BaseLLM: The wrapped language model instance.
        """
        return self.llm

    def generate(self, prompt: str, schema: BaseModel, **kwargs) -> BaseModel:
        """
        Synchronously generate a response for the given prompt.

        Converts the string prompt into a message, invokes the underlying LLM, and
        validates the content of the returned message against the provided schema.

        Args:
            prompt (str): The input text to generate a response for.
            schema (BaseModel): Pydantic model used to parse and validate the response.

        Returns:
            BaseModel: The generated response parsed into the provided schema.

        Note:
            This method handles the conversion from DeepEval's string-based interface
            to BaseLLM's message-based interface.
        """
        messages = [HumanMessage(content=prompt)]
        response: BaseMessage = self.llm.invoke(messages=messages)
        return schema.model_validate_json(response.content)

    async def a_generate(self, prompt: str, schema: BaseModel, **kwargs) -> BaseModel:
        """
        Asynchronously generate a response for the given prompt.

        Provides an asynchronous interface for generating responses, useful for
        high-throughput or I/O-bound evaluation scenarios.

        Args:
            prompt (str): The input text to generate a response for.
            schema (BaseModel): Pydantic model used to parse and validate the response.

        Returns:
            BaseModel: The generated response parsed into the provided schema.

        Note:
            This method handles the conversion from DeepEval's string-based interface
            to BaseLLM's message-based interface in an asynchronous context.
        """
        messages = [HumanMessage(content=prompt)]
        response: BaseMessage = await self.llm.ainvoke(messages=messages)
        return schema.model_validate_json(response.content)

    def get_model_name(self):
        """
        Retrieve the identifier of the underlying language model.

        Returns:
            str: The model identifier as specified in the LLM's configuration.
        """
        return self.llm.config.model_id
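
For context, a minimal usage sketch of this adapter with a DeepEval metric follows. It is not part of this commit: the LLMConfig contents are placeholder assumptions, while GEval, LLMTestCase, and LLMTestCaseParams come from DeepEval's public API.

from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams

from fmcore.experimental.llm.base_llm import BaseLLM
from fmcore.experimental.types.llm_types import LLMConfig

# Hypothetical config; the concrete fields depend on the provider in use.
llm = BaseLLM.of(llm_config=LLMConfig(...))
judge = DeepEvalLLMAdapter(llm=llm)

correctness = GEval(
    name="Correctness",
    criteria="Does the actual output agree with the expected output?",
    evaluation_params=[
        LLMTestCaseParams.ACTUAL_OUTPUT,
        LLMTestCaseParams.EXPECTED_OUTPUT,
    ],
    model=judge,  # any DeepEvalBaseLLM subclass is accepted here
)
correctness.measure(LLMTestCase(input="2+2?", actual_output="4", expected_output="4"))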
@@ -0,0 +1,58 @@
from typing import List, Optional

import dspy
from langchain_core.messages import BaseMessage

from fmcore.experimental.llm.base_llm import BaseLLM
from fmcore.experimental.types.llm_types import LLMConfig


class DSPyLLMAdapter(dspy.LM):

    def __init__(
        self,
        llm_config: LLMConfig,
        **kwargs,
    ):
        super().__init__(model=llm_config.model_id, **kwargs)
        self.llm: BaseLLM = BaseLLM.of(llm_config=llm_config)
        self.history = []

    def __call__(
        self, prompt: Optional[str] = None, messages: Optional[List[BaseMessage]] = None, **kwargs
    ) -> List[str]:
        """
        Executes inference with either a text prompt or a predefined list of messages.

        If a prompt is provided, it is converted into a single user message.

        Args:
            prompt (str, optional): The input prompt to generate a response for.
            messages (List[BaseMessage], optional): Predefined list of messages for inference.

        Returns:
            List[str]: The generated responses from the model.

        Raises:
            ValueError: If both prompt and messages are provided.

        Example:
            predictions = claude_model(prompt="the sky is blue")
            print(predictions)
        """
        if prompt and messages:
            raise ValueError("You can only provide either a 'prompt' or 'messages', not both.")

        if prompt:
            messages = [{"role": "user", "content": prompt}]

        response = self.llm.invoke(messages)
        result = [response.content]

        # Update the LM's history using DSPy constructs, which currently support only dictionaries
        entry = {
            "messages": messages,
            "outputs": result,
            "kwargs": kwargs,
        }
        self.history.append(entry)
        self.update_global_history(entry)
        return result
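
Similarly, a hypothetical sketch of registering this adapter as DSPy's default LM; it is not part of this commit, and the LLMConfig contents are assumptions, while dspy.configure and dspy.Predict are DSPy's public API.

import dspy

from fmcore.experimental.types.llm_types import LLMConfig

# Hypothetical config; provider and account details are elided.
lm = DSPyLLMAdapter(llm_config=LLMConfig(...))
dspy.configure(lm=lm)

# DSPy modules now route their calls through DSPyLLMAdapter.__call__.
qa = dspy.Predict("question -> answer")
print(qa(question="What color is the sky?").answer)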

src/fmcore/experimental/factory/bedrock_factory.py

+45 −10
@@ -8,26 +8,61 @@
 from fmcore.experimental.proxy.rate_limit_proxy import RateLimitedProxy
 from fmcore.experimental.types.provider_types import BedrockAccountConfig

-# Clearer alias
 BedrockClientProxy: TypeAlias = RateLimitedProxy[ChatBedrockConverse]


 class BedrockFactory:
-    """Factory class for creating Bedrock clients with additional functionalities like rate limiting."""
+    """Factory class for creating Bedrock clients with additional functionality such as rate limiting.
+
+    This class provides static methods to create and configure Amazon Bedrock clients
+    with built-in rate limiting capabilities. It handles the creation of both single
+    and multiple clients based on the provided configurations.
+
+    The factory supports multiple AWS accounts and automatically configures rate limiting
+    based on account-specific parameters.
+    """

     @staticmethod
-    def create_bedrock_clients(*, llm_config: LLMConfig) -> List[BedrockClientProxy]:
-        """Creates multiple Bedrock clients based on the provided configuration."""
+    def create_bedrock_clients(llm_config: LLMConfig) -> List[BedrockClientProxy]:
+        """Creates multiple Bedrock clients based on the provided configuration.
+
+        Args:
+            llm_config (LLMConfig): Configuration object containing LLM settings and provider
+                parameters, including account configurations and model parameters.
+
+        Returns:
+            List[BedrockClientProxy]: A list of rate-limited Bedrock client proxies, one for
+                each account specified in the configuration.
+
+        Example:
+            llm_config = LLMConfig(...)
+            clients = BedrockFactory.create_bedrock_clients(llm_config)
+        """
         return [
-            BedrockFactory._create_bedrock_client_with_converse(account, llm_config)
+            BedrockFactory._create_bedrock_client_with_converse(
+                account_config=account, llm_config=llm_config
+            )
             for account in llm_config.provider_params.accounts
         ]

     @staticmethod
     def _create_bedrock_client_with_converse(
         account_config: BedrockAccountConfig, llm_config: LLMConfig
     ) -> BedrockClientProxy:
-        """Helper method to create a single Bedrock client with rate limiting."""
+        """Creates a single Bedrock client with rate limiting capabilities.
+
+        Args:
+            account_config (BedrockAccountConfig): Configuration for a specific AWS account,
+                including region, role ARN, and rate limits.
+            llm_config (LLMConfig): Configuration containing model settings and parameters.
+
+        Returns:
+            BedrockClientProxy: A rate-limited proxy wrapper around the Bedrock client.
+
+        Note:
+            The method configures rate limiting based on the account's specified rate limit
+            and wraps the ChatBedrockConverse client in a proxy for controlled access.
+        """
         boto_client = BotoFactory.get_client(
             service_name="bedrock-runtime",
             region=account_config.region,
@@ -37,11 +72,11 @@ def _create_bedrock_client_with_converse(
         converse_client = ChatBedrockConverse(
             model_id=llm_config.model_id,
             client=boto_client,
-            **llm_config.model_params.dict(exclude_none=True),
+            **llm_config.model_params.model_dump(exclude_none=True),
         )

-        # Currently, We are using the off the shelf rate limiters provided by aiolimiter
-        # TODO Implement custom rate limiters
+        # Create a rate limiter based on the account config
         rate_limiter = AsyncLimiter(max_rate=account_config.rate_limit)

+        # Create the proxy without a weight
         return BedrockClientProxy(client=converse_client, rate_limiter=rate_limiter)
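
RateLimitedProxy itself is outside this diff. As a rough illustration of the pattern, here is a minimal sketch assuming the proxy simply gates each delegated call through aiolimiter's leaky-bucket limiter; the real fmcore class may differ.

from typing import Generic, TypeVar

from aiolimiter import AsyncLimiter

T = TypeVar("T")


class RateLimitedProxySketch(Generic[T]):
    """Illustrative stand-in for fmcore's RateLimitedProxy, not the real class."""

    def __init__(self, client: T, rate_limiter: AsyncLimiter):
        self.client = client
        self.rate_limiter = rate_limiter

    async def ainvoke(self, *args, **kwargs):
        # "async with" waits until the limiter has capacity, then delegates.
        async with self.rate_limiter:
            return await self.client.ainvoke(*args, **kwargs)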

src/fmcore/experimental/factory/boto_factory.py

+6 −18
@@ -33,23 +33,15 @@ def __get_refreshable_session(

     def refresh() -> dict:
         """Refreshes credentials by assuming the specified role."""
-        sts_client = boto3.client(
-            AWSConstants.AWS_SERVICE_STS, region_name=region_name
-        )
-        response = sts_client.assume_role(
-            RoleArn=role_arn, RoleSessionName=session_name
-        )
+        sts_client = boto3.client(AWSConstants.AWS_SERVICE_STS, region_name=region_name)
+        response = sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)
         credentials = response[AWSConstants.CREDENTIALS]
         return {
-            AWSConstants.AWS_CREDENTIALS_ACCESS_KEY: credentials[
-                AWSConstants.ACCESS_KEY_ID
-            ],
+            AWSConstants.AWS_CREDENTIALS_ACCESS_KEY: credentials[AWSConstants.ACCESS_KEY_ID],
             AWSConstants.AWS_CREDENTIALS_SECRET_KEY: credentials[
                 AWSConstants.SECRET_ACCESS_KEY
             ],
-            AWSConstants.AWS_CREDENTIALS_TOKEN: credentials[
-                AWSConstants.SESSION_TOKEN
-            ],
+            AWSConstants.AWS_CREDENTIALS_TOKEN: credentials[AWSConstants.SESSION_TOKEN],
             AWSConstants.AWS_CREDENTIALS_EXPIRY_TIME: credentials[
                 AWSConstants.EXPIRATION
             ].isoformat(),
@@ -82,17 +74,13 @@ def __create_session(cls, *, region: str, role_arn: str) -> boto3.Session:
             boto3.Session: A configured Boto3 session.
         """
         return (
-            cls.__get_refreshable_session(
-                role_arn=role_arn, session_name="BedrockRealtime"
-            )
+            cls.__get_refreshable_session(role_arn=role_arn, session_name="BedrockRealtime")
             if role_arn
             else boto3.Session(region_name=region)
         )

     @classmethod
-    def get_client(
-        cls, *, service_name: str, region: str, role_arn: str
-    ) -> boto3.client:
+    def get_client(cls, *, service_name: str, region: str, role_arn: str) -> boto3.client:
         """
         Retrieves a cached Boto3 client or creates a new one.
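
The refresh() callable above follows botocore's refreshable-credentials convention. A sketch of how such a session is typically assembled is shown below; the helper name is hypothetical, and the actual __get_refreshable_session body is not part of this hunk.

import boto3
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session


def refreshable_session_sketch(refresh, region_name: str) -> boto3.Session:
    # refresh() must return access_key/secret_key/token/expiry_time entries,
    # which the AWSConstants keys above are assumed to map to.
    credentials = RefreshableCredentials.create_from_metadata(
        metadata=refresh(),
        refresh_using=refresh,  # re-invoked automatically as expiry approaches
        method="sts-assume-role",
    )
    botocore_session = get_session()
    botocore_session._credentials = credentials  # botocore has no public setter
    botocore_session.set_config_variable("region", region_name)
    return boto3.Session(botocore_session=botocore_session)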
+1 −1
@@ -1,2 +1,2 @@
 from fmcore.experimental.llm.base_llm import BaseLLM
-from fmcore.experimental.llm.bedrock_llm import BedrockLLM
\ No newline at end of file
+from fmcore.experimental.llm.bedrock_llm import BedrockLLM

src/fmcore/experimental/llm/base_llm.py

+1 −29
@@ -73,7 +73,7 @@ def of(cls, llm_config: LLMConfig):
         - Acts as a factory by retrieving the necessary constructor parameters and instantiating the subclass.

         Args:
-            config (LLMConfig): The configuration containing provider details.
+            llm_config (LLMConfig): The configuration containing provider details.

         Returns:
             BaseLLM: An instance of the corresponding LLM subclass.
@@ -133,31 +133,3 @@ def astream(self, messages: List[BaseMessage]) -> Iterator[BaseMessageChunk]:
             Iterator[BaseMessageChunk]: A stream of LLM response chunks.
         """
         pass
-
-    @abstractmethod
-    def batch(self, message_batches: List[List[BaseMessage]]) -> List[BaseMessage]:
-        """
-        Synchronously processes multiple batches of messages.
-
-        Args:
-            message_batches (List[List[BaseMessage]]): A list of message batches.
-
-        Returns:
-            List[BaseMessage]: The responses for each batch.
-        """
-        pass
-
-    @abstractmethod
-    async def abatch(
-        self, message_batches: List[List[BaseMessage]]
-    ) -> List[BaseMessage]:
-        """
-        Asynchronously processes multiple batches of messages.
-
-        Args:
-            message_batches (List[List[BaseMessage]]): A list of message batches.
-
-        Returns:
-            List[BaseMessage]: The responses for each batch.
-        """
-        pass
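
With the abstract batch()/abatch() hooks removed, callers that still need batching can fan out over ainvoke() themselves. A hypothetical helper, not part of this commit:

import asyncio
from typing import List

from langchain_core.messages import BaseMessage

from fmcore.experimental.llm.base_llm import BaseLLM


async def abatch_invoke(
    llm: BaseLLM, message_batches: List[List[BaseMessage]]
) -> List[BaseMessage]:
    # One ainvoke() per batch, run concurrently; results preserve input order.
    return await asyncio.gather(
        *(llm.ainvoke(messages=batch) for batch in message_batches)
    )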
