add inference component support to SageMaker Endpoint (#229)
* add inference component support to SageMaker Endpoint
* add corresponding tests in SageMaker Endpoint integration tests.

---------

Co-authored-by: Pravali Uppugunduri <[email protected]>
Co-authored-by: Piyush Jain <[email protected]>
3 people authored Oct 4, 2024
1 parent ba30daa commit dabdd99
Showing 2 changed files with 125 additions and 14 deletions.
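For reference, a minimal usage sketch of the new parameter, modeled on the tests in this commit (the endpoint and component names are hypothetical, and the handler mirrors the ContentHandler defined in the streaming test below):

    import json
    from typing import Dict

    from langchain_aws.llms import SagemakerEndpoint
    from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler

    class ContentHandler(LLMContentHandler):
        accepts = "application/json"
        content_type = "application/json"

        def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
            return json.dumps({"inputs": prompt, **model_kwargs}).encode()

        def transform_output(self, output: bytes) -> str:
            return json.loads(output).get("outputs")[0]

    # Hypothetical endpoint that hosts an inference component.
    llm = SagemakerEndpoint(
        endpoint_name="my-endpoint",
        inference_component_name="my-inference-component",
        region_name="us-west-2",
        content_handler=ContentHandler(),
        model_kwargs={"parameters": {"max_new_tokens": 50}},
    )
    print(llm.invoke("What is Sagemaker endpoints?"))

When inference_component_name is omitted, the request targets the endpoint directly, so existing callers are unaffected.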
54 changes: 40 additions & 14 deletions libs/aws/langchain_aws/llms/sagemaker_endpoint.py
@@ -183,6 +183,14 @@ class SagemakerEndpoint(LLM):
region_name=region_name,
credentials_profile_name=credentials_profile_name
)
# Usage with Inference Component
se = SagemakerEndpoint(
endpoint_name=endpoint_name,
inference_component_name=inference_component_name,
region_name=region_name,
credentials_profile_name=credentials_profile_name
)
# Use with boto3 client
client = boto3.client(
@@ -203,6 +211,10 @@ class SagemakerEndpoint(LLM):
"""The name of the endpoint from the deployed Sagemaker model.
Must be unique within an AWS Region."""

inference_component_name: Optional[str] = None
"""Optional name of the inference component to invoke
if specified with endpoint name."""

region_name: str = ""
"""The AWS region where the SageMaker model is deployed, e.g. `us-west-2`."""

@@ -296,6 +308,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
_model_kwargs = self.model_kwargs or {}
return {
**{"endpoint_name": self.endpoint_name},
**{"inference_component_name": self.inference_component_name},
**{"model_kwargs": _model_kwargs},
}

@@ -315,13 +328,19 @@ def _stream(
_model_kwargs = {**_model_kwargs, **kwargs}
_endpoint_kwargs = self.endpoint_kwargs or {}

invocation_params = {
"EndpointName": self.endpoint_name,
"Body": self.content_handler.transform_input(prompt, _model_kwargs),
"ContentType": self.content_handler.content_type,
**_endpoint_kwargs,
}

# If inference_component_name is specified, append it to invocation_params
if self.inference_component_name:
invocation_params["InferenceComponentName"] = self.inference_component_name

try:
- resp = self.client.invoke_endpoint_with_response_stream(
- EndpointName=self.endpoint_name,
- Body=self.content_handler.transform_input(prompt, _model_kwargs),
- ContentType=self.content_handler.content_type,
- **_endpoint_kwargs,
- )
+ resp = self.client.invoke_endpoint_with_response_stream(**invocation_params)
iterator = LineIterator(resp["Body"])

for line in iterator:
@@ -349,7 +368,8 @@ def _call(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Sagemaker inference endpoint.
"""Call out to SageMaker inference endpoint or inference component
of SageMaker inference endpoint.
Args:
prompt: The prompt to pass into the model.
@@ -371,20 +391,26 @@ def _call(
content_type = self.content_handler.content_type
accepts = self.content_handler.accepts

invocation_params = {
"EndpointName": self.endpoint_name,
"Body": body,
"ContentType": content_type,
"Accept": accepts,
**_endpoint_kwargs,
}

# If inference_component_name is specified, append it to invocation_params
if self.inference_component_name:
invocation_params["InferenceComponentName"] = self.inference_component_name

if self.streaming and run_manager:
completion: str = ""
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion

try:
- response = self.client.invoke_endpoint(
- EndpointName=self.endpoint_name,
- Body=body,
- ContentType=content_type,
- Accept=accepts,
- **_endpoint_kwargs,
- )
+ response = self.client.invoke_endpoint(**invocation_params)
except Exception as e:
logging.error(f"Error raised by inference endpoint: {e}")
if run_manager is not None:
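To make the change concrete, the request the wrapper now sends is equivalent to the raw boto3 call sketched below (names are hypothetical; the SageMaker Runtime API accepts InferenceComponentName as an optional field on both invoke_endpoint and invoke_endpoint_with_response_stream):

    import boto3

    client = boto3.client("sagemaker-runtime", region_name="us-west-2")
    response = client.invoke_endpoint(
        EndpointName="my-endpoint",
        # Sent only when an inference component is targeted; omit otherwise.
        InferenceComponentName="my-inference-component",
        Body=b'{"inputs": "Hello", "parameters": {"max_new_tokens": 50}}',
        ContentType="application/json",
        Accept="application/json",
    )
    print(response["Body"].read())

This mirrors the kwargs asserted in the new tests below.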
85 changes: 85 additions & 0 deletions libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py
@@ -49,6 +49,39 @@ def test_sagemaker_endpoint_invoke() -> None:
)


def test_sagemaker_endpoint_inference_component_invoke() -> None:
client = Mock()
response = {
"ContentType": "application/json",
"Body": b'[{"generated_text": "SageMaker Endpoint"}]',
}
client.invoke_endpoint.return_value = response

llm = SagemakerEndpoint(
endpoint_name="my-endpoint",
inference_component_name="my-inference-component",
region_name="us-west-2",
content_handler=DefaultHandler(),
model_kwargs={
"parameters": {
"max_new_tokens": 50,
}
},
client=client,
)

service_response = llm.invoke("What is Sagemaker endpoints?")

assert service_response == "SageMaker Endpoint"
client.invoke_endpoint.assert_called_once_with(
EndpointName="my-endpoint",
Body=b"What is Sagemaker endpoints?",
ContentType="application/json",
Accept="application/json",
InferenceComponentName="my-inference-component",
)


def test_sagemaker_endpoint_stream() -> None:
class ContentHandler(LLMContentHandler):
accepts = "application/json"
@@ -97,3 +130,55 @@ def transform_output(self, output: bytes) -> str:
Body=expected_body,
ContentType="application/json",
)


def test_sagemaker_endpoint_inference_component_stream() -> None:
class ContentHandler(LLMContentHandler):
accepts = "application/json"
content_type = "application/json"

def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
body = json.dumps({"inputs": prompt, **model_kwargs})
return body.encode()

def transform_output(self, output: bytes) -> str:
body = json.loads(output)
return body.get("outputs")[0]

body = (
{"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}},
{"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}},
{"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}},
)

response = {"ContentType": "application/json", "Body": body}

client = Mock()
client.invoke_endpoint_with_response_stream.return_value = response

llm = SagemakerEndpoint(
endpoint_name="my-endpoint",
inference_component_name="my_inference_component",
region_name="us-west-2",
content_handler=ContentHandler(),
client=client,
model_kwargs={"parameters": {"max_new_tokens": 50}},
)

expected_body = json.dumps(
{"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}}
).encode()

chunks = ["S", "age", "Maker"]
service_chunks = []

for chunk in llm.stream("What is Sagemaker endpoints?"):
service_chunks.append(chunk)

assert service_chunks == chunks
client.invoke_endpoint_with_response_stream.assert_called_once_with(
EndpointName="my-endpoint",
Body=expected_body,
ContentType="application/json",
InferenceComponentName="my_inference_component",
)
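Assuming a standard pytest setup, the new tests can be run in isolation with a name filter (the path comes from the diff header above):

    pytest libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py -k inference_component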
