@@ -84,8 +84,8 @@ var defaultModels = []bedrock_support.BedrockModel{
8484 },
8585 {
8686 Name : "anthropic.claude-3-5-sonnet-20240620-v1:0" ,
87- Completion : & bedrock_support.CohereCompletion {},
88- Response : & bedrock_support.CohereResponse {},
87+ Completion : & bedrock_support.CohereMessagesCompletion {},
88+ Response : & bedrock_support.CohereMessagesResponse {},
8989 Config : bedrock_support.BedrockModelConfig {
9090 // sensible defaults
9191 MaxTokens : 100 ,
@@ -96,8 +96,8 @@ var defaultModels = []bedrock_support.BedrockModel{
9696 },
9797 {
9898 Name : "us.anthropic.claude-3-5-sonnet-20241022-v2:0" ,
99- Completion : & bedrock_support.CohereCompletion {},
100- Response : & bedrock_support.CohereResponse {},
99+ Completion : & bedrock_support.CohereMessagesCompletion {},
100+ Response : & bedrock_support.CohereMessagesResponse {},
101101 Config : bedrock_support.BedrockModelConfig {
102102 // sensible defaults
103103 MaxTokens : 100 ,
@@ -353,10 +353,10 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
353353
354354 // Get the model input
355355 modelInput := config .GetModel ()
356-
356+
357357 // Determine the appropriate region to use
358358 var region string
359-
359+
360360 // Check if the model input is actually an inference profile ARN
361361 if validateInferenceProfileArn (modelInput ) {
362362 // Extract the region from the inference profile ARN
@@ -370,11 +370,11 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
370370 // Use the provided region or default
371371 region = GetRegionOrDefault (config .GetProviderRegion ())
372372 }
373-
373+
374374 // Only create AWS clients if they haven't been injected (for testing)
375375 if a .client == nil || a .mgmtClient == nil {
376376 // Create a new AWS config with the determined region
377- cfg , err := awsconfig .LoadDefaultConfig (context .Background (),
377+ cfg , err := awsconfig .LoadDefaultConfig (context .Background (),
378378 awsconfig .WithRegion (region ),
379379 )
380380 if err != nil {
@@ -385,7 +385,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
385385 a .client = bedrockruntime .NewFromConfig (cfg )
386386 a .mgmtClient = bedrock .NewFromConfig (cfg )
387387 }
388-
388+
389389 // Handle model selection based on input type
390390 if validateInferenceProfileArn (modelInput ) {
391391 // Get the inference profile details
@@ -399,15 +399,15 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
399399 if err != nil {
400400 return fmt .Errorf ("failed to extract model ID from inference profile: %v" , err )
401401 }
402-
402+
403403 // Find the model configuration for the extracted model ID
404404 foundModel , err := a .getModelFromString (modelID )
405405 if err != nil {
406406 // Instead of using a fallback model, throw an error
407407 return fmt .Errorf ("failed to find model configuration for %s: %v" , modelID , err )
408408 }
409409 a .model = foundModel
410-
410+
411411 // Use the inference profile ARN as the model ID for API calls
412412 a .model .Config .ModelName = modelInput
413413 }
@@ -420,7 +420,7 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
420420 a .model = foundModel
421421 a .model .Config .ModelName = foundModel .Config .ModelName
422422 }
423-
423+
424424 // Set common configuration parameters
425425 a .temperature = config .GetTemperature ()
426426 a .topP = config .GetTopP ()
@@ -438,20 +438,20 @@ func (a *AmazonBedRockClient) getInferenceProfile(ctx context.Context, inference
438438 if len (parts ) != 2 {
439439 return nil , fmt .Errorf ("invalid inference profile ARN format: %s" , inferenceProfileARN )
440440 }
441-
441+
442442 profileID := parts [1 ]
443-
443+
444444 // Create the input for the GetInferenceProfile API call
445445 input := & bedrock.GetInferenceProfileInput {
446446 InferenceProfileIdentifier : aws .String (profileID ),
447447 }
448-
448+
449449 // Call the GetInferenceProfile API
450450 output , err := a .mgmtClient .GetInferenceProfile (ctx , input )
451451 if err != nil {
452452 return nil , fmt .Errorf ("failed to get inference profile: %w" , err )
453453 }
454-
454+
455455 return output , nil
456456}
457457
@@ -460,25 +460,25 @@ func (a *AmazonBedRockClient) extractModelFromInferenceProfile(profile *bedrock.
460460 if profile == nil || len (profile .Models ) == 0 {
461461 return "" , fmt .Errorf ("inference profile does not contain any models" )
462462 }
463-
463+
464464 // Check if the first model has a non-nil ModelArn
465465 if profile .Models [0 ].ModelArn == nil {
466466 return "" , fmt .Errorf ("model information is missing in inference profile" )
467467 }
468-
468+
469469 // Get the first model ARN from the profile
470470 modelARN := aws .ToString (profile .Models [0 ].ModelArn )
471471 if modelARN == "" {
472472 return "" , fmt .Errorf ("model ARN is empty in inference profile" )
473473 }
474-
474+
475475 // Extract the model ID from the ARN
476476 // ARN format: arn:aws:bedrock:region::foundation-model/model-id
477477 parts := strings .Split (modelARN , "/" )
478478 if len (parts ) != 2 {
479479 return "" , fmt .Errorf ("invalid model ARN format: %s" , modelARN )
480480 }
481-
481+
482482 modelID := parts [1 ]
483483 return modelID , nil
484484}
@@ -494,15 +494,15 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
494494 if err != nil {
495495 return "" , err
496496 }
497-
497+
498498 // Build the parameters for the model invocation
499499 params := & bedrockruntime.InvokeModelInput {
500500 Body : body ,
501501 ModelId : aws .String (a .model .Config .ModelName ),
502502 ContentType : aws .String ("application/json" ),
503503 Accept : aws .String ("application/json" ),
504504 }
505-
505+
506506 // Invoke the model
507507 resp , err := a .client .InvokeModel (ctx , params )
508508 if err != nil {
0 commit comments