Skip to content

Commit 36475da

Browse files
committed
fix(config): use single (renamed) settings
1 parent 9e5beb0 commit 36475da

File tree

15 files changed

+163
-163
lines changed

15 files changed

+163
-163
lines changed

src/flare_ai_consensus/config.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

src/flare_ai_consensus/consensus/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
from .aggregator import async_centralized_llm_aggregator, centralized_llm_aggregator
2-
from .config import AggregatorConfig, ConsensusConfig, ModelConfig
32
from .consensus import send_round
43

54
__all__ = [
6-
"AggregatorConfig",
7-
"ConsensusConfig",
8-
"ModelConfig",
95
"async_centralized_llm_aggregator",
106
"centralized_llm_aggregator",
117
"send_round",

src/flare_ai_consensus/consensus/aggregator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
from flare_ai_consensus.consensus.config import AggregatorConfig
21
from flare_ai_consensus.router import (
32
AsyncOpenRouterProvider,
43
ChatRequest,
5-
Message,
64
OpenRouterProvider,
75
)
6+
from flare_ai_consensus.settings import AggregatorConfig, Message
87

98

10-
def concatenate_aggregator(responses: dict[str, str]) -> str:
9+
def _concatenate_aggregator(responses: dict[str, str]) -> str:
1110
"""
1211
Aggregate responses by concatenating each model's answer with a label.
1312
@@ -35,7 +34,7 @@ def centralized_llm_aggregator(
3534
messages.extend(aggregator_config.context)
3635

3736
# Add a system message with the aggregated responses.
38-
aggregated_str = concatenate_aggregator(aggregated_responses)
37+
aggregated_str = _concatenate_aggregator(aggregated_responses)
3938
messages.append(
4039
{"role": "system", "content": f"Aggregated responses:\n{aggregated_str}"}
4140
)

src/flare_ai_consensus/consensus/config.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

src/flare_ai_consensus/consensus/consensus.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
import structlog
44

5-
from flare_ai_consensus.consensus.config import ConsensusConfig, ModelConfig
6-
from flare_ai_consensus.router import AsyncOpenRouterProvider, ChatRequest, Message
5+
from flare_ai_consensus.router import AsyncOpenRouterProvider, ChatRequest
6+
from flare_ai_consensus.settings import ConsensusConfig, Message, ModelConfig
77
from flare_ai_consensus.utils import parse_chat_response
88

99
logger = structlog.get_logger(__name__)
1010

1111

12-
def build_improvement_conversation(
12+
def _build_improvement_conversation(
1313
consensus_config: ConsensusConfig, aggregated_response: str
1414
) -> list[Message]:
1515
"""Build an updated conversation using the consensus configuration.
@@ -35,7 +35,7 @@ def build_improvement_conversation(
3535
return conversation
3636

3737

38-
async def get_response_for_model(
38+
async def _get_response_for_model(
3939
provider: AsyncOpenRouterProvider,
4040
consensus_config: ConsensusConfig,
4141
model: ModelConfig,
@@ -57,7 +57,7 @@ async def get_response_for_model(
5757
logger.info("sending initial prompt", model_id=model.model_id)
5858
else:
5959
# Build the improvement conversation.
60-
conversation = build_improvement_conversation(
60+
conversation = _build_improvement_conversation(
6161
consensus_config, aggregated_response
6262
)
6363
logger.info("sending improvement prompt", model_id=model.model_id)
@@ -89,7 +89,7 @@ async def send_round(
8989
:return: A dictionary mapping model IDs to their response texts.
9090
"""
9191
tasks = [
92-
get_response_for_model(provider, consensus_config, model, aggregated_response)
92+
_get_response_for_model(provider, consensus_config, model, aggregated_response)
9393
for model in consensus_config.models
9494
]
9595
results = await asyncio.gather(*tasks)

src/flare_ai_consensus/main.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
import structlog
44

5-
from flare_ai_consensus.config import config
65
from flare_ai_consensus.consensus import (
7-
ConsensusConfig,
86
async_centralized_llm_aggregator,
97
send_round,
108
)
119
from flare_ai_consensus.router import AsyncOpenRouterProvider
10+
from flare_ai_consensus.settings import ConsensusConfig, settings
1211
from flare_ai_consensus.utils import load_json, save_json
1312

1413
logger = structlog.get_logger(__name__)
@@ -56,7 +55,7 @@ async def run_consensus(
5655
response_data[f"aggregate_{i + 1}"] = aggregated_response
5756

5857
# Step 3: Save final consensus.
59-
output_file = config.data_path / "final_consensus.json"
58+
output_file = settings.data_path / "final_consensus.json"
6059
save_json(
6160
response_data,
6261
output_file,
@@ -69,16 +68,17 @@ async def run_consensus(
6968

7069
def main() -> None:
7170
# Load the consensus configuration from input.json
72-
config_json = load_json(config.input_path / "input.json")
73-
consensus_config = ConsensusConfig.load_parameters(config_json)
71+
config_json = load_json(settings.input_path / "input.json")
72+
settings.load_consensus_config(config_json)
7473

7574
# Initialize the OpenRouter provider.
7675
provider = AsyncOpenRouterProvider(
77-
api_key=config.open_router_api_key, base_url=config.open_router_base_url
76+
api_key=settings.open_router_api_key, base_url=settings.open_router_base_url
7877
)
7978

8079
# Run the consensus learning process with synchronous requests.
81-
asyncio.run(run_consensus(provider, consensus_config))
80+
if settings.consensus_config:
81+
asyncio.run(run_consensus(provider, settings.consensus_config))
8282

8383

8484
if __name__ == "__main__":
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from .base_router import ChatRequest, CompletionRequest, Message
1+
from .base_router import ChatRequest, CompletionRequest
22
from .openrouter import AsyncOpenRouterProvider, OpenRouterProvider
33

44
__all__ = [
55
"AsyncOpenRouterProvider",
66
"ChatRequest",
77
"CompletionRequest",
8-
"Message",
98
"OpenRouterProvider",
109
]

src/flare_ai_consensus/router/base_router.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Any, Literal, TypedDict
1+
from typing import Any, TypedDict
22

33
import httpx
44
import requests
55

6+
from flare_ai_consensus.settings import Message
7+
68

79
class CompletionRequest(TypedDict):
810
model: str
@@ -11,11 +13,6 @@ class CompletionRequest(TypedDict):
1113
temperature: float
1214

1315

14-
class Message(TypedDict):
15-
role: Literal["user", "assistant", "system"]
16-
content: str
17-
18-
1916
class ChatRequest(TypedDict):
2017
model: str
2118
messages: list[Message]

src/flare_ai_consensus/settings.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from pathlib import Path
2+
from typing import Literal, TypedDict
3+
4+
import structlog
5+
from pydantic import BaseModel
6+
from pydantic_settings import BaseSettings, SettingsConfigDict
7+
8+
logger = structlog.get_logger(__name__)
9+
10+
11+
def create_path(folder_name: str) -> Path:
12+
"""Creates and returns a path for storing data or logs."""
13+
path = Path(__file__).parent.resolve().parent / f"{folder_name}"
14+
path.mkdir(exist_ok=True)
15+
return path
16+
17+
18+
class Message(TypedDict):
19+
role: str
20+
content: str
21+
22+
23+
class ModelConfig(BaseModel):
24+
"""Configuration for individual models"""
25+
26+
model_id: str
27+
max_tokens: int = 50
28+
temperature: float = 0.7
29+
30+
31+
class AggregatorConfig(BaseModel):
32+
"""Configuration for the aggregator"""
33+
34+
model: ModelConfig
35+
approach: str
36+
context: list[Message]
37+
prompt: list[Message]
38+
39+
40+
class ConsensusConfig(BaseModel):
41+
"""Configuration for the consensus mechanism"""
42+
43+
models: list[ModelConfig]
44+
aggregator_config: AggregatorConfig
45+
initial_prompt: list[Message]
46+
improvement_prompt: str
47+
iterations: int
48+
aggregated_prompt_type: Literal["user", "assistant", "system"]
49+
50+
@classmethod
51+
def from_json(cls, json_data: dict) -> "ConsensusConfig":
52+
"""Create ConsensusConfig from JSON data"""
53+
# Parse the list of models
54+
models = [
55+
ModelConfig(
56+
model_id=m["id"],
57+
max_tokens=m["max_tokens"],
58+
temperature=m["temperature"],
59+
)
60+
for m in json_data.get("models", [])
61+
]
62+
63+
# Parse the aggregator configuration
64+
aggr_data = json_data.get("aggregator", [])[0]
65+
aggr_model_data = aggr_data.get("model", {})
66+
aggregator_model = ModelConfig(
67+
model_id=aggr_model_data["id"],
68+
max_tokens=aggr_model_data["max_tokens"],
69+
temperature=aggr_model_data["temperature"],
70+
)
71+
72+
aggregator_config = AggregatorConfig(
73+
model=aggregator_model,
74+
approach=aggr_data.get("approach", ""),
75+
context=aggr_data.get("aggregator_context", []),
76+
prompt=aggr_data.get("aggregator_prompt", []),
77+
)
78+
79+
return cls(
80+
models=models,
81+
aggregator_config=aggregator_config,
82+
initial_prompt=json_data.get("initial_conversation", []),
83+
improvement_prompt=json_data.get("improvement_prompt", ""),
84+
iterations=json_data.get("iterations", 1),
85+
aggregated_prompt_type=json_data.get("aggregated_prompt_type", "system"),
86+
)
87+
88+
89+
class Settings(BaseSettings):
90+
"""
91+
Application settings model that provides configuration for all components.
92+
Combines both infrastructure and consensus settings.
93+
"""
94+
95+
# OpenRouter Settings
96+
open_router_base_url: str = "https://openrouter.ai/api/v1"
97+
open_router_api_key: str = ""
98+
99+
# Path Settings
100+
data_path: Path = create_path("data")
101+
input_path: Path = create_path("flare_ai_consensus")
102+
103+
# Consensus Settings
104+
consensus_config: ConsensusConfig | None = None
105+
106+
model_config = SettingsConfigDict(
107+
env_file=".env",
108+
env_file_encoding="utf-8",
109+
extra="ignore",
110+
)
111+
112+
def load_consensus_config(self, json_data: dict) -> None:
113+
"""Load consensus configuration from JSON data"""
114+
self.consensus_config = ConsensusConfig.from_json(json_data)
115+
logger.info("loaded consensus configuration")
116+
117+
118+
# Create a global settings instance
119+
settings = Settings()
120+
logger.debug("settings initialized", settings=settings.model_dump())

0 commit comments

Comments
 (0)