Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mem0/graphs/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class GraphStoreConfig(BaseModel):
le=1.0,
)

custom_search_prompt: Optional[str] = Field(None, description="Custom prompt for entity extraction during graph search")

@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
Expand Down
10 changes: 9 additions & 1 deletion mem0/graphs/neptune/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,19 @@ def _retrieve_nodes_from_data(self, data, filters):
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]

default_search_prompt = f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question."

search_prompt = getattr(self.config.graph_store, 'custom_search_prompt', None) or default_search_prompt

if "{user_id}" in search_prompt:
search_prompt = search_prompt.replace("{user_id}", str(filters.get("user_id", "USER")))

search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
"content": search_prompt,
},
{"role": "user", "content": data},
],
Expand Down
66 changes: 37 additions & 29 deletions mem0/memory/graph_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,37 +194,45 @@ def get_all(self, filters, limit=100):
return final_results

def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": data},
],
tools=_tools,
)

entity_type_map = {}

try:
for tool_call in search_results["tool_calls"]:
if tool_call["name"] != "extract_entities":
continue
for item in tool_call["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.exception(
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]

default_search_prompt = f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question."

search_prompt = getattr(self.config.graph_store, 'custom_search_prompt', None) or default_search_prompt

if "{user_id}" in search_prompt:
search_prompt = search_prompt.replace("{user_id}", str(filters.get("user_id", "USER")))

search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": search_prompt,
},
{"role": "user", "content": data},
],
tools=_tools,
)

entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
return entity_type_map
entity_type_map = {}

try:
for tool_call in search_results["tool_calls"]:
if tool_call["name"] != "extract_entities":
continue
for item in tool_call["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.exception(
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)

entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}\n search_results={search_results}")
return entity_type_map

def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
"""Establish relations among the extracted nodes."""
Expand Down
10 changes: 9 additions & 1 deletion mem0/memory/memgraph_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,19 @@ def _retrieve_nodes_from_data(self, data, filters):
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]

default_search_prompt = f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question."

search_prompt = getattr(self.config.graph_store, 'custom_search_prompt', None) or default_search_prompt

if "{user_id}" in search_prompt:
search_prompt = search_prompt.replace("{user_id}", str(filters.get("user_id", "USER")))

search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
"content": search_prompt,
},
{"role": "user", "content": data},
],
Expand Down