add inference component support to SageMaker Endpoint #229

Merged · 4 commits · Oct 4, 2024
Changes from 2 commits
54 changes: 40 additions & 14 deletions libs/aws/langchain_aws/llms/sagemaker_endpoint.py
@@ -183,6 +183,14 @@ class SagemakerEndpoint(LLM):
                region_name=region_name,
                credentials_profile_name=credentials_profile_name
            )

            #Use with SageMaker Endpoint and Inference Component
            se = SagemakerEndpoint(
                endpoint_name=endpoint_name,
                inference_component_name=inference_component_name,
                region_name=region_name,
                credentials_profile_name=credentials_profile_name
            )

            #Use with boto3 client
            client = boto3.client(
@@ -203,6 +211,10 @@ class SagemakerEndpoint(LLM):
    """The name of the endpoint from the deployed Sagemaker model.
    Must be unique within an AWS Region."""

    inference_component_name: Optional[str] = None
    """Optional name of the inference component to invoke,
    used together with the endpoint name."""

    region_name: str = ""
    """The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""

@@ -296,6 +308,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_name": self.endpoint_name},
            **{"inference_component_name": self.inference_component_name},
            **{"model_kwargs": _model_kwargs},
        }

@@ -315,13 +328,19 @@ def _stream(
        _model_kwargs = {**_model_kwargs, **kwargs}
        _endpoint_kwargs = self.endpoint_kwargs or {}

        invocation_params = {
            "EndpointName": self.endpoint_name,
            "Body": self.content_handler.transform_input(prompt, _model_kwargs),
            "ContentType": self.content_handler.content_type,
            **_endpoint_kwargs,
        }

        # If inference_component_name is specified, append it to invocation_params
        if self.inference_component_name:
            invocation_params["InferenceComponentName"] = self.inference_component_name

        try:
-            resp = self.client.invoke_endpoint_with_response_stream(
-                EndpointName=self.endpoint_name,
-                Body=self.content_handler.transform_input(prompt, _model_kwargs),
-                ContentType=self.content_handler.content_type,
-                **_endpoint_kwargs,
-            )
+            resp = self.client.invoke_endpoint_with_response_stream(**invocation_params)
            iterator = LineIterator(resp["Body"])

            for line in iterator:
@@ -349,7 +368,8 @@ def _call(
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
-        """Call out to Sagemaker inference endpoint.
+        """Call out to SageMaker inference endpoint or inference component
+        of SageMaker inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
@@ -371,20 +391,26 @@ def _call(
        content_type = self.content_handler.content_type
        accepts = self.content_handler.accepts

        invocation_params = {
            "EndpointName": self.endpoint_name,
            "Body": body,
            "ContentType": content_type,
            "Accept": accepts,
            **_endpoint_kwargs,
        }

        # If inference_component_name is specified, append it to invocation_params
        if self.inference_component_name:
            invocation_params["InferenceComponentName"] = self.inference_component_name

        if self.streaming and run_manager:
            completion: str = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion

        try:
-            response = self.client.invoke_endpoint(
-                EndpointName=self.endpoint_name,
-                Body=body,
-                ContentType=content_type,
-                Accept=accepts,
-                **_endpoint_kwargs,
-            )
+            response = self.client.invoke_endpoint(**invocation_params)
        except Exception as e:
            logging.error(f"Error raised by inference endpoint: {e}")
            if run_manager is not None:
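Taken together, the change lets the wrapper target a single inference component on an endpoint that hosts several. A minimal end-to-end sketch of the new parameter in use; the endpoint and component names are illustrative (not from this PR), and the content handler follows the usual JSON pattern, assuming the runtime returns a readable streaming body:

import json
from typing import Dict

from langchain_aws.llms import SagemakerEndpoint
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler


class JSONContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Serialize the prompt plus any model kwargs into the request body
        return json.dumps({"inputs": prompt, **model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # Read and parse the response body returned by invoke_endpoint
        return json.loads(output.read())[0]["generated_text"]


llm = SagemakerEndpoint(
    endpoint_name="my-endpoint",  # hypothetical endpoint name
    inference_component_name="my-inference-component",  # hypothetical component name
    region_name="us-west-2",
    content_handler=JSONContentHandler(),
    model_kwargs={"parameters": {"max_new_tokens": 50}},
)
print(llm.invoke("What is an inference component?"))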
85 changes: 85 additions & 0 deletions libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py
@@ -49,6 +49,39 @@ def test_sagemaker_endpoint_invoke() -> None:
    )


def test_sagemaker_endpoint_inference_component_invoke() -> None:
    client = Mock()
    response = {
        "ContentType": "application/json",
        "Body": b'[{"generated_text": "SageMaker Endpoint"}]',
    }
    client.invoke_endpoint.return_value = response

    llm = SagemakerEndpoint(
        endpoint_name="my-endpoint",
        inference_component_name="my-inference-component",
        region_name="us-west-2",
        content_handler=DefaultHandler(),
        model_kwargs={
            "parameters": {
                "max_new_tokens": 50,
            }
        },
        client=client,
    )

    service_response = llm.invoke("What is Sagemaker endpoints?")

    assert service_response == "SageMaker Endpoint"
    client.invoke_endpoint.assert_called_once_with(
        EndpointName="my-endpoint",
        Body=b"What is Sagemaker endpoints?",
        ContentType="application/json",
        Accept="application/json",
        InferenceComponentName="my-inference-component",
    )


def test_sagemaker_endpoint_stream() -> None:
    class ContentHandler(LLMContentHandler):
        accepts = "application/json"
Expand Down Expand Up @@ -97,3 +130,55 @@ def transform_output(self, output: bytes) -> str:
Body=expected_body,
ContentType="application/json",
)


def test_sagemaker_endpoint_inference_component_stream() -> None:
    class ContentHandler(LLMContentHandler):
        accepts = "application/json"
        content_type = "application/json"

        def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
            body = json.dumps({"inputs": prompt, **model_kwargs})
            return body.encode()

        def transform_output(self, output: bytes) -> str:
            body = json.loads(output)
            return body.get("outputs")[0]

    body = (
        {"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}},
        {"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}},
        {"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}},
    )

    response = {"ContentType": "application/json", "Body": body}

    client = Mock()
    client.invoke_endpoint_with_response_stream.return_value = response

    llm = SagemakerEndpoint(
        endpoint_name="my-endpoint",
        inference_component_name="my_inference_component",
        region_name="us-west-2",
        content_handler=ContentHandler(),
        client=client,
        model_kwargs={"parameters": {"max_new_tokens": 50}},
    )

    expected_body = json.dumps(
        {"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}}
    ).encode()

    chunks = ["S", "age", "Maker"]
    service_chunks = []

    for chunk in llm.stream("What is Sagemaker endpoints?"):
        service_chunks.append(chunk)

    assert service_chunks == chunks
    client.invoke_endpoint_with_response_stream.assert_called_once_with(
        EndpointName="my-endpoint",
        Body=expected_body,
        ContentType="application/json",
        InferenceComponentName="my_inference_component",
    )
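The assertions above mirror what lands on the SageMaker Runtime API: InferenceComponentName is an optional parameter of the sagemaker-runtime invoke_endpoint call, which the wrapper simply forwards. A raw boto3 sketch with the same illustrative names, for comparison:

import boto3

# The wrapper's self.client is a sagemaker-runtime client like this one
client = boto3.client("sagemaker-runtime", region_name="us-west-2")

response = client.invoke_endpoint(
    EndpointName="my-endpoint",  # hypothetical
    InferenceComponentName="my-inference-component",  # selects one component on the endpoint
    Body=b'{"inputs": "What is an inference component?"}',
    ContentType="application/json",
    Accept="application/json",
)
print(response["Body"].read().decode("utf-8"))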