Skip to content

Commit 7e33276

Browse files
authored
feat: reintroduced inference code (#1548)
Signed-off-by: AlexsJones <alexsimonjones@gmail.com>
1 parent 0239b2f commit 7e33276

2 files changed

Lines changed: 28 additions & 1 deletion

File tree

SUPPORTED_MODELS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ K8sGPT supports a variety of AI/LLM providers (backends). Some providers have a
4949
- us.amazon.nova-lite-v1:0
5050
- anthropic.claude-3-haiku-20240307-v1:0
5151

52+
> **Note:**
53+
> If you use an AWS Bedrock inference profile ARN (e.g., `arn:aws:bedrock:us-east-1:<account>:application-inference-profile/<id>`) as the model name, you must still provide a valid modelId (e.g., `anthropic.claude-3-sonnet-20240229-v1:0`). K8sGPT automatically sets the required `X-Amzn-Bedrock-Inference-Profile-ARN` header when making requests to Bedrock.
54+
5255
### Amazon SageMaker
5356
- **Model:** User-configurable (any model deployed in your SageMaker endpoint)
5457

pkg/ai/amazonbedrock.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
awsconfig "github.com/aws/aws-sdk-go-v2/config"
1515
"github.com/aws/aws-sdk-go-v2/service/bedrock"
1616
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
17+
"github.com/aws/smithy-go/middleware"
18+
smithyhttp "github.com/aws/smithy-go/transport/http"
1719
)
1820

1921
const amazonbedrockAIClientName = "amazonbedrock"
@@ -583,8 +585,30 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
583585
Accept: aws.String("application/json"),
584586
}
585587

588+
// Detect if the model name is an inference profile ARN and set the header if so
589+
var optFns []func(*bedrockruntime.Options)
590+
if validateInferenceProfileArn(a.model.Config.ModelName) {
591+
inferenceProfileArn := a.model.Config.ModelName
592+
optFns = append(optFns, func(options *bedrockruntime.Options) {
593+
options.APIOptions = append(options.APIOptions, func(stack *middleware.Stack) error {
594+
return stack.Initialize.Add(middleware.InitializeMiddlewareFunc("InferenceProfileHeader", func(ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler) (out middleware.InitializeOutput, metadata middleware.Metadata, err error) {
595+
req, ok := in.Parameters.(*smithyhttp.Request)
596+
if ok {
597+
req.Header.Set("X-Amzn-Bedrock-Inference-Profile-ARN", inferenceProfileArn)
598+
}
599+
return next.HandleInitialize(ctx, in)
600+
}), middleware.Before)
601+
})
602+
})
603+
}
604+
586605
// Invoke the model
587-
resp, err := a.client.InvokeModel(ctx, params)
606+
var resp *bedrockruntime.InvokeModelOutput
607+
if len(optFns) > 0 {
608+
resp, err = a.client.InvokeModel(ctx, params, optFns...)
609+
} else {
610+
resp, err = a.client.InvokeModel(ctx, params)
611+
}
588612
if err != nil {
589613
if strings.Contains(err.Error(), "InvalidAccessKeyId") || strings.Contains(err.Error(), "SignatureDoesNotMatch") || strings.Contains(err.Error(), "NoCredentialProviders") {
590614
return "", fmt.Errorf("AWS credentials are invalid or missing. Please check your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables or AWS config. Details: %v", err)

0 commit comments

Comments
 (0)