Skip to content

Commit da9edcd

Browse files
authored
Merge pull request #4 from flare-research/openrouter-typing
fix(typing): strictly type openrouter, simplify imports and routing logic
2 parents 08be5ca + 36475da commit da9edcd

23 files changed

+380
-351
lines changed

src/flare_ai_consensus/config.py

Lines changed: 0 additions & 42 deletions
This file was deleted.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .aggregator import async_centralized_llm_aggregator, centralized_llm_aggregator
2+
from .consensus import send_round
3+
4+
__all__ = [
5+
"async_centralized_llm_aggregator",
6+
"centralized_llm_aggregator",
7+
"send_round",
8+
]
Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
from flare_ai_consensus.consensus.config import AggregatorConfig
2-
from flare_ai_consensus.router.client import AsyncOpenRouterClient, OpenRouterClient
1+
from flare_ai_consensus.router import (
2+
AsyncOpenRouterProvider,
3+
ChatRequest,
4+
OpenRouterProvider,
5+
)
6+
from flare_ai_consensus.settings import AggregatorConfig, Message
37

48

5-
def concatenate_aggregator(responses: dict[str, str]) -> str:
9+
def _concatenate_aggregator(responses: dict[str, str]) -> str:
610
"""
711
Aggregate responses by concatenating each model's answer with a label.
812
@@ -13,52 +17,52 @@ def concatenate_aggregator(responses: dict[str, str]) -> str:
1317

1418

1519
def centralized_llm_aggregator(
16-
client: OpenRouterClient,
20+
provider: OpenRouterProvider,
1721
aggregator_config: AggregatorConfig,
1822
aggregated_responses: dict[str, str],
1923
) -> str:
2024
"""Use a centralized LLM to combine responses.
2125
22-
:param client: An OpenRouterClient instance.
26+
:param provider: An OpenRouterProvider instance.
2327
:param aggregator_config: An instance of AggregatorConfig.
2428
:param aggregated_responses: A dictionary mapping each model's ID to its
2529
response text.
2630
:return: The aggregator's combined response.
2731
"""
2832
# Build the message list.
29-
messages = []
33+
messages: list[Message] = []
3034
messages.extend(aggregator_config.context)
3135

3236
# Add a system message with the aggregated responses.
33-
aggregated_str = concatenate_aggregator(aggregated_responses)
37+
aggregated_str = _concatenate_aggregator(aggregated_responses)
3438
messages.append(
3539
{"role": "system", "content": f"Aggregated responses:\n{aggregated_str}"}
3640
)
3741

3842
# Add the aggregator prompt
3943
messages.extend(aggregator_config.prompt)
4044

41-
payload = {
45+
payload: ChatRequest = {
4246
"model": aggregator_config.model.model_id,
4347
"messages": messages,
4448
"max_tokens": aggregator_config.model.max_tokens,
4549
"temperature": aggregator_config.model.temperature,
4650
}
4751

4852
# Get aggregated response from the centralized LLM
49-
response = client.send_chat_completion(payload)
53+
response = provider.send_chat_completion(payload)
5054
return response.get("choices", [])[0].get("message", {}).get("content", "")
5155

5256

5357
async def async_centralized_llm_aggregator(
54-
client: AsyncOpenRouterClient,
58+
provider: AsyncOpenRouterProvider,
5559
aggregator_config: AggregatorConfig,
5660
aggregated_responses: dict[str, str],
5761
) -> str:
5862
"""
59-
Use a centralized LLM (via an async client) to combine responses.
63+
Use a centralized LLM (via an async provider) to combine responses.
6064
61-
:param client: An asynchronous OpenRouter client.
65+
:param provider: An asynchronous OpenRouterProvider.
6266
:param aggregator_config: An instance of AggregatorConfig.
6367
:param aggregated_responses: A dictionary mapping each model's ID to its
6468
response text.
@@ -71,12 +75,12 @@ async def async_centralized_llm_aggregator(
7175
)
7276
messages.extend(aggregator_config.prompt)
7377

74-
payload = {
78+
payload: ChatRequest = {
7579
"model": aggregator_config.model.model_id,
7680
"messages": messages,
7781
"max_tokens": aggregator_config.model.max_tokens,
7882
"temperature": aggregator_config.model.temperature,
7983
}
8084

81-
response = await client.send_chat_completion(payload)
85+
response = await provider.send_chat_completion(payload)
8286
return response.get("choices", [])[0].get("message", {}).get("content", "")

src/flare_ai_consensus/consensus/config.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

src/flare_ai_consensus/consensus/consensus.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22

33
import structlog
44

5-
from flare_ai_consensus.consensus.config import ConsensusConfig, ModelConfig
6-
from flare_ai_consensus.router.client import AsyncOpenRouterClient
7-
from flare_ai_consensus.utils.parser import parse_chat_response
5+
from flare_ai_consensus.router import AsyncOpenRouterProvider, ChatRequest
6+
from flare_ai_consensus.settings import ConsensusConfig, Message, ModelConfig
7+
from flare_ai_consensus.utils import parse_chat_response
88

99
logger = structlog.get_logger(__name__)
1010

1111

12-
def build_improvement_conversation(
12+
def _build_improvement_conversation(
1313
consensus_config: ConsensusConfig, aggregated_response: str
14-
) -> list:
14+
) -> list[Message]:
1515
"""Build an updated conversation using the consensus configuration.
1616
1717
:param consensus_config: An instance of ConsensusConfig.
@@ -35,16 +35,16 @@ def build_improvement_conversation(
3535
return conversation
3636

3737

38-
async def get_response_for_model(
39-
client: AsyncOpenRouterClient,
38+
async def _get_response_for_model(
39+
provider: AsyncOpenRouterProvider,
4040
consensus_config: ConsensusConfig,
4141
model: ModelConfig,
4242
aggregated_response: str | None,
4343
) -> tuple[str | None, str]:
4444
"""
4545
Asynchronously sends a chat completion request for a given model.
4646
47-
:param client: An instance of an asynchronous OpenRouter client.
47+
:param provider: An instance of an asynchronous OpenRouter provider.
4848
:param consensus_config: An instance of ConsensusConfig.
4949
:param aggregated_response: The aggregated consensus response
5050
from the previous round (or None).
@@ -57,39 +57,39 @@ async def get_response_for_model(
5757
logger.info("sending initial prompt", model_id=model.model_id)
5858
else:
5959
# Build the improvement conversation.
60-
conversation = build_improvement_conversation(
60+
conversation = _build_improvement_conversation(
6161
consensus_config, aggregated_response
6262
)
6363
logger.info("sending improvement prompt", model_id=model.model_id)
6464

65-
payload = {
65+
payload: ChatRequest = {
6666
"model": model.model_id,
6767
"messages": conversation,
6868
"max_tokens": model.max_tokens,
6969
"temperature": model.temperature,
7070
}
71-
response = await client.send_chat_completion(payload)
71+
response = await provider.send_chat_completion(payload)
7272
text = parse_chat_response(response)
7373
logger.info("new response", model_id=model.model_id, response=text)
7474
return model.model_id, text
7575

7676

7777
async def send_round(
78-
client: AsyncOpenRouterClient,
78+
provider: AsyncOpenRouterProvider,
7979
consensus_config: ConsensusConfig,
8080
aggregated_response: str | None = None,
8181
) -> dict:
8282
"""
8383
Asynchronously sends a round of chat completion requests for all models.
8484
85-
:param client: An instance of an asynchronous OpenRouter client.
85+
:param provider: An instance of an asynchronous OpenRouter provider.
8686
:param consensus_config: An instance of ConsensusConfig.
8787
:param aggregated_response: The aggregated consensus response from the
8888
previous round (or None).
8989
:return: A dictionary mapping model IDs to their response texts.
9090
"""
9191
tasks = [
92-
get_response_for_model(client, consensus_config, model, aggregated_response)
92+
_get_response_for_model(provider, consensus_config, model, aggregated_response)
9393
for model in consensus_config.models
9494
]
9595
results = await asyncio.gather(*tasks)

0 commit comments

Comments
 (0)