SPDX-License-Identifier: Apache-2.0
"""
6+ import time
67from typing import Callable
78
89from loguru import logger
1213from bhive .utils import parallel_bedrock_exec
1314
1415
def run_inference(
    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
) -> tuple[str | list[str], chat.ChatLog]:
    """Run the inference loop: an initial answer plus optional reflection/debate rounds.

    Round 0 issues the initial call(s); each later round first appends a
    reflection prompt (single model) or a debate prompt built from the other
    models' latest answers (multiple models), then calls the model(s) again.

    Args:
        config: Hive configuration — model ids, number of reflections, optional
            wall-clock budget, optional verifier and aggregator model.
        chatlog: Conversation state; mutated in place with prompts and responses.
        _converse_func: Callable performing a single converse request.
        message: The original user question; when given, it is restated inside
            reflection/debate prompts.

    Returns:
        Tuple of (final answer — or per-model answers — and the updated chat log).
    """
    is_single = config.n_models == 1
    start_time = time.monotonic()

    for n_reflect in range(config.num_reflections + 1):
        # Merged guard: both the time-budget check and the prompt construction
        # only apply to reflection/debate rounds (the original checked
        # `n_reflect > 0` twice in a row).
        if n_reflect > 0:
            if config.max_reasoning_seconds is not None:
                elapsed = time.monotonic() - start_time
                if elapsed >= config.max_reasoning_seconds:
                    logger.info(
                        f"Exiting early at round {n_reflect}/{config.num_reflections} "
                        f"after {elapsed:.1f}s (limit: {config.max_reasoning_seconds}s)"
                    )
                    break
            if is_single:
                reflect_msg = prompt.reflect + "\n"
                if config.verifier:
                    past_answer = chatlog.get_last_answer()
                    reflect_msg += apply_verification(past_answer, config.verifier)  # type: ignore[arg-type]
                if message:
                    reflect_msg += f"\nAs a reminder, the original question is {message}"
                chatlog.add_user_msg(reflect_msg, invoke_index=0)
            else:
                for index in range(config.n_models):
                    recent_other_answers = chatlog.get_recent_other_answers(index)
                    debate_msg = prompt.debate
                    for recent_ans in recent_other_answers:
                        answer_text = recent_ans["content"][0]["text"]
                        debate_msg += f"\n\nOne agent response: ```{answer_text}```"
                        if config.verifier:
                            debate_msg += apply_verification(answer_text, config.verifier)
                    debate_msg += f"\n\n{prompt.careful}\n"
                    if message:
                        debate_msg += f"\nAs a reminder, the original question is {message}"
                    chatlog.add_user_msg(debate_msg, index)

        if is_single:
            modelid = config.bedrock_model_ids[0]
            response = _converse_func(model_id=modelid, messages=chatlog.history[0].chat_history)
            _record_response(chatlog, 0, modelid, response)
        else:
            responses = parallel_bedrock_exec(_converse_func, chathistory=chatlog.history)
            for (index, modelid), response in responses.items():
                _record_response(chatlog, index, modelid, response)

    if config.aggregator_model_id:
        chatlog = aggregate_last_responses(config, chatlog, _converse_func, message)

    return chatlog.get_last_answer(), chatlog
68+
69+
def _record_response(
    chatlog: chat.ChatLog, index: int, modelid: str, response: chat.ConverseResponse
) -> None:
    """Record one model response on the chat log, mutating *chatlog* in place.

    Appends the assistant answer (and the thinking trace, when present) under
    *index*, then updates usage stats for *modelid* and stores the stop reason
    and trace from *response*.
    """
    chatlog.add_assistant_msg(response.answer, index)
    if response.thinking:
        chatlog.add_thinking_trace(response.thinking, index)
    chatlog.update_stats(modelid, response)
    chatlog.add_stop_reason(response.stopReason)
    chatlog.add_trace(response.trace)
79+
80+
1581def aggregate_last_responses (
1682 config : HiveConfig , chatlog : chat .ChatLog , _converse_func : Callable , message : str | None = None
1783) -> chat .ChatLog :
@@ -30,131 +96,12 @@ def aggregate_last_responses(
3096 logger .info (f"Aggregating a final response using { config .aggregator_model_id = } " )
3197 response : chat .ConverseResponse = _converse_func (config .aggregator_model_id , [fmt_msg ])
3298
33- chatlog .add_assistant_msg (response .answer , 0 )
34- if response .thinking :
35- chatlog .add_thinking_trace (response .thinking , 0 )
36- chatlog .update_stats (config .aggregator_model_id , response )
37- chatlog .add_stop_reason (response .stopReason )
38- chatlog .add_trace (response .trace )
99+ _record_response (chatlog , 0 , config .aggregator_model_id , response )
39100
40101 return chatlog
41102
42103
def single_model_single_call(
    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable
) -> tuple[str, chat.ChatLog]:
    """Issue exactly one converse call to the single configured model.

    Args:
        config: Hive configuration; only the first entry of
            ``bedrock_model_ids`` is used.
        chatlog: Conversation state; mutated in place with the response.
        _converse_func: Callable performing a single converse request.

    Returns:
        Tuple of (answer text, updated chat log).
    """
    modelid = config.bedrock_model_ids[0]
    logger.info(f"Calling {modelid} with no self-reflection")
    response: chat.ConverseResponse = _converse_func(
        model_id=modelid, messages=chatlog.history[0].chat_history
    )
    # Use the shared helper instead of repeating the recording sequence inline.
    _record_response(chatlog, 0, modelid, response)

    return response.answer, chatlog
59-
60-
def multi_model_single_call(
    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
) -> tuple[str | list[str], chat.ChatLog]:
    """Call every configured model once in parallel, optionally aggregating.

    Args:
        config: Hive configuration (model ids, optional aggregator model).
        chatlog: Conversation state; mutated in place with all responses.
        _converse_func: Callable performing a single converse request.
        message: Original question, forwarded to the aggregation step.

    Returns:
        Tuple of (last answer — or per-model answers — and the updated chat log).
    """
    logger.info(f"Calling {config.bedrock_model_ids} with no self-reflection")
    responses: dict[tuple[int, str], chat.ConverseResponse] = parallel_bedrock_exec(
        _converse_func, chathistory=chatlog.history
    )
    for (index, modelid), response in responses.items():
        # Use the shared helper instead of repeating the recording sequence inline.
        _record_response(chatlog, index, modelid, response)

    if config.aggregator_model_id:
        # aggregate an answer
        chatlog = aggregate_last_responses(config, chatlog, _converse_func, message)
    return chatlog.get_last_answer(), chatlog
81-
def single_model_multi_call(
    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
) -> tuple[str, chat.ChatLog]:
    """Call a single model repeatedly with rounds of self-reflection.

    Round 0 answers the question; each later round prepends a reflection
    prompt (optionally augmented by an external verifier) before re-asking.

    Args:
        config: Hive configuration (model id, reflection count, verifier).
        chatlog: Conversation state; mutated in place each round.
        _converse_func: Callable performing a single converse request.
        message: Original question, restated inside reflection prompts.

    Returns:
        Tuple of (final answer text, updated chat log).
    """
    modelid = config.bedrock_model_ids[0]
    logger.info(f"Calling {modelid} with {config.num_reflections} rounds of self-reflection")
    for n_reflect in range(config.num_reflections + 1):
        if n_reflect > 0:
            reflect_msg = prompt.reflect + "\n"
            if config.verifier:
                past_answer = chatlog.get_last_answer()
                assert isinstance(past_answer, str), (
                    "Received multiple responds when doing a single model call"
                )
                reflect_msg += apply_verification(past_answer, config.verifier)
            if message:
                reflect_msg += f"\nAs a reminder, the original question is {message}"
            chatlog.add_user_msg(reflect_msg, invoke_index=0)
        response: chat.ConverseResponse = _converse_func(
            model_id=modelid, messages=chatlog.history[0].chat_history
        )
        # Use the shared helper instead of repeating the recording sequence inline.
        _record_response(chatlog, 0, modelid, response)

    return response.answer, chatlog
110-
111-
def multi_model_multi_call(
    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
) -> tuple[str | list[str], chat.ChatLog]:
    """Call multiple models with rounds of cross-model debate.

    Round 0 answers the question in parallel; each later round shows every
    model the other models' most recent answers (optionally verifier-annotated)
    before re-asking. Optionally aggregates a final answer at the end.

    Args:
        config: Hive configuration (model ids, reflection count, verifier,
            optional aggregator model).
        chatlog: Conversation state; mutated in place each round.
        _converse_func: Callable performing a single converse request.
        message: Original question, restated inside debate prompts and
            forwarded to the aggregation step.

    Returns:
        Tuple of (last answer — or per-model answers — and the updated chat log).
    """
    logger.info(f"Calling {config.bedrock_model_ids} with {config.num_reflections} rounds")
    for n_reflect in range(config.num_reflections + 1):
        if n_reflect > 0:
            # consider others & debate
            for index, model_log in enumerate(chatlog.history):
                recent_other_answers = chatlog.get_recent_other_answers(index)
                debate_msg = prompt.debate
                for recent_ans in recent_other_answers:
                    # NOTE could alternatively summarise messages
                    answer_text = recent_ans["content"][0]["text"]
                    debate_msg += f"\n\nOne agent response: ```{answer_text}```"
                    if config.verifier:
                        debate_msg += apply_verification(answer_text, config.verifier)
                debate_msg += f"\n\n{prompt.careful}\n"
                if message:
                    debate_msg += f"\nAs a reminder, the original question is {message}"
                logger.debug(f"Sending request to {model_log.modelid}:\n{debate_msg}")
                chatlog.add_user_msg(debate_msg, index)

        logger.info(
            f"Fetching debate #{n_reflect + 1} answers from all {config.bedrock_model_ids=}"
        )
        responses: dict[tuple[int, str], chat.ConverseResponse] = parallel_bedrock_exec(
            _converse_func, chathistory=chatlog.history
        )
        for (index, modelid), response in responses.items():
            # Use the shared helper instead of repeating the recording sequence inline.
            _record_response(chatlog, index, modelid, response)

    if config.aggregator_model_id:
        # aggregate an answer
        chatlog = aggregate_last_responses(config, chatlog, _converse_func, message)
    return chatlog.get_last_answer(), chatlog
152-
153-
def apply_verification(past_answer: str, verifier: Callable[[str], str]) -> str:
    """Run the external *verifier* over *past_answer* and format its feedback.

    Returns a sentence that presents the verifier's output so it can be
    appended to a reflection or debate prompt.
    """
    extra_context = verifier(past_answer)
    logger.debug(f"External verification function returned: {extra_context}")
    return "An external verifier has added the following to this answer: " + extra_context
0 commit comments