Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions metagpt/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ class Memory(BaseModel):
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
ignore_id: bool = False

def _filter_expired_messages(self, messages: list[Message]) -> list[Message]:
"""Filter out expired messages from the given list.

Args:
messages: List of messages to filter.

Returns:
List of non-expired messages.
"""
return [message for message in messages if not message.is_expired()]

def add(self, message: Message):
"""Add a new message to storage, while updating the index"""
if self.ignore_id:
Expand All @@ -40,11 +51,13 @@ def add_batch(self, messages: Iterable[Message]):

def get_by_role(self, role: str) -> list[Message]:
"""Return all messages of a specified role"""
return [message for message in self.storage if message.role == role]
messages = [message for message in self.storage if message.role == role]
return self._filter_expired_messages(messages)

def get_by_content(self, content: str) -> list[Message]:
"""Return all messages containing a specified content"""
return [message for message in self.storage if content in message.content]
messages = [message for message in self.storage if content in message.content]
return self._filter_expired_messages(messages)

def delete_newest(self) -> "Message":
"""delete the newest message from the storage"""
Expand Down Expand Up @@ -75,11 +88,13 @@ def count(self) -> int:

def try_remember(self, keyword: str) -> list[Message]:
"""Try to recall all messages containing a specified keyword"""
return [message for message in self.storage if keyword in message.content]
messages = [message for message in self.storage if keyword in message.content]
return self._filter_expired_messages(messages)

def get(self, k=0) -> list[Message]:
"""Return the most recent k memories, return all when k=0"""
return self.storage[-k:]
messages = self.storage[-k:]
return self._filter_expired_messages(messages)

def find_news(self, observed: list[Message], k=0) -> list[Message]:
"""find news (previously unseen messages) from the most recent k memories, from all memories when k=0"""
Expand All @@ -94,7 +109,8 @@ def find_news(self, observed: list[Message], k=0) -> list[Message]:
def get_by_action(self, action) -> list[Message]:
"""Return all messages triggered by a specified Action"""
index = any_to_str(action)
return self.index[index]
messages = self.index[index]
return self._filter_expired_messages(messages)

def get_by_actions(self, actions: Set) -> list[Message]:
"""Return all messages triggered by specified Actions"""
Expand All @@ -104,9 +120,12 @@ def get_by_actions(self, actions: Set) -> list[Message]:
if action not in self.index:
continue
rsp += self.index[action]
return rsp
return self._filter_expired_messages(rsp)

@handle_exception
def get_by_position(self, position: int) -> Optional[Message]:
"""Returns the message at the given position if valid, otherwise returns None"""
return self.storage[position]
"""Returns the message at the given position if valid and not expired, otherwise returns None"""
message = self.storage[position]
if message and message.is_expired():
return None
return message
14 changes: 14 additions & 0 deletions metagpt/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ class Message(BaseModel):
sent_from: str = Field(default="", validate_default=True)
send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
metadata: Dict[str, Any] = Field(default_factory=dict) # metadata for `content` and `instruct_content`
ttl: int = Field(default=-1, validate_default=True) # Time-To-Live in seconds, -1 means never expire
created_at: float = Field(default_factory=time.time) # Creation time in seconds since epoch

@field_validator("id", mode="before")
@classmethod
Expand Down Expand Up @@ -415,6 +417,18 @@ def is_user_message(self) -> bool:
def is_ai_message(self) -> bool:
return self.role == "assistant"

def is_expired(self) -> bool:
"""Check if the message has expired based on its TTL.

Returns:
bool: True if the message has expired, False otherwise.
Messages with ttl=-1 never expire.
"""
if self.ttl == -1:
return False
current_time = time.time()
return current_time - self.created_at > self.ttl


class UserMessage(Message):
"""便于支持OpenAI的消息
Expand Down
134 changes: 134 additions & 0 deletions tests/metagpt/memory/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of Memory

import time

from metagpt.actions import UserRequirement
from metagpt.memory.memory import Memory
from metagpt.schema import Message
Expand Down Expand Up @@ -55,3 +57,135 @@ def test_memory():
memory.clear()
assert memory.count() == 0
assert len(memory.index) == 0


def test_message_ttl_and_created_at():
"""Test Message class ttl and created_at fields"""
# Test default values
message = Message(content="test message", role="user")
assert message.ttl == -1
assert message.created_at > 0

# Test custom ttl
message_with_ttl = Message(content="test message with ttl", role="user", ttl=60)
assert message_with_ttl.ttl == 60

# Test is_expired method for ttl=-1 (never expire)
assert message.is_expired() == False

# Test serialization
dumped = message.dump()
loaded = Message.load(dumped)
assert loaded.ttl == message.ttl
assert abs(loaded.created_at - message.created_at) < 0.001


def test_message_expiration():
"""Test message expiration functionality"""
# Create a message that expires in 1 second
message = Message(content="expiring message", role="user", ttl=1)
assert message.is_expired() == False

# Wait for message to expire
time.sleep(1.1)
assert message.is_expired() == True

# Create a message that never expires
message_never_expire = Message(content="never expiring message", role="user", ttl=-1)
assert message_never_expire.is_expired() == False

# Wait and verify it still doesn't expire
time.sleep(0.5)
assert message_never_expire.is_expired() == False


def test_memory_filter_expired_messages():
"""Test Memory class filtering expired messages"""
memory = Memory()

# Create messages with different TTLs
message1 = Message(content="never expire", role="user1", ttl=-1)
message2 = Message(content="expire in 1 sec", role="user2", ttl=1)
message3 = Message(content="never expire too", role="user1", ttl=-1)

# Add all messages to memory
memory.add_batch([message1, message2, message3])
assert memory.count() == 3

# Wait for message2 to expire
time.sleep(1.1)

# Test get() method filters expired messages
messages = memory.get()
assert len(messages) == 2
assert message2 not in messages

# Test get_by_role() method filters expired messages
messages = memory.get_by_role("user2")
assert len(messages) == 0

# Test get_by_role() for non-expired messages
messages = memory.get_by_role("user1")
assert len(messages) == 2

# Test get_by_content() method filters expired messages
messages = memory.get_by_content("expire")
assert len(messages) == 2
assert message2 not in messages

# Test try_remember() method filters expired messages
messages = memory.try_remember("expire")
assert len(messages) == 2
assert message2 not in messages

# Test get_by_action() method filters expired messages
messages = memory.get_by_action(UserRequirement)
assert len(messages) == 2
assert message2 not in messages

# Test get_by_actions() method filters expired messages
messages = memory.get_by_actions({UserRequirement})
assert len(messages) == 2
assert message2 not in messages

# Test get_by_position() method returns None for expired messages
# Note: message2 is at position 1 in storage
message = memory.get_by_position(1)
assert message is None


def test_memory_backward_compatibility():
"""Test backward compatibility with existing code"""
memory = Memory()

# Create messages without specifying ttl (should use default -1)
message1 = Message(content="message1", role="user1")
message2 = Message(content="message2", role="user2")

# Verify default ttl is -1
assert message1.ttl == -1
assert message2.ttl == -1

# Add to memory
memory.add_batch([message1, message2])

# Verify all retrieval methods work as before
assert memory.count() == 2

messages = memory.get()
assert len(messages) == 2

messages = memory.get_by_role("user1")
assert len(messages) == 1
assert messages[0].content == "message1"

messages = memory.get_by_content("message")
assert len(messages) == 2

# Verify messages don't expire
time.sleep(0.5)
assert message1.is_expired() == False
assert message2.is_expired() == False

messages = memory.get()
assert len(messages) == 2
Loading