Commit 00b1e09
Merge branch 'feat/reasoning-timeout' into 'main'

Feat/reasoning timeout

See merge request proserve/genaiid/innovation-assets/hive!4

2 parents (c97d831 + 473be42), commit 00b1e09

6 files changed: 444 additions & 158 deletions

README.md

Lines changed: 21 additions & 1 deletion

````diff
@@ -236,7 +236,27 @@ print(response)
 
 You can also apply the verifier from the previous stage in this inference method, applying it independently to each revision from each model.
 
-### 4) Optimisation
+### 4) Reasoning Timeout
+
+When using reflection or multi-model debate, you can set `max_reasoning_seconds` to cap the total inference time. If the time limit is reached before all reflection rounds complete, the library returns the best answer available at that point. The initial model call always executes; the timeout applies only to the subsequent reflection rounds.
+
+```python
+from bhive import Hive, HiveConfig
+
+bhive_client = Hive()
+bhive_config = HiveConfig(
+    bedrock_model_ids=["anthropic.claude-haiku-4-5-20251001-v1:0"],
+    num_reflections=5,
+    max_reasoning_seconds=10.0,  # return best answer after 10s even if reflections remain
+)
+messages = [{"role": "user", "content": [{"text": "What is 2 + 2?"}]}]
+response = bhive_client.converse(messages, bhive_config)
+print(response)
+```
+
+This is useful for latency-sensitive applications where you want to allow extra reasoning time when available but need a hard upper bound.
+
+### 5) Optimisation
 
 If you are not sure which exact hyperparameter configuration will suit your needs, you can use the hyperparameter optimisation functionality. Here, you can define a set of ranges for the inference parameters such as the Amazon Bedrock models or rounds of reflection and these will be evaluated against a test dataset. You can also specify a budget constraining the maximum cost ($) and maximum latency (seconds) per example.
 
````
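Because `max_reasoning_seconds` composes with the existing verifier and multi-model debate options, one config can put a hard bound on an entire debate. A hypothetical sketch (the toy `check_arithmetic` verifier and the second model id are illustrative, not part of this commit):

```python
from bhive import Hive, HiveConfig

def check_arithmetic(answer: str) -> str:
    # Toy verifier: returns extra context that is appended to the reflection/debate prompt.
    return "Looks correct." if "4" in answer else "The arithmetic may be wrong; re-check it."

bhive_client = Hive()
bhive_config = HiveConfig(
    bedrock_model_ids=[
        "anthropic.claude-haiku-4-5-20251001-v1:0",
        "amazon.nova-lite-v1:0",  # illustrative second model to trigger debate
    ],
    num_reflections=3,
    verifier=check_arithmetic,
    max_reasoning_seconds=30.0,  # hard cap across all debate rounds
)
messages = [{"role": "user", "content": [{"text": "What is 2 + 2?"}]}]
response = bhive_client.converse(messages, bhive_config)
print(response)
```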

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "bee_hive"
-version = "0.7.12"
+version = "0.8.0"
 description = "Library for enabling inference-time-compute augmentations in Bedrock"
 authors = [
     {name = "Jack Butler", email = "jackbtlr@amazon.co.uk"},
```

src/bhive/client.py

Lines changed: 1 addition & 22 deletions

```diff
@@ -97,28 +97,7 @@ def converse(
         logger.info(f"Starting inference with {config=} and {converse_kwargs=}")
 
         _converse_func = functools.partial(self._converse, **converse_kwargs)
-        response: str | list[str] | None = None
-        if config.single_model_single_call:
-            # single model call
-            response, chatlog = inference.single_model_single_call(config, chatlog, _converse_func)
-
-        elif config.multi_model_single_call:
-            # multi model / single round debate
-            response, chatlog = inference.multi_model_single_call(
-                config, chatlog, _converse_func, message
-            )
-
-        elif config.single_model_multi_call:
-            # single model but reflection
-            response, chatlog = inference.single_model_multi_call(
-                config, chatlog, _converse_func, message
-            )
-
-        else:
-            # multi model + multi round debate
-            response, chatlog = inference.multi_model_multi_call(
-                config, chatlog, _converse_func, message
-            )
+        response, chatlog = inference.run_inference(config, chatlog, _converse_func, message)
 
         # parsing structured outputs
         parsed_response = None
```
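With the four-way branch gone, mode selection now lives inside `inference.run_inference`, driven by the same two signals the removed `HiveConfig` properties encoded: `n_models` (one model vs. several) and `num_reflections` (zero vs. more). The single-call modes fall out naturally as the first loop iteration with no further rounds.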

src/bhive/config.py

Lines changed: 2 additions & 13 deletions

```diff
@@ -28,6 +28,7 @@ class HiveConfig(pydantic.BaseModel):
     verifier: Callable[[str], str] | None = None
     use_prompt_caching: bool = False
     output_model: type[pydantic.BaseModel] | None = None
+    max_reasoning_seconds: float | None = pydantic.Field(default=None, gt=0)
 
     @pydantic.field_validator("bedrock_model_ids")
     @classmethod
@@ -42,7 +43,7 @@ def validate_configuration(self: "HiveConfig") -> "HiveConfig":
             logger.warning("We recommend a final aggregator_model when using multiple models.")
         if self.aggregator_model_id and self.n_models == 1:
             logger.warning("No need for an aggregator_model when using a single model.")
-        if self.single_model_single_call and self.verifier:
+        if self.n_models == 1 and self.no_reflections and self.verifier:
             raise ValueError("verifier cannot be provided when using a single model call.")
         if self.use_prompt_caching:
             logger.warning("Cache read / write pricing is approximate but may not be exact.")
@@ -56,18 +57,6 @@ def n_models(self) -> int:
     def no_reflections(self) -> bool:
         return self.num_reflections == 0
 
-    @property
-    def single_model_single_call(self) -> bool:
-        return self.n_models == 1 and self.no_reflections
-
-    @property
-    def multi_model_single_call(self) -> bool:
-        return self.n_models > 1 and self.no_reflections
-
-    @property
-    def single_model_multi_call(self) -> bool:
-        return self.n_models == 1 and not self.no_reflections
-
 
 class TrialConfig(pydantic.BaseModel):
     """Configuration class for Hive trials, managing trial settings and validation."""
```

src/bhive/inference.py

Lines changed: 68 additions & 121 deletions

```diff
@@ -3,6 +3,7 @@
 SPDX-License-Identifier: Apache-2.0
 """
 
+import time
 from typing import Callable
 
 from loguru import logger
@@ -12,6 +13,71 @@
 from bhive.utils import parallel_bedrock_exec
 
 
+def run_inference(
+    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
+) -> tuple[str | list[str], chat.ChatLog]:
+    is_single = config.n_models == 1
+    start_time = time.monotonic()
+
+    for n_reflect in range(config.num_reflections + 1):
+        if n_reflect > 0:
+            if config.max_reasoning_seconds is not None:
+                elapsed = time.monotonic() - start_time
+                if elapsed >= config.max_reasoning_seconds:
+                    logger.info(
+                        f"Exiting early at round {n_reflect}/{config.num_reflections} "
+                        f"after {elapsed:.1f}s (limit: {config.max_reasoning_seconds}s)"
+                    )
+                    break
+        if n_reflect > 0:
+            if is_single:
+                reflect_msg = prompt.reflect + "\n"
+                if config.verifier:
+                    past_answer = chatlog.get_last_answer()
+                    reflect_msg += apply_verification(past_answer, config.verifier)  # type: ignore[arg-type]
+                if message:
+                    reflect_msg += f"\nAs a reminder, the original question is {message}"
+                chatlog.add_user_msg(reflect_msg, invoke_index=0)
+            else:
+                for index in range(config.n_models):
+                    recent_other_answers = chatlog.get_recent_other_answers(index)
+                    debate_msg = prompt.debate
+                    for recent_ans in recent_other_answers:
+                        answer_text = recent_ans["content"][0]["text"]
+                        debate_msg += f"\n\nOne agent response: ```{answer_text}```"
+                        if config.verifier:
+                            debate_msg += apply_verification(answer_text, config.verifier)
+                    debate_msg += f"\n\n {prompt.careful}\n"
+                    if message:
+                        debate_msg += f"\nAs a reminder, the original question is {message}"
+                    chatlog.add_user_msg(debate_msg, index)
+
+        if is_single:
+            modelid = config.bedrock_model_ids[0]
+            response = _converse_func(model_id=modelid, messages=chatlog.history[0].chat_history)
+            _record_response(chatlog, 0, modelid, response)
+        else:
+            responses = parallel_bedrock_exec(_converse_func, chathistory=chatlog.history)
+            for (index, modelid), response in responses.items():
+                _record_response(chatlog, index, modelid, response)
+
+    if config.aggregator_model_id:
+        chatlog = aggregate_last_responses(config, chatlog, _converse_func, message)
+
+    return chatlog.get_last_answer(), chatlog
+
+
+def _record_response(
+    chatlog: chat.ChatLog, index: int, modelid: str, response: chat.ConverseResponse
+):
+    chatlog.add_assistant_msg(response.answer, index)
+    if response.thinking:
+        chatlog.add_thinking_trace(response.thinking, index)
+    chatlog.update_stats(modelid, response)
+    chatlog.add_stop_reason(response.stopReason)
+    chatlog.add_trace(response.trace)
+
+
 def aggregate_last_responses(
     config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
 ) -> chat.ChatLog:
@@ -30,131 +96,12 @@ def aggregate_last_responses(
     logger.info(f"Aggregating a final response using {config.aggregator_model_id=}")
     response: chat.ConverseResponse = _converse_func(config.aggregator_model_id, [fmt_msg])
 
-    chatlog.add_assistant_msg(response.answer, 0)
-    if response.thinking:
-        chatlog.add_thinking_trace(response.thinking, 0)
-    chatlog.update_stats(config.aggregator_model_id, response)
-    chatlog.add_stop_reason(response.stopReason)
-    chatlog.add_trace(response.trace)
+    _record_response(chatlog, 0, config.aggregator_model_id, response)
 
     return chatlog
 
 
-def single_model_single_call(
-    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable
-) -> tuple[str, chat.ChatLog]:
-    modelid = config.bedrock_model_ids[0]
-    logger.info(f"Calling {modelid} with no self-reflection")
-    response: chat.ConverseResponse = _converse_func(
-        model_id=modelid, messages=chatlog.history[0].chat_history
-    )
-    chatlog.add_assistant_msg(response.answer, 0)
-    if response.thinking:
-        chatlog.add_thinking_trace(response.thinking, 0)
-    chatlog.update_stats(modelid, response)
-    chatlog.add_stop_reason(response.stopReason)
-    chatlog.add_trace(response.trace)
-
-    return response.answer, chatlog
-
-
-def multi_model_single_call(
-    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
-) -> tuple[str | list[str], chat.ChatLog]:
-    logger.info(f"Calling {config.bedrock_model_ids} with no self-reflection")
-    responses: dict[tuple[int, str], chat.ConverseResponse] = parallel_bedrock_exec(
-        _converse_func, chathistory=chatlog.history
-    )
-    for (index, modelid), response in responses.items():
-        chatlog.add_assistant_msg(response.answer, index)
-        if response.thinking:
-            chatlog.add_thinking_trace(response.thinking, index)
-        chatlog.update_stats(modelid, response)
-        chatlog.add_stop_reason(response.stopReason)
-        chatlog.add_trace(response.trace)
-
-    if config.aggregator_model_id:
-        # aggregate an answer
-        chatlog = aggregate_last_responses(config, chatlog, _converse_func, message)
-    return chatlog.get_last_answer(), chatlog
-
-
-def single_model_multi_call(
-    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
-) -> tuple[str, chat.ChatLog]:
-    modelid = config.bedrock_model_ids[0]
-    logger.info(f"Calling {modelid} with {config.num_reflections} rounds of self-reflection")
-    for n_reflect in range(config.num_reflections + 1):
-        if 0 < n_reflect:
-            reflect_msg = prompt.reflect + "\n"
-            if config.verifier:
-                past_answer = chatlog.get_last_answer()
-                assert isinstance(past_answer, str), (
-                    "Received multiple responds when doing a single model call"
-                )
-                reflect_msg += apply_verification(past_answer, config.verifier)
-            if message:
-                reflect_msg += f"\nAs a reminder, the original question is {message}"
-            chatlog.add_user_msg(reflect_msg, invoke_index=0)
-        response: chat.ConverseResponse = _converse_func(
-            model_id=modelid, messages=chatlog.history[0].chat_history
-        )
-        chatlog.add_assistant_msg(response.answer, invoke_index=0)
-        if response.thinking:
-            chatlog.add_thinking_trace(response.thinking, invoke_index=0)
-        chatlog.update_stats(modelid, response)
-        chatlog.add_stop_reason(response.stopReason)
-        chatlog.add_trace(response.trace)
-
-    return response.answer, chatlog
-
-
-def multi_model_multi_call(
-    config: HiveConfig, chatlog: chat.ChatLog, _converse_func: Callable, message: str | None = None
-) -> tuple[str | list[str], chat.ChatLog]:
-    logger.info(f"Calling {config.bedrock_model_ids} with {config.num_reflections} rounds")
-    for n_reflect in range(config.num_reflections + 1):
-        if 0 < n_reflect:
-            # consider others & debate
-            for index, model_log in enumerate(chatlog.history):
-                recent_other_answers = chatlog.get_recent_other_answers(index)
-                debate_msg = prompt.debate
-                for recent_ans in recent_other_answers:
-                    # NOTE could alternatively summarise messages
-                    answer_text = recent_ans["content"][0]["text"]
-                    debate_msg += f"\n\nOne agent response: ```{answer_text}```"
-                    if config.verifier:
-                        debate_msg += apply_verification(answer_text, config.verifier)
-                debate_msg += f"\n\n {prompt.careful}\n"
-                if message:
-                    debate_msg += f"\nAs a reminder, the original question is {message}"
-                logger.debug(f"Sending request to {model_log.modelid}:\n{debate_msg}")
-                chatlog.add_user_msg(debate_msg, index)
-
-        logger.info(
-            f"Fetching debate #{n_reflect + 1} answers from all {config.bedrock_model_ids=}"
-        )
-        responses: dict[tuple[int, str], chat.ConverseResponse] = parallel_bedrock_exec(
-            _converse_func, chathistory=chatlog.history
-        )
-        for (index, modelid), response in responses.items():
-            chatlog.add_assistant_msg(response.answer, index)
-            if response.thinking:
-                chatlog.add_thinking_trace(response.thinking, index)
-            chatlog.update_stats(modelid, response)
-            chatlog.add_stop_reason(response.stopReason)
-            chatlog.add_trace(response.trace)
-
-    if config.aggregator_model_id:
-        # aggregate an answer
-        chatlog = aggregate_last_responses(config, chatlog, _converse_func, message)
-    return chatlog.get_last_answer(), chatlog
-
-
 def apply_verification(past_answer: str, verifier: Callable[[str], str]) -> str:
-    # Applies a verification function to add more context
     verifier_context = verifier(past_answer)
     logger.debug(f"External verification function returned: {verifier_context}")
-    return (
-        f"\nAn external verification function has added context to this answer: {verifier_context}"
-    )
+    return f"An external verifier has added the following to this answer: {verifier_context}"
```
