@@ -594,6 +594,9 @@ class CustomCommercialModel(DeepEvalBaseLLM):
594594 def __init__ (self , model_type = "claude" , model_name = "claude-3-7-sonnet-latest" ):
595595 self .model_type = model_type
596596 self .generation_model_name = model_name
597+ # Dealing with the new {"provider": "<model_name>"} output format
598+ if isinstance (model_name , dict ):
599+ self .generation_model_name = model_name .get ("provider" , model_name )
597600
598601 if model_type == "claude" :
599602 self .chat_completions_url = "https://api.anthropic.com/v1/chat/completions"
@@ -655,6 +658,8 @@ def load_model(self):
655658
656659 def generate (self , prompt : str , schema = None ):
657660 client = self .load_model ()
661+ if isinstance (self .generation_model_name , dict ):
662+ self .generation_model_name = self .generation_model_name .get ("provider" , self .generation_model_name )
658663 if schema :
659664 import instructor
660665
@@ -675,6 +680,7 @@ def generate(self, prompt: str, schema=None):
675680 model = self .generation_model_name ,
676681 messages = [{"role" : "user" , "content" : prompt }],
677682 )
683+
678684 return response .choices [0 ].message .content
679685
680686 async def a_generate (self , prompt : str , schema = None ):
0 commit comments