Skip to content

Commit 9e5beb0

Browse files
committed
fix(typing): strictly type openrouter, simplify imports and routing logic
1 parent 08be5ca commit 9e5beb0

21 files changed

+243
-214
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from .aggregator import async_centralized_llm_aggregator, centralized_llm_aggregator
2+
from .config import AggregatorConfig, ConsensusConfig, ModelConfig
3+
from .consensus import send_round
4+
5+
__all__ = [
6+
"AggregatorConfig",
7+
"ConsensusConfig",
8+
"ModelConfig",
9+
"async_centralized_llm_aggregator",
10+
"centralized_llm_aggregator",
11+
"send_round",
12+
]

src/flare_ai_consensus/consensus/aggregator.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from flare_ai_consensus.consensus.config import AggregatorConfig
2-
from flare_ai_consensus.router.client import AsyncOpenRouterClient, OpenRouterClient
2+
from flare_ai_consensus.router import (
3+
AsyncOpenRouterProvider,
4+
ChatRequest,
5+
Message,
6+
OpenRouterProvider,
7+
)
38

49

510
def concatenate_aggregator(responses: dict[str, str]) -> str:
@@ -13,20 +18,20 @@ def concatenate_aggregator(responses: dict[str, str]) -> str:
1318

1419

1520
def centralized_llm_aggregator(
16-
client: OpenRouterClient,
21+
provider: OpenRouterProvider,
1722
aggregator_config: AggregatorConfig,
1823
aggregated_responses: dict[str, str],
1924
) -> str:
2025
"""Use a centralized LLM to combine responses.
2126
22-
:param client: An OpenRouterClient instance.
27+
:param provider: An OpenRouterProvider instance.
2328
:param aggregator_config: An instance of AggregatorConfig.
2429
:param aggregated_responses: A string containing aggregated
2530
responses from individual models.
2631
:return: The aggregator's combined response.
2732
"""
2833
# Build the message list.
29-
messages = []
34+
messages: list[Message] = []
3035
messages.extend(aggregator_config.context)
3136

3237
# Add a system message with the aggregated responses.
@@ -38,27 +43,27 @@ def centralized_llm_aggregator(
3843
# Add the aggregator prompt
3944
messages.extend(aggregator_config.prompt)
4045

41-
payload = {
46+
payload: ChatRequest = {
4247
"model": aggregator_config.model.model_id,
4348
"messages": messages,
4449
"max_tokens": aggregator_config.model.max_tokens,
4550
"temperature": aggregator_config.model.temperature,
4651
}
4752

4853
# Get aggregated response from the centralized LLM
49-
response = client.send_chat_completion(payload)
54+
response = provider.send_chat_completion(payload)
5055
return response.get("choices", [])[0].get("message", {}).get("content", "")
5156

5257

5358
async def async_centralized_llm_aggregator(
54-
client: AsyncOpenRouterClient,
59+
provider: AsyncOpenRouterProvider,
5560
aggregator_config: AggregatorConfig,
5661
aggregated_responses: dict[str, str],
5762
) -> str:
5863
"""
59-
Use a centralized LLM (via an async client) to combine responses.
64+
Use a centralized LLM (via an async provider) to combine responses.
6065
61-
:param client: An asynchronous OpenRouter client.
66+
:param provider: An asynchronous OpenRouterProvider.
6267
:param aggregator_config: An instance of AggregatorConfig.
6368
:param aggregated_responses: A string containing aggregated
6469
responses from individual models.
@@ -71,12 +76,12 @@ async def async_centralized_llm_aggregator(
7176
)
7277
messages.extend(aggregator_config.prompt)
7378

74-
payload = {
79+
payload: ChatRequest = {
7580
"model": aggregator_config.model.model_id,
7681
"messages": messages,
7782
"max_tokens": aggregator_config.model.max_tokens,
7883
"temperature": aggregator_config.model.temperature,
7984
}
8085

81-
response = await client.send_chat_completion(payload)
86+
response = await provider.send_chat_completion(payload)
8287
return response.get("choices", [])[0].get("message", {}).get("content", "")

src/flare_ai_consensus/consensus/config.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from dataclasses import dataclass
2+
from typing import Literal
3+
4+
from flare_ai_consensus.router import Message
25

36

47
@dataclass(frozen=True)
58
class ModelConfig:
6-
model_id: str | None = None
9+
model_id: str
710
max_tokens: int = 50
811
temperature: float = 0.7
912

@@ -12,18 +15,18 @@ class ModelConfig:
1215
class AggregatorConfig:
1316
model: ModelConfig
1417
approach: str
15-
context: list[dict]
16-
prompt: list[dict]
18+
context: list[Message]
19+
prompt: list[Message]
1720

1821

1922
@dataclass(frozen=True)
2023
class ConsensusConfig:
2124
models: list[ModelConfig]
2225
aggregator_config: AggregatorConfig
23-
initial_prompt: list[dict]
26+
initial_prompt: list[Message]
2427
improvement_prompt: str
2528
iterations: int
26-
aggregated_prompt_type: str
29+
aggregated_prompt_type: Literal["user", "assistant", "system"]
2730

2831
@staticmethod
2932
def load_parameters(json_data: dict) -> "ConsensusConfig":

src/flare_ai_consensus/consensus/consensus.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import structlog
44

55
from flare_ai_consensus.consensus.config import ConsensusConfig, ModelConfig
6-
from flare_ai_consensus.router.client import AsyncOpenRouterClient
7-
from flare_ai_consensus.utils.parser import parse_chat_response
6+
from flare_ai_consensus.router import AsyncOpenRouterProvider, ChatRequest, Message
7+
from flare_ai_consensus.utils import parse_chat_response
88

99
logger = structlog.get_logger(__name__)
1010

1111

1212
def build_improvement_conversation(
1313
consensus_config: ConsensusConfig, aggregated_response: str
14-
) -> list:
14+
) -> list[Message]:
1515
"""Build an updated conversation using the consensus configuration.
1616
1717
:param consensus_config: An instance of ConsensusConfig.
@@ -36,15 +36,15 @@ def build_improvement_conversation(
3636

3737

3838
async def get_response_for_model(
39-
client: AsyncOpenRouterClient,
39+
provider: AsyncOpenRouterProvider,
4040
consensus_config: ConsensusConfig,
4141
model: ModelConfig,
4242
aggregated_response: str | None,
4343
) -> tuple[str | None, str]:
4444
"""
4545
Asynchronously sends a chat completion request for a given model.
4646
47-
:param client: An instance of an asynchronous OpenRouter client.
47+
:param provider: An instance of an asynchronous OpenRouter provider.
4848
:param consensus_config: An instance of ConsensusConfig.
4949
:param aggregated_response: The aggregated consensus response
5050
from the previous round (or None).
@@ -62,34 +62,34 @@ async def get_response_for_model(
6262
)
6363
logger.info("sending improvement prompt", model_id=model.model_id)
6464

65-
payload = {
65+
payload: ChatRequest = {
6666
"model": model.model_id,
6767
"messages": conversation,
6868
"max_tokens": model.max_tokens,
6969
"temperature": model.temperature,
7070
}
71-
response = await client.send_chat_completion(payload)
71+
response = await provider.send_chat_completion(payload)
7272
text = parse_chat_response(response)
7373
logger.info("new response", model_id=model.model_id, response=text)
7474
return model.model_id, text
7575

7676

7777
async def send_round(
78-
client: AsyncOpenRouterClient,
78+
provider: AsyncOpenRouterProvider,
7979
consensus_config: ConsensusConfig,
8080
aggregated_response: str | None = None,
8181
) -> dict:
8282
"""
8383
Asynchronously sends a round of chat completion requests for all models.
8484
85-
:param client: An instance of an asynchronous OpenRouter client.
85+
:param provider: An instance of an asynchronous OpenRouter provider.
8686
:param consensus_config: An instance of ConsensusConfig.
8787
:param aggregated_response: The aggregated consensus response from the
8888
previous round (or None).
8989
:return: A dictionary mapping model IDs to their response texts.
9090
"""
9191
tasks = [
92-
get_response_for_model(client, consensus_config, model, aggregated_response)
92+
get_response_for_model(provider, consensus_config, model, aggregated_response)
9393
for model in consensus_config.models
9494
]
9595
results = await asyncio.gather(*tasks)

src/flare_ai_consensus/main.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,35 @@
33
import structlog
44

55
from flare_ai_consensus.config import config
6-
from flare_ai_consensus.consensus import aggregator, consensus
7-
from flare_ai_consensus.consensus.config import ConsensusConfig
8-
from flare_ai_consensus.router.client import AsyncOpenRouterClient
9-
from flare_ai_consensus.utils import (
10-
loader,
11-
saver,
6+
from flare_ai_consensus.consensus import (
7+
ConsensusConfig,
8+
async_centralized_llm_aggregator,
9+
send_round,
1210
)
11+
from flare_ai_consensus.router import AsyncOpenRouterProvider
12+
from flare_ai_consensus.utils import load_json, save_json
1313

1414
logger = structlog.get_logger(__name__)
1515

1616

1717
async def run_consensus(
18-
client: AsyncOpenRouterClient,
18+
provider: AsyncOpenRouterProvider,
1919
consensus_config: ConsensusConfig,
2020
) -> None:
2121
"""
2222
Asynchronously runs the consensus learning loop.
2323
24-
:param client: An instance of a synchronous OpenRouterClient (used for aggregation).
25-
:param async_client: An instance of an asynchronous OpenRouterClient.
24+
:param provider: An instance of a OpenRouterProvider (used for aggregation).
25+
:param async_provider: An instance of an AsyncOpenRouterProvider.
2626
:param consensus_config: An instance of ConsensusConfig.
2727
"""
2828
response_data = {}
2929
response_data["initial_conversation"] = consensus_config.initial_prompt
3030

3131
# Step 1: Initial round.
32-
responses = await consensus.send_round(client, consensus_config)
33-
aggregated_response = await aggregator.async_centralized_llm_aggregator(
34-
client, consensus_config.aggregator_config, responses
32+
responses = await send_round(provider, consensus_config)
33+
aggregated_response = await async_centralized_llm_aggregator(
34+
provider, consensus_config.aggregator_config, responses
3535
)
3636
logger.info(
3737
"initial response aggregation complete", aggregated_response=aggregated_response
@@ -42,11 +42,9 @@ async def run_consensus(
4242

4343
# Step 2: Improvement rounds.
4444
for i in range(consensus_config.iterations):
45-
responses = await consensus.send_round(
46-
client, consensus_config, aggregated_response
47-
)
48-
aggregated_response = await aggregator.async_centralized_llm_aggregator(
49-
client, consensus_config.aggregator_config, responses
45+
responses = await send_round(provider, consensus_config, aggregated_response)
46+
aggregated_response = await async_centralized_llm_aggregator(
47+
provider, consensus_config.aggregator_config, responses
5048
)
5149
logger.info(
5250
"responses aggregated",
@@ -59,28 +57,28 @@ async def run_consensus(
5957

6058
# Step 3: Save final consensus.
6159
output_file = config.data_path / "final_consensus.json"
62-
saver.save_json(
60+
save_json(
6361
response_data,
6462
output_file,
6563
)
6664
logger.info("saved consensus", output_file=output_file)
6765

68-
# Close the async client to release resources.
69-
await client.close()
66+
# Close the async provider to release resources.
67+
await provider.close()
7068

7169

7270
def main() -> None:
7371
# Load the consensus configuration from input.json
74-
config_json = loader.load_json(config.input_path / "input.json")
72+
config_json = load_json(config.input_path / "input.json")
7573
consensus_config = ConsensusConfig.load_parameters(config_json)
7674

77-
# Initialize the OpenRouter client.
78-
client = AsyncOpenRouterClient(
75+
# Initialize the OpenRouter provider.
76+
provider = AsyncOpenRouterProvider(
7977
api_key=config.open_router_api_key, base_url=config.open_router_base_url
8078
)
8179

8280
# Run the consensus learning process with synchronous requests.
83-
asyncio.run(run_consensus(client, consensus_config))
81+
asyncio.run(run_consensus(provider, consensus_config))
8482

8583

8684
if __name__ == "__main__":
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .base_router import ChatRequest, CompletionRequest, Message
2+
from .openrouter import AsyncOpenRouterProvider, OpenRouterProvider
3+
4+
__all__ = [
5+
"AsyncOpenRouterProvider",
6+
"ChatRequest",
7+
"CompletionRequest",
8+
"Message",
9+
"OpenRouterProvider",
10+
]

src/flare_ai_consensus/router/async_requests.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)