Skip to content

Commit a6ba8ca

Browse files
authored
Merge pull request #533 from transformerlab/fix/commercial-model-wrapper
Fix model name format for commercial model wrappers
2 parents 6a9f7a3 + 778c644 commit a6ba8ca

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

transformerlab/plugin_sdk/transformerlab/sdk/v1/tlab_plugin.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,9 @@ class CustomCommercialModel(DeepEvalBaseLLM):
594594
def __init__(self, model_type="claude", model_name="claude-3-7-sonnet-latest"):
595595
self.model_type = model_type
596596
self.generation_model_name = model_name
597+
# Dealing with the new {"provider": "<model_name>"} output format
598+
if isinstance(model_name, dict):
599+
self.generation_model_name = model_name.get("provider", model_name)
597600

598601
if model_type == "claude":
599602
self.chat_completions_url = "https://api.anthropic.com/v1/chat/completions"
@@ -655,6 +658,8 @@ def load_model(self):
655658

656659
def generate(self, prompt: str, schema=None):
657660
client = self.load_model()
661+
if isinstance(self.generation_model_name, dict):
662+
self.generation_model_name = self.generation_model_name.get("provider", self.generation_model_name)
658663
if schema:
659664
import instructor
660665

@@ -675,6 +680,7 @@ def generate(self, prompt: str, schema=None):
675680
model=self.generation_model_name,
676681
messages=[{"role": "user", "content": prompt}],
677682
)
683+
678684
return response.choices[0].message.content
679685

680686
async def a_generate(self, prompt: str, schema=None):

0 commit comments

Comments
 (0)