Skip to content

Commit ebee09b

Browse files
Edge extraction and Node Deduplication updates (camel-ai#564)
* update tests * updated fact extraction * optimize node deduplication * linting * Update graphiti_core/utils/maintenance/edge_operations.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
1 parent e3f1c67 commit ebee09b

File tree

6 files changed

+50
-95
lines changed

6 files changed

+50
-95
lines changed

examples/podcast/podcast_runner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ class Person(BaseModel):
6363
occupation: str | None = Field(..., description="The person's work occupation")
6464

6565

66+
class IsPresidentOf(BaseModel):
67+
"""Relationship between a person and the entity they are a president of"""
68+
69+
6670
async def main():
6771
setup_logging()
6872
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
@@ -84,6 +88,8 @@ async def main():
8488
source_description='Podcast Transcript',
8589
group_id=group_id,
8690
entity_types={'Person': Person},
91+
edge_types={'IS_PRESIDENT_OF': IsPresidentOf},
92+
edge_type_map={('Person', 'Entity'): ['IS_PRESIDENT_OF']},
8793
previous_episode_uuids=episode_uuids,
8894
)
8995

graphiti_core/prompts/dedupe_nodes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,12 @@ def nodes(context: dict[str, Any]) -> list[Message]:
137137
<ENTITIES>
138138
{json.dumps(context['extracted_nodes'], indent=2)}
139139
</ENTITIES>
140+
141+
<EXISTING ENTITIES>
142+
{json.dumps(context['existing_nodes'], indent=2)}
143+
</EXISTING ENTITIES>
140144
141-
For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates.
145+
For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
142146
143147
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
144148
@@ -152,9 +156,9 @@ def nodes(context: dict[str, Any]) -> list[Message]:
152156
For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx
153157
as an integer.
154158
155-
- If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a
159+
- If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a
156160
duplicate of.
157-
- If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx
161+
- If an entity is not a duplicate of one of the EXISTING ENTITIES, return -1 as the duplicate_idx
158162
""",
159163
),
160164
]

graphiti_core/prompts/extract_edges.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424

2525
class Edge(BaseModel):
2626
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
27-
source_entity_name: str = Field(..., description='The name of the source entity of the fact.')
28-
target_entity_name: str = Field(..., description='The name of the target entity of the fact.')
27+
source_entity_id: int = Field(..., description='The id of the source entity of the fact.')
28+
target_entity_id: int = Field(..., description='The id of the target entity of the fact.')
2929
fact: str = Field(..., description='')
3030
valid_at: str | None = Field(
3131
None,
@@ -77,7 +77,7 @@ def edge(context: dict[str, Any]) -> list[Message]:
7777
</CURRENT_MESSAGE>
7878
7979
<ENTITIES>
80-
{context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases
80+
{context['nodes']}
8181
</ENTITIES>
8282
8383
<REFERENCE_TIME>
@@ -94,8 +94,9 @@ def edge(context: dict[str, Any]) -> list[Message]:
9494
- involve two DISTINCT ENTITIES from the ENTITIES list,
9595
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
9696
and can be represented as edges in a knowledge graph.
97-
- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that
98-
could be classified into one of the provided fact types
97+
- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types
98+
- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one
99+
of the FACT TYPES
99100
100101
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
101102

graphiti_core/utils/maintenance/edge_operations.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,6 @@ async def extract_edges(
9292
extract_edges_max_tokens = 16384
9393
llm_client = clients.llm_client
9494

95-
node_uuids_by_name_map = {node.name: node.uuid for node in nodes}
96-
9795
edge_types_context = (
9896
[
9997
{
@@ -109,7 +107,7 @@ async def extract_edges(
109107
# Prepare context for LLM
110108
context = {
111109
'episode_content': episode.content,
112-
'nodes': [node.name for node in nodes],
110+
'nodes': [{'id': idx, 'name': node.name} for idx, node in enumerate(nodes)],
113111
'previous_episodes': [ep.content for ep in previous_episodes],
114112
'reference_time': episode.valid_at,
115113
'edge_types': edge_types_context,
@@ -160,14 +158,16 @@ async def extract_edges(
160158
invalid_at = edge_data.get('invalid_at', None)
161159
valid_at_datetime = None
162160
invalid_at_datetime = None
163-
source_node_uuid = node_uuids_by_name_map.get(edge_data.get('source_entity_name', ''), '')
164-
target_node_uuid = node_uuids_by_name_map.get(edge_data.get('target_entity_name', ''), '')
165161

166-
if source_node_uuid == '' or target_node_uuid == '':
162+
source_node_idx = edge_data.get('source_entity_id', -1)
163+
target_node_idx = edge_data.get('target_entity_id', -1)
164+
if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)):
167165
logger.warning(
168-
f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_uuid} and target_node_uuid: {target_node_uuid} '
166+
f'WARNING: source or target node not filled {edge_data.get("relation_type")}. source_node_idx: {source_node_idx} and target_node_idx: {target_node_idx} '
169167
)
170168
continue
169+
source_node_uuid = nodes[source_node_idx].uuid
170+
target_node_uuid = nodes[target_node_idx].uuid
171171

172172
if valid_at:
173173
try:

graphiti_core/utils/maintenance/node_operations.py

Lines changed: 23 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from graphiti_core.llm_client.config import ModelSize
3030
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
3131
from graphiti_core.prompts import prompt_library
32-
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
32+
from graphiti_core.prompts.dedupe_nodes import NodeResolutions
3333
from graphiti_core.prompts.extract_nodes import (
3434
ExtractedEntities,
3535
ExtractedEntity,
@@ -241,7 +241,25 @@ async def resolve_extracted_nodes(
241241
]
242242
)
243243

244-
existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results]
244+
existing_nodes_dict: dict[str, EntityNode] = {
245+
node.uuid: node for result in search_results for node in result.nodes
246+
}
247+
248+
existing_nodes: list[EntityNode] = list(existing_nodes_dict.values())
249+
250+
existing_nodes_context = (
251+
[
252+
{
253+
**{
254+
'idx': i,
255+
'name': candidate.name,
256+
'entity_types': candidate.labels,
257+
},
258+
**candidate.attributes,
259+
}
260+
for i, candidate in enumerate(existing_nodes)
261+
]
262+
)
245263

246264
entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {}
247265

@@ -255,23 +273,13 @@ async def resolve_extracted_nodes(
255273
next((item for item in node.labels if item != 'Entity'), '')
256274
).__doc__
257275
or 'Default Entity Type',
258-
'duplication_candidates': [
259-
{
260-
**{
261-
'idx': j,
262-
'name': candidate.name,
263-
'entity_types': candidate.labels,
264-
},
265-
**candidate.attributes,
266-
}
267-
for j, candidate in enumerate(existing_nodes_lists[i])
268-
],
269276
}
270277
for i, node in enumerate(extracted_nodes)
271278
]
272279

273280
context = {
274281
'extracted_nodes': extracted_nodes_context,
282+
'existing_nodes': existing_nodes_context,
275283
'episode_content': episode.content if episode is not None else '',
276284
'previous_episodes': [ep.content for ep in previous_episodes]
277285
if previous_episodes is not None
@@ -294,8 +302,8 @@ async def resolve_extracted_nodes(
294302
extracted_node = extracted_nodes[resolution_id]
295303

296304
resolved_node = (
297-
existing_nodes_lists[resolution_id][duplicate_idx]
298-
if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id])
305+
existing_nodes[duplicate_idx]
306+
if 0 <= duplicate_idx < len(existing_nodes)
299307
else extracted_node
300308
)
301309

@@ -309,70 +317,6 @@ async def resolve_extracted_nodes(
309317
return resolved_nodes, uuid_map
310318

311319

312-
async def resolve_extracted_node(
313-
llm_client: LLMClient,
314-
extracted_node: EntityNode,
315-
existing_nodes: list[EntityNode],
316-
episode: EpisodicNode | None = None,
317-
previous_episodes: list[EpisodicNode] | None = None,
318-
entity_type: BaseModel | None = None,
319-
) -> EntityNode:
320-
start = time()
321-
if len(existing_nodes) == 0:
322-
return extracted_node
323-
324-
# Prepare context for LLM
325-
existing_nodes_context = [
326-
{
327-
**{
328-
'id': i,
329-
'name': node.name,
330-
'entity_types': node.labels,
331-
},
332-
**node.attributes,
333-
}
334-
for i, node in enumerate(existing_nodes)
335-
]
336-
337-
extracted_node_context = {
338-
'name': extracted_node.name,
339-
'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore
340-
}
341-
342-
context = {
343-
'existing_nodes': existing_nodes_context,
344-
'extracted_node': extracted_node_context,
345-
'entity_type_description': entity_type.__doc__
346-
if entity_type is not None
347-
else 'Default Entity Type',
348-
'episode_content': episode.content if episode is not None else '',
349-
'previous_episodes': [ep.content for ep in previous_episodes]
350-
if previous_episodes is not None
351-
else [],
352-
}
353-
354-
llm_response = await llm_client.generate_response(
355-
prompt_library.dedupe_nodes.node(context),
356-
response_model=NodeDuplicate,
357-
model_size=ModelSize.small,
358-
)
359-
360-
duplicate_id: int = llm_response.get('duplicate_node_id', -1)
361-
362-
node = (
363-
existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node
364-
)
365-
366-
node.name = llm_response.get('name', '')
367-
368-
end = time()
369-
logger.debug(
370-
f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms'
371-
)
372-
373-
return node
374-
375-
376320
async def extract_attributes_from_nodes(
377321
clients: GraphitiClients,
378322
nodes: list[EntityNode],

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "graphiti-core"
33
description = "A temporal graph building library"
4-
version = "0.12.0pre4"
4+
version = "0.12.0"
55
authors = [
66
{ "name" = "Paul Paliychuk", "email" = "[email protected]" },
77
{ "name" = "Preston Rasmussen", "email" = "[email protected]" },

0 commit comments

Comments
 (0)