Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions metagpt/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
@Author : alexanderwu
@File : memory.py
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
@Modified By: 2026-04-21. Added TTL (Time-To-Live) support for messages with automatic cleanup.
"""
import asyncio
from collections import defaultdict
from typing import DefaultDict, Iterable, Optional, Set

from pydantic import BaseModel, Field, SerializeAsAny
from pydantic import BaseModel, Field, PrivateAttr, SerializeAsAny

from metagpt.const import IGNORED_MESSAGE_ID
from metagpt.schema import Message
Expand All @@ -24,6 +26,24 @@ class Memory(BaseModel):
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
ignore_id: bool = False

_cleanup_task: Optional[asyncio.Task] = PrivateAttr(default=None)
_is_running: bool = PrivateAttr(default=False)

def _cleanup_expired_messages(self) -> int:
"""Remove all expired messages from storage and index.

Returns:
int: The number of expired messages that were removed.
"""
expired_messages = [message for message in self.storage if message.is_expired()]

for message in expired_messages:
self.storage.remove(message)
if message.cause_by and message in self.index[message.cause_by]:
self.index[message.cause_by].remove(message)

return len(expired_messages)

def add(self, message: Message):
"""Add a new message to storage, while updating the index"""
if self.ignore_id:
Expand All @@ -40,10 +60,12 @@ def add_batch(self, messages: Iterable[Message]):

def get_by_role(self, role: str) -> list[Message]:
"""Return all messages of a specified role"""
self._cleanup_expired_messages()
return [message for message in self.storage if message.role == role]

def get_by_content(self, content: str) -> list[Message]:
"""Return all messages containing a specified content"""
self._cleanup_expired_messages()
return [message for message in self.storage if content in message.content]

def delete_newest(self) -> "Message":
Expand Down Expand Up @@ -71,14 +93,17 @@ def clear(self):

def count(self) -> int:
"""Return the number of messages in storage"""
self._cleanup_expired_messages()
return len(self.storage)

def try_remember(self, keyword: str) -> list[Message]:
"""Try to recall all messages containing a specified keyword"""
self._cleanup_expired_messages()
return [message for message in self.storage if keyword in message.content]

def get(self, k=0) -> list[Message]:
"""Return the most recent k memories, return all when k=0"""
self._cleanup_expired_messages()
return self.storage[-k:]

def find_news(self, observed: list[Message], k=0) -> list[Message]:
Expand All @@ -93,11 +118,13 @@ def find_news(self, observed: list[Message], k=0) -> list[Message]:

def get_by_action(self, action) -> list[Message]:
"""Return all messages triggered by a specified Action"""
self._cleanup_expired_messages()
index = any_to_str(action)
return self.index[index]

def get_by_actions(self, actions: Set) -> list[Message]:
"""Return all messages triggered by specified Actions"""
self._cleanup_expired_messages()
rsp = []
indices = any_to_str_set(actions)
for action in indices:
Expand All @@ -108,5 +135,58 @@ def get_by_actions(self, actions: Set) -> list[Message]:

@handle_exception
def get_by_position(self, position: int) -> Optional[Message]:
"""Returns the message at the given position if valid, otherwise returns None"""
"""Returns the message at the given position if valid and not expired, otherwise returns None"""
self._cleanup_expired_messages()
if position < 0 or position >= len(self.storage):
return None
return self.storage[position]

async def _periodic_cleanup(self, interval: int = 60):
"""Asynchronous background task that periodically cleans up expired messages.

Args:
interval: The time in seconds between each cleanup check.
"""
while self._is_running:
try:
self._cleanup_expired_messages()
await asyncio.sleep(interval)
except asyncio.CancelledError:
break
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Error in periodic cleanup: {e}")
await asyncio.sleep(interval)

def start(self, cleanup_interval: int = 60):
"""Start the background periodic cleanup task.

Args:
cleanup_interval: The time in seconds between each cleanup check. Defaults to 60.
"""
if self._is_running:
return

self._is_running = True

try:
loop = asyncio.get_running_loop()
self._cleanup_task = loop.create_task(self._periodic_cleanup(cleanup_interval))
except RuntimeError:
import threading
from metagpt.utils.async_helper import run_coroutine_in_new_loop

def run_cleanup():
asyncio.run(self._periodic_cleanup(cleanup_interval))

thread = threading.Thread(target=run_cleanup, daemon=True)
thread.start()

def stop(self):
"""Stop the background periodic cleanup task gracefully."""
self._is_running = False

if self._cleanup_task:
self._cleanup_task.cancel()
self._cleanup_task = None
14 changes: 14 additions & 0 deletions metagpt/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ class Message(BaseModel):
sent_from: str = Field(default="", validate_default=True)
send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
metadata: Dict[str, Any] = Field(default_factory=dict) # metadata for `content` and `instruct_content`
ttl: int = Field(default=-1, validate_default=True) # Time-To-Live in seconds, -1 means never expire
created_at: float = Field(default_factory=time.time) # Creation time in seconds since epoch

@field_validator("id", mode="before")
@classmethod
Expand Down Expand Up @@ -415,6 +417,18 @@ def is_user_message(self) -> bool:
def is_ai_message(self) -> bool:
return self.role == "assistant"

def is_expired(self) -> bool:
"""Check if the message has expired based on its TTL.

Returns:
bool: True if the message has expired, False otherwise.
Messages with ttl=-1 never expire.
"""
if self.ttl == -1:
return False
current_time = time.time()
return current_time - self.created_at > self.ttl


class UserMessage(Message):
"""便于支持OpenAI的消息
Expand Down
Loading
Loading