@@ -458,26 +458,25 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
458458 // Get the inference profile details
459459 profile , err := a .getInferenceProfile (context .Background (), modelInput )
460460 if err != nil {
461- // Instead of using a fallback model, throw an error
462461 return fmt .Errorf ("failed to get inference profile: %v" , err )
463- } else {
464- // Extract the model ID from the inference profile
465- modelID , err := a .extractModelFromInferenceProfile (profile )
466- if err != nil {
467- return fmt .Errorf ("failed to extract model ID from inference profile: %v" , err )
468- }
469-
470- // Find the model configuration for the extracted model ID
471- foundModel , err := a .getModelFromString (modelID )
472- if err != nil {
473- // Instead of using a fallback model, throw an error
474- return fmt .Errorf ("failed to find model configuration for %s: %v" , modelID , err )
475- }
476- a .model = foundModel
477-
478- // Use the inference profile ARN as the model ID for API calls
479- a .model .Config .ModelName = modelInput
480462 }
463+ // Extract the model ID from the inference profile
464+ modelID , err := a .extractModelFromInferenceProfile (profile )
465+ if err != nil {
466+ return fmt .Errorf ("failed to extract model ID from inference profile: %v" , err )
467+ }
468+ // Find the model configuration for the extracted model ID
469+ foundModel , err := a .getModelFromString (modelID )
470+ if err != nil {
471+ // Instead of failing, use a generic config for completion/response
472+ // But still warn user
473+ return fmt .Errorf ("failed to find model configuration for %s: %v" , modelID , err )
474+ }
475+ // Use the found model config for completion/response, but set ModelName to the profile ARN
476+ a .model = foundModel
477+ a .model .Config .ModelName = modelInput
478+ // Mark that we're using an inference profile
479+ // (could add a field if needed)
481480 } else {
482481 // Regular model ID provided
483482 foundModel , err := a .getModelFromString (modelInput )
@@ -562,7 +561,8 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
562561 supportedModels [i ] = m .Name
563562 }
564563
565- if ! bedrock_support .IsModelSupported (a .model .Config .ModelName , supportedModels ) {
564+ // Allow valid inference profile ARNs as supported models
565+ if ! bedrock_support .IsModelSupported (a .model .Config .ModelName , supportedModels ) && ! validateInferenceProfileArn (a .model .Config .ModelName ) {
566566 return "" , fmt .Errorf ("model '%s' is not supported.\n Supported models:\n %s" , a .model .Config .ModelName , func () string {
567567 s := ""
568568 for _ , m := range supportedModels {
0 commit comments