
Commit 39fd7b6

add model_id tracking
1 parent 65cbbe8

6 files changed: +67 additions, -29 deletions


openevolve/database.py

Lines changed: 3 additions & 2 deletions

@@ -55,6 +55,7 @@ class Program:
     generation: int = 0
     timestamp: float = field(default_factory=time.time)
     iteration_found: int = 0  # Track which iteration this program was found
+    model_id: Optional[int] = None  # Track the id of the model that generated this program
 
     # Performance metrics
     metrics: Dict[str, float] = field(default_factory=dict)
@@ -1016,10 +1017,10 @@ def _llm_judge_novelty(self, program: Program, similar_program: Program) -> bool
                     messages=[{"role": "user", "content": user_msg}],
                 ),
             )
-            content: str = future.result()
+            content, _model_id = future.result()
         except RuntimeError:
            # No event loop running, safe to use asyncio.run()
-            content: str = asyncio.run(
+            content, _model_id = asyncio.run(
                 self.novelty_llm.generate_with_context(
                     system_message=NOVELTY_SYSTEM_MSG,
                     messages=[{"role": "user", "content": user_msg}],

openevolve/iteration.py

Lines changed: 2 additions & 1 deletion

@@ -89,7 +89,7 @@ async def run_iteration_with_shared_db(
     iteration_start = time.time()
 
     # Generate code modification
-    llm_response = await llm_ensemble.generate_with_context(
+    llm_response, model_id = await llm_ensemble.generate_with_context(
         system_message=prompt["system"],
         messages=[{"role": "user", "content": prompt["user"]}],
     )
@@ -181,6 +181,7 @@ async def run_iteration_with_shared_db(
         generation=parent.generation + 1,
         metrics=result.child_metrics,
         iteration_found=iteration,
+        model_id=model_id,
         metadata={
             "changes": changes_summary,
             "parent_metrics": parent.metrics,

openevolve/llm/ensemble.py

Lines changed: 50 additions & 18 deletions

@@ -55,39 +55,71 @@ def __init__(self, models_cfg: List[LLMModelConfig]):
         )
         logger._ensemble_logged = True
 
-    async def generate(self, prompt: str, **kwargs) -> str:
-        """Generate text using a randomly selected model based on weights"""
-        model = self._sample_model()
-        return await model.generate(prompt, **kwargs)
+    async def generate(self, prompt: str, **kwargs) -> Tuple[str, int]:
+        """Generate text using a randomly selected model based on weights
+
+        Returns:
+            Tuple of (generated_text, model_id) where model_id is the index
+            of the selected model in the ensemble
+        """
+        model, model_id = self._sample_model()
+        response = await model.generate(prompt, **kwargs)
+        return response, model_id
 
     async def generate_with_context(
         self, system_message: str, messages: List[Dict[str, str]], **kwargs
-    ) -> str:
-        """Generate text using a system message and conversational context"""
-        model = self._sample_model()
-        return await model.generate_with_context(system_message, messages, **kwargs)
-
-    def _sample_model(self) -> LLMInterface:
-        """Sample a model from the ensemble based on weights"""
+    ) -> Tuple[str, int]:
+        """Generate text using a system message and conversational context
+
+        Returns:
+            Tuple of (generated_text, model_id) where model_id is the index
+            of the selected model in the ensemble
+        """
+        model, model_id = self._sample_model()
+        response = await model.generate_with_context(system_message, messages, **kwargs)
+        return response, model_id
+
+    def _sample_model(self) -> Tuple[LLMInterface, int]:
+        """Sample a model from the ensemble based on weights
+
+        Returns:
+            Tuple of (model, model_id) where model_id is the index of the
+            selected model in the ensemble
+        """
         index = self.random_state.choices(range(len(self.models)), weights=self.weights, k=1)[0]
         sampled_model = self.models[index]
         logger.info(f"Sampled model: {vars(sampled_model)['model']}")
-        return sampled_model
+        return sampled_model, index
 
-    async def generate_multiple(self, prompt: str, n: int, **kwargs) -> List[str]:
-        """Generate multiple texts in parallel"""
+    async def generate_multiple(self, prompt: str, n: int, **kwargs) -> List[Tuple[str, int]]:
+        """Generate multiple texts in parallel
+
+        Returns:
+            List of (generated_text, model_id) tuples where model_id is the
+            index of the selected model in the ensemble
+        """
         tasks = [self.generate(prompt, **kwargs) for _ in range(n)]
         return await asyncio.gather(*tasks)
 
-    async def parallel_generate(self, prompts: List[str], **kwargs) -> List[str]:
-        """Generate responses for multiple prompts in parallel"""
+    async def parallel_generate(self, prompts: List[str], **kwargs) -> List[Tuple[str, int]]:
+        """Generate responses for multiple prompts in parallel
+
+        Returns:
+            List of (generated_text, model_id) tuples where model_id is the
+            index of the selected model in the ensemble
+        """
         tasks = [self.generate(prompt, **kwargs) for prompt in prompts]
         return await asyncio.gather(*tasks)
 
     async def generate_all_with_context(
         self, system_message: str, messages: List[Dict[str, str]], **kwargs
-    ) -> str:
-        """Generate text using a all available models and average their returned metrics"""
+    ) -> List[str]:
+        """Generate text using all available models and average their returned metrics
+
+        Returns:
+            List of generated texts, one per model in the ensemble (order matches
+            self.models). The model_id for each response is its index in the list.
+        """
         responses = []
         for model in self.models:
             responses.append(await model.generate_with_context(system_message, messages, **kwargs))
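With this change every sampling call also reports which ensemble member produced the text. A minimal usage sketch of the new (text, model_id) return convention; the model names, keys, and the LLMModelConfig import path are illustrative assumptions, only the tuple shape and the LLMModelConfig keyword arguments come from this commit:

import asyncio

from openevolve.config import LLMModelConfig  # assumed import path
from openevolve.llm.ensemble import LLMEnsemble

async def main() -> None:
    # Hypothetical two-model ensemble; the weights steer _sample_model's choice
    ensemble = LLMEnsemble(
        [
            LLMModelConfig(name="model-a", weight=0.8, api_key="test", api_base="http://test"),
            LLMModelConfig(name="model-b", weight=0.2, api_key="test", api_base="http://test"),
        ]
    )

    # generate_with_context now returns (text, model_id) instead of a bare str
    text, model_id = await ensemble.generate_with_context(
        system_message="You are a code mutator.",
        messages=[{"role": "user", "content": "Improve this function."}],
    )

    # model_id indexes ensemble.models, so the producing model can be recovered
    print(f"model #{model_id} produced {len(text)} characters")

asyncio.run(main())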

openevolve/process_parallel.py

Lines changed: 2 additions & 1 deletion

@@ -197,7 +197,7 @@ def _run_iteration_worker(
 
     # Generate code modification (sync wrapper for async)
     try:
-        llm_response = asyncio.run(
+        llm_response, model_id = asyncio.run(
             _worker_llm_ensemble.generate_with_context(
                 system_message=prompt["system"],
                 messages=[{"role": "user", "content": prompt["user"]}],
@@ -304,6 +304,7 @@ def _run_iteration_worker(
         generation=parent.generation + 1,
         metrics=child_metrics,
         iteration_found=iteration,
+        model_id=model_id,
         metadata={
             "changes": changes_summary,
             "parent_metrics": parent.metrics,

tests/test_llm_ensemble.py

Lines changed: 7 additions & 4 deletions

@@ -17,19 +17,22 @@ def test_weighted_sampling(self):
         ensemble = LLMEnsemble(models)
         # Should always sample model 'b'
         for _ in range(10):
-            self.assertEqual(ensemble._sample_model().model, "b")
+            model, model_id = ensemble._sample_model()
+            self.assertEqual(model.model, "b")
+            self.assertEqual(model_id, 1)
 
         models = [
             LLMModelConfig(name="a", weight=0.3, api_key="test", api_base="http://test"),
             LLMModelConfig(name="b", weight=0.3, api_key="test", api_base="http://test"),
             LLMModelConfig(name="c", weight=0.3, api_key="test", api_base="http://test"),
         ]
         ensemble = LLMEnsemble(models)
-        # Should sample both models. Track sampled models in a set
+        # Should sample all models. Track sampled models in a set
         sampled_models = set()
         for _ in range(1000):
-            sampled_models.add(ensemble._sample_model().model)
-            # Cancel once we have both models
+            model, model_id = ensemble._sample_model()
+            sampled_models.add(model.model)
+            # Cancel once we have all models
             if len(sampled_models) == len(models):
                 break
         self.assertEqual(len(sampled_models), len(models))

tests/test_novelty_asyncio_issue.py

Lines changed: 3 additions & 3 deletions

@@ -15,11 +15,11 @@
 
 
 class MockLLM:
-    """Mock LLM that implements the async interface"""
+    """Mock LLM that implements the LLMEnsemble async interface"""
 
     async def generate_with_context(self, system_message: str, messages: list):
-        """Mock async generate method that returns NOVEL"""
-        return "NOVEL"
+        """Mock async generate method that returns NOVEL with model_id"""
+        return "NOVEL", 0
 
 
 class TestNoveltyAsyncioIssue(unittest.TestCase):
