Commit c6886e5

InferenceComponentName
1 parent 9bdfb35 commit c6886e5

9 files changed

Lines changed: 68 additions & 396 deletions

kendra_retriever_samples/README.md

Lines changed: 11 additions & 12 deletions
@@ -41,21 +41,21 @@ pip install --force-reinstall "boto3>=1.28.57"
 ## Running samples
 Before you run the sample, you need to deploy a Large Language Model (or get an API key if you using Anthropic or OPENAI). The samples in this repository have been tested on models deployed using SageMaker Jumpstart. The model id for the LLMS are specified in the table below.

+With the latest sagemaker release each endpoint can hold multiple models (called InferenceComponent). For jumpstart models, optionally specify the INFERENCE_COMPONENT_NAME as well as an environment varialbe

-| Model name | env var name | Jumpstart model id | streamlit provider name |
+
+| Model name | env var name | Endpoint Name | Inference component name (optional) |streamlit provider name |
 | -----------| -------- | ------------------ | ----------------- |
-| Flan XL | FLAN_XL_ENDPOINT | huggingface-text2text-flan-t5-xl | flanxl |
-| Flan XXL | FLAN_XXL_ENDPOINT | huggingface-text2text-flan-t5-xxl | flanxxl |
-| Falcon 40B instruct | FALCON_40B_ENDPOINT | huggingface-llm-falcon-40b-instruct-bf16 | falcon40b |
-| Llama2 70B instruct | LLAMA_2_ENDPOINT | meta-textgeneration-llama-2-70b-f | llama2 |
-| Bedrock Titan | None | | bedrock_titan|
-| Bedrock Claude | None | | bedrock_claude|
-| Bedrock Claude V2 | None | | bedrock_claudev2|
+| Falcon 40B instruct | FALCON_40B_ENDPOINT, INFERENCE_COMPONENT_NAME | <Endpoint_name> | <Inference_component_name>|falcon40b |
+| Llama2 70B instruct | LLAMA_2_ENDPOINT, INFERENCE_COMPONENT_NAME |<Endpoint_name> | <Inference_component_name> | llama2 |
+| Bedrock Titan | None | | | bedrock_titan|
+| Bedrock Claude | None | | | bedrock_claude|
+| Bedrock Claude V2 | None | | | bedrock_claudev2|


-after deploying the LLM, set up environment variables for kendra id, aws_region and the endpoint name (or the API key for an external provider)
+after deploying the LLM, set up environment variables for kendra id, aws_region endpoint name (or the API key for an external provider) and optionally the inference component name

-For example, for running the `kendra_chat_flan_xl.py` sample, these environment variables must be set: AWS_REGION, KENDRA_INDEX_ID and FLAN_XL_ENDPOINT.
+For example, for running the `kendra_chat_llama_2.py` sample, these environment variables must be set: AWS_REGION, KENDRA_INDEX_ID, LLAMA_2_ENDPOINT and INFERENCE_COMPONENT_NAME. INFERENCE_COMPONENT_NAME is only required when deploying the jumpstart through the console or if you explicitely create an inference component using code. It is also possible to create an endpoint without and inference component in which case, do not set the INFERENCE_COMPONENT_FIELD.

 You can use commands as below to set the environment variables. Only set the environment variable for the provider that you are using. For example, if you are using Flan-xl only set the FLAN_XXL_ENDPOINT. There is no need to set the other Endpoints and keys.

@@ -64,10 +64,9 @@ export AWS_REGION=<YOUR-AWS-REGION>
 export AWS_PROFILE=<AWS Profile>
 export KENDRA_INDEX_ID=<YOUR-KENDRA-INDEX-ID>

-export FLAN_XL_ENDPOINT=<YOUR-SAGEMAKER-ENDPOINT-FOR-FLAN-T-XL> # only if you are using FLAN_XL
-export FLAN_XXL_ENDPOINT=<YOUR-SAGEMAKER-ENDPOINT-FOR-FLAN-T-XXL> # only if you are using FLAN_XXL
 export FALCON_40B_ENDPOINT=<YOUR-SAGEMAKER-ENDPOINT-FOR-FALCON> # only if you are using falcon as the endpoint
 export LLAMA_2_ENDPOINT=<YOUR-SAGEMAKER-ENDPOINT-FOR-LLAMA2> #only if you are using llama2 as the endpoint
+export INFERENCE_COMPONENT_NAME=<YOUR-SAGEMAKER-INFERENCE-COMPONENT-NAME> # if you are deploying the FM via the JumpStart console.

 export OPENAI_API_KEY=<YOUR-OPEN-AI-API-KEY> # only if you are using OPENAI as the endpoint
 export ANTHROPIC_API_KEY=<YOUR-ANTHROPIC-API-KEY> # only if you are using Anthropic as the endpoint
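
Review note: the README hunks above describe an environment contract of two common variables, one provider-specific endpoint or API key, and an optional INFERENCE_COMPONENT_NAME. A sample could check that contract before it starts. The following is a minimal sketch and not part of this commit; `PROVIDER_ENV_VAR`, `COMMON_ENV_VARS`, and `missing_env_vars` are hypothetical names, and the mapping is read off the table and export list above.

```python
import os

# Hypothetical mapping (not in the repository): streamlit provider name ->
# the provider-specific variable named in the README. Bedrock providers
# list "None" in the README table, so they map to None here.
PROVIDER_ENV_VAR = {
    "falcon40b": "FALCON_40B_ENDPOINT",
    "llama2": "LLAMA_2_ENDPOINT",
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "bedrock_titan": None,
    "bedrock_claude": None,
    "bedrock_claudev2": None,
}

# Variables the README asks for regardless of provider.
COMMON_ENV_VARS = ("AWS_REGION", "KENDRA_INDEX_ID")


def missing_env_vars(provider, env=None):
    """Return the required variables that are unset or empty for a provider.

    INFERENCE_COMPONENT_NAME is deliberately not checked, because the
    README describes it as optional.
    """
    env = os.environ if env is None else env
    required = list(COMMON_ENV_VARS)
    specific = PROVIDER_ENV_VAR[provider]
    if specific is not None:
        required.append(specific)
    return [name for name in required if not env.get(name)]
```

Calling `missing_env_vars("llama2")` at the top of a sample and exiting with a message when the list is non-empty would surface a missing variable earlier than a `KeyError` from `os.environ[...]`.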

kendra_retriever_samples/app.py

Lines changed: 0 additions & 10 deletions
@@ -3,8 +3,6 @@
 import sys

 import kendra_chat_anthropic as anthropic
-import kendra_chat_flan_xl as flanxl
-import kendra_chat_flan_xxl as flanxxl
 import kendra_chat_open_ai as openai
 import kendra_chat_falcon_40b as falcon40b
 import kendra_chat_llama_2 as llama2
@@ -20,8 +18,6 @@
 PROVIDER_MAP = {
     'openai': 'Open AI',
     'anthropic': 'Anthropic',
-    'flanxl': 'Flan XL',
-    'flanxxl': 'Flan XXL',
     'falcon40b': 'Falcon 40B',
     'llama2' : 'Llama 2'
 }
@@ -52,12 +48,6 @@ def read_properties_file(filename):
 if (sys.argv[1] == 'anthropic'):
     st.session_state['llm_app'] = anthropic
     st.session_state['llm_chain'] = anthropic.build_chain()
-elif (sys.argv[1] == 'flanxl'):
-    st.session_state['llm_app'] = flanxl
-    st.session_state['llm_chain'] = flanxl.build_chain()
-elif (sys.argv[1] == 'flanxxl'):
-    st.session_state['llm_app'] = flanxxl
-    st.session_state['llm_chain'] = flanxxl.build_chain()
 elif (sys.argv[1] == 'openai'):
     st.session_state['llm_app'] = openai
     st.session_state['llm_chain'] = openai.build_chain()

kendra_retriever_samples/kendra_chat_falcon_40b.py

Lines changed: 25 additions & 2 deletions
@@ -24,6 +24,9 @@ def build_chain():
     region = os.environ["AWS_REGION"]
     kendra_index_id = os.environ["KENDRA_INDEX_ID"]
     endpoint_name = os.environ["FALCON_40B_ENDPOINT"]
+    if "INFERENCE_COMPONENT_NAME" in os.environ:
+        inference_component_name = os.environ["INFERENCE_COMPONENT_NAME"]
+

     class ContentHandler(LLMContentHandler):
         content_type = "application/json"
@@ -40,7 +43,24 @@ def transform_output(self, output: bytes) -> str:

     content_handler = ContentHandler()

-    llm=SagemakerEndpoint(
+    if inference_component_name:
+        llm=SagemakerEndpoint(
+            endpoint_name=endpoint_name,
+            region_name=region,
+            model_kwargs={
+                "temperature": 0.8,
+                "max_new_tokens": 512,
+                "do_sample": True,
+                "top_p": 0.9,
+                "repetition_penalty": 1.03,
+                "stop": ["\nUser:","<|endoftext|>","</s>"],
+            },
+            endpoint_kwargs={"CustomAttributes":"accept_eula=true",
+                             "InferenceComponentName":inference_component_name},
+            content_handler=content_handler
+        )
+    else :
+        llm=SagemakerEndpoint(
             endpoint_name=endpoint_name,
             region_name=region,
             model_kwargs={
@@ -49,10 +69,13 @@ def transform_output(self, output: bytes) -> str:
                 "do_sample": True,
                 "top_p": 0.9,
                 "repetition_penalty": 1.03,
-                "stop": ["\nUser:","<|endoftext|>","</s>"]
+                "stop": ["\nUser:","<|endoftext|>","</s>"],
             },
             content_handler=content_handler
         )
+
+
+

     retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region, top_k=2)
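
Review note: as far as these hunks show, `inference_component_name` in `kendra_chat_falcon_40b.py` is assigned only when INFERENCE_COMPONENT_NAME is set, so the `if inference_component_name:` test would raise `UnboundLocalError` when the variable is absent. One way to avoid both that and the duplicated `SagemakerEndpoint(...)` call is to build `endpoint_kwargs` once. This is a minimal sketch, not the commit's code; `endpoint_kwargs_from_env` is a hypothetical helper name.

```python
import os


def endpoint_kwargs_from_env(base=None, env=None):
    """Build the endpoint_kwargs dict for SagemakerEndpoint.

    Starts from `base` (for example the CustomAttributes entry used in the
    diff) and adds InferenceComponentName only when the optional
    INFERENCE_COMPONENT_NAME variable is set to a non-empty value.
    """
    env = os.environ if env is None else env
    kwargs = dict(base or {})  # copy so the caller's dict is not mutated
    component = env.get("INFERENCE_COMPONENT_NAME")  # None when unset
    if component:
        kwargs["InferenceComponentName"] = component
    return kwargs
```

With this helper, a single call such as `SagemakerEndpoint(..., endpoint_kwargs=endpoint_kwargs_from_env({"CustomAttributes": "accept_eula=true"}), ...)` would cover both branches. One behavioural difference to be aware of: the `else` branch in the Falcon hunk passes no `endpoint_kwargs` at all, so always sending `CustomAttributes` is an assumption that it is harmless for that endpoint.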

kendra_retriever_samples/kendra_chat_flan_xl.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

kendra_retriever_samples/kendra_chat_flan_xxl.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

kendra_retriever_samples/kendra_chat_llama_2.py

Lines changed: 18 additions & 3 deletions
@@ -1,7 +1,7 @@
 from langchain.retrievers import AmazonKendraRetriever
 from langchain.chains import ConversationalRetrievalChain
 from langchain.prompts import PromptTemplate
-from langchain import SagemakerEndpoint
+from langchain.llms import SagemakerEndpoint
 from langchain.llms.sagemaker_endpoint import LLMContentHandler
 import sys
 import json
@@ -26,6 +26,8 @@ def build_chain():
     region = os.environ["AWS_REGION"]
     kendra_index_id = os.environ["KENDRA_INDEX_ID"]
     endpoint_name = os.environ["LLAMA_2_ENDPOINT"]
+    if "INFERENCE_COMPONENT_NAME" in os.environ:
+        inference_component_name = os.environ["INFERENCE_COMPONENT_NAME"]

     class ContentHandler(LLMContentHandler):
         content_type = "application/json"
@@ -47,14 +49,27 @@ def transform_output(self, output: bytes) -> str:

     content_handler = ContentHandler()

-    llm=SagemakerEndpoint(
+
+
+    if 'inference_component_name' in locals():
+        llm=SagemakerEndpoint(
+            endpoint_name=endpoint_name,
+            region_name=region,
+            model_kwargs={"max_new_tokens": 1500, "top_p": 0.8,"temperature":0.6},
+            endpoint_kwargs={"CustomAttributes":"accept_eula=true",
+                             "InferenceComponentName":inference_component_name},
+            content_handler=content_handler,
+        )
+    else :
+        llm=SagemakerEndpoint(
             endpoint_name=endpoint_name,
             region_name=region,
             model_kwargs={"max_new_tokens": 1500, "top_p": 0.8,"temperature":0.6},
             endpoint_kwargs={"CustomAttributes":"accept_eula=true"},
             content_handler=content_handler,

-    )
+        )
+

     retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)
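
Review note: for context on what the new `InferenceComponentName` entry does, my understanding is that LangChain's `SagemakerEndpoint` forwards `endpoint_kwargs` to the SageMaker runtime `invoke_endpoint` call, which accepts an `InferenceComponentName` parameter in recent boto3 releases. The sketch below only assembles the request arguments and makes no AWS call; `build_invoke_request` is a hypothetical helper, and the payload shape is a placeholder because the sample's `ContentHandler` body is not shown in this diff.

```python
import json


def build_invoke_request(endpoint_name, payload, inference_component_name=None):
    """Assemble keyword arguments for sagemaker-runtime invoke_endpoint.

    InferenceComponentName is included only when a component name is given,
    mirroring the optional environment variable introduced in this commit.
    """
    request = {
        "EndpointName": endpoint_name,
        "ContentType": "application/json",
        "Accept": "application/json",
        "Body": json.dumps(payload).encode("utf-8"),
        "CustomAttributes": "accept_eula=true",
    }
    if inference_component_name:
        request["InferenceComponentName"] = inference_component_name
    return request


# Usage (not executed here, requires AWS credentials and a deployed endpoint):
#   import boto3
#   client = boto3.client("sagemaker-runtime", region_name=region)
#   response = client.invoke_endpoint(**build_invoke_request(name, payload, component))
```

If the installed boto3 predates inference component support, passing `InferenceComponentName` would presumably fail parameter validation, so the `boto3>=1.28.57` pin mentioned in the README may need revisiting.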
