Skip to content

Commit 729ec1f

Browse files
Reduce max_iterations, fix JSON extraction, and introduce pooling for workers. (#92)
Co-authored-by: lucifertrj <[email protected]>
1 parent a110ad4 commit 729ec1f

File tree

3 files changed

+79
-95
lines changed

3 files changed

+79
-95
lines changed

src/openagi/prompts/worker_task_execution.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
from textwrap import dedent
2-
31
from openagi.prompts.base import BasePrompt
42

5-
WORKER_TASK_EXECUTION = dedent(
6-
"""
7-
You: {worker_description}
3+
WORKER_TASK_EXECUTION = """
4+
You are expert in: {worker_description}
85
96
# Instructions
107
- You run in a loop of Thought, Action, Observation. Follow the instructions below to understand the workflow and follow them in each iteration of the loop.
@@ -13,9 +10,9 @@
1310
- Observation will be the result of running those actions. Make sure to thoroughly analyze the observation to see if it aligns with your expectations.
1411
- On each observation, try to understand the drawbacks and mistakes and learn from them to improve further and get back on track.
1512
- Take the context into account when you are answering the question. It will be the results or data from the past executions. If no context is provided, then you can assume that the context is empty and you can start from scratch. Use context to ensure consistency and accuracy in your responses.
16-
- Output the answer when you feel the observations are correct and aligned with the goal. They do not have to be very accurate, but ensure they are reasonably reliable.
17-
- The output should always be in the following format in all the iterations. Ensure the JSON format is suitable for utilization with json.loads(), enclosed in triple backticks:
13+
- Output the answer when you feel the observations are reasonably good and aligned with the goal. They do not have to be very accurate, but ensure they are reasonably reliable.
1814
- No Action/Output should be without json. Trying not include your thoughts as part of the action. You can skip the action if not required.
15+
- The output needs to be in JSON ONLY:
1916
- For Running an action:
2017
```json
2118
{
@@ -72,8 +69,6 @@
7269
Begin!
7370
{thought_provokes}
7471
""".strip()
75-
)
76-
7772

7873
class WorkerAgentTaskExecution(BasePrompt):
79-
base_prompt: str = WORKER_TASK_EXECUTION
74+
base_prompt: str = WORKER_TASK_EXECUTION

src/openagi/utils/extraction.py

Lines changed: 34 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,27 @@
22
import json
33
import logging
44
import re
5-
from textwrap import dedent
65
from typing import Dict, List, Optional, Tuple
76

87
from openagi.exception import OpenAGIException
98
from openagi.llms.base import LLMBaseModel
109

1110

12-
def force_json_output(resp_txt: str, llm):
11+
def force_json_output(resp_txt: str, llm) -> str:
1312
"""
14-
Forces the output once the max iterations are reached.
13+
Forces proper JSON output format in first attempt.
1514
"""
16-
#prompt = dedent(
17-
# """
18-
# Below is a JSON block. Please try to provide the output in the format shown below only
19-
# ```json
20-
# {"key": "value"}
21-
# ```
22-
# the contents between ```json and ``` will be extracted and passed to json.loads() in python to convert it to a dictionary. Make sure that it works when passed else you will be fined. If its already in the correct format, then you can return the same output in the expected output format.
23-
# Input:
24-
# {resp_txt}
25-
# Output:
26-
# """
27-
#).strip()
28-
29-
prompt = dedent(
30-
"""
31-
Your task is to process the input JSON and provide a valid JSON output. Follow these instructions carefully:
32-
1. The output must be enclosed in a code block using triple backticks and the 'json' language identifier, like this:
33-
```json
34-
{"key": "value"}
35-
```
36-
2. The JSON inside the code block must be valid and parseable by Python's json.loads() function.
37-
3. Ensure there are no extra spaces, newlines, or characters outside the JSON object within the code block.
38-
4. If the input is already in the correct format, reproduce it exactly in the output format specified above.
39-
5. Do not include any explanations, comments, or additional text in your response. The output needs be in JSON only.
40-
6. Verify your output carefully before submitting. Incorrect responses will result in penalties.
15+
prompt = """
16+
You are a JSON formatting expert. Your task is to process the input and provide a valid JSON output.
4117
42-
Input: {resp_txt}
43-
Output:
44-
"""
45-
).strip()
18+
FOLLOW THESE INSTRUCTIONS to convert:
19+
- Output must be ONLY a JSON object wrapped in ```json code block
20+
- Do not include any explanations, comments, or additional text in your response. The output needs be in JSON only.
21+
22+
Convert this INPUT to proper JSON:
23+
INPUT: {resp_txt}
24+
Output only the JSON:
25+
""".strip()
4626

4727
prompt = prompt.replace("{resp_txt}", resp_txt)
4828
return llm.run(prompt)
@@ -52,47 +32,39 @@ def get_last_json(
5232
text: str, llm: Optional[LLMBaseModel] = None, max_iterations: int = 5
5333
) -> Optional[Dict]:
5434
"""
55-
Extracts the last block of text between ```json and ``` markers from a given string.
56-
57-
Args:
58-
text (str): The string from which to extract the JSON block.
59-
llm (Optional[LLMBaseModel]): The language model instance to use for reformatting.
60-
max_iterations (int): Maximum number of iterations to try reformatting.
61-
62-
Returns:
63-
dict or None: The last JSON block as a dictionary if found and parsed, otherwise None.
35+
Extracts valid JSON from text with improved reliability.
6436
"""
65-
pattern = r"```json(.*?)```"
66-
matches = re.findall(pattern, text, flags=re.DOTALL)
37+
# More precise JSON block pattern
38+
pattern = r"```json\s*(\{[\s\S]*?\})\s*```"
39+
matches = re.findall(pattern, text, re.MULTILINE)
40+
6741
if matches:
68-
last_json = matches[-1].strip().replace("\n", "")
6942
try:
43+
last_json = matches[-1].strip()
44+
last_json = re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', last_json)
45+
last_json = re.sub(r'\s+', ' ', last_json)
7046
return json.loads(last_json)
71-
except json.JSONDecodeError:
72-
logging.error("JSON not extracted. Trying again.", exc_info=True)
73-
pass
74-
47+
except json.JSONDecodeError as e:
48+
logging.error(f"JSON parsing failed: {str(e)}", exc_info=True)
49+
if llm:
50+
text = force_json_output(last_json, llm)
51+
return get_last_json(text, None, max_iterations)
52+
7553
if llm:
7654
for iteration in range(1, max_iterations + 1):
77-
logging.info(f"Iteration {iteration} to extract JSON from LLM output.")
7855
try:
7956
text = force_json_output(text, llm)
80-
matches = re.findall(pattern, text, flags=re.DOTALL)
81-
if matches:
82-
last_json = matches[-1].strip().replace("\n", "")
83-
json_resp = json.loads(last_json)
84-
logging.info("JSON extracted successfully.")
85-
return json_resp
86-
except json.JSONDecodeError:
87-
logging.error("JSON not extracted. Trying again.", exc_info=True)
88-
continue
89-
if iteration == max_iterations:
90-
raise OpenAGIException(
91-
"The last output is not a valid JSON. Please check the output of the last action."
92-
)
57+
return get_last_json(text, None, max_iterations)
58+
except Exception as e:
59+
logging.error(f"Attempt {iteration} failed: {str(e)}", exc_info=True)
60+
if iteration == max_iterations:
61+
raise OpenAGIException(
62+
f"Failed to extract valid JSON after {max_iterations} attempts. Last error: {str(e)}"
63+
)
9364
return None
9465

9566

67+
9668
def get_act_classes_from_json(json_data) -> List[Tuple[str, Optional[Dict]]]:
9769
"""
9870
Extracts the Action class names and parameters from a JSON block.

src/openagi/worker.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import functools
2+
from concurrent.futures import ThreadPoolExecutor
13
import logging
24
from pathlib import Path
35
import re
4-
from textwrap import dedent
56
from typing import Any, Dict, List, Optional, Union
67

78
from pydantic import BaseModel, Field, field_validator
@@ -35,7 +36,7 @@ class Worker(BaseModel):
3536
default_factory=list,
3637
)
3738
max_iterations: int = Field(
38-
default=20,
39+
default=10,
3940
description="Maximum number of steps to achieve the objective.",
4041
)
4142
output_key: str = Field(
@@ -46,7 +47,7 @@ class Worker(BaseModel):
4647
default=True,
4748
description="If set to True, the output will be overwritten even if it exists.",
4849
)
49-
50+
5051
# Validate output_key. Should contain only alphabets and only underscore are allowed. Not alphanumeric
5152
@field_validator("output_key")
5253
@classmethod
@@ -70,7 +71,7 @@ def worker_doc(self):
7071
}
7172

7273
def provoke_thought_obs(self, observation):
73-
thoughts = dedent(f"""Observation: {observation}""".strip())
74+
thoughts = f"""Observation: {observation}""".strip()
7475
return thoughts
7576

7677
def should_continue(self, llm_resp: str) -> Union[bool, Optional[Dict]]:
@@ -84,7 +85,7 @@ def _force_output(
8485
"""Force the output once the max iterations are reached."""
8586
prompt = (
8687
"\n".join(all_thoughts_and_obs)
87-
+ "Based on the previous action and observation, give me the output."
88+
+ "Based on the previous action and observation, force and give me the output."
8889
)
8990
output = self.llm.run(prompt)
9091
cont, final_output = self.should_continue(output)
@@ -101,43 +102,53 @@ def _force_output(
101102
)
102103
return (cont, final_output)
103104

105+
@functools.lru_cache(maxsize=100)
106+
def _cached_llm_run(self, prompt: str) -> str:
107+
"""Cache LLM responses for identical prompts"""
108+
return self.llm.run(prompt)
109+
104110
def save_to_memory(self, task: Task):
105-
"""Saves the output to the memory."""
106-
return self.memory.update_task(task)
111+
"""Optimized memory update"""
112+
if not hasattr(self, '_memory_buffer'):
113+
self._memory_buffer = []
114+
self._memory_buffer.append(task)
115+
116+
# Batch update memory when buffer reaches certain size
117+
if len(self._memory_buffer) >= 5:
118+
for buffered_task in self._memory_buffer:
119+
self.memory.update_task(buffered_task)
120+
self._memory_buffer.clear()
121+
return True
107122

108123
def execute_task(self, task: Task, context: Any = None) -> Any:
109-
"""Executes the specified task."""
110-
logging.info(
111-
f"{'>'*20} Executing Task - {task.name}[{task.id}] with worker - {self.role}[{self.id}] {'<'*20}"
112-
)
124+
"""Optimized task execution"""
125+
logging.info(f"{'>'*20} Executing Task - {task.name}[{task.id}] with worker - {self.role}[{self.id}] {'<'*20}")
126+
127+
# Pre-compute common values
113128
iteration = 1
114129
task_to_execute = f"{task.description}"
115130
worker_description = f"{self.role} - {self.instructions}"
116131
all_thoughts_and_obs = []
117-
118-
logging.debug("Provoking initial thought observation...")
119-
initial_thought_provokes = self.provoke_thought_obs(None)
132+
133+
# Generate base prompt once
120134
te_vars = dict(
121135
task_to_execute=task_to_execute,
122136
worker_description=worker_description,
123137
supported_actions=[action.cls_doc() for action in self.actions],
124-
thought_provokes=initial_thought_provokes,
138+
thought_provokes=self.provoke_thought_obs(None),
125139
output_key=self.output_key,
126140
context=context,
127141
max_iterations=self.max_iterations,
128142
)
129-
130-
logging.debug("Generating base prompt...")
131143
base_prompt = WorkerAgentTaskExecution().from_template(te_vars)
144+
145+
# Use cached LLM run
132146
prompt = f"{base_prompt}\nThought:\nIteration: {iteration}\nActions:\n"
133-
134-
logging.debug("Running LLM with prompt...")
135-
observations = self.llm.run(prompt)
136-
logging.info(f"LLM execution completed. Observations: {observations}")
147+
observations = self._cached_llm_run(prompt)
137148
all_thoughts_and_obs.append(prompt)
138149

139-
max_iters = self.max_iterations + 1
140-
while iteration < max_iters:
150+
while iteration < self.max_iterations + 1:
151+
141152
logging.info(f"---- Iteration {iteration} ----")
142153
logging.debug("Checking if task should continue...")
143154
continue_flag, output = self.should_continue(observations)
@@ -210,6 +221,7 @@ def execute_task(self, task: Task, context: Any = None) -> Any:
210221
prompt = f"{base_prompt}\n" + "\n".join(all_thoughts_and_obs)
211222
logging.debug(f"\nSTART:{'*' * 20}\n{prompt}\n{'*' * 20}:END")
212223
pth = Path(f"{self.memory.session_id}/logs/{task.name}-{iteration}.log")
224+
213225
pth.parent.mkdir(parents=True, exist_ok=True)
214226
with open(pth, "w", encoding="utf-8") as f:
215227
f.write(f"{prompt}\n")
@@ -240,3 +252,8 @@ def execute_task(self, task: Task, context: Any = None) -> Any:
240252
f"Task Execution Completed - {task.name} with worker - {self.role}[{self.id}] in {iteration} iterations"
241253
)
242254
return output, task
255+
256+
def __del__(self):
257+
"""Cleanup thread pool on deletion"""
258+
if hasattr(self, '_thread_pool'):
259+
self._thread_pool.shutdown(wait=False)

0 commit comments

Comments
 (0)