Skip to content

Commit 982d74e

Browse files
authored
feat(aws): support model profiles (#768)
1 parent 7f97389 commit 982d74e

File tree

9 files changed

+787
-6
lines changed

9 files changed

+787
-6
lines changed

libs/aws/langchain_aws/chat_models/bedrock.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
BaseChatModel,
2525
LangSmithParams,
2626
LanguageModelInput,
27+
ModelProfile,
28+
ModelProfileRegistry,
2729
)
2830
from langchain_core.language_models.chat_models import generate_from_stream
2931
from langchain_core.messages import (
@@ -49,9 +51,11 @@
4951
from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass
5052
from langchain_core.utils.utils import _build_model_kwargs
5153
from pydantic import BaseModel, ConfigDict, Field, model_validator
54+
from typing_extensions import Self
5255

5356
from langchain_aws.chat_models._compat import _convert_from_v1_to_anthropic
5457
from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
58+
from langchain_aws.data._profiles import _PROFILES
5559
from langchain_aws.function_calling import (
5660
AnthropicTool,
5761
ToolsOutputParser,
@@ -77,6 +81,14 @@
7781
logger = logging.getLogger(__name__)
7882

7983

84+
_MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES)
85+
86+
87+
def _get_default_model_profile(model_name: str) -> ModelProfile:
88+
default = _MODEL_PROFILES.get(model_name) or {}
89+
return default.copy()
90+
91+
8092
def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
8193
if isinstance(message, ChatMessage):
8294
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
@@ -844,6 +856,14 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
844856
}
845857
return values
846858

859+
@model_validator(mode="after")
860+
def _set_model_profile(self) -> Self:
861+
"""Set model profile if not overridden."""
862+
if self.profile is None:
863+
model_id = re.sub(r"^[A-Za-z]{2}\.", "", self.model_id)
864+
self.profile = _get_default_model_profile(model_id)
865+
return self
866+
847867
@property
848868
def lc_attributes(self) -> Dict[str, Any]:
849869
attributes: Dict[str, Any] = {}

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424

2525
from langchain_core.callbacks import CallbackManagerForLLMRun
2626
from langchain_core.exceptions import OutputParserException
27-
from langchain_core.language_models import BaseChatModel, LanguageModelInput
27+
from langchain_core.language_models import (
28+
BaseChatModel,
29+
LanguageModelInput,
30+
ModelProfile,
31+
ModelProfileRegistry,
32+
)
2833
from langchain_core.language_models.base import LangSmithParams
2934
from langchain_core.messages import (
3035
AIMessage,
@@ -58,6 +63,7 @@
5863
from typing_extensions import Self
5964

6065
from langchain_aws.chat_models._compat import _convert_from_v1_to_converse
66+
from langchain_aws.data._profiles import _PROFILES
6167
from langchain_aws.function_calling import ToolsOutputParser
6268
from langchain_aws.utils import (
6369
count_tokens_api_supported_for_model,
@@ -66,6 +72,16 @@
6672
)
6773

6874
logger = logging.getLogger(__name__)
75+
76+
77+
_MODEL_PROFILES = cast("ModelProfileRegistry", _PROFILES)
78+
79+
80+
def _get_default_model_profile(model_name: str) -> ModelProfile:
81+
default = _MODEL_PROFILES.get(model_name) or {}
82+
return default.copy()
83+
84+
6985
_BM = TypeVar("_BM", bound=BaseModel)
7086

7187
EMPTY_CONTENT = "."
@@ -837,6 +853,14 @@ def validate_environment(self) -> Self:
837853

838854
return self
839855

856+
@model_validator(mode="after")
857+
def _set_model_profile(self) -> Self:
858+
"""Set model profile if not overridden."""
859+
if self.profile is None:
860+
model_id = re.sub(r"^[A-Za-z]{2}\.", "", self.model_id)
861+
self.profile = _get_default_model_profile(model_id)
862+
return self
863+
840864
def _get_base_model(self) -> str:
841865
"""Return base model id, stripping any regional prefix."""
842866

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Model profile data. All edits should be made in profile_augmentations.toml."""

0 commit comments

Comments
 (0)