Skip to content

Commit b2241c0

Browse files
feat: adding fixes for Messages API issue 1391 (#1504)
Signed-off-by: rkarthikr <38294804+rkarthikr@users.noreply.github.com> Co-authored-by: Alex Jones <alexsimonjones@gmail.com>
1 parent 0b7ddf5 commit b2241c0

1 file changed

Lines changed: 22 additions & 22 deletions

File tree

pkg/ai/amazonbedrock.go

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ var defaultModels = []bedrock_support.BedrockModel{
8484
},
8585
{
8686
Name: "anthropic.claude-3-5-sonnet-20240620-v1:0",
87-
Completion: &bedrock_support.CohereCompletion{},
88-
Response: &bedrock_support.CohereResponse{},
87+
Completion: &bedrock_support.CohereMessagesCompletion{},
88+
Response: &bedrock_support.CohereMessagesResponse{},
8989
Config: bedrock_support.BedrockModelConfig{
9090
// sensible defaults
9191
MaxTokens: 100,
@@ -96,8 +96,8 @@ var defaultModels = []bedrock_support.BedrockModel{
9696
},
9797
{
9898
Name: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
99-
Completion: &bedrock_support.CohereCompletion{},
100-
Response: &bedrock_support.CohereResponse{},
99+
Completion: &bedrock_support.CohereMessagesCompletion{},
100+
Response: &bedrock_support.CohereMessagesResponse{},
101101
Config: bedrock_support.BedrockModelConfig{
102102
// sensible defaults
103103
MaxTokens: 100,
@@ -353,10 +353,10 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
353353

354354
// Get the model input
355355
modelInput := config.GetModel()
356-
356+
357357
// Determine the appropriate region to use
358358
var region string
359-
359+
360360
// Check if the model input is actually an inference profile ARN
361361
if validateInferenceProfileArn(modelInput) {
362362
// Extract the region from the inference profile ARN
@@ -370,11 +370,11 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
370370
// Use the provided region or default
371371
region = GetRegionOrDefault(config.GetProviderRegion())
372372
}
373-
373+
374374
// Only create AWS clients if they haven't been injected (for testing)
375375
if a.client == nil || a.mgmtClient == nil {
376376
// Create a new AWS config with the determined region
377-
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
377+
cfg, err := awsconfig.LoadDefaultConfig(context.Background(),
378378
awsconfig.WithRegion(region),
379379
)
380380
if err != nil {
@@ -385,7 +385,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
385385
a.client = bedrockruntime.NewFromConfig(cfg)
386386
a.mgmtClient = bedrock.NewFromConfig(cfg)
387387
}
388-
388+
389389
// Handle model selection based on input type
390390
if validateInferenceProfileArn(modelInput) {
391391
// Get the inference profile details
@@ -399,15 +399,15 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
399399
if err != nil {
400400
return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
401401
}
402-
402+
403403
// Find the model configuration for the extracted model ID
404404
foundModel, err := a.getModelFromString(modelID)
405405
if err != nil {
406406
// Instead of using a fallback model, throw an error
407407
return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
408408
}
409409
a.model = foundModel
410-
410+
411411
// Use the inference profile ARN as the model ID for API calls
412412
a.model.Config.ModelName = modelInput
413413
}
@@ -420,7 +420,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
420420
a.model = foundModel
421421
a.model.Config.ModelName = foundModel.Config.ModelName
422422
}
423-
423+
424424
// Set common configuration parameters
425425
a.temperature = config.GetTemperature()
426426
a.topP = config.GetTopP()
@@ -438,20 +438,20 @@ func (a *AmazonBedRockClient) getInferenceProfile(ctx context.Context, inference
438438
if len(parts) != 2 {
439439
return nil, fmt.Errorf("invalid inference profile ARN format: %s", inferenceProfileARN)
440440
}
441-
441+
442442
profileID := parts[1]
443-
443+
444444
// Create the input for the GetInferenceProfile API call
445445
input := &bedrock.GetInferenceProfileInput{
446446
InferenceProfileIdentifier: aws.String(profileID),
447447
}
448-
448+
449449
// Call the GetInferenceProfile API
450450
output, err := a.mgmtClient.GetInferenceProfile(ctx, input)
451451
if err != nil {
452452
return nil, fmt.Errorf("failed to get inference profile: %w", err)
453453
}
454-
454+
455455
return output, nil
456456
}
457457

@@ -460,25 +460,25 @@ func (a *AmazonBedRockClient) extractModelFromInferenceProfile(profile *bedrock.
460460
if profile == nil || len(profile.Models) == 0 {
461461
return "", fmt.Errorf("inference profile does not contain any models")
462462
}
463-
463+
464464
// Check if the first model has a non-nil ModelArn
465465
if profile.Models[0].ModelArn == nil {
466466
return "", fmt.Errorf("model information is missing in inference profile")
467467
}
468-
468+
469469
// Get the first model ARN from the profile
470470
modelARN := aws.ToString(profile.Models[0].ModelArn)
471471
if modelARN == "" {
472472
return "", fmt.Errorf("model ARN is empty in inference profile")
473473
}
474-
474+
475475
// Extract the model ID from the ARN
476476
// ARN format: arn:aws:bedrock:region::foundation-model/model-id
477477
parts := strings.Split(modelARN, "/")
478478
if len(parts) != 2 {
479479
return "", fmt.Errorf("invalid model ARN format: %s", modelARN)
480480
}
481-
481+
482482
modelID := parts[1]
483483
return modelID, nil
484484
}
@@ -494,15 +494,15 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
494494
if err != nil {
495495
return "", err
496496
}
497-
497+
498498
// Build the parameters for the model invocation
499499
params := &bedrockruntime.InvokeModelInput{
500500
Body: body,
501501
ModelId: aws.String(a.model.Config.ModelName),
502502
ContentType: aws.String("application/json"),
503503
Accept: aws.String("application/json"),
504504
}
505-
505+
506506
// Invoke the model
507507
resp, err := a.client.InvokeModel(ctx, params)
508508
if err != nil {

0 commit comments

Comments
 (0)