Skip to content

Commit 291e42d

Browse files
authored
feat: fix for broken inference (#1575)
Signed-off-by: Alex <alexsimonjones@gmail.com>
1 parent 8bbffed commit 291e42d

1 file changed

Lines changed: 19 additions & 19 deletions

File tree

pkg/ai/amazonbedrock.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -458,26 +458,25 @@ func (a *AmazonBedRockClient) Configure(config IAIConfig) error {
458458
// Get the inference profile details
459459
profile, err := a.getInferenceProfile(context.Background(), modelInput)
460460
if err != nil {
461-
// Instead of using a fallback model, throw an error
462461
return fmt.Errorf("failed to get inference profile: %v", err)
463-
} else {
464-
// Extract the model ID from the inference profile
465-
modelID, err := a.extractModelFromInferenceProfile(profile)
466-
if err != nil {
467-
return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
468-
}
469-
470-
// Find the model configuration for the extracted model ID
471-
foundModel, err := a.getModelFromString(modelID)
472-
if err != nil {
473-
// Instead of using a fallback model, throw an error
474-
return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
475-
}
476-
a.model = foundModel
477-
478-
// Use the inference profile ARN as the model ID for API calls
479-
a.model.Config.ModelName = modelInput
480462
}
463+
// Extract the model ID from the inference profile
464+
modelID, err := a.extractModelFromInferenceProfile(profile)
465+
if err != nil {
466+
return fmt.Errorf("failed to extract model ID from inference profile: %v", err)
467+
}
468+
// Find the model configuration for the extracted model ID
469+
foundModel, err := a.getModelFromString(modelID)
470+
if err != nil {
471+
// Instead of failing, use a generic config for completion/response
472+
// But still warn user
473+
return fmt.Errorf("failed to find model configuration for %s: %v", modelID, err)
474+
}
475+
// Use the found model config for completion/response, but set ModelName to the profile ARN
476+
a.model = foundModel
477+
a.model.Config.ModelName = modelInput
478+
// Mark that we're using an inference profile
479+
// (could add a field if needed)
481480
} else {
482481
// Regular model ID provided
483482
foundModel, err := a.getModelFromString(modelInput)
@@ -562,7 +561,8 @@ func (a *AmazonBedRockClient) GetCompletion(ctx context.Context, prompt string)
562561
supportedModels[i] = m.Name
563562
}
564563

565-
if !bedrock_support.IsModelSupported(a.model.Config.ModelName, supportedModels) {
564+
// Allow valid inference profile ARNs as supported models
565+
if !bedrock_support.IsModelSupported(a.model.Config.ModelName, supportedModels) && !validateInferenceProfileArn(a.model.Config.ModelName) {
566566
return "", fmt.Errorf("model '%s' is not supported.\nSupported models:\n%s", a.model.Config.ModelName, func() string {
567567
s := ""
568568
for _, m := range supportedModels {

0 commit comments

Comments (0)