@@ -55,39 +55,71 @@ def __init__(self, models_cfg: List[LLMModelConfig]):
             )
             logger._ensemble_logged = True
 
-    async def generate(self, prompt: str, **kwargs) -> str:
-        """Generate text using a randomly selected model based on weights"""
-        model = self._sample_model()
-        return await model.generate(prompt, **kwargs)
+    async def generate(self, prompt: str, **kwargs) -> Tuple[str, int]:
+        """Generate text using a randomly selected model based on weights
+
+        Returns:
+            Tuple of (generated_text, model_id) where model_id is the index
+            of the selected model in the ensemble
+        """
+        model, model_id = self._sample_model()
+        response = await model.generate(prompt, **kwargs)
+        return response, model_id
 
     async def generate_with_context(
         self, system_message: str, messages: List[Dict[str, str]], **kwargs
-    ) -> str:
-        """Generate text using a system message and conversational context"""
-        model = self._sample_model()
-        return await model.generate_with_context(system_message, messages, **kwargs)
-
-    def _sample_model(self) -> LLMInterface:
-        """Sample a model from the ensemble based on weights"""
+    ) -> Tuple[str, int]:
+        """Generate text using a system message and conversational context
+
+        Returns:
+            Tuple of (generated_text, model_id) where model_id is the index
+            of the selected model in the ensemble
+        """
+        model, model_id = self._sample_model()
+        response = await model.generate_with_context(system_message, messages, **kwargs)
+        return response, model_id
+
+    def _sample_model(self) -> Tuple[LLMInterface, int]:
+        """Sample a model from the ensemble based on weights
+
+        Returns:
+            Tuple of (model, model_id) where model_id is the index of the
+            selected model in the ensemble
+        """
         index = self.random_state.choices(range(len(self.models)), weights=self.weights, k=1)[0]
         sampled_model = self.models[index]
         logger.info(f"Sampled model: {vars(sampled_model)['model']}")
-        return sampled_model
+        return sampled_model, index
+
+    async def generate_multiple(self, prompt: str, n: int, **kwargs) -> List[Tuple[str, int]]:
+        """Generate multiple texts in parallel
 
-    async def generate_multiple(self, prompt: str, n: int, **kwargs) -> List[str]:
-        """Generate multiple texts in parallel"""
+        Returns:
+            List of (generated_text, model_id) tuples where model_id is the
+            index of the selected model in the ensemble
+        """
         tasks = [self.generate(prompt, **kwargs) for _ in range(n)]
         return await asyncio.gather(*tasks)
 
-    async def parallel_generate(self, prompts: List[str], **kwargs) -> List[str]:
-        """Generate responses for multiple prompts in parallel"""
+    async def parallel_generate(self, prompts: List[str], **kwargs) -> List[Tuple[str, int]]:
+        """Generate responses for multiple prompts in parallel
+
+        Returns:
+            List of (generated_text, model_id) tuples where model_id is the
+            index of the selected model in the ensemble
+        """
         tasks = [self.generate(prompt, **kwargs) for prompt in prompts]
         return await asyncio.gather(*tasks)
 
     async def generate_all_with_context(
         self, system_message: str, messages: List[Dict[str, str]], **kwargs
-    ) -> str:
-        """Generate text using a all available models and average their returned metrics"""
+    ) -> List[str]:
+        """Generate text using all available models and average their returned metrics
+
+        Returns:
+            List of generated texts, one per model in the ensemble (order matches
+            self.models). The model_id for each response is its index in the list.
+        """
         responses = []
         for model in self.models:
             responses.append(await model.generate_with_context(system_message, messages, **kwargs))
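
For reviewers who want to see the new calling convention end to end, here is a minimal, self-contained sketch of how the tuple-returning API behaves. `StubModel` and `MiniEnsemble` are hypothetical stand-ins for this repository's `LLMInterface` implementations and ensemble class, not code from the diff; only the weighted `random.choices` draw and the `asyncio.gather` fan-out mirror the changes above.

```python
import asyncio
import random
from typing import List, Tuple


class StubModel:
    """Hypothetical stand-in for an LLMInterface; echoes the prompt."""

    def __init__(self, name: str):
        self.model = name

    async def generate(self, prompt: str, **kwargs) -> str:
        return f"{self.model}: {prompt}"


class MiniEnsemble:
    def __init__(self, models: List[StubModel], weights: List[float], seed: int = 0):
        self.models = models
        self.weights = weights
        self.random_state = random.Random(seed)

    def _sample_model(self) -> Tuple[StubModel, int]:
        # Weighted draw; the index doubles as the model_id handed back to callers.
        index = self.random_state.choices(range(len(self.models)), weights=self.weights, k=1)[0]
        return self.models[index], index

    async def generate(self, prompt: str, **kwargs) -> Tuple[str, int]:
        model, model_id = self._sample_model()
        return await model.generate(prompt, **kwargs), model_id

    async def generate_multiple(self, prompt: str, n: int, **kwargs) -> List[Tuple[str, int]]:
        # Each call re-samples a model, so n samples may span several members.
        return await asyncio.gather(*(self.generate(prompt, **kwargs) for _ in range(n)))


async def main() -> None:
    ensemble = MiniEnsemble([StubModel("small"), StubModel("large")], weights=[0.8, 0.2])
    for text, model_id in await ensemble.generate_multiple("hello", n=4):
        # model_id identifies which ensemble member produced each sample.
        print(model_id, text)


asyncio.run(main())
```

Returning the sampled index alongside the text lets downstream callers attribute each sample to the ensemble member that produced it, without re-deriving the random draw.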