Skip to content

Commit e304e5d

Browse files
[Feat] Support resetting agent memory recursively or by keypath (#271)
* update `reset` method * fix type hints
1 parent b2bf23d commit e304e5d

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

lagent/agents/agent.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,26 @@ def register_hook(self, hook: Callable):
165165
self._hooks[handle.id] = hook
166166
return handle
167167

168-
def reset(self, session_id=0):
169-
if self.memory:
170-
self.memory.reset(session_id=session_id)
168+
def reset(self,
169+
session_id=0,
170+
keypath: Optional[str] = None,
171+
recursive: bool = False):
172+
assert not (keypath and
173+
recursive), 'keypath and recursive can\'t be used together'
174+
if keypath:
175+
keys, agent = keypath.split('.'), self
176+
for key in keys:
177+
agents = getattr(agent, '_agents', {})
178+
if key not in agents:
179+
raise KeyError(f'No sub-agent named {key} in {agent}')
180+
agent = agents[key]
181+
agent.reset(session_id, recursive=False)
182+
else:
183+
if self.memory:
184+
self.memory.reset(session_id=session_id)
185+
if recursive:
186+
for agent in getattr(self, '_agents', {}).values():
187+
agent.reset(session_id, recursive=True)
171188

172189
def __repr__(self):
173190

lagent/llms/openai.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from concurrent.futures import ThreadPoolExecutor
88
from logging import getLogger
99
from threading import Lock
10-
from typing import Dict, List, Optional, Union
10+
from typing import AsyncGenerator, Dict, List, Optional, Union
1111

1212
import aiohttp
1313
import requests
@@ -701,7 +701,8 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
701701
f'{max_num_retries} times. Check the logs for '
702702
f'details. errmsg: {errmsg}')
703703

704-
async def _stream_chat(self, messages: List[dict], **gen_params) -> str:
704+
async def _stream_chat(self, messages: List[dict],
705+
**gen_params) -> AsyncGenerator[str, None]:
705706
"""Generate completion from a list of templates.
706707
707708
Args:

0 commit comments

Comments
 (0)