diff --git a/Cargo.lock b/Cargo.lock index 9ef2078e64..c4188c14cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -445,6 +445,28 @@ dependencies = [ "regex-syntax 0.8.5", ] +[[package]] +name = "arroy" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08e6111f351d004bd13e95ab540721272136fd3218b39d3ec95a2ea1c4e6a0a6" +dependencies = [ + "bytemuck", + "byteorder", + "enum-iterator", + "heed", + "memmap2 0.9.5", + "nohash", + "ordered-float 4.6.0", + "page_size", + "rand 0.8.5", + "rayon", + "roaring", + "tempfile", + "thiserror 2.0.12", + "tracing", +] + [[package]] name = "ascii_utils" version = "0.9.3" @@ -1336,6 +1358,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -2031,6 +2062,15 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" +[[package]] +name = "doxygen-rs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "415b6ec780d34dcf624666747194393603d0373b7141eef01d12ee58881507d9" +dependencies = [ + "phf", +] + [[package]] name = "dyn-clone" version = "1.0.19" @@ -2078,6 +2118,26 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-iterator" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c280b9e6b3ae19e152d8e31cf47f18389781e119d4013a2a2bb0180e5facc635" +dependencies = [ + "enum-iterator-derive", +] + +[[package]] +name = "enum-iterator-derive" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1ab991c1362ac86c61ab6f556cff143daa22e5a15e4e189df818b2fd19fe65b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "enum_dispatch" version = "0.3.13" @@ -2345,7 +2405,7 @@ dependencies = [ "libc", "log", "rustversion", - "windows", + "windows 0.58.0", ] [[package]] @@ -2531,6 +2591,44 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "heed" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a56c94661ddfb51aa9cdfbf102cfcc340aa69267f95ebccc4af08d7c530d393" +dependencies = [ + "bitflags 2.9.0", + "byteorder", + "heed-traits", + "heed-types", + "libc", + "lmdb-master-sys", + "once_cell", + "page_size", + "serde", + "synchronoise", + "url", +] + +[[package]] +name = "heed-traits" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3130048d404c57ce5a1ac61a903696e8fcde7e8c2991e9fcfc1f27c3ef74ff" + +[[package]] +name = "heed-types" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c255bdf46e07fb840d120a36dcc81f385140d7191c76a7391672675c01a55d" +dependencies = [ + "bincode", + "byteorder", + "heed-traits", + "serde", + "serde_json", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -3145,6 +3243,17 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" +[[package]] +name = "lmdb-master-sys" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "864808e0b19fb6dd3b70ba94ee671b82fce17554cf80aeb0a155c65bb08027df" +dependencies = [ + "cc", + "doxygen-rs", + "libc", +] + [[package]] name = "lock_api" version = "0.4.12" @@ -3514,6 +3623,12 @@ dependencies = [ "libc", ] +[[package]] +name = "nohash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0f889fb66f7acdf83442c35775764b51fed3c606ab9cee51500dbde2cf528ca" + [[package]] name = "nom" version = "7.1.3" @@ -3524,6 +3639,15 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -3640,6 +3764,25 @@ dependencies = [ "rustc-hash 2.1.1", ] +[[package]] +name = "objc2-core-foundation" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" +dependencies = [ + "bitflags 2.9.0", +] + +[[package]] +name = "objc2-io-kit" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71c1c64d6120e51cd86033f67176b1cb66780c2efe34dec55176f77befd93c0a" +dependencies = [ + "libc", + "objc2-core-foundation", +] + [[package]] name = "object" version = "0.36.7" @@ -3833,6 +3976,16 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -4016,6 +4169,7 @@ version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ + "phf_macros", "phf_shared", ] @@ -4039,6 +4193,19 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "phf_macros" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "phf_shared" version = "0.11.3" @@ -4820,6 +4987,7 @@ dependencies = [ "arrow-ipc", "arrow-json", "arrow-schema", + "arroy", "async-openai", "async-trait", "bigdecimal", @@ -4838,6 +5006,7 @@ dependencies = [ "futures-util", "glam", "hashbrown 0.15.3", + "heed", "indexmap 2.9.0", "indoc", "itertools 0.13.0", @@ -4846,6 +5015,7 @@ dependencies = [ "memmap2 0.9.5", "minijinja", "minijinja-contrib", + "moka", "neo4rs", "num", "num-bigint", @@ -4884,6 +5054,7 @@ dependencies = [ "serde_json", "streaming-stats", "strsim", + "sysinfo", "tantivy", "tempfile", "thiserror 2.0.12", @@ -4937,6 +5108,7 @@ dependencies = [ "serde", "sorted_vector_map", "tempfile", + "tokio", "tracing", ] @@ -5903,6 +6075,15 @@ dependencies = [ "futures-core", ] +[[package]] +name = "synchronoise" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dbc01390fc626ce8d1cffe3376ded2b72a11bb70e1c75f404a210e4daa4def2" +dependencies = [ + "crossbeam-queue", +] + [[package]] name = "synstructure" version = "0.13.2" @@ -5914,6 +6095,20 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "sysinfo" +version = "0.35.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79251336d17c72d9762b8b54be4befe38d2db56fbbc0241396d70f173c39d47a" +dependencies = [ + "libc", + "memchr", + "ntapi", + "objc2-core-foundation", + "objc2-io-kit", + "windows 0.61.1", +] + [[package]] name = "tagptr" version = "0.2.0" @@ -6890,6 +7085,28 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5ee8f3d025738cb02bad7868bbb5f8a6327501e870bf51f1b455b0a2454a419" +dependencies = [ + "windows-collections", + "windows-core 0.61.0", + "windows-future", + "windows-link", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.0", +] + [[package]] name = "windows-core" version = "0.58.0" @@ -6916,6 +7133,16 @@ dependencies = [ "windows-strings 0.4.0", ] +[[package]] +name = "windows-future" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a1d6bbefcb7b60acd19828e1bc965da6fcf18a7e39490c5f8be71e54a19ba32" +dependencies = [ + "windows-core 0.61.0", + "windows-link", +] + [[package]] name = "windows-implement" version = "0.58.0" @@ -6966,6 +7193,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.0", + "windows-link", +] + [[package]] name = "windows-registry" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index 0138c07e9f..c249a716bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -141,6 +141,9 @@ pest_derive = "2.7.8" minijinja = "2.2.0" minijinja-contrib = { version = "2.2.0", features = ["datetime"] } datafusion = { version = "43.0.0" } +arroy = "0.6.1" +heed = "0.22.0" +sysinfo = "0.35.1" sqlparser = "0.51.0" futures = "0.3" arrow = { version = "=53.2.0" } diff --git a/python/tests/test_base_install/test_graphql/misc/test_graphql_vectors.py b/python/tests/test_base_install/test_graphql/misc/test_graphql_vectors.py index f346d1d1bc..57b2d4e1a5 100644 --- a/python/tests/test_base_install/test_graphql/misc/test_graphql_vectors.py +++ b/python/tests/test_base_install/test_graphql/misc/test_graphql_vectors.py @@ -14,28 +14,17 @@ def test_embedding(): def setup_graph(g): - g.update_constant_properties({"name": "abb"}) g.add_node(1, "aab") g.add_edge(1, "aab", "bbb") def assert_correct_documents(client): query = """{ - plugins { - globalSearch(query: "aab", limit: 1) { - entity { - __typename - ... on DocumentGraph { - name - } - } - content - embedding - } - } vectorisedGraph(path: "abb") { - algorithms { - similaritySearch(query:"ab", limit: 1) { + entitiesBySimilarity(query: "aab", limit: 1) { + getDocuments { + content + embedding entity { __typename ... on Node { @@ -50,26 +39,15 @@ def assert_correct_documents(client): } } } - content - embedding } } } }""" result = client.query(query) assert result == { - "plugins": { - "globalSearch": [ - { - "entity": {"__typename": "DocumentGraph", "name": "abb"}, - "content": "abb", - "embedding": [1.0, 2.0], - }, - ], - }, "vectorisedGraph": { - "algorithms": { - "similaritySearch": [ + "entitiesBySimilarity": { + "getDocuments": [ { "entity": {"__typename": "Node", "name": "aab"}, "content": "aab", @@ -87,7 +65,6 @@ def setup_server(work_dir): cache="/tmp/graph-cache", embedding=embedding, nodes="{{ name }}", - graphs="{{ properties.name }}", edges=False, ) return server @@ -130,7 +107,3 @@ def test_include_graph(): with server.start(): client = RaphtoryClient("http://localhost:1736") assert_correct_documents(client) - - -test_upload_graph() -test_include_graph() diff --git a/python/tests/test_base_install/test_vectors.py b/python/tests/test_base_install/test_vectors.py index 2cafa316fd..9eb455eae4 100644 --- a/python/tests/test_base_install/test_vectors.py +++ b/python/tests/test_base_install/test_vectors.py @@ -27,11 +27,11 @@ def floats_are_equals(float1: float, float2: float) -> bool: return float1 + 0.001 > float2 and float1 - 0.001 < float2 -# the graph generated by this function looks like this (entities labeled with (2d) have 2 documents): +# the graph generated by this function looks like this: # -# edge1 (2d) +# edge1 # ╭─────── node2 -# node1 (2d) +# node1 # ╰─────── node3 ───── node4 # edge2 edge3 # @@ -48,9 +48,7 @@ def create_graph() -> VectorisedGraph: g.add_edge(3, "node1", "node3", {"name": "edge2"}) g.add_edge(4, "node3", "node4", {"name": "edge3"}) - vg = g.vectorise( - embedding, nodes="{{ name }}", edges="{{ properties.name }}", graph=False - ) + vg = g.vectorise(embedding, nodes="{{ name }}", edges="{{ properties.name }}") return vg @@ -58,6 +56,16 @@ def create_graph() -> VectorisedGraph: def test_selection(): vg = create_graph() + ################################ + selection = vg.empty_selection() + nodes_to_select = ["node1", "node2"] + edges_to_select = [("node1", "node2"), ("node1", "node3")] + selection = vg.empty_selection() + selection.add_nodes(nodes_to_select) + selection.add_edges(edges_to_select) + nodes = selection.nodes() + ########################### + assert len(vg.empty_selection().get_documents()) == 0 assert len(vg.empty_selection().get_documents_with_scores()) == 0 @@ -105,21 +113,9 @@ def test_search(): assert edge_names_returned == [("node1", "node2")] # TODO: same for edges ? - (doc1, score1), (doc2, score2) = vg.documents_by_similarity( - [1.0, 0.0, 0.0], 2 - ).get_documents_with_scores() + [(doc1, score1)] = vg.entities_by_similarity("node1", 1).get_documents_with_scores() assert floats_are_equals(score1, 1.0) assert (doc1.entity.name, doc1.content) == ("node1", "node1") - assert (doc2.entity.src.name, doc2.entity.dst.name) == ("node1", "node2") - - [(doc1, score1)] = vg.entities_by_similarity( - [1.0, 0.0, 0.0], 1 - ).get_documents_with_scores() - assert floats_are_equals(score1, 1.0) - assert (doc1.entity.name, doc1.content) == ("node1", "node1") - - docs = vg.documents_by_similarity([0.0, 0.0, 1.1], 3).get_documents() - assert [doc.content for doc in docs] == ["node3", "edge3", "edge2"] # chained search node_selection = vg.nodes_by_similarity("node2", 1) @@ -183,21 +179,12 @@ def test_windows(): contents = [doc.content for doc in selection.get_documents()] assert contents == ["node1", "edge1", "node2"] - selection.expand_documents_by_similarity("edge2", 100, (0, 4)) - contents = [doc.content for doc in selection.get_documents()] - assert contents == ["node1", "edge1", "node2", "edge2", "node3"] - - # this should leave the selection unchanged - selection.expand_documents_by_similarity("node1", 100, (20, 100)) - contents = [doc.content for doc in selection.get_documents()] - assert contents == ["node1", "edge1", "node2", "edge2", "node3"] - - # this should also leave the selection unchanged - selection.expand_entities_by_similarity("node1", 100, (20, 100)) + # this leave the selection unchanged, as only edge3 and node4 exist + selection.expand_entities_by_similarity("node1", 100, (4, 100)) contents = [doc.content for doc in selection.get_documents()] - assert contents == ["node1", "edge1", "node2", "edge2", "node3"] + assert contents == ["node1", "edge1", "node2"] - selection.expand(10, (4, 100)) + selection.expand(10, (3, 100)) contents = [doc.content for doc in selection.get_documents()] assert contents == ["node1", "edge1", "node2", "edge2", "node3", "edge3", "node4"] @@ -219,7 +206,7 @@ def test_filtering_by_entity_type(): def constant_embedding(texts): - return [[1, 0, 0] for text in texts] + return [[1.0, 0.0, 0.0] for text in texts] def test_default_template(): @@ -239,223 +226,3 @@ def test_default_template(): edge_docs[0].content == "There is an edge from node1 to node1 with events at:\n- Jan 1 1970 00:00\n" ) - - -### MULTI-DOCUMENT VERSION TO BE RE-ENABLED - -# from raphtory import Graph -# from raphtory.vectors import VectorisedGraph - -# embedding_map = { -# "node1": [1.0, 0.0, 0.0], -# "node2": [0.0, 1.0, 0.0], -# "node3": [0.0, 0.0, 1.0], -# "node4": [1.0, 1.0, 0.0], -# "edge1": [1.0, 0.1, 0.0], -# "edge2": [0.0, 1.0, 0.1], -# "edge3": [0.0, 1.0, 1.0], -# "node1-extra": [0.0, 1.0, 1.0], -# "edge1-extra": [0.1, 1.0, 0.0], -# } - - -# def single_embedding(text: str): -# try: -# return embedding_map[text] -# except: -# raise Exception(f"unexpected document content: {text}") - - -# def embedding(texts): -# return [single_embedding(text) for text in texts] - - -# def floats_are_equals(float1: float, float2: float) -> bool: -# return float1 + 0.001 > float2 and float1 - 0.001 < float2 - - -# # the graph generated by this function looks like this (entities labeled with (2d) have 2 documents): -# # -# # edge1 (2d) -# # ╭─────── node2 -# # node1 (2d) -# # ╰─────── node3 ───── node4 -# # edge2 edge3 -# # -# # -# def create_graph() -> VectorisedGraph: -# g = Graph() - -# g.add_node(1, "node1", {"doc": ["node1", "node1-extra"]}) # multi-document node -# g.add_node(2, "node2", {"doc": ["node2"]}) -# g.add_node(3, "node3", {"doc": ["node3"]}) -# g.add_node(4, "node4", {"doc": ["node4"]}) - -# g.add_edge(2, "node1", "node2", {"doc": ["edge1", "edge1-extra"]}) # multi-document edge -# g.add_edge(3, "node1", "node3", {"doc": ["edge2"]}) -# g.add_edge(4, "node3", "node4", {"doc": ["edge3"]}) - -# vg = g.vectorise(embedding, node_document="doc", edge_document="doc") - -# return vg - - -# def test_selection(): -# vg = create_graph() - -# assert len(vg.empty_selection().get_documents()) == 0 -# assert len(vg.empty_selection().get_documents_with_scores()) == 0 - -# nodes_to_select = ["node1", "node2"] -# edges_to_select = [("node1", "node2"), ("node1", "node3")] - -# selection = vg.empty_selection() -# selection.add_nodes(nodes_to_select) -# nodes = selection.nodes() -# node_names_returned = [node.name for node in nodes] -# assert node_names_returned == nodes_to_select -# docs = [doc.content for doc in selection.get_documents()] -# assert docs == ["node1", "node1-extra", "node2"] - -# selection = vg.empty_selection() -# selection.add_edges(edges_to_select) -# edges = selection.edges() -# edge_names_returned = [(edge.src.name, edge.dst.name) for edge in edges] -# assert edge_names_returned == edges_to_select -# docs = [doc.content for doc in selection.get_documents()] -# assert docs == ["edge1", "edge1-extra", "edge2"] - -# edge_tuples = [(edge.src, edge.dst) for edge in edges] -# selection = vg.empty_selection() -# selection.add_nodes(nodes) -# selection.add_edges(edge_tuples) -# nodes_returned = selection.nodes() -# assert nodes == nodes_returned -# edges_returned = selection.edges() -# assert edges == edges_returned - - -# def test_search(): -# vg = create_graph() - -# assert len(vg.edges_by_similarity("edge1", 10).nodes()) == 0 -# assert len(vg.nodes_by_similarity("node1", 10).edges()) == 0 - -# selection = vg.nodes_by_similarity([1.0, 0.0, 0.0], 1) -# assert [node.name for node in selection.nodes()] == ["node1"] -# assert [doc.content for doc in selection.get_documents()] == ["node1", "node1-extra"] - -# edges = vg.edges_by_similarity([1.0, 0.0, 0.0], 1).edges() -# edge_names_returned = [(edge.src.name, edge.dst.name) for edge in edges] -# assert edge_names_returned == [("node1", "node2")] -# # TODO: same for edges ? - -# (doc1, score1), (doc2, score2) = vg.documents_by_similarity( -# [1.0, 0.0, 0.0], 2 -# ).get_documents_with_scores() -# assert floats_are_equals(score1, 1.0) -# assert (doc1.entity.name, doc1.content) == ("node1", "node1") -# assert (doc2.entity.src.name, doc2.entity.dst.name) == ("node1", "node2") - -# (doc1, score1), (doc2, score2) = vg.entities_by_similarity( -# [1.0, 0.0, 0.0], 1 -# ).get_documents_with_scores() -# assert floats_are_equals(score1, 1.0) -# assert (doc1.entity.name, doc1.content) == ("node1", "node1") -# assert (doc2.entity.name, doc2.content) == ("node1", "node1-extra") - -# docs = vg.documents_by_similarity([0.0, 0.0, 1.1], 3).get_documents() -# assert [doc.content for doc in docs] == ["node3", "node1-extra", "edge3"] - -# # chained search -# node_selection = vg.nodes_by_similarity("node2", 1); -# edge_selection = vg.edges_by_similarity("node3", 1); -# entity_selection = vg.entities_by_similarity("node1", 4); -# docs = node_selection.join(edge_selection).join(entity_selection).get_documents()[:4] -# # assert [doc.content for doc in docs] == ['node2', 'edge3', 'node1', 'edge1'] -# assert [doc.content for doc in docs] == ["node2", "edge3", "node1", "node1-extra"] -# # the intention of this test was getting all the documents of for different entities, -# # including at least node and one edge at the top. -# # However, we don't have a way currently of taking the documents of the first N entities -# # we could have a method selection.limit_entities() -# # or we could also have a method entity.get_documents for the entities we return (not trivial) - - -# def test_expansion(): -# vg = create_graph() - -# selection = vg.entities_by_similarity("node1", 1) -# selection.expand(2) -# assert len(selection.get_documents()) == 7 -# assert len(selection.nodes()) == 3 -# assert len(selection.edges()) == 2 - -# selection = vg.entities_by_similarity("node1", 1) -# selection.expand_entities_by_similarity("edge1", 1) -# selection.expand_entities_by_similarity("node2", 1) -# assert len(selection.get_documents()) == 5 -# nodes = selection.nodes() -# node_names_returned = [node.name for node in nodes] -# assert node_names_returned == ["node1", "node2"] -# edges = selection.edges() -# edge_names_returned = [(edge.src.name, edge.dst.name) for edge in edges] -# assert edge_names_returned == [("node1", "node2")] - -# selection = vg.empty_selection() -# selection.expand_entities_by_similarity("node3", 10) -# assert len(selection.get_documents()) == 0 - -# selection = vg.entities_by_similarity("node1", 1) -# selection.expand_entities_by_similarity("node3", 10) -# assert len(selection.get_documents()) == 9 -# assert len(selection.nodes()) == 4 -# assert len(selection.edges()) == 3 -# # TODO: add some expand_documents here - - -# def test_windows(): -# vg = create_graph() - -# selection = vg.nodes_by_similarity("node1", 1, (4, 5)) -# assert [doc.content for doc in selection.get_documents()] == ["node4"] - -# selection = vg.nodes_by_similarity("node4", 1, (1, 2)) -# assert [doc.content for doc in selection.get_documents()] == ["node1", "node1-extra"] - -# selection.expand(10, (0, 3)) -# contents = [doc.content for doc in selection.get_documents()] -# assert contents == ["node1", "node1-extra", "edge1", "edge1-extra", "node2"] - -# selection.expand_documents_by_similarity("edge2", 100, (0, 4)) -# contents = [doc.content for doc in selection.get_documents()] -# assert contents == ["node1", "node1-extra", "edge1", "edge1-extra", "node2", "edge2", "node3"] - -# # this should leave the selection unchanged -# selection.expand_documents_by_similarity("node1", 100, (20, 100)) -# contents = [doc.content for doc in selection.get_documents()] -# assert contents == ["node1", "node1-extra", "edge1", "edge1-extra", "node2", "edge2", "node3"] - -# # this should also leave the selection unchanged -# selection.expand_entities_by_similarity("node1", 100, (20, 100)) -# contents = [doc.content for doc in selection.get_documents()] -# assert contents == ["node1", "node1-extra", "edge1", "edge1-extra", "node2", "edge2", "node3"] - -# selection.expand(10, (4, 100)) -# contents = [doc.content for doc in selection.get_documents()] -# assert contents == ["node1", "node1-extra", "edge1", "edge1-extra", "node2", "edge2", "node3", "edge3", "node4"] - - -# def test_filtering_by_entity_type(): -# vg = create_graph() - -# selection = vg.empty_selection() -# selection.add_nodes(["node1"]) -# selection.expand_nodes_by_similarity("node2", 10) -# contents = [doc.content for doc in selection.get_documents()] -# assert contents == ["node1", "node1-extra", "node2", "node3", "node4"] - -# selection = vg.empty_selection() -# selection.add_edges([("node1", "node2")]) -# selection.expand_edges_by_similarity("edge3", 10) -# contents = [doc.content for doc in selection.get_documents()] -# assert contents == ["edge1", "edge1-extra", "edge2", "edge3"] diff --git a/raphtory-benchmark/Cargo.toml b/raphtory-benchmark/Cargo.toml index 9e3589e2f8..9b21e3a42a 100644 --- a/raphtory-benchmark/Cargo.toml +++ b/raphtory-benchmark/Cargo.toml @@ -7,19 +7,28 @@ edition = "2021" [dependencies] criterion = { workspace = true } -raphtory = { path = "../raphtory", features = ["io", "proto"], version = "0.15.1" } +raphtory = { path = "../raphtory", features = [ + "io", + "proto", + "vectors", +], version = "0.15.1" } raphtory-api = { path = "../raphtory-api", version = "0.15.1" } sorted_vector_map = { workspace = true } rand = { workspace = true } rayon = { workspace = true } tempfile = { workspace = true } -tracing = {workspace = true} +tracing = { workspace = true } once_cell = { workspace = true } -serde = { workspace = true } -itertools = { workspace = true } +serde = { workspace = true } +itertools = { workspace = true } fake = { workspace = true } csv = { workspace = true } chrono = { workspace = true } +tokio = { workspace = true } + +[[bin]] +name = "vectorise" +path = "bin/vectorise.rs" [[bench]] name = "tgraph_benchmarks" @@ -66,5 +75,9 @@ name = "search_bench" harness = false required-features = ["search"] +[[bench]] +name = "vectors" +harness = false + [features] search = ["raphtory/search"] diff --git a/raphtory-benchmark/benches/vectors.rs b/raphtory-benchmark/benches/vectors.rs new file mode 100644 index 0000000000..19d421bb3e --- /dev/null +++ b/raphtory-benchmark/benches/vectors.rs @@ -0,0 +1,18 @@ +use criterion::{criterion_group, criterion_main, Criterion}; + +use raphtory_benchmark::common::vectors::{ + create_graph_for_vector_bench, gen_embedding_for_bench, vectorise_graph_for_bench, +}; + +fn bench_search_entities(c: &mut Criterion) { + let g = create_graph_for_vector_bench(100_000); + let v = vectorise_graph_for_bench(g); + + let query = gen_embedding_for_bench("0"); + c.bench_function("semantic_search_entities", |b| { + b.iter(|| v.entities_by_similarity(&query, 10, None)) + }); +} + +criterion_group!(vector_benches, bench_search_entities,); +criterion_main!(vector_benches); diff --git a/raphtory-benchmark/bin/vectorise.rs b/raphtory-benchmark/bin/vectorise.rs new file mode 100644 index 0000000000..e69b37e2ed --- /dev/null +++ b/raphtory-benchmark/bin/vectorise.rs @@ -0,0 +1,19 @@ +use std::time::SystemTime; + +use raphtory_benchmark::common::vectors::{ + create_graph_for_vector_bench, vectorise_graph_for_bench, +}; + +fn print_time(start: SystemTime, message: &str) { + let duration = SystemTime::now().duration_since(start).unwrap().as_secs(); + println!("{message} - took {duration}s"); +} + +fn main() { + for size in [1_000_000] { + let graph = create_graph_for_vector_bench(size); + let start = SystemTime::now(); + vectorise_graph_for_bench(graph); + print_time(start, &format!(">>> vectorise {}k", size / 1000)); + } +} diff --git a/raphtory-benchmark/src/common/mod.rs b/raphtory-benchmark/src/common/mod.rs index caf4247db4..3c2a0f2e43 100644 --- a/raphtory-benchmark/src/common/mod.rs +++ b/raphtory-benchmark/src/common/mod.rs @@ -1,5 +1,7 @@ #![allow(dead_code)] +pub mod vectors; + use criterion::{ black_box, measurement::WallTime, BatchSize, Bencher, BenchmarkGroup, BenchmarkId, Criterion, }; diff --git a/raphtory-benchmark/src/common/vectors.rs b/raphtory-benchmark/src/common/vectors.rs new file mode 100644 index 0000000000..94456443a8 --- /dev/null +++ b/raphtory-benchmark/src/common/vectors.rs @@ -0,0 +1,46 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; + +use rand::{rngs::StdRng, Rng, SeedableRng}; +use raphtory::{ + prelude::{AdditionOps, Graph, NO_PROPS}, + vectors::{ + cache::VectorCache, embeddings::EmbeddingResult, template::DocumentTemplate, + vectorisable::Vectorisable, vectorised_graph::VectorisedGraph, Embedding, + }, +}; +use tokio::runtime::Runtime; + +pub fn gen_embedding_for_bench(text: &str) -> Embedding { + let mut hasher = DefaultHasher::new(); + text.hash(&mut hasher); + let hash = hasher.finish(); + + let mut rng: StdRng = SeedableRng::seed_from_u64(hash); + (0..1024).map(|_| rng.gen()).collect() +} + +async fn embedding_model(texts: Vec) -> EmbeddingResult> { + Ok(texts + .iter() + .map(|text| gen_embedding_for_bench(text)) + .collect()) +} + +pub fn create_graph_for_vector_bench(size: usize) -> Graph { + let graph = Graph::new(); + for id in 0..size { + graph.add_node(0, id as u64, NO_PROPS, None).unwrap(); + } + graph +} + +pub fn vectorise_graph_for_bench(graph: Graph) -> VectorisedGraph { + let cache = VectorCache::in_memory(embedding_model); + let template = DocumentTemplate { + node_template: Some("{{name}}".to_owned()), + edge_template: None, + }; + let rt = Runtime::new().unwrap(); + rt.block_on(graph.vectorise(cache, template, None, true)) + .unwrap() +} diff --git a/raphtory-graphql/src/data.rs b/raphtory-graphql/src/data.rs index 9bc5eccfbd..9bc0c3c136 100644 --- a/raphtory-graphql/src/data.rs +++ b/raphtory-graphql/src/data.rs @@ -1,18 +1,17 @@ use crate::{ config::app_config::AppConfig, graph::GraphWithVectors, - model::plugins::query_plugin::QueryPlugin, paths::{valid_path, ExistingGraphFolder, ValidGraphFolder}, }; use itertools::Itertools; use moka::sync::Cache; use raphtory::{ - core::utils::errors::{GraphError, GraphResult, InvalidPathReason}, + core::utils::errors::{GraphError, InvalidPathReason}, db::api::view::MaterializedGraph, + prelude::CacheOps, vectors::{ - embedding_cache::EmbeddingCache, embeddings::openai_embedding, template::DocumentTemplate, - vectorisable::Vectorisable, vectorised_graph::VectorisedGraph, Embedding, - EmbeddingFunction, + cache::VectorCache, embeddings::openai_embedding, template::DocumentTemplate, + vectorisable::Vectorisable, vectorised_graph::VectorisedGraph, }, }; use std::{ @@ -27,8 +26,7 @@ use walkdir::WalkDir; #[derive(Clone)] pub struct EmbeddingConf { - pub(crate) function: Arc, - pub(crate) cache: Arc>, // FIXME: no need for this to be Option + pub(crate) cache: VectorCache, pub(crate) global_template: Option, pub(crate) individual_templates: HashMap, } @@ -116,17 +114,6 @@ impl Data { &self, path: &str, graph: MaterializedGraph, - ) -> Result<(), GraphError> { - let folder = ValidGraphFolder::try_from(self.work_dir.clone(), path)?; - let vectors = self.vectorise(graph.clone(), &folder).await; - let graph = GraphWithVectors::new(graph, vectors); - self.insert_graph_with_vectors(path, graph) - } - - pub fn insert_graph_with_vectors( - &self, - path: &str, - graph: GraphWithVectors, ) -> Result<(), GraphError> { // TODO: replace ValidGraphFolder with ValidNonExistingGraphFolder !!!!!!!!! // or even a NewGraphFolder, so that we try to create the graph file and if that is sucessful @@ -136,7 +123,12 @@ impl Data { Ok(_) => Err(GraphError::GraphNameAlreadyExists(folder.to_error_path())), Err(_) => { fs::create_dir_all(folder.get_base_path())?; - graph.cache(folder)?; + graph.cache(folder.clone())?; + let vectors = self.vectorise(graph.clone(), &folder).await; + let graph = GraphWithVectors::new(graph, vectors); + graph + .folder + .get_or_try_init(|| Ok::<_, GraphError>(folder.into()))?; self.cache.insert(path.into(), graph); Ok(()) } @@ -150,19 +142,6 @@ impl Data { Ok(()) } - pub async fn embed_query(&self, query: String) -> GraphResult { - let embedding_function = self - .embedding_conf - .as_ref() - .map(|conf| conf.function.clone()); - let embedding = if let Some(embedding_function) = embedding_function { - embedding_function.call(vec![query]).await?.remove(0) - } else { - openai_embedding(vec![query]).await?.remove(0) - }; - Ok(embedding) - } - fn resolve_template(&self, graph: &Path) -> Option<&DocumentTemplate> { let conf = self.embedding_conf.as_ref()?; conf.individual_templates @@ -179,11 +158,9 @@ impl Data { let conf = self.embedding_conf.as_ref()?; let vectors = graph .vectorise( - Box::new(conf.function.clone()), conf.cache.clone(), - true, // overwrite template.clone(), - Some(folder.get_original_path_str().to_owned()), + Some(&folder.get_vectors_path()), true, // verbose ) .await; @@ -206,24 +183,19 @@ impl Data { self.vectorise_with_template(graph, folder, template).await } - async fn vectorise_folder( - &self, - folder: &ExistingGraphFolder, - ) -> Option> { + async fn vectorise_folder(&self, folder: &ExistingGraphFolder) -> Option<()> { // it's important that we check if there is a valid template set for this graph path // before actually loading the graph, otherwise we are loading the graph for no reason let template = self.resolve_template(folder.get_original_path())?; let graph = self.read_graph_from_folder(folder).ok()?.graph; - self.vectorise_with_template(graph, folder, template).await + self.vectorise_with_template(graph, folder, template).await; + Some(()) } pub(crate) async fn vectorise_all_graphs_that_are_not(&self) -> Result<(), GraphError> { for folder in self.get_all_graph_folders() { if !folder.get_vectors_path().exists() { - let vectors = self.vectorise_folder(&folder).await; - if let Some(vectors) = vectors { - vectors.write_to_path(&folder.get_vectors_path())?; - } + self.vectorise_folder(&folder).await; } } Ok(()) @@ -244,38 +216,12 @@ impl Data { .collect() } - pub(crate) fn get_global_plugins(&self) -> QueryPlugin { - let graphs = self - .get_all_graph_folders() - .into_iter() - .filter_map(|folder| { - Some(( - folder.get_original_path_str().to_owned(), - self.read_graph_from_folder(&folder).ok()?.vectors?, - )) - }) - .collect::>(); - QueryPlugin { - graphs: graphs.into(), - } - } - fn read_graph_from_folder( &self, folder: &ExistingGraphFolder, ) -> Result { - let embedding = self - .embedding_conf - .as_ref() - .map(|conf| conf.function.clone()) - .unwrap_or(Arc::new(openai_embedding)); - let cache = self - .embedding_conf - .as_ref() - .map(|conf| conf.cache.clone()) - .unwrap_or(Arc::new(None)); - - GraphWithVectors::read_from_folder(folder, embedding, cache, self.create_index) + let cache = self.embedding_conf.as_ref().map(|conf| conf.cache.clone()); + GraphWithVectors::read_from_folder(folder, cache, self.create_index) } } diff --git a/raphtory-graphql/src/embeddings.rs b/raphtory-graphql/src/embeddings.rs new file mode 100644 index 0000000000..78a6f98c7b --- /dev/null +++ b/raphtory-graphql/src/embeddings.rs @@ -0,0 +1,17 @@ +use async_graphql::Context; +use raphtory::{core::utils::errors::GraphResult, vectors::Embedding}; + +use crate::data::Data; + +pub(crate) trait EmbedQuery { + async fn embed_query(&self, text: String) -> GraphResult; +} + +impl EmbedQuery for Context<'_> { + /// this is meant to be called from a vector context, so the embedding conf is assumed to exist + async fn embed_query(&self, text: String) -> GraphResult { + let data = self.data_unchecked::(); + let cache = &data.embedding_conf.as_ref().unwrap().cache; + cache.get_single(text).await + } +} diff --git a/raphtory-graphql/src/graph.rs b/raphtory-graphql/src/graph.rs index 61570a5536..f526f25047 100644 --- a/raphtory-graphql/src/graph.rs +++ b/raphtory-graphql/src/graph.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use crate::paths::ExistingGraphFolder; use once_cell::sync::OnceCell; #[cfg(feature = "storage")] @@ -25,16 +23,14 @@ use raphtory::{ }, prelude::{CacheOps, DeletionOps, EdgeViewOps, NodeViewOps, SearchableGraphOps}, serialise::GraphFolder, - vectors::{ - embedding_cache::EmbeddingCache, vectorised_graph::VectorisedGraph, EmbeddingFunction, - }, + vectors::{cache::VectorCache, vectorised_graph::VectorisedGraph}, }; #[derive(Clone)] pub struct GraphWithVectors { pub graph: MaterializedGraph, pub vectors: Option>, - folder: OnceCell, + pub(crate) folder: OnceCell, } impl GraphWithVectors { @@ -49,16 +45,6 @@ impl GraphWithVectors { } } - pub(crate) async fn update_graph_embeddings( - &self, - graph_name: Option, - ) -> GraphResult<()> { - if let Some(vectors) = &self.vectors { - vectors.update_graph(graph_name).await?; - } - Ok(()) - } - pub(crate) async fn update_node_embeddings(&self, node: T) -> GraphResult<()> { if let Some(vectors) = &self.vectors { vectors.update_node(node).await?; @@ -77,42 +63,17 @@ impl GraphWithVectors { Ok(()) } - pub(crate) fn cache(&self, path: impl Into) -> Result<(), GraphError> { - let folder = path.into(); - self.folder - .get_or_try_init(|| Ok::<_, GraphError>(folder.clone()))?; - self.graph.cache(folder)?; - self.dump_vectors_to_disk() - } - pub(crate) fn write_updates(&self) -> Result<(), GraphError> { match self.graph.core_graph() { - GraphStorage::Mem(_) | GraphStorage::Unlocked(_) => { - self.graph.write_updates()?; - } + GraphStorage::Mem(_) | GraphStorage::Unlocked(_) => self.graph.write_updates(), #[cfg(feature = "storage")] - GraphStorage::Disk(_) => {} - } - self.dump_vectors_to_disk() - } - - fn dump_vectors_to_disk(&self) -> Result<(), GraphError> { - if let Some(vectors) = &self.vectors { - vectors.write_to_path( - &self - .folder - .get() - .ok_or(GraphError::CacheNotInnitialised)? - .get_vectors_path(), - )?; + GraphStorage::Disk(_) => Ok(()), } - Ok(()) } pub(crate) fn read_from_folder( folder: &ExistingGraphFolder, - embedding: Arc, - cache: Arc>, + cache: Option, create_index: bool, ) -> Result { let graph_path = &folder.get_graph_path(); @@ -121,20 +82,13 @@ impl GraphWithVectors { } else { MaterializedGraph::load_cached(folder.clone())? }; - - let vectors = VectorisedGraph::read_from_path( - &folder.get_vectors_path(), - graph.clone(), - embedding, - cache, - ); - + let vectors = cache.and_then(|cache| { + VectorisedGraph::read_from_path(&folder.get_vectors_path(), graph.clone(), cache).ok() + }); println!("Graph loaded = {}", folder.get_original_path_str()); - if create_index { graph.create_index()?; } - Ok(Self { graph: graph.clone(), vectors, diff --git a/raphtory-graphql/src/lib.rs b/raphtory-graphql/src/lib.rs index 5e32eeb77d..24c3e38dda 100644 --- a/raphtory-graphql/src/lib.rs +++ b/raphtory-graphql/src/lib.rs @@ -1,6 +1,7 @@ pub use crate::server::GraphServer; mod auth; pub mod data; +mod embeddings; mod graph; pub mod model; pub mod observability; diff --git a/raphtory-graphql/src/model/algorithms/document.rs b/raphtory-graphql/src/model/algorithms/document.rs deleted file mode 100644 index faf1bbfcf4..0000000000 --- a/raphtory-graphql/src/model/algorithms/document.rs +++ /dev/null @@ -1,71 +0,0 @@ -use crate::model::graph::{edge::GqlEdge, node::GqlNode}; -use dynamic_graphql::{SimpleObject, Union}; -use raphtory::{ - core::Lifespan, - db::api::view::{IntoDynamic, StaticGraphViewOps}, - vectors::{Document as RustDocument, DocumentEntity}, -}; - -#[derive(SimpleObject)] -struct DocumentGraph { - name: String, // TODO: maybe return the graph as well here -} - -impl From for DocumentGraph { - fn from(value: String) -> Self { - Self { name: value } - } -} - -#[derive(Union)] -#[graphql(name = "DocumentEntity")] -enum GqlDocumentEntity { - DocNode(GqlNode), - DocEdge(GqlEdge), - DocGraph(DocumentGraph), -} - -impl From> for GqlDocumentEntity { - fn from(value: DocumentEntity) -> Self { - match value { - DocumentEntity::Graph { name, .. } => Self::DocGraph(DocumentGraph { - name: name.unwrap(), - }), - DocumentEntity::Node(node) => Self::DocNode(GqlNode::from(node)), - DocumentEntity::Edge(edge) => Self::DocEdge(GqlEdge::from(edge)), - } - } -} - -#[derive(SimpleObject)] -pub struct Document { - entity: GqlDocumentEntity, - content: String, - embedding: Vec, - life: Vec, // TODO: give this a proper type -} - -impl From> for Document { - fn from(value: RustDocument) -> Self { - let RustDocument { - entity, - content, - embedding, - life, - } = value; - Self { - entity: entity.into(), - content, - embedding: embedding.to_vec(), - life: lifespan_into_vec(life), - } - } -} - -fn lifespan_into_vec(life: Lifespan) -> Vec { - match life { - Lifespan::Inherited => vec![], - Lifespan::Event { time } => vec![time], - Lifespan::Interval { start, end } => vec![start, end], - } -} diff --git a/raphtory-graphql/src/model/algorithms/global_search.rs b/raphtory-graphql/src/model/algorithms/global_search.rs deleted file mode 100644 index 8e53a0423d..0000000000 --- a/raphtory-graphql/src/model/algorithms/global_search.rs +++ /dev/null @@ -1,62 +0,0 @@ -use crate::{ - data::Data, - model::{ - algorithms::document::Document, - plugins::{operation::Operation, query_plugin::QueryPlugin}, - }, -}; -use async_graphql::{ - dynamic::{FieldValue, ResolverContext, TypeRef}, - FieldResult, -}; -use dynamic_graphql::internal::TypeName; -use futures_util::future::BoxFuture; -use raphtory::vectors::vectorised_cluster::VectorisedCluster; -use std::ops::Deref; -use tracing::info; - -pub(crate) struct GlobalSearch; - -impl<'a> Operation<'a, QueryPlugin> for GlobalSearch { - type OutputType = Document; - - fn output_type() -> TypeRef { - TypeRef::named_nn_list_nn(Document::get_type_name()) - } - - fn args<'b>() -> Vec<(&'b str, TypeRef)> { - vec![ - ("query", TypeRef::named_nn(TypeRef::STRING)), - ("limit", TypeRef::named_nn(TypeRef::INT)), - ] - } - - fn apply<'b>( - entry_point: &QueryPlugin, - ctx: ResolverContext, - ) -> BoxFuture<'b, FieldResult>>> { - let data = ctx.data_unchecked::().clone(); - let query = ctx - .args - .try_get("query") - .unwrap() - .string() - .unwrap() - .to_owned(); - let limit = ctx.args.try_get("limit").unwrap().u64().unwrap() as usize; - let graphs = entry_point.graphs.clone(); - - Box::pin(async move { - info!("running global search for {query}"); - let embedding = data.embed_query(query).await?; - - let cluster = VectorisedCluster::new(graphs.deref()); - let documents = cluster.search_graph_documents(&embedding, limit, None); // TODO: add window - - let gql_documents = documents - .into_iter() - .map(|doc| FieldValue::owned_any(Document::from(doc))); - Ok(Some(FieldValue::list(gql_documents))) - }) - } -} diff --git a/raphtory-graphql/src/model/algorithms/mod.rs b/raphtory-graphql/src/model/algorithms/mod.rs deleted file mode 100644 index f0a29dba95..0000000000 --- a/raphtory-graphql/src/model/algorithms/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -use async_graphql::dynamic::Object; -use dynamic_graphql::internal::Registry; - -pub mod algorithms; -pub mod document; -pub mod global_search; -pub mod similarity_search; - -pub type RegisterFunction = Box (Registry, Object) + Send>; diff --git a/raphtory-graphql/src/model/algorithms/similarity_search.rs b/raphtory-graphql/src/model/algorithms/similarity_search.rs deleted file mode 100644 index a07287fa6c..0000000000 --- a/raphtory-graphql/src/model/algorithms/similarity_search.rs +++ /dev/null @@ -1,80 +0,0 @@ -use crate::{ - data::Data, - model::{ - algorithms::document::Document, - plugins::{operation::Operation, vector_algorithm_plugin::VectorAlgorithmPlugin}, - }, -}; -use async_graphql::{ - dynamic::{FieldValue, ResolverContext, TypeRef}, - FieldResult, -}; -use dynamic_graphql::internal::TypeName; -use futures_util::future::BoxFuture; -use tracing::info; - -pub(crate) struct SimilaritySearch; - -impl<'a> Operation<'a, VectorAlgorithmPlugin> for SimilaritySearch { - type OutputType = Document; - - fn output_type() -> TypeRef { - TypeRef::named_nn_list_nn(Document::get_type_name()) - } - - fn args<'b>() -> Vec<(&'b str, TypeRef)> { - vec![ - ("query", TypeRef::named_nn(TypeRef::STRING)), - ("limit", TypeRef::named_nn(TypeRef::INT)), - ("start", TypeRef::named(TypeRef::INT)), - ("end", TypeRef::named(TypeRef::INT)), - ] - } - - fn apply<'b>( - entry_point: &VectorAlgorithmPlugin, - ctx: ResolverContext, - ) -> BoxFuture<'b, FieldResult>>> { - let data = ctx.data_unchecked::().clone(); - let query = ctx - .args - .try_get("query") - .unwrap() - .string() - .unwrap() - .to_owned(); - let limit = ctx.args.try_get("limit").unwrap().u64().unwrap() as usize; - let graph = entry_point.graph.clone(); - let start = ctx - .args - .try_get("start") - .map(|start| start.u64().ok().map(|value| value as i64)) - .ok() - .flatten(); - - let end = ctx - .args - .try_get("end") - .map(|end| end.u64().ok().map(|value| value as i64)) - .ok() - .flatten(); - let window = match (start, end) { - (Some(start), Some(end)) => Some((start, end)), - _ => None, - }; - - Box::pin(async move { - info!("running similarity search for {query}"); - let embedding = data.embed_query(query).await?; - - let documents = graph - .documents_by_similarity(&embedding, limit, window) - .get_documents(); - - let gql_documents = documents - .into_iter() - .map(|doc| FieldValue::owned_any(Document::from(doc))); - Ok(Some(FieldValue::list(gql_documents))) - }) - } -} diff --git a/raphtory-graphql/src/model/graph/document.rs b/raphtory-graphql/src/model/graph/document.rs new file mode 100644 index 0000000000..37bc8af98c --- /dev/null +++ b/raphtory-graphql/src/model/graph/document.rs @@ -0,0 +1,31 @@ +use dynamic_graphql::{SimpleObject, Union}; +use raphtory::{ + db::api::view::{IntoDynamic, StaticGraphViewOps}, + vectors::DocumentEntity, +}; + +use super::{edge::GqlEdge, node::GqlNode}; + +#[derive(Union)] +#[graphql(name = "DocumentEntity")] +pub(crate) enum GqlDocumentEntity { + Node(GqlNode), + Edge(GqlEdge), +} + +impl From> for GqlDocumentEntity { + fn from(value: DocumentEntity) -> Self { + match value { + DocumentEntity::Node(node) => Self::Node(GqlNode::from(node)), + DocumentEntity::Edge(edge) => Self::Edge(GqlEdge::from(edge)), + } + } +} + +#[derive(SimpleObject)] +pub struct GqlDocument { + pub(crate) entity: GqlDocumentEntity, + pub(crate) content: String, + pub(crate) embedding: Vec, + pub(crate) score: f32, +} diff --git a/raphtory-graphql/src/model/graph/mod.rs b/raphtory-graphql/src/model/graph/mod.rs index c7c73de7eb..203f156a60 100644 --- a/raphtory-graphql/src/model/graph/mod.rs +++ b/raphtory-graphql/src/model/graph/mod.rs @@ -1,5 +1,6 @@ use dynamic_graphql::OneOfInput; +mod document; pub(crate) mod edge; mod edges; pub(crate) mod filtering; @@ -13,6 +14,7 @@ pub(crate) mod node; mod nodes; mod path_from_node; pub(crate) mod property; +pub(crate) mod vector_selection; pub(crate) mod vectorised_graph; mod windowset; diff --git a/raphtory-graphql/src/model/graph/mutable_graph.rs b/raphtory-graphql/src/model/graph/mutable_graph.rs index dc524834a7..1ee0606984 100644 --- a/raphtory-graphql/src/model/graph/mutable_graph.rs +++ b/raphtory-graphql/src/model/graph/mutable_graph.rs @@ -299,17 +299,11 @@ impl GqlMutableGraph { properties: Vec, ) -> Result { let self_clone = self.clone(); - let self_clone_2 = self.clone(); spawn_blocking(move || { self_clone .graph - .add_properties(t, as_properties(properties)?) - }) - .await - .unwrap()?; - self.update_graph_embeddings().await; - spawn_blocking(move || { - self_clone_2.graph.write_updates()?; + .add_properties(t, as_properties(properties)?)?; + self_clone.graph.write_updates()?; Ok(true) }) .await @@ -322,17 +316,11 @@ impl GqlMutableGraph { properties: Vec, ) -> Result { let self_clone = self.clone(); - let self_clone_2 = self.clone(); spawn_blocking(move || { self_clone .graph - .add_constant_properties(as_properties(properties)?) - }) - .await - .unwrap()?; - self.update_graph_embeddings().await; - spawn_blocking(move || { - self_clone_2.graph.write_updates()?; + .add_constant_properties(as_properties(properties)?)?; + self_clone.graph.write_updates()?; Ok(true) }) .await @@ -345,17 +333,11 @@ impl GqlMutableGraph { properties: Vec, ) -> Result { let self_clone = self.clone(); - let self_clone_2 = self.clone(); spawn_blocking(move || { self_clone .graph - .update_constant_properties(as_properties(properties)?) - }) - .await - .unwrap()?; - self.update_graph_embeddings().await; - spawn_blocking(move || { - self_clone_2.graph.write_updates()?; + .update_constant_properties(as_properties(properties)?)?; + self_clone.graph.write_updates()?; Ok(true) }) .await @@ -364,13 +346,6 @@ impl GqlMutableGraph { } impl GqlMutableGraph { - async fn update_graph_embeddings(&self) { - let _ = self - .graph - .update_graph_embeddings(Some(self.path.get_original_path_str().to_owned())) - .await; - } - fn get_node_view(&self, name: String) -> Result, GraphError> { self.graph .node(name.clone()) diff --git a/raphtory-graphql/src/model/graph/vector_selection.rs b/raphtory-graphql/src/model/graph/vector_selection.rs new file mode 100644 index 0000000000..051da78427 --- /dev/null +++ b/raphtory-graphql/src/model/graph/vector_selection.rs @@ -0,0 +1,142 @@ +use crate::{embeddings::EmbedQuery, model::blocking}; +use async_graphql::Context; +use dynamic_graphql::{InputObject, ResolvedObject, ResolvedObjectFields}; +use raphtory::{ + core::utils::errors::GraphResult, db::api::view::MaterializedGraph, + vectors::vector_selection::VectorSelection, +}; + +use super::{ + document::GqlDocument, + edge::GqlEdge, + node::GqlNode, + vectorised_graph::{IntoWindowTuple, Window}, +}; + +#[derive(InputObject)] +pub(super) struct InputEdge { + src: String, + dst: String, +} + +#[derive(ResolvedObject)] +pub(crate) struct GqlVectorSelection(VectorSelection); + +impl From> for GqlVectorSelection { + fn from(value: VectorSelection) -> Self { + Self(value) + } +} + +#[ResolvedObjectFields] +impl GqlVectorSelection { + async fn nodes(&self) -> Vec { + self.0.nodes().into_iter().map(|e| e.into()).collect() + } + + async fn edges(&self) -> Vec { + self.0.edges().into_iter().map(|e| e.into()).collect() + } + + async fn get_documents(&self) -> GraphResult> { + let cloned = self.0.clone(); + blocking(move || { + let docs = cloned.get_documents_with_scores()?.into_iter(); + Ok(docs + .map(|(doc, score)| GqlDocument { + content: doc.content, + entity: doc.entity.into(), + embedding: doc.embedding.to_vec(), + score, + }) + .collect()) + }) + .await + } + + async fn add_nodes(&self, nodes: Vec) -> Self { + let mut selection = self.cloned(); + blocking(move || { + selection.add_nodes(nodes); + selection.into() + }) + .await + } + + async fn add_edges(&self, edges: Vec) -> Self { + let mut selection = self.cloned(); + blocking(move || { + let edges = edges.into_iter().map(|edge| (edge.src, edge.dst)).collect(); + selection.add_edges(edges); + selection.into() + }) + .await + } + + async fn expand(&self, hops: usize, window: Option) -> Self { + let window = window.into_window_tuple(); + let mut selection = self.cloned(); + blocking(move || { + selection.expand(hops, window); + selection.into() + }) + .await + } + + async fn expand_entities_by_similarity( + &self, + ctx: &Context<'_>, + query: String, + limit: usize, + window: Option, + ) -> GraphResult { + let vector = ctx.embed_query(query).await?; + let window = window.into_window_tuple(); + let mut selection = self.cloned(); + blocking(move || { + selection.expand_entities_by_similarity(&vector, limit, window)?; + Ok(selection.into()) + }) + .await + } + + async fn expand_nodes_by_similarity( + &self, + ctx: &Context<'_>, + query: String, + limit: usize, + window: Option, + ) -> GraphResult { + let vector = ctx.embed_query(query).await?; + let window = window.into_window_tuple(); + let mut selection = self.cloned(); + blocking(move || { + selection.expand_nodes_by_similarity(&vector, limit, window)?; + Ok(selection.into()) + }) + .await + } + + async fn expand_edges_by_similarity( + &self, + ctx: &Context<'_>, + query: String, + limit: usize, + window: Option, + ) -> GraphResult { + let vector = ctx.embed_query(query).await?; + let window = window.into_window_tuple(); + let mut selection = self.cloned(); + blocking(move || { + selection.expand_edges_by_similarity(&vector, limit, window)?; + Ok(selection.into()) + }) + .await + } +} + +impl GqlVectorSelection { + fn cloned(&self) -> VectorSelection { + self.0.clone() + } +} diff --git a/raphtory-graphql/src/model/graph/vectorised_graph.rs b/raphtory-graphql/src/model/graph/vectorised_graph.rs index 8fdefd22c0..dcef926fac 100644 --- a/raphtory-graphql/src/model/graph/vectorised_graph.rs +++ b/raphtory-graphql/src/model/graph/vectorised_graph.rs @@ -1,24 +1,82 @@ -use crate::model::plugins::vector_algorithm_plugin::VectorAlgorithmPlugin; -use dynamic_graphql::{ResolvedObject, ResolvedObjectFields}; -use raphtory::{db::api::view::MaterializedGraph, vectors::vectorised_graph::VectorisedGraph}; +use async_graphql::Context; +use dynamic_graphql::{InputObject, ResolvedObject, ResolvedObjectFields}; +use raphtory::{ + core::utils::errors::GraphResult, db::api::view::MaterializedGraph, + vectors::vectorised_graph::VectorisedGraph, +}; + +use crate::{embeddings::EmbedQuery, model::blocking}; + +use super::vector_selection::GqlVectorSelection; + +#[derive(InputObject)] +pub(super) struct Window { + start: i64, + end: i64, +} + +pub(super) trait IntoWindowTuple { + fn into_window_tuple(self) -> Option<(i64, i64)>; +} + +impl IntoWindowTuple for Option { + fn into_window_tuple(self) -> Option<(i64, i64)> { + self.map(|window| (window.start, window.end)) + } +} #[derive(ResolvedObject)] #[graphql(name = "VectorisedGraph")] -pub(crate) struct GqlVectorisedGraph { - graph: VectorisedGraph, -} +pub(crate) struct GqlVectorisedGraph(VectorisedGraph); impl From> for GqlVectorisedGraph { fn from(value: VectorisedGraph) -> Self { - Self { - graph: value.clone(), - } + Self(value.clone()) } } #[ResolvedObjectFields] impl GqlVectorisedGraph { - async fn algorithms(&self) -> VectorAlgorithmPlugin { - self.graph.clone().into() + async fn empty_selection(&self) -> GqlVectorSelection { + self.0.empty_selection().into() + } + + async fn entities_by_similarity( + &self, + ctx: &Context<'_>, + query: String, + limit: usize, + window: Option, + ) -> GraphResult { + let vector = ctx.embed_query(query).await?; + let w = window.into_window_tuple(); + let cloned = self.0.clone(); + blocking(move || Ok(cloned.entities_by_similarity(&vector, limit, w)?.into())).await + } + + async fn nodes_by_similarity( + &self, + ctx: &Context<'_>, + query: String, + limit: usize, + window: Option, + ) -> GraphResult { + let vector = ctx.embed_query(query).await?; + let w = window.into_window_tuple(); + let cloned = self.0.clone(); + blocking(move || Ok(cloned.nodes_by_similarity(&vector, limit, w)?.into())).await + } + + async fn edges_by_similarity( + &self, + ctx: &Context<'_>, + query: String, + limit: usize, + window: Option, + ) -> GraphResult { + let vector = ctx.embed_query(query).await?; + let w = window.into_window_tuple(); + let cloned = self.0.clone(); + blocking(move || Ok(cloned.edges_by_similarity(&vector, limit, w)?.into())).await } } diff --git a/raphtory-graphql/src/model/mod.rs b/raphtory-graphql/src/model/mod.rs index 2ab9094d66..8e38349af1 100644 --- a/raphtory-graphql/src/model/mod.rs +++ b/raphtory-graphql/src/model/mod.rs @@ -34,12 +34,20 @@ use std::{ }; use zip::ZipArchive; -pub mod algorithms; pub(crate) mod graph; pub mod plugins; pub(crate) mod schema; pub(crate) mod sorting; +/// a thin wrapper around spawn_blocking that unwraps the join handle +pub(crate) async fn blocking(f: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + tokio::task::spawn_blocking(f).await.unwrap() +} + #[derive(Debug)] pub struct MissingGraph; @@ -129,9 +137,8 @@ impl QueryRoot { Namespace::new(data.work_dir.clone(), data.work_dir.clone()) } - async fn plugins<'a>(ctx: &Context<'a>) -> QueryPlugin { - let data = ctx.data_unchecked::(); - data.get_global_plugins() + async fn plugins<'a>() -> QueryPlugin { + QueryPlugin::default() } async fn receive_graph<'a>(ctx: &Context<'a>, path: String) -> Result> { diff --git a/raphtory-graphql/src/model/algorithms/algorithms.rs b/raphtory-graphql/src/model/plugins/algorithms.rs similarity index 100% rename from raphtory-graphql/src/model/algorithms/algorithms.rs rename to raphtory-graphql/src/model/plugins/algorithms.rs diff --git a/raphtory-graphql/src/model/plugins/entry_point.rs b/raphtory-graphql/src/model/plugins/entry_point.rs index be1f589bcf..691a4c8cdb 100644 --- a/raphtory-graphql/src/model/plugins/entry_point.rs +++ b/raphtory-graphql/src/model/plugins/entry_point.rs @@ -1,9 +1,10 @@ -use crate::model::algorithms::RegisterFunction; use async_graphql::dynamic::Object; use dynamic_graphql::internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName}; use itertools::Itertools; use std::{collections::HashMap, sync::MutexGuard}; +use super::RegisterFunction; + pub trait EntryPoint<'a>: Register + TypeName + OutputTypeName + ResolveOwned<'a> + Sync { fn predefined_operations() -> HashMap<&'static str, RegisterFunction>; diff --git a/raphtory-graphql/src/model/plugins/graph_algorithm_plugin.rs b/raphtory-graphql/src/model/plugins/graph_algorithm_plugin.rs index 9e5e54fbb1..04e45280be 100644 --- a/raphtory-graphql/src/model/plugins/graph_algorithm_plugin.rs +++ b/raphtory-graphql/src/model/plugins/graph_algorithm_plugin.rs @@ -1,10 +1,4 @@ -use crate::model::{ - algorithms::{ - algorithms::{Pagerank, ShortestPath}, - RegisterFunction, - }, - plugins::{entry_point::EntryPoint, operation::Operation}, -}; +use crate::model::plugins::entry_point::EntryPoint; use async_graphql::{dynamic::FieldValue, Context}; use dynamic_graphql::internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName}; use once_cell::sync::Lazy; @@ -15,6 +9,12 @@ use std::{ sync::{Mutex, MutexGuard}, }; +use super::{ + algorithms::{Pagerank, ShortestPath}, + operation::Operation, + RegisterFunction, +}; + pub static GRAPH_ALGO_PLUGINS: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); diff --git a/raphtory-graphql/src/model/plugins/mod.rs b/raphtory-graphql/src/model/plugins/mod.rs index e8ba301b52..d499e1d39c 100644 --- a/raphtory-graphql/src/model/plugins/mod.rs +++ b/raphtory-graphql/src/model/plugins/mod.rs @@ -1,3 +1,7 @@ +use async_graphql::dynamic::Object; +use dynamic_graphql::internal::Registry; + +pub mod algorithms; pub mod entry_point; pub mod graph_algorithm_plugin; pub mod mutation_entry_point; @@ -5,4 +9,5 @@ pub mod mutation_plugin; pub mod operation; pub mod query_entry_point; pub mod query_plugin; -pub mod vector_algorithm_plugin; + +pub type RegisterFunction = Box (Registry, Object) + Send>; diff --git a/raphtory-graphql/src/model/plugins/mutation_entry_point.rs b/raphtory-graphql/src/model/plugins/mutation_entry_point.rs index 1df8c74f8d..10c0e88166 100644 --- a/raphtory-graphql/src/model/plugins/mutation_entry_point.rs +++ b/raphtory-graphql/src/model/plugins/mutation_entry_point.rs @@ -1,9 +1,10 @@ -use crate::model::algorithms::RegisterFunction; use async_graphql::dynamic::Object; use dynamic_graphql::internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName}; use itertools::Itertools; use std::{collections::HashMap, sync::MutexGuard}; +use super::RegisterFunction; + pub trait MutationEntryPoint<'a>: Register + TypeName + OutputTypeName + ResolveOwned<'a> + Sync { diff --git a/raphtory-graphql/src/model/plugins/mutation_plugin.rs b/raphtory-graphql/src/model/plugins/mutation_plugin.rs index 0ca8d9321e..518a6e5ab0 100644 --- a/raphtory-graphql/src/model/plugins/mutation_plugin.rs +++ b/raphtory-graphql/src/model/plugins/mutation_plugin.rs @@ -1,9 +1,6 @@ -use crate::model::{ - algorithms::RegisterFunction, - plugins::{ - entry_point::EntryPoint, - operation::{NoOpMutation, Operation}, - }, +use crate::model::plugins::{ + entry_point::EntryPoint, + operation::{NoOpMutation, Operation}, }; use async_graphql::{dynamic::FieldValue, Context}; use dynamic_graphql::internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName}; @@ -14,11 +11,13 @@ use std::{ sync::{Mutex, MutexGuard}, }; +use super::RegisterFunction; + pub static MUTATION_PLUGINS: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); #[derive(Clone, Default)] -pub struct MutationPlugin {} +pub struct MutationPlugin; impl<'a> EntryPoint<'a> for MutationPlugin { fn predefined_operations() -> HashMap<&'static str, RegisterFunction> { diff --git a/raphtory-graphql/src/model/plugins/operation.rs b/raphtory-graphql/src/model/plugins/operation.rs index 1202ec48b7..144a829b3c 100644 --- a/raphtory-graphql/src/model/plugins/operation.rs +++ b/raphtory-graphql/src/model/plugins/operation.rs @@ -6,6 +6,8 @@ use async_graphql::{ use dynamic_graphql::internal::{Register, Registry}; use futures_util::future::BoxFuture; +use super::query_plugin::QueryPlugin; + pub trait Operation<'a, A: Send + Sync + 'static> { type OutputType: Register + 'static; @@ -55,3 +57,24 @@ impl<'a> Operation<'a, MutationPlugin> for NoOpMutation { Box::pin(async move { Ok(Some(FieldValue::value("no-op".to_owned()))) }) } } + +pub(crate) struct NoOpQuery; + +impl<'a> Operation<'a, QueryPlugin> for NoOpQuery { + type OutputType = String; + + fn output_type() -> TypeRef { + TypeRef::named_nn(TypeRef::STRING) + } + + fn args<'b>() -> Vec<(&'b str, TypeRef)> { + vec![] + } + + fn apply<'b>( + _entry_point: &QueryPlugin, + _ctx: ResolverContext, + ) -> BoxFuture<'b, FieldResult>>> { + Box::pin(async move { Ok(Some(FieldValue::value("no-op".to_owned()))) }) + } +} diff --git a/raphtory-graphql/src/model/plugins/query_entry_point.rs b/raphtory-graphql/src/model/plugins/query_entry_point.rs index 9017677006..964b7ba18f 100644 --- a/raphtory-graphql/src/model/plugins/query_entry_point.rs +++ b/raphtory-graphql/src/model/plugins/query_entry_point.rs @@ -1,9 +1,10 @@ -use crate::model::algorithms::RegisterFunction; use async_graphql::dynamic::Object; use dynamic_graphql::internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName}; use itertools::Itertools; use std::{collections::HashMap, sync::MutexGuard}; +use super::RegisterFunction; + pub trait QueryEntryPoint<'a>: Register + TypeName + OutputTypeName + ResolveOwned<'a> + Sync { diff --git a/raphtory-graphql/src/model/plugins/query_plugin.rs b/raphtory-graphql/src/model/plugins/query_plugin.rs index 889d713282..3319d33e6b 100644 --- a/raphtory-graphql/src/model/plugins/query_plugin.rs +++ b/raphtory-graphql/src/model/plugins/query_plugin.rs @@ -1,30 +1,26 @@ -use crate::model::{ - algorithms::{global_search::GlobalSearch, RegisterFunction}, - plugins::{entry_point::EntryPoint, operation::Operation}, -}; +use crate::model::plugins::{entry_point::EntryPoint, operation::Operation}; use async_graphql::{dynamic::FieldValue, Context}; use dynamic_graphql::internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName}; use once_cell::sync::Lazy; -use raphtory::{db::api::view::MaterializedGraph, vectors::vectorised_graph::VectorisedGraph}; use std::{ borrow::Cow, collections::HashMap, - sync::{Arc, Mutex, MutexGuard}, + sync::{Mutex, MutexGuard}, }; +use super::{operation::NoOpQuery, RegisterFunction}; + pub static QUERY_PLUGINS: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); -#[derive(Clone)] -pub struct QueryPlugin { - pub graphs: Arc>>, -} +#[derive(Clone, Default)] +pub struct QueryPlugin; impl<'a> EntryPoint<'a> for QueryPlugin { fn predefined_operations() -> HashMap<&'static str, RegisterFunction> { HashMap::from([( - "globalSearch", - Box::new(GlobalSearch::register_operation) as RegisterFunction, + "NoOps", + Box::new(NoOpQuery::register_operation) as RegisterFunction, )]) } diff --git a/raphtory-graphql/src/model/plugins/vector_algorithm_plugin.rs b/raphtory-graphql/src/model/plugins/vector_algorithm_plugin.rs deleted file mode 100644 index 0601d21c2b..0000000000 --- a/raphtory-graphql/src/model/plugins/vector_algorithm_plugin.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::model::{ - algorithms::{similarity_search::SimilaritySearch, RegisterFunction}, - plugins::{entry_point::EntryPoint, operation::Operation}, -}; -use async_graphql::{dynamic::FieldValue, Context}; -use dynamic_graphql::internal::{OutputTypeName, Register, Registry, ResolveOwned, TypeName}; -use once_cell::sync::Lazy; -use raphtory::{db::api::view::MaterializedGraph, vectors::vectorised_graph::VectorisedGraph}; -use std::{ - borrow::Cow, - collections::HashMap, - sync::{Mutex, MutexGuard}, -}; - -pub static VECTOR_ALGO_PLUGINS: Lazy>> = - Lazy::new(|| Mutex::new(HashMap::new())); - -pub struct VectorAlgorithmPlugin { - pub graph: VectorisedGraph, -} - -impl From> for VectorAlgorithmPlugin { - fn from(graph: VectorisedGraph) -> Self { - Self { graph } - } -} - -impl<'a> EntryPoint<'a> for VectorAlgorithmPlugin { - fn predefined_operations() -> HashMap<&'static str, RegisterFunction> { - HashMap::from([( - "similaritySearch", - Box::new(SimilaritySearch::register_operation) as RegisterFunction, - )]) - } - - fn lock_plugins() -> MutexGuard<'static, HashMap> { - VECTOR_ALGO_PLUGINS.lock().unwrap() - } -} - -impl Register for VectorAlgorithmPlugin { - fn register(registry: Registry) -> Registry { - Self::register_operations(registry) - } -} - -impl TypeName for VectorAlgorithmPlugin { - fn get_type_name() -> Cow<'static, str> { - "VectorAlgorithmPlugin".into() - } -} - -impl OutputTypeName for VectorAlgorithmPlugin {} - -impl<'a> ResolveOwned<'a> for VectorAlgorithmPlugin { - fn resolve_owned(self, _ctx: &Context) -> dynamic_graphql::Result>> { - Ok(Some(FieldValue::owned_any(self))) - } -} diff --git a/raphtory-graphql/src/python/global_plugins.rs b/raphtory-graphql/src/python/global_plugins.rs deleted file mode 100644 index 52c793dd9a..0000000000 --- a/raphtory-graphql/src/python/global_plugins.rs +++ /dev/null @@ -1,77 +0,0 @@ -use crate::model::plugins::query_plugin::QueryPlugin; -use pyo3::{pyclass, pymethods, PyResult, Python}; -use raphtory::{ - db::api::view::DynamicGraph, - python::packages::vectors::{ - compute_embedding, translate_window, PyQuery, PyVectorisedGraph, PyWindow, - }, - vectors::{vectorised_cluster::VectorisedCluster, Document}, -}; - -/// A class for accessing graphs hosted in a Raphtory GraphQL server and running global search for -/// graph documents -#[pyclass(name = "GraphqlGraphs", module = "raphtory.graphql")] -pub struct PyGlobalPlugins(pub(crate) QueryPlugin); - -#[pymethods] -impl PyGlobalPlugins { - /// Return the top documents with the smallest cosine distance to `query` - /// - /// Arguments: - /// query (str): the text or the embedding to score against - /// limit (int): the maximum number of documents to return - /// window (Tuple[TimeInput, TimeInput], optional): the window where documents need to belong to in order to be considered - /// - /// Returns: - /// list[Document]: A list of documents - fn search_graph_documents( - &self, - py: Python, - query: PyQuery, - limit: usize, - window: PyWindow, - ) -> PyResult>> { - let docs = self.search_graph_documents_with_scores(py, query, limit, window)?; - Ok(docs.into_iter().map(|(doc, _)| doc).collect()) - } - - /// Same as `search_graph_documents` but it also returns the scores alongside the documents - /// - /// Arguments: - /// query (str): the text or the embedding to score against - /// limit (int): the maximum number of documents to return - /// window (Tuple[TimeInput, TimeInput], optional): the window where documents need to belong to in order to be considered - /// - /// Returns: - /// list[Tuple[Document, float]]: A list of documents and their scores - fn search_graph_documents_with_scores( - &self, - _py: Python, - query: PyQuery, - limit: usize, - window: PyWindow, - ) -> PyResult, f32)>> { - let window = translate_window(window); - let graphs = &self.0.graphs; - let cluster = VectorisedCluster::new(&graphs); - let graph_entry = graphs.iter().next(); - let (_, first_graph) = graph_entry - .expect("trying to search documents with no vectorised graphs on the server"); - let embedding = compute_embedding(first_graph, query)?; - let documents = cluster.search_graph_documents_with_scores(&embedding, limit, window); - Ok(documents - .into_iter() - .map(|(doc, score)| (doc.into_dynamic(), score)) - .collect()) - } - - /// Return the `VectorisedGraph` with name `name` or `None` if it doesn't exist - /// - /// Arguments: - /// name (str): the name of the graph - /// Returns: - /// Optional[VectorisedGraph]: the graph if it exists - fn get(&self, name: &str) -> Option { - self.0.graphs.get(name).map(|graph| graph.clone().into()) - } -} diff --git a/raphtory-graphql/src/python/mod.rs b/raphtory-graphql/src/python/mod.rs index c2576f3e93..d380bcb044 100644 --- a/raphtory-graphql/src/python/mod.rs +++ b/raphtory-graphql/src/python/mod.rs @@ -10,7 +10,6 @@ use raphtory::{db::api::view::MaterializedGraph, python::utils::errors::adapt_er use serde_json::{Map, Number, Value as JsonValue}; pub mod client; -pub mod global_plugins; pub mod pymodule; pub mod server; diff --git a/raphtory-graphql/src/python/pymodule.rs b/raphtory-graphql/src/python/pymodule.rs index 8a16d967b8..59041eaa29 100644 --- a/raphtory-graphql/src/python/pymodule.rs +++ b/raphtory-graphql/src/python/pymodule.rs @@ -4,14 +4,12 @@ use crate::python::{ remote_node::PyRemoteNode, PyEdgeAddition, PyNodeAddition, PyUpdate, }, decode_graph, encode_graph, - global_plugins::PyGlobalPlugins, server::{running_server::PyRunningGraphServer, server::PyGraphServer}, }; use pyo3::prelude::*; pub fn base_graphql_module(py: Python<'_>) -> Result, PyErr> { let graphql_module = PyModule::new(py, "graphql")?; - graphql_module.add_class::()?; graphql_module.add_class::()?; graphql_module.add_class::()?; graphql_module.add_class::()?; diff --git a/raphtory-graphql/src/python/server/server.rs b/raphtory-graphql/src/python/server/server.rs index 13eef300cc..34672c7282 100644 --- a/raphtory-graphql/src/python/server/server.rs +++ b/raphtory-graphql/src/python/server/server.rs @@ -1,39 +1,23 @@ use crate::{ config::{app_config::AppConfigBuilder, auth_config::PUBLIC_KEY_DECODING_ERR_MSG}, - model::{ - algorithms::document::Document as GqlDocument, - plugins::{entry_point::EntryPoint, query_plugin::QueryPlugin}, - }, - python::{ - adapt_graphql_value, - global_plugins::PyGlobalPlugins, - server::{ - running_server::PyRunningGraphServer, take_server_ownership, wait_server, BridgeCommand, - }, + python::server::{ + running_server::PyRunningGraphServer, take_server_ownership, wait_server, BridgeCommand, }, GraphServer, }; -use async_graphql::dynamic::{Field, FieldFuture, FieldValue, InputValue, Object, TypeRef}; -use dynamic_graphql::internal::{Registry, TypeName}; -use itertools::intersperse; use pyo3::{ exceptions::{PyAttributeError, PyException, PyValueError}, prelude::*, - types::{IntoPyDict, PyFunction, PyList}, - IntoPyObjectExt, + types::PyFunction, }; use raphtory::{ - db::api::view::DynamicGraph, - python::{packages::vectors::TemplateConfig, types::wrappers::document::PyDocument}, + python::packages::vectors::TemplateConfig, vectors::{ - embeddings::openai_embedding, - template::{ - DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_GRAPH_TEMPLATE, DEFAULT_NODE_TEMPLATE, - }, - Document, EmbeddingFunction, + embeddings::{openai_embedding, EmbeddingFunction}, + template::{DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_NODE_TEMPLATE}, }, }; -use std::{collections::HashMap, path::PathBuf, sync::Arc, thread}; +use std::{path::PathBuf, sync::Arc, thread}; /// A class for defining and running a Raphtory GraphQL server /// @@ -60,16 +44,11 @@ impl<'py> IntoPyObject<'py> for GraphServer { } } -fn template_from_python( - graphs: TemplateConfig, - nodes: TemplateConfig, - edges: TemplateConfig, -) -> Option { - if graphs.is_disabled() && nodes.is_disabled() && edges.is_disabled() { +fn template_from_python(nodes: TemplateConfig, edges: TemplateConfig) -> Option { + if nodes.is_disabled() && edges.is_disabled() { None } else { Some(DocumentTemplate { - graph_template: graphs.get_template_or(DEFAULT_GRAPH_TEMPLATE), node_template: nodes.get_template_or(DEFAULT_NODE_TEMPLATE), edge_template: edges.get_template_or(DEFAULT_EDGE_TEMPLATE), }) @@ -85,85 +64,14 @@ impl PyGraphServer { slf: PyRefMut, cache: String, embedding: F, - graphs: TemplateConfig, nodes: TemplateConfig, edges: TemplateConfig, ) -> PyResult { - let global_template = template_from_python(graphs, nodes, edges); + let global_template = template_from_python(nodes, edges); let server = take_server_ownership(slf)?; let cache = PathBuf::from(cache); - Ok(server.set_embeddings(embedding, &cache, global_template)) - } - - fn with_generic_document_search_function< - 'a, - E: EntryPoint<'a> + 'static, - F: Fn(&E, Python) -> PyObject + Send + Sync + 'static, - >( - slf: PyRefMut, - name: String, - input: HashMap, - function: Py, - adapter: F, - ) -> PyResult { - let input_mapper = HashMap::from([ - ("str", TypeRef::named_nn(TypeRef::STRING)), - ("int", TypeRef::named_nn(TypeRef::INT)), - ("float", TypeRef::named_nn(TypeRef::FLOAT)), - ]); - - let input_values = input - .into_iter() - .map(|(name, type_name)| { - let type_ref = input_mapper.get(&type_name.as_str()).cloned(); - type_ref - .map(|type_ref| InputValue::new(name, type_ref)) - .ok_or_else(|| { - let valid_types = input_mapper.keys().map(|key| key.to_owned()); - let valid_types_string: String = intersperse(valid_types, ", ").collect(); - let msg = format!("types in input have to be one of: {valid_types_string}"); - PyAttributeError::new_err(msg) - }) - }) - .collect::>>()?; - - // FIXME: this should return a result! - let register_function = |name: &str, registry: Registry, parent: Object| { - let registry = registry.register::(); - let output_type = TypeRef::named_nn_list_nn(GqlDocument::get_type_name()); - let mut field = Field::new(name, output_type, move |ctx| { - let documents: Vec> = Python::with_gil(|py| { - let entry_point = adapter(ctx.parent_value.downcast_ref().unwrap(), py); - let kw_args: HashMap<&str, PyObject> = ctx - .args - .iter() - .map(|(name, value)| (name.as_str(), adapt_graphql_value(&value, py))) - .collect(); - let py_kw_args = kw_args.into_py_dict(py).unwrap(); - let result = function - .call(py, (entry_point,), Some(&py_kw_args)) - .unwrap(); - let list = result.downcast_bound::(py).unwrap(); - let py_documents = list.iter().map(|doc| doc.extract::().unwrap()); - py_documents.map(|doc| doc.into()).collect() - }); - - let gql_documents = documents - .into_iter() - .map(|doc| FieldValue::owned_any(GqlDocument::from(doc))); - - FieldFuture::Value(Some(FieldValue::list(gql_documents))) - }); - for input_value in input_values { - field = field.argument(input_value); - } - let parent = parent.field(field); - (registry, parent) - }; - E::lock_plugins().insert(name, Box::new(register_function)); - - let new_server = take_server_ownership(slf)?; - Ok(new_server) + let rt = tokio::runtime::Runtime::new().unwrap(); + Ok(rt.block_on(server.set_embeddings(embedding, &cache, global_template))?) } } @@ -236,31 +144,27 @@ impl PyGraphServer { /// Arguments: /// cache (str): the directory to use as cache for the embeddings. /// embedding (Callable, optional): the embedding function to translate documents to embeddings. - /// graphs (bool | str): if graphs have to be embedded or not or the custom template to use if a str is provided. Defaults to True. /// nodes (bool | str): if nodes have to be embedded or not or the custom template to use if a str is provided. Defaults to True. /// edges (bool | str): if edges have to be embedded or not or the custom template to use if a str is provided. Defaults to True. /// /// Returns: /// GraphServer: A new server object with embeddings setup. #[pyo3( - signature = (cache, embedding = None, graphs = TemplateConfig::Bool(true), nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true)) + signature = (cache, embedding = None, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true)) )] fn set_embeddings( slf: PyRefMut, cache: String, embedding: Option>, - graphs: TemplateConfig, nodes: TemplateConfig, edges: TemplateConfig, ) -> PyResult { match embedding { Some(embedding) => { let embedding: Arc = Arc::new(embedding); - Self::set_generic_embeddings(slf, cache, embedding, graphs, nodes, edges) - } - None => { - Self::set_generic_embeddings(slf, cache, openai_embedding, graphs, nodes, edges) + Self::set_generic_embeddings(slf, cache, embedding, nodes, edges) } + None => Self::set_generic_embeddings(slf, cache, openai_embedding, nodes, edges), } } @@ -268,59 +172,28 @@ impl PyGraphServer { /// /// Arguments: /// graph_names (list[str]): the names of the graphs to vectorise. All by default. - /// graphs (bool | str): if graphs have to be embedded or not or the custom template to use if a str is provided. Defaults to True. /// nodes (bool | str): if nodes have to be embedded or not or the custom template to use if a str is provided. Defaults to True. /// edges (bool | str): if edges have to be embedded or not or the custom template to use if a str is provided. Defaults to True. /// /// Returns: /// GraphServer: A new server object containing the vectorised graphs. #[pyo3( - signature = (graph_names, graphs = TemplateConfig::Bool(true), nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true)) + signature = (graph_names, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true)) )] fn with_vectorised_graphs( slf: PyRefMut, graph_names: Vec, // TODO: support more models by just providing a string, e.g. "openai", here and in the VectorisedGraph API - graphs: TemplateConfig, nodes: TemplateConfig, edges: TemplateConfig, ) -> PyResult { - let template = - template_from_python(graphs, nodes, edges).ok_or(PyAttributeError::new_err( - "some of graph_template, node_template, edge_template has to be set", - ))?; + let template = template_from_python(nodes, edges).ok_or(PyAttributeError::new_err( + "node_template and/or edge_template has to be set", + ))?; let server = take_server_ownership(slf)?; Ok(server.with_vectorised_graphs(graph_names, template)) } - /// Register a function in the GraphQL schema for document search among all the graphs. - /// - /// The function needs to take a `GraphqlGraphs` object as the first argument followed by a - /// pre-defined set of keyword arguments. Supported types are `str`, `int`, and `float`. - /// They have to be specified using the `input` parameter as a dict where the keys are the - /// names of the parameters and the values are the types, expressed as strings. - /// - /// Arguments: - /// name (str): the name of the function in the GraphQL schema. - /// input (dict[str, str]): the keyword arguments expected by the function. - /// function (Callable): the function to run. - /// - /// Returns: - /// GraphServer: A new server object with the function registered - pub fn with_global_search_function( - slf: PyRefMut, - name: String, - input: HashMap, - function: Py, - ) -> PyResult { - let adapter = |entry_point: &QueryPlugin, py: Python| { - PyGlobalPlugins(entry_point.clone()) - .into_py_any(py) - .unwrap() - }; - PyGraphServer::with_generic_document_search_function(slf, name, input, function, adapter) - } - /// Start the server and return a handle to it. /// /// Arguments: @@ -340,31 +213,28 @@ impl PyGraphServer { timeout_ms: u64, ) -> PyResult { let (sender, receiver) = crossbeam_channel::bounded::(1); - let server = take_server_ownership(slf)?; - let cloned_sender = sender.clone(); + let server = take_server_ownership(slf)?; + let join_handle = thread::spawn(move || { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap() - .block_on(async move { - let handler = server.start_with_port(port); - let running_server = handler.await?; - let tokio_sender = running_server._get_sender().clone(); - tokio::task::spawn_blocking(move || { - match receiver.recv().expect("Failed to wait for cancellation") { - BridgeCommand::StopServer => tokio_sender - .blocking_send(()) - .expect("Failed to send cancellation signal"), - BridgeCommand::StopListening => (), - } - }); - let result = running_server.wait().await; - _ = cloned_sender.send(BridgeCommand::StopListening); - result - }) + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + let handler = server.start_with_port(port); + let running_server = handler.await?; + let tokio_sender = running_server._get_sender().clone(); + tokio::task::spawn_blocking(move || { + match receiver.recv().expect("Failed to wait for cancellation") { + BridgeCommand::StopServer => tokio_sender + .blocking_send(()) + .expect("Failed to send cancellation signal"), + BridgeCommand::StopListening => (), + } + }); + let result = running_server.wait().await; + _ = cloned_sender.send(BridgeCommand::StopListening); + result + }) }); let mut server = PyRunningGraphServer::new(join_handle, sender, port)?; diff --git a/raphtory-graphql/src/server.rs b/raphtory-graphql/src/server.rs index 41cf3b0df8..8be9c25240 100644 --- a/raphtory-graphql/src/server.rs +++ b/raphtory-graphql/src/server.rs @@ -21,12 +21,14 @@ use poem::{ middleware::{Cors, CorsEndpoint}, EndpointExt, Route, Server, }; -use raphtory::vectors::{template::DocumentTemplate, EmbeddingFunction}; +use raphtory::{ + core::utils::errors::GraphResult, + vectors::{cache::VectorCache, embeddings::EmbeddingFunction, template::DocumentTemplate}, +}; use serde_json::json; use std::{ - fs, + fs::create_dir_all, path::{Path, PathBuf}, - sync::Arc, }; use thiserror::Error; use tokio::{ @@ -89,7 +91,7 @@ impl GraphServer { config_path: Option, ) -> IoResult { if !work_dir.exists() { - fs::create_dir_all(&work_dir)?; + create_dir_all(&work_dir)?; } let config = load_config(app_config, config_path).map_err(|err| ServerError::ConfigError(err))?; @@ -103,21 +105,19 @@ impl GraphServer { self } - pub fn set_embeddings( + pub async fn set_embeddings( mut self, embedding: F, - cache: &Path, // TODO: maybe now that we are storing vectors we could bin the cache!!! + cache: &Path, // or maybe it could be in a standard location like /tmp/raphtory/embedding_cache global_template: Option, - ) -> Self { - let cache = Some(PathBuf::from(cache).into()).into(); + ) -> GraphResult { self.data.embedding_conf = Some(EmbeddingConf { - function: Arc::new(embedding), - cache, + cache: VectorCache::on_disk(cache, embedding)?, // TODO: better do this lazily, actually do it when running the server global_template, individual_templates: Default::default(), }); - self + Ok(self) } /// Vectorise a subset of the graphs of the server. @@ -343,9 +343,10 @@ mod server_tests { use chrono::prelude::*; use raphtory::{ prelude::{AdditionOps, Graph, StableEncode, NO_PROPS}, - vectors::{template::DocumentTemplate, Embedding, EmbeddingResult}, + vectors::{embeddings::EmbeddingResult, template::DocumentTemplate, Embedding}, }; use raphtory_api::core::utils::logging::global_info_logger; + use tempfile::tempdir; use tokio::time::{sleep, Duration}; use tracing::info; @@ -384,9 +385,11 @@ mod server_tests { node_template: Some("{{ name }}".to_owned()), ..Default::default() }; - let cache = Path::new("/tmp/graph-cache"); + let cache_dir = tempdir().unwrap(); let handler = server - .set_embeddings(failing_embedding, cache, Some(template)) + .set_embeddings(failing_embedding, cache_dir.path(), Some(template)) + .await + .unwrap() .start_with_port(0); sleep(Duration::from_secs(5)).await; handler.await.unwrap().stop().await diff --git a/raphtory/Cargo.toml b/raphtory/Cargo.toml index 8a31708615..d78b946fba 100644 --- a/raphtory/Cargo.toml +++ b/raphtory/Cargo.toml @@ -16,11 +16,11 @@ homepage.workspace = true [dependencies] raphtory-api = { path = "../raphtory-api", version = "0.15.1" } -arrow-ipc = {workspace = true} -arrow-array = {workspace = true, features = ["chrono-tz"]} -arrow-schema = {workspace = true} -arrow-buffer = {workspace = true} -arrow-data ={ workspace = true } +arrow-ipc = { workspace = true } +arrow-array = { workspace = true, features = ["chrono-tz"] } +arrow-schema = { workspace = true } +arrow-buffer = { workspace = true } +arrow-data = { workspace = true } hashbrown = { workspace = true } chrono = { workspace = true } itertools = { workspace = true } @@ -71,6 +71,10 @@ async-openai = { workspace = true, optional = true } bincode = { workspace = true, optional = true } minijinja = { workspace = true, optional = true } minijinja-contrib = { workspace = true, optional = true } +arroy = { workspace = true, optional = true } +heed = { workspace = true, optional = true } +sysinfo = { workspace = true, optional = true } +moka = { workspace = true, optional = true } # python binding optional dependencies pyo3 = { workspace = true, optional = true } @@ -84,7 +88,7 @@ arrow-json = { workspace = true, optional = true } memmap2 = { workspace = true, optional = true } tempfile = { workspace = true, optional = true } pometry-storage = { workspace = true, optional = true } -pyo3-arrow = {workspace = true, optional = true} +pyo3-arrow = { workspace = true, optional = true } prost = { workspace = true, optional = true } prost-types = { workspace = true, optional = true } @@ -100,8 +104,8 @@ pretty_assertions = { workspace = true } quickcheck = { workspace = true } quickcheck_macros = { workspace = true } tempfile = { workspace = true } -tokio = { workspace = true } # for vector testing -dotenv = { workspace = true } # for vector testing +tokio = { workspace = true } # for vector testing +dotenv = { workspace = true } # for vector testing streaming-stats = { workspace = true } proptest = { workspace = true } polars-core = { workspace = true, features = ["fmt"] } @@ -136,6 +140,11 @@ vectors = [ "dep:bincode", "dep:minijinja", "dep:minijinja-contrib", + "dep:arroy", + "dep:heed", + "dep:sysinfo", + "dep:moka", + "dep:tempfile", # also used for the storage feature ] # Enables generating the pyo3 python bindings diff --git a/raphtory/src/core/mod.rs b/raphtory/src/core/mod.rs index 650d61918f..645c1a12bf 100644 --- a/raphtory/src/core/mod.rs +++ b/raphtory/src/core/mod.rs @@ -57,33 +57,6 @@ pub mod utils; use crate::core::prop_array::PropArray; pub use raphtory_api::core::*; -#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Hash, Default)] -pub enum Lifespan { - Interval { - start: i64, - end: i64, - }, - Event { - time: i64, - }, - #[default] - Inherited, -} - -/// struct containing all the necessary information to allow Raphtory creating a document and -/// storing it -#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash, Default)] -pub struct DocumentInput { - pub content: String, - pub life: Lifespan, -} - -impl Display for DocumentInput { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.write_str(&self.content) - } -} - /// Denotes the types of properties allowed to be stored in the graph. #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] pub enum Prop { @@ -894,16 +867,6 @@ impl From for Value { } } -impl From for Value { - fn from(lifespan: Lifespan) -> Self { - match lifespan { - Lifespan::Interval { start, end } => json!({ "start": start, "end": end }), - Lifespan::Event { time } => json!({ "time": time }), - Lifespan::Inherited => Value::String("inherited".to_string()), - } - } -} - pub fn sort_comparable_props(props: Vec<&Prop>) -> Vec<&Prop> { // Filter out non-comparable props let mut comparable_props: Vec<_> = props diff --git a/raphtory/src/core/utils/errors.rs b/raphtory/src/core/utils/errors.rs index afc52bb84b..53f56aeaf3 100644 --- a/raphtory/src/core/utils/errors.rs +++ b/raphtory/src/core/utils/errors.rs @@ -233,14 +233,12 @@ pub enum GraphError { #[from] source: zip::result::ZipError, }, - #[cfg(feature = "vectors")] - #[error("bincode operation failed")] - BincodeError { - #[from] - source: bincode::Error, - }, - + #[error("Arroy error: {0}")] + ArroyError(#[from] arroy::Error), + #[cfg(feature = "vectors")] + #[error("Heed error: {0}")] + HeedError(#[from] heed::Error), #[cfg(feature = "arrow")] #[error("Failed to load graph: {0}")] LoadFailure(String), @@ -273,6 +271,10 @@ pub enum GraphError { source: Box, }, + #[cfg(feature = "vectors")] + #[error("The path {0} does not contain a vector DB")] + VectorDbDoesntExist(String), + #[cfg(feature = "search")] #[error("Index operation failed")] QueryError { diff --git a/raphtory/src/python/packages/vectors.rs b/raphtory/src/python/packages/vectors.rs index 644d41bcba..e40eb4740e 100644 --- a/raphtory/src/python/packages/vectors.rs +++ b/raphtory/src/python/packages/vectors.rs @@ -1,5 +1,5 @@ use crate::{ - core::utils::{errors::GraphError, time::IntoTime}, + core::utils::time::IntoTime, db::{ api::view::{DynamicGraph, IntoDynamic, MaterializedGraph, StaticGraphViewOps}, graph::{edge::EdgeView, node::NodeView}, @@ -10,13 +10,13 @@ use crate::{ utils::{execute_async_task, PyNodeRef, PyTime}, }, vectors::{ - template::{ - DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_GRAPH_TEMPLATE, DEFAULT_NODE_TEMPLATE, - }, + cache::VectorCache, + embeddings::{EmbeddingFunction, EmbeddingResult}, + template::{DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_NODE_TEMPLATE}, vector_selection::DynamicVectorSelection, vectorisable::Vectorisable, - vectorised_graph::{DynamicVectorisedGraph, VectorisedGraph}, - Document, DocumentEntity, Embedding, EmbeddingFunction, EmbeddingResult, + vectorised_graph::VectorisedGraph, + Document, DocumentEntity, Embedding, }, }; use futures_util::future::BoxFuture; @@ -27,6 +27,8 @@ use pyo3::{ types::{PyFunction, PyList}, }; +type DynamicVectorisedGraph = VectorisedGraph; + pub type PyWindow = Option<(PyTime, PyTime)>; pub fn translate_window(window: PyWindow) -> Option<(i64, i64)> { @@ -40,14 +42,17 @@ pub enum PyQuery { } impl PyQuery { - async fn into_embedding( + fn into_embedding( self, - embedding: &E, + graph: &VectorisedGraph, ) -> PyResult { match self { Self::Raw(query) => { - let result = embedding.call(vec![query]).await; - Ok(result.map_err(GraphError::from)?.remove(0)) + let cache = graph.cache.clone(); + let result = Ok(execute_async_task(move || async move { + cache.get_single(query).await + })?); + result } Self::Computed(embedding) => Ok(embedding), } @@ -83,13 +88,8 @@ impl Document { entity, content, embedding, - life, } = self; let entity = match entity { - DocumentEntity::Graph { name, graph } => DocumentEntity::Graph { - name, - graph: graph.into_dynamic(), - }, // TODO: define a common method node/edge.into_dynamic for NodeView, as this code is duplicated in model/graph/node.rs and model/graph/edge.rs DocumentEntity::Node(node) => DocumentEntity::Node(NodeView { base_graph: node.base_graph.into_dynamic(), @@ -107,7 +107,6 @@ impl Document { entity, content, embedding, - life, } } } @@ -152,47 +151,29 @@ impl PyGraphView { /// /// Args: /// embedding (Callable[[list], list]): the embedding function to translate documents to embeddings - /// cache (str, optional): the file to be used as a cache to avoid calling the embedding function - /// overwrite_cache (bool): whether or not to overwrite the cache if there are new embeddings. Defaults to False. - /// graph (bool | str): if the graph has to be embedded or not or the custom template to use if a str is provided. Defaults to True. /// nodes (bool | str): if nodes have to be embedded or not or the custom template to use if a str is provided. Defaults to True. /// edges (bool | str): if edges have to be embedded or not or the custom template to use if a str is provided. Defaults to True. - /// graph_name (str, optional): the name of the graph /// verbose (bool): whether or not to print logs reporting the progress. Defaults to False. /// /// Returns: /// VectorisedGraph: A VectorisedGraph with all the documents/embeddings computed and with an initial empty selection - #[pyo3(signature = (embedding, cache = None, overwrite_cache = false, graph = TemplateConfig::Bool(true), nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true), graph_name = None, verbose = false))] + #[pyo3(signature = (embedding, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true), verbose = false))] fn vectorise( &self, embedding: Bound, - cache: Option, - overwrite_cache: bool, - graph: TemplateConfig, nodes: TemplateConfig, edges: TemplateConfig, - graph_name: Option, verbose: bool, ) -> PyResult { let template = DocumentTemplate { - graph_template: graph.get_template_or(DEFAULT_GRAPH_TEMPLATE), node_template: nodes.get_template_or(DEFAULT_NODE_TEMPLATE), edge_template: edges.get_template_or(DEFAULT_EDGE_TEMPLATE), }; let embedding = embedding.unbind(); - let cache = cache.map(|cache| cache.into()).into(); + let cache = VectorCache::in_memory(embedding); let graph = self.graph.clone(); execute_async_task(move || async move { - Ok(graph - .vectorise( - Box::new(embedding), - cache, - overwrite_cache, - template, - graph_name, - verbose, - ) - .await?) + Ok(graph.vectorise(cache, template, None, verbose).await?) }) } } @@ -236,46 +217,11 @@ impl<'py> IntoPyObject<'py> for DynamicVectorSelection { /// over those documents #[pymethods] impl PyVectorisedGraph { - /// Save the embeddings present in this graph to `file` so they can be further used in a call to `vectorise` - fn save_embeddings(&self, file: String) { - self.0.save_embeddings(file.into()); - } - /// Return an empty selection of documents fn empty_selection(&self) -> DynamicVectorSelection { self.0.empty_selection() } - /// Return all the graph level documents - /// - /// Returns: - /// list[Document]: list of graph level documents - pub fn get_graph_documents(&self) -> Vec> { - self.0.get_graph_documents() - } - - /// Search the top scoring documents according to `query` with no more than `limit` documents - /// - /// Args: - /// query (str | list): the text or the embedding to score against - /// limit (int): the maximum number of documents to search - /// window (Tuple[int | str, int | str], optional): the window where documents need to belong to in order to be considered - /// - /// Returns: - /// VectorSelection: The vector selection resulting from the search - #[pyo3(signature = (query, limit, window=None))] - pub fn documents_by_similarity( - &self, - query: PyQuery, - limit: usize, - window: PyWindow, - ) -> PyResult { - let embedding = compute_embedding(&self.0, query)?; - Ok(self - .0 - .documents_by_similarity(&embedding, limit, translate_window(window))) - } - /// Search the top scoring entities according to `query` with no more than `limit` entities /// /// Args: @@ -292,10 +238,10 @@ impl PyVectorisedGraph { limit: usize, window: PyWindow, ) -> PyResult { - let embedding = compute_embedding(&self.0, query)?; + let embedding = query.into_embedding(&self.0)?; Ok(self .0 - .entities_by_similarity(&embedding, limit, translate_window(window))) + .entities_by_similarity(&embedding, limit, translate_window(window))?) } /// Search the top scoring nodes according to `query` with no more than `limit` nodes @@ -314,10 +260,10 @@ impl PyVectorisedGraph { limit: usize, window: PyWindow, ) -> PyResult { - let embedding = compute_embedding(&self.0, query)?; + let embedding = query.into_embedding(&self.0)?; Ok(self .0 - .nodes_by_similarity(&embedding, limit, translate_window(window))) + .nodes_by_similarity(&embedding, limit, translate_window(window))?) } /// Search the top scoring edges according to `query` with no more than `limit` edges @@ -336,10 +282,10 @@ impl PyVectorisedGraph { limit: usize, window: PyWindow, ) -> PyResult { - let embedding = compute_embedding(&self.0, query)?; + let embedding = query.into_embedding(&self.0)?; Ok(self .0 - .edges_by_similarity(&embedding, limit, translate_window(window))) + .edges_by_similarity(&embedding, limit, translate_window(window))?) } } @@ -378,17 +324,16 @@ impl PyVectorSelection { /// /// Returns: /// list[Document]: list of documents in the current selection - fn get_documents(&self) -> Vec> { - // TODO: review if I can simplify this - self.0.get_documents() + fn get_documents(&self) -> PyResult>> { + Ok(self.0.get_documents()?) } /// Return the documents alongside their scores present in the current selection /// /// Returns: /// list[Tuple[Document, float]]: list of documents and scores - fn get_documents_with_scores(&self) -> Vec<(Document, f32)> { - self.0.get_documents_with_scores() + fn get_documents_with_scores(&self) -> PyResult, f32)>> { + Ok(self.0.get_documents_with_scores()?) } /// Add all the documents associated with the `nodes` to the current selection @@ -446,38 +391,6 @@ impl PyVectorSelection { self_.0.expand(hops, translate_window(window)) } - /// Add the top `limit` adjacent documents with higher score for `query` to the selection - /// - /// The expansion algorithm is a loop with two steps on each iteration: - /// 1. All the documents 1 hop away of some of the documents included on the selection (and - /// not already selected) are marked as candidates. - /// 2. Those candidates are added to the selection in descending order according to the - /// similarity score obtained against the `query`. - /// - /// This loops goes on until the current selection reaches a total of `limit` documents or - /// until no more documents are available - /// - /// Args: - /// query (str | list): the text or the embedding to score against - /// limit (int): the number of documents to add - /// window (Tuple[int | str, int | str], optional): the window where documents need to belong to in order to be considered - /// - /// Returns: - /// None: - #[pyo3(signature = (query, limit, window=None))] - fn expand_documents_by_similarity( - mut self_: PyRefMut<'_, Self>, - query: PyQuery, - limit: usize, - window: PyWindow, - ) -> PyResult<()> { - let embedding = compute_embedding(&self_.0.graph, query)?; - self_ - .0 - .expand_documents_by_similarity(&embedding, limit, translate_window(window)); - Ok(()) - } - /// Add the top `limit` adjacent entities with higher score for `query` to the selection /// /// The expansion algorithm is a loop with two steps on each iteration: @@ -503,10 +416,10 @@ impl PyVectorSelection { limit: usize, window: PyWindow, ) -> PyResult<()> { - let embedding = compute_embedding(&self_.0.graph, query)?; + let embedding = query.into_embedding(&self_.0.graph)?; self_ .0 - .expand_entities_by_similarity(&embedding, limit, translate_window(window)); + .expand_entities_by_similarity(&embedding, limit, translate_window(window))?; Ok(()) } @@ -528,10 +441,10 @@ impl PyVectorSelection { limit: usize, window: PyWindow, ) -> PyResult<()> { - let embedding = compute_embedding(&self_.0.graph, query)?; + let embedding = query.into_embedding(&self_.0.graph)?; self_ .0 - .expand_nodes_by_similarity(&embedding, limit, translate_window(window)); + .expand_nodes_by_similarity(&embedding, limit, translate_window(window))?; Ok(()) } @@ -553,22 +466,14 @@ impl PyVectorSelection { limit: usize, window: PyWindow, ) -> PyResult<()> { - let embedding = compute_embedding(&self_.0.graph, query)?; + let embedding = query.into_embedding(&self_.0.graph)?; self_ .0 - .expand_edges_by_similarity(&embedding, limit, translate_window(window)); + .expand_edges_by_similarity(&embedding, limit, translate_window(window))?; Ok(()) } } -pub fn compute_embedding( - vectors: &VectorisedGraph, - query: PyQuery, -) -> PyResult { - let embedding = vectors.embedding.clone(); - execute_async_task(move || async move { query.into_embedding(embedding.as_ref()).await }) -} - impl EmbeddingFunction for Py { fn call(&self, texts: Vec) -> BoxFuture<'static, EmbeddingResult>> { let embedding_function = Python::with_gil(|py| self.clone_ref(py)); diff --git a/raphtory/src/python/types/wrappers/document.rs b/raphtory/src/python/types/wrappers/document.rs index 3f4f912768..a7a6788bc7 100644 --- a/raphtory/src/python/types/wrappers/document.rs +++ b/raphtory/src/python/types/wrappers/document.rs @@ -1,37 +1,9 @@ use crate::{ - core::Lifespan, db::api::view::DynamicGraph, - python::{ - graph::views::graph_view::PyGraphView, - types::repr::{Repr, StructReprBuilder}, - }, + python::types::repr::{Repr, StructReprBuilder}, vectors::{Document, DocumentEntity, Embedding}, }; -use pyo3::{prelude::*, types::PyNone, IntoPyObjectExt}; - -impl<'py> IntoPyObject<'py> for Lifespan { - type Target = PyAny; - type Output = Bound<'py, PyAny>; - type Error = PyErr; - - fn into_pyobject(self, py: Python<'py>) -> Result { - Ok(match self { - Lifespan::Inherited => PyNone::get(py).to_owned().into_any(), - Lifespan::Event { time } => time.into_pyobject(py)?.into_any(), - Lifespan::Interval { start, end } => (start, end).into_pyobject(py)?.into_any(), - }) - } -} - -impl Repr for Lifespan { - fn repr(&self) -> String { - match self { - Lifespan::Interval { start, end } => (start, end).repr(), - Lifespan::Event { time } => time.repr(), - Lifespan::Inherited => "None".to_string(), - } - } -} +use pyo3::{prelude::*, IntoPyObjectExt}; /// A Document /// @@ -68,7 +40,6 @@ impl PyDocument { #[getter] fn entity(&self, py: Python) -> PyResult { match &self.0.entity { - DocumentEntity::Graph { graph, .. } => graph.clone().into_py_any(py), DocumentEntity::Node(entity) => entity.clone().into_py_any(py), DocumentEntity::Edge(entity) => entity.clone().into_py_any(py), } @@ -82,15 +53,6 @@ impl PyDocument { fn embedding(&self) -> PyEmbedding { PyEmbedding(self.0.embedding.clone()) } - - /// the life span - /// - /// Returns: - /// Optional[Union[int | Tuple[int, int]]]: - #[getter] - fn life(&self) -> Lifespan { - self.0.life - } } #[pyclass(name = "Embedding", module = "raphtory.vectors", frozen)] @@ -114,17 +76,12 @@ impl Repr for PyDocument { fn repr(&self) -> String { let repr = StructReprBuilder::new("Document"); let with_entity = match &self.0.entity { - DocumentEntity::Graph { graph, .. } => { - let graph = graph.clone(); - repr.add_field("entity", PyGraphView { graph }) - } DocumentEntity::Node(node) => repr.add_field("entity", node), DocumentEntity::Edge(edge) => repr.add_field("entity", edge), }; with_entity .add_field("content", &self.content()) .add_field("embedding", &self.embedding()) - .add_field("life", &self.life()) .finish() } } diff --git a/raphtory/src/vectors/cache.rs b/raphtory/src/vectors/cache.rs new file mode 100644 index 0000000000..bdd7b086f7 --- /dev/null +++ b/raphtory/src/vectors/cache.rs @@ -0,0 +1,265 @@ +use crate::{core::utils::errors::GraphResult, vectors::Embedding}; +use futures_util::StreamExt; +use heed::{types::SerdeBincode, Database, Env, EnvOpenOptions}; +use moka::sync::Cache; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::{ + collections::{HashMap, VecDeque}, + hash::{DefaultHasher, Hash, Hasher}, + path::Path, + sync::Arc, +}; + +use super::embeddings::EmbeddingFunction; + +const MAX_DISK_ITEMS: usize = 1_000_000; +const MAX_VECTOR_DIM: usize = 8960; +const MAX_TEXT_LENGTH: usize = 200_000; + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct CacheEntry { + key: String, + value: Embedding, +} +type VectorDb = Database, SerdeBincode>; + +enum VectorStore { + Mem(RwLock>), + Disk { env: Env, db: VectorDb }, +} + +impl VectorStore { + fn in_memory() -> Self { + Self::Mem(Default::default()) + } + fn on_disk(path: &Path) -> GraphResult { + let _ = std::fs::create_dir_all(path); + let page_size = 16384; + let max_size = + (MAX_DISK_ITEMS * (MAX_VECTOR_DIM * 4 + MAX_TEXT_LENGTH)) / page_size * page_size; + + let env = unsafe { EnvOpenOptions::new().map_size(max_size).open(&path) }?; + let mut wtxn = env.write_txn().unwrap(); + let db: VectorDb = env.create_database(&mut wtxn, None)?; + wtxn.commit()?; + Ok(Self::Disk { env, db }) + } + + fn get_disk_keys(&self) -> GraphResult> { + match self { + VectorStore::Mem(_) => Ok(vec![]), + VectorStore::Disk { env, db } => { + let rtxn = env.read_txn()?; + let iter = db.iter(&rtxn)?; + let result: Result, heed::Error> = + iter.map(|result| result.map(|(id, _)| id)).collect(); + Ok(result?) // TODO: simplify this?, use into inside of the map? + } + } + } + + fn get(&self, key: &u64) -> Option { + match self { + VectorStore::Mem(store) => store.read().get(key).cloned(), + VectorStore::Disk { env, db } => { + let rtxn = env.read_txn().ok()?; + db.get(&rtxn, key).ok()? + } + } + } + + fn insert(&self, key: u64, value: CacheEntry) { + match self { + VectorStore::Mem(store) => { + store.write().insert(key, value); + } + VectorStore::Disk { env, db } => { + if let Ok(mut wtxn) = env.write_txn() { + let _ = db.put(&mut wtxn, &key, &value); + let _ = wtxn.commit(); + } + } + } + } + + fn remove(&self, key: &u64) { + match self { + VectorStore::Mem(store) => { + store.write().remove(key); + } + VectorStore::Disk { env, db } => { + // this is a bit dangerous, because if delete ops fail and insert ops succeed, + // the cache might explode in size, but that is very unlikely to happen + if let Ok(mut wtxn) = env.write_txn() { + let _ = db.delete(&mut wtxn, key); + let _ = wtxn.commit(); + } + } + } + } +} + +#[derive(Clone)] +pub struct VectorCache { + store: Arc, + cache: Cache, + function: Arc, +} + +impl VectorCache { + pub fn in_memory(function: impl EmbeddingFunction + 'static) -> Self { + Self { + store: VectorStore::in_memory().into(), + cache: Cache::new(10), + function: Arc::new(function), + } + } + + pub fn on_disk(path: &Path, function: impl EmbeddingFunction + 'static) -> GraphResult { + let store: Arc<_> = VectorStore::on_disk(path)?.into(); + let cloned = store.clone(); + + let cache: Cache = Cache::builder() + .max_capacity(MAX_DISK_ITEMS as u64) + .eviction_listener(move |key: Arc, _value: (), _cause| cloned.remove(key.as_ref())) + .build(); + + for key in store.get_disk_keys()? { + cache.insert(key, ()); + } + + Ok(Self { + store, + cache, + function: Arc::new(function), + }) + } + + fn get(&self, text: &str) -> Option { + let hash = hash(text); + self.cache.get(&hash)?; + let entry = self.store.get(&hash)?; + if &entry.key == text { + Some(entry.value) + } else { + None + } + } + + fn insert(&self, text: String, vector: Embedding) { + let hash = hash(&text); + let entry = CacheEntry { + key: text, + value: vector, + }; + self.store.insert(hash, entry); + self.cache.insert(hash, ()); + } + + pub(super) async fn get_embeddings( + &self, + texts: Vec, + ) -> GraphResult + '_> { + // TODO: review, turned this into a vec only to make compute_embeddings work + let mut results: Vec<_> = futures_util::stream::iter(texts) + .then(|text| async move { + match self.get(&text) { + Some(cached) => (text, Some(cached)), + None => (text, None), + } + }) + .collect() + .await; + let misses: Vec<_> = results + .iter_mut() + .filter_map(|(text, vector)| match vector { + Some(_) => None, + None => Some(text.clone()), + }) + .collect(); + let mut fresh_vectors: VecDeque<_> = if misses.len() > 0 { + self.function.call(misses).await?.into() + } else { + vec![].into() + }; + let embeddings = results.into_iter().map(move |(text, vector)| match vector { + Some(vector) => vector, + None => { + let vector = fresh_vectors.pop_front().unwrap(); + self.insert(text, vector.clone()); + vector + } + }); + Ok(embeddings) + } + + pub async fn get_single(&self, text: String) -> GraphResult { + let mut embeddings = self.get_embeddings(vec![text]).await?; + Ok(embeddings.next().unwrap()) + } +} + +fn hash(text: &str) -> u64 { + let mut hasher = DefaultHasher::new(); + text.hash(&mut hasher); + hasher.finish() +} + +#[cfg(test)] +mod cache_tests { + use tempfile::tempdir; + + use crate::vectors::{embeddings::EmbeddingResult, Embedding}; + + use super::VectorCache; + + async fn placeholder_embedding(texts: Vec) -> EmbeddingResult> { + dbg!(texts); + todo!() + } + + async fn test_abstract_cache(cache: VectorCache) { + let vector_a: Embedding = [1.0].into(); + let vector_b: Embedding = [0.5].into(); + + assert_eq!(cache.get("a"), None); + assert_eq!(cache.get("b"), None); + + cache.insert("a".to_owned(), vector_a.clone()); + assert_eq!(cache.get("a"), Some(vector_a.clone())); + assert_eq!(cache.get("b"), None); + + cache.insert("b".to_owned(), vector_b.clone()); + assert_eq!(cache.get("a"), Some(vector_a)); + assert_eq!(cache.get("b"), Some(vector_b)); + } + + #[tokio::test] + async fn test_empty_request() { + let cache = VectorCache::in_memory(placeholder_embedding); + let result: Vec<_> = cache.get_embeddings(vec![]).await.unwrap().collect(); + assert_eq!(result, vec![]); + } + + #[tokio::test] + async fn test_cache() { + test_abstract_cache(VectorCache::in_memory(placeholder_embedding)).await; + let dir = tempdir().unwrap(); + test_abstract_cache(VectorCache::on_disk(dir.path(), placeholder_embedding).unwrap()).await; + } + + #[tokio::test] + async fn test_on_disk_cache() { + let vector: Embedding = [1.0].into(); + let dir = tempdir().unwrap(); + + { + let cache = VectorCache::on_disk(dir.path(), placeholder_embedding).unwrap(); + cache.insert("a".to_owned(), vector.clone()); + } // here the heed env gets closed + + let loaded_from_disk = VectorCache::on_disk(dir.path(), placeholder_embedding).unwrap(); + assert_eq!(loaded_from_disk.get("a"), Some(vector)) + } +} diff --git a/raphtory/src/vectors/db.rs b/raphtory/src/vectors/db.rs new file mode 100644 index 0000000000..a7a40247d5 --- /dev/null +++ b/raphtory/src/vectors/db.rs @@ -0,0 +1,273 @@ +use std::{ + collections::HashSet, + ops::Deref, + path::{Path, PathBuf}, + sync::{Arc, OnceLock}, +}; + +use arroy::{distances::Cosine, Database as ArroyDatabase, Reader, Writer}; +use futures_util::StreamExt; +use rand::{rngs::StdRng, SeedableRng}; +use sysinfo::System; +use tempfile::TempDir; + +use crate::{ + core::utils::errors::{GraphError, GraphResult}, + db::api::view::StaticGraphViewOps, +}; + +use super::{ + entity_ref::{EntityRef, IntoDbId}, + Embedding, +}; + +const LMDB_MAX_SIZE: usize = 1024 * 1024 * 1024 * 1024; // 1TB + +#[derive(Clone)] +pub(super) struct NodeDb(pub(super) VectorDb); + +impl Deref for NodeDb { + type Target = VectorDb; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl EntityDb for NodeDb { + fn from_vector_db(db: VectorDb) -> Self { + Self(db) + } + + fn get_db(&self) -> &VectorDb { + &self.0 + } + + fn into_entity_ref(id: u32) -> EntityRef { + EntityRef::Node(id) + } + + fn view_has_entity(entity: &EntityRef, view: &G) -> bool { + view.has_node(entity.as_node_gid(view).unwrap()) + } + + fn all_valid_entities(view: G) -> impl Iterator { + view.nodes().into_iter().map(|node| node.into_db_id()) + } +} + +#[derive(Clone)] +pub(super) struct EdgeDb(pub(super) VectorDb); + +impl Deref for EdgeDb { + type Target = VectorDb; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl EntityDb for EdgeDb { + fn from_vector_db(db: VectorDb) -> Self { + Self(db) + } + + fn get_db(&self) -> &VectorDb { + &self.0 + } + + fn into_entity_ref(id: u32) -> EntityRef { + EntityRef::Edge(id) + } + + fn view_has_entity(entity: &EntityRef, view: &G) -> bool { + let (src, dst) = entity.as_edge_gids(view).unwrap(); + view.has_edge(src, dst) // TODO: there should be a quicker way of chking of some edge exist by pid + } + + fn all_valid_entities(view: G) -> impl Iterator { + view.edges().into_iter().map(|edge| edge.into_db_id()) + } +} + +pub(super) trait EntityDb: Sized { + fn from_vector_db(db: VectorDb) -> Self; + fn get_db(&self) -> &VectorDb; + fn into_entity_ref(id: u32) -> EntityRef; + fn view_has_entity(entity: &EntityRef, view: &G) -> bool; + fn all_valid_entities(view: G) -> impl Iterator + 'static; + + async fn from_vectors( + vectors: impl futures_util::Stream> + Send, + path: Option, + ) -> GraphResult { + let db = VectorDb::from_vectors(vectors, path).await?; + Ok(Self::from_vector_db(db)) + } + + fn from_path(path: &Path) -> GraphResult { + VectorDb::from_path(path).map(Self::from_vector_db) + } + + fn top_k( + &self, + query: &Embedding, + k: usize, + view: Option, + filter: Option>, + ) -> GraphResult> { + let candidates: Option>> = match (view, filter) { + (None, None) => None, + (view, Some(filter)) => Some(Box::new( + filter + .into_iter() + .filter(move |entity| { + view.as_ref() + .map_or(true, |view| Self::view_has_entity(entity, view)) + }) + .map(|entity| entity.id()), + )), + (Some(view), None) => Some(Box::new(Self::all_valid_entities(view))), + }; + self.top_k_with_candidates(query, k, candidates) + } + + fn top_k_with_candidates( + &self, + query: &Embedding, + k: usize, + candidates: Option>, + ) -> GraphResult> { + let db = self.get_db(); + let rtxn = db.env.read_txn()?; + let vectors = match Reader::open(&rtxn, 0, db.vectors) { + Ok(reader) => { + let mut query_builder = reader.nns(k); + let candidates = candidates.map(|filter| roaring::RoaringBitmap::from_iter(filter)); + let query_builder = if let Some(filter) = &candidates { + query_builder.candidates(filter) + } else { + &query_builder + }; + query_builder.by_vector(&rtxn, query.as_ref())? + } + Err(arroy::Error::MissingMetadata(_)) => vec![], // this just means the db is empty + Err(error) => return Err(error.into()), + }; + Ok(vectors + .into_iter() + // for arroy, distance = (1.0 - score) / 2.0, where score is cosine: [-1, 1] + .map(|(id, distance)| (Self::into_entity_ref(id), 1.0 - 2.0 * distance))) + } +} + +#[derive(Clone)] +pub(crate) struct VectorDb { + pub(crate) vectors: ArroyDatabase, + pub(crate) env: heed::Env, + pub(crate) _tempdir: Option>, // do we really need this, is the file open not enough + pub(crate) dimensions: OnceLock, +} + +impl VectorDb { + pub(super) fn insert_vector(&self, id: usize, embedding: &Embedding) -> GraphResult<()> { + let mut wtxn = self.env.write_txn()?; + + let dimensions = self.dimensions.get_or_init(|| embedding.len()); + let writer = Writer::::new(self.vectors, 0, *dimensions); + writer.add_item(&mut wtxn, id as u32, embedding.as_ref())?; + + let mut rng = StdRng::from_entropy(); + writer.builder(&mut rng).build(&mut wtxn)?; + + wtxn.commit()?; + Ok(()) + } + + pub(super) fn get_id(&self, id: u32) -> GraphResult> { + let rtxn = self.env.read_txn()?; + let reader = Reader::open(&rtxn, 0, self.vectors)?; + let vector = reader.item_vector(&rtxn, id)?; + Ok(vector.map(|vector| vector.into())) + } + + fn from_path(path: &Path) -> GraphResult { + let env = open_env(path)?; + let rtxn = env.read_txn()?; + let db: ArroyDatabase = env + .open_database(&rtxn, None)? + .ok_or_else(|| GraphError::VectorDbDoesntExist(path.display().to_string()))?; + let first_vector = Reader::open(&rtxn, 0, db) + .ok() + .and_then(|reader| reader.iter(&rtxn).ok()?.next()?.ok()); + let dimensions = if let Some((_, vector)) = first_vector { + vector.len().into() + } else { + OnceLock::new() + }; + rtxn.commit()?; + Ok(Self { + vectors: db, + env, + _tempdir: None, + dimensions, + }) + } + + async fn from_vectors( + vectors: impl futures_util::Stream> + Send, + path: Option, + ) -> GraphResult { + let (env, tempdir) = match path { + Some(path) => { + std::fs::create_dir_all(&path)?; + (open_env(&path)?, None) + } + None => { + let tempdir = tempfile::tempdir()?; + (open_env(tempdir.path())?, Some(tempdir.into())) + } + }; + + let mut wtxn = env.write_txn()?; + let db: ArroyDatabase = env.create_database(&mut wtxn, None)?; + + futures_util::pin_mut!(vectors); + let first_vector = vectors.next().await; + let dimensions = if let Some(Ok((first_id, first_vector))) = first_vector { + let dimensions = first_vector.len(); + let writer = Writer::::new(db, 0, dimensions); + + writer.add_item(&mut wtxn, first_id, &first_vector)?; + while let Some(result) = vectors.next().await { + let (id, vector) = result?; + writer.add_item(&mut wtxn, id, &vector)?; + } + + // TODO: review this -> You can specify the number of trees to use or specify None. + let mut rng = StdRng::seed_from_u64(42); + writer + .builder(&mut rng) + .available_memory(System::new().total_memory() as usize / 2) + .build(&mut wtxn)?; + dimensions.into() + } else { + OnceLock::new() + }; + + wtxn.commit()?; + + Ok(Self { + vectors: db, + env, + _tempdir: tempdir.into(), + dimensions, + }) + } +} + +fn open_env(path: &Path) -> heed::Result { + unsafe { + heed::EnvOpenOptions::new() + .map_size(LMDB_MAX_SIZE) + .open(path) + } +} diff --git a/raphtory/src/vectors/document_ref.rs b/raphtory/src/vectors/document_ref.rs deleted file mode 100644 index e23bd744e4..0000000000 --- a/raphtory/src/vectors/document_ref.rs +++ /dev/null @@ -1,134 +0,0 @@ -use crate::{ - db::api::view::StaticGraphViewOps, - vectors::{entity_id::EntityId, template::DocumentTemplate, Document, Embedding, Lifespan}, -}; -use serde::{Deserialize, Serialize}; -use std::hash::{Hash, Hasher}; - -use super::DocumentEntity; - -/// this struct contains the minimum amount of information need to regenerate a document using a -/// template and to quickly apply windows over them -#[derive(Clone, Debug, Serialize, Deserialize)] -pub(crate) struct DocumentRef { - pub(crate) entity_id: EntityId, - index: usize, - pub(crate) embedding: Embedding, - pub(crate) life: Lifespan, -} - -impl Hash for DocumentRef { - fn hash(&self, state: &mut H) { - match &self.entity_id { - EntityId::Graph { .. } => (), - EntityId::Node { id } => id.hash(state), - EntityId::Edge { src, dst } => { - src.hash(state); - dst.hash(state); - } - }; - state.write_usize(self.index); - } -} - -impl PartialEq for DocumentRef { - fn eq(&self, other: &Self) -> bool { - self.entity_id == other.entity_id && self.index == other.index - } -} - -impl Eq for DocumentRef {} - -impl DocumentRef { - pub fn new(entity_id: EntityId, index: usize, embedding: Embedding, life: Lifespan) -> Self { - Self { - entity_id, - index, - embedding, - life, - } - } - #[allow(dead_code)] - pub fn id(&self) -> (EntityId, usize) { - (self.entity_id.clone(), self.index) - } - - // TODO: review -> does window really need to be an Option - /// This function expects a graph with a window that matches the one provided in `window` - pub fn exists_on_window(&self, graph: Option<&G>, window: &Option<(i64, i64)>) -> bool - where - G: StaticGraphViewOps, - { - match self.life { - Lifespan::Event { time } => { - self.entity_exists_in_graph(graph) - && window - .map(|(start, end)| start <= time && time < end) - .unwrap_or(true) - } - Lifespan::Interval { - start: doc_start, - end: doc_end, - } => { - self.entity_exists_in_graph(graph) - && window - .map(|(start, end)| doc_end > start && doc_start < end) - .unwrap_or(true) - } - Lifespan::Inherited => self.entity_exists_in_graph(graph), - } - } - - fn entity_exists_in_graph(&self, graph: Option<&G>) -> bool { - match &self.entity_id { - EntityId::Graph { .. } => true, // TODO: maybe consider dead a graph with no entities - EntityId::Node { id } => graph.map(|g| g.has_node(id)).unwrap_or(true), - EntityId::Edge { src, dst } => graph.map(|g| g.has_edge(src, dst)).unwrap_or(true), - // TODO: Edge should probably contain a layer filter that we can pass to has_edge() - } - } - - pub fn regenerate(&self, original_graph: &G, template: &DocumentTemplate) -> Document - where - G: StaticGraphViewOps, - { - // FIXME: there is a problem here. We need to use the original graph so the number of - // documents is the same and the index is therefore consistent. However, we want to return - // the document using the windowed values for the properties of the entities - let (entity, content) = match &self.entity_id { - EntityId::Graph { name } => ( - DocumentEntity::Graph { - name: name.clone(), - graph: original_graph.clone(), - }, - template - .graph(original_graph) - .nth(self.index) - .unwrap() - .content, - ), - EntityId::Node { id } => ( - DocumentEntity::Node(original_graph.node(id).unwrap()), - template - .node((&&original_graph).node(id).unwrap()) - .nth(self.index) - .unwrap() - .content, - ), - EntityId::Edge { src, dst } => ( - DocumentEntity::Edge(original_graph.edge(src, dst).unwrap()), - template - .edge(original_graph.edge(src, dst).unwrap().as_ref()) - .nth(self.index) - .unwrap() - .content, - ), - }; - Document { - entity, - content, - embedding: self.embedding.clone(), - life: self.life, - } - } -} diff --git a/raphtory/src/vectors/embedding_cache.rs b/raphtory/src/vectors/embedding_cache.rs deleted file mode 100644 index 55bb32954e..0000000000 --- a/raphtory/src/vectors/embedding_cache.rs +++ /dev/null @@ -1,73 +0,0 @@ -use crate::vectors::Embedding; -use parking_lot::RwLock; -use std::{ - collections::{hash_map::DefaultHasher, HashMap}, - fs::{create_dir_all, File}, - hash::{Hash, Hasher}, - io::{BufReader, BufWriter}, - path::PathBuf, -}; - -pub type CacheStore = HashMap; - -#[derive(Debug)] -pub struct EmbeddingCache { - cache: RwLock, // TODO: double check that we really need a RwLock !! - path: PathBuf, -} - -impl From for EmbeddingCache { - fn from(path: PathBuf) -> Self { - let inner_cache = Self::try_reading_from_disk(&path).unwrap_or(HashMap::new()); - let cache = RwLock::new(inner_cache); - Self { cache, path } - } -} - -impl From for EmbeddingCache { - fn from(path: String) -> Self { - PathBuf::from(path).into() - } -} - -impl EmbeddingCache { - pub(crate) fn new(path: PathBuf) -> Self { - let cache = RwLock::new(CacheStore::new()); - Self { cache, path } - } - - fn try_reading_from_disk(path: &PathBuf) -> Option { - let file = File::open(&path).ok()?; - let mut reader = BufReader::new(file); - bincode::deserialize_from(&mut reader).ok() - } - - pub(crate) fn get_embedding(&self, text: &str) -> Option { - let hash = Self::hash_text(text); - self.cache.read().get(&hash).cloned() - } - - pub(crate) fn upsert_embedding(&self, text: &str, embedding: Embedding) { - let hash = Self::hash_text(text); - self.cache.write().insert(hash, embedding); - } - - fn hash_text(text: &str) -> u64 { - let mut hasher = DefaultHasher::new(); - text.hash(&mut hasher); - hasher.finish() - } - - // TODO: remove entries that weren't read in the last usage - pub(crate) fn dump_to_disk(&self) { - self.path.parent().iter().for_each(|parent_path| { - create_dir_all(parent_path).expect("Impossible to use cache dir"); - }); - // TODO: print helpful error if the path is a directory, maybe when creating the cache - // instead of here to save the embedding model to be called - let file = File::create(&self.path).expect("Couldn't create file to store embedding cache"); - let mut writer = BufWriter::new(file); - bincode::serialize_into::<_, CacheStore>(&mut writer, &self.cache.read()) - .expect("Couldn't serialize embedding cache"); - } -} diff --git a/raphtory/src/vectors/embeddings.rs b/raphtory/src/vectors/embeddings.rs index d684446bf9..944bb47e2f 100644 --- a/raphtory/src/vectors/embeddings.rs +++ b/raphtory/src/vectors/embeddings.rs @@ -1,11 +1,39 @@ -use crate::vectors::Embedding; +use std::{future::Future, ops::Deref, pin::Pin, sync::Arc}; + +use crate::{core::utils::errors::GraphResult, vectors::Embedding}; use async_openai::{ types::{CreateEmbeddingRequest, EmbeddingInput}, Client, }; +use futures_util::{future::BoxFuture, Stream, StreamExt}; use tracing::info; -use super::EmbeddingResult; +use super::cache::VectorCache; + +const CHUNK_SIZE: usize = 1000; + +pub(crate) type EmbeddingError = Box; +pub type EmbeddingResult = Result; + +pub trait EmbeddingFunction: Send + Sync { + fn call(&self, texts: Vec) -> BoxFuture<'static, EmbeddingResult>>; +} + +impl EmbeddingFunction for T +where + T: Fn(Vec) -> F + Send + Sync, + F: Future>> + Send + 'static, +{ + fn call(&self, texts: Vec) -> BoxFuture<'static, EmbeddingResult>> { + Box::pin(self(texts)) + } +} + +impl EmbeddingFunction for Arc { + fn call(&self, texts: Vec) -> BoxFuture<'static, EmbeddingResult>> { + Box::pin(self.deref().call(texts)) + } +} pub async fn openai_embedding(texts: Vec) -> EmbeddingResult> { info!("computing embeddings for {} texts", texts.len()); @@ -26,32 +54,30 @@ pub async fn openai_embedding(texts: Vec) -> EmbeddingResult) -> Vec { -// info!("computing embeddings for {} texts", texts.len()); -// Python::with_gil(|py| { -// let sentence_transformers = py.import("sentence_transformers")?; -// let locals = [("sentence_transformers", sentence_transformers)].into_py_dict(py); -// locals.set_item("texts", texts); -// -// let pyarray: &PyArray2 = py -// .eval( -// &format!( -// "sentence_transformers.SentenceTransformer('thenlper/gte-small').encode(texts)" -// ), -// Some(locals), -// None, -// )? -// .extract()?; -// -// let readonly = pyarray.readonly(); -// let chunks = readonly.as_slice().unwrap().chunks(384).into_iter(); -// let embeddings = chunks -// .map(|chunk| chunk.iter().copied().collect_vec()) -// .collect_vec(); -// -// Ok::>, Box>(embeddings) -// }) -// .unwrap() -// } +pub(super) fn compute_embeddings<'a, I>( + documents: I, + cache: &'a VectorCache, +) -> impl Stream> + Send + 'a +where + I: Iterator + Send + 'a, +{ + futures_util::stream::iter(documents) + .chunks(CHUNK_SIZE) + .then(|chunk| async { + let texts = chunk.iter().map(|(_, text)| text.clone()).collect(); + let stream: Pin> + Send>> = + match cache.get_embeddings(texts).await { + Ok(embeddings) => { + let embedded: Vec<_> = chunk + .into_iter() + .zip(embeddings) + .map(|((id, _), vector)| Ok((id, vector))) + .collect(); // TODO: do I really need this collect? + Box::pin(futures_util::stream::iter(embedded)) + } + Err(error) => Box::pin(futures_util::stream::iter([Err(error)])), + }; + stream + }) + .flatten() +} diff --git a/raphtory/src/vectors/entity_id.rs b/raphtory/src/vectors/entity_id.rs deleted file mode 100644 index b710119f4f..0000000000 --- a/raphtory/src/vectors/entity_id.rs +++ /dev/null @@ -1,75 +0,0 @@ -use crate::{ - db::graph::{edge::EdgeView, node::NodeView}, - prelude::{EdgeViewOps, GraphViewOps, NodeViewOps}, -}; -use raphtory_api::core::entities::GID; -use serde::{Deserialize, Serialize}; -use std::fmt::{Display, Formatter}; - -#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] -pub(crate) enum EntityId { - Graph { name: Option }, - Node { id: GID }, - Edge { src: GID, dst: GID }, -} - -impl EntityId { - pub(crate) fn for_graph(name: Option) -> Self { - Self::Graph { name } - } - - pub(crate) fn from_node<'graph, G: GraphViewOps<'graph>>(node: NodeView) -> Self { - Self::Node { id: node.id() } - } - - pub(crate) fn from_edge<'graph, G: GraphViewOps<'graph>>(edge: EdgeView) -> Self { - Self::Edge { - src: edge.src().id(), - dst: edge.dst().id(), - } - } - - #[allow(dead_code)] - pub(crate) fn is_graph(&self) -> bool { - match self { - EntityId::Graph { .. } => true, - EntityId::Node { .. } => false, - EntityId::Edge { .. } => false, - } - } - - pub(crate) fn is_node(&self) -> bool { - match self { - EntityId::Graph { .. } => false, - EntityId::Node { .. } => true, - EntityId::Edge { .. } => false, - } - } - - pub(crate) fn is_edge(&self) -> bool { - match self { - EntityId::Graph { .. } => false, - EntityId::Node { .. } => false, - EntityId::Edge { .. } => true, - } - } -} - -impl Display for EntityId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - EntityId::Graph { name } => { - let graph_name = name.clone().unwrap_or("_unnamed".to_owned()); - f.write_str(&format!("graph:{graph_name}")) - } - EntityId::Node { id } => f.write_str(&id.to_str()), - EntityId::Edge { src, dst } => { - f.write_str(&src.to_str()) - .expect("src ID couldn't be serialized"); - f.write_str("-") - .expect("edge ID separator couldn't be serialized"); - f.write_str(&dst.to_str()) - } - } - } -} diff --git a/raphtory/src/vectors/entity_ref.rs b/raphtory/src/vectors/entity_ref.rs new file mode 100644 index 0000000000..55fdac1d40 --- /dev/null +++ b/raphtory/src/vectors/entity_ref.rs @@ -0,0 +1,90 @@ +use either::Either; +use raphtory_api::core::entities::GID; + +use crate::db::{ + api::{storage::graph::edges::edge_storage_ops::EdgeStorageOps, view::StaticGraphViewOps}, + graph::{edge::EdgeView, node::NodeView}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(super) enum EntityRef { + Node(u32), + Edge(u32), +} + +impl From> for EntityRef { + fn from(value: NodeView) -> Self { + EntityRef::Node(value.into_db_id()) + } +} + +impl From> for EntityRef { + fn from(value: EdgeView) -> Self { + EntityRef::Edge(value.into_db_id()) + } +} + +impl EntityRef { + pub(super) fn id(&self) -> u32 { + match self { + EntityRef::Node(id) => *id, + EntityRef::Edge(id) => *id, + } + } + + pub(super) fn resolve_entity( + &self, + graph: &G, + ) -> Option, EdgeView>> { + match self.resolve_entity_gids(graph) { + Either::Left(node) => Some(Either::Left(graph.node(node)?)), + Either::Right((src, dst)) => Some(Either::Right(graph.edge(src, dst)?)), + } + } + + pub(super) fn as_node_view(&self, graph: &G) -> Option> { + self.resolve_entity(graph)?.left() + } + + pub(super) fn as_edge_view(&self, graph: &G) -> Option> { + self.resolve_entity(graph)?.right() + } + + pub(super) fn as_node_gid(&self, graph: &G) -> Option { + self.resolve_entity_gids(graph).left() + } + + pub(super) fn as_edge_gids(&self, graph: &G) -> Option<(GID, GID)> { + self.resolve_entity_gids(graph).right() + } + + fn resolve_entity_gids(&self, graph: &G) -> Either { + match self { + EntityRef::Node(id) => { + Either::Left(graph.node_id(raphtory_api::core::entities::VID((*id) as usize))) + } + EntityRef::Edge(id) => { + let edge = graph.core_edge((*id as usize).into()); + let src = graph.node_id(edge.src()); + let dst = graph.node_id(edge.dst()); + Either::Right((src, dst)) + } + } + } +} + +pub(super) trait IntoDbId { + fn into_db_id(self) -> u32; +} + +impl IntoDbId for NodeView { + fn into_db_id(self) -> u32 { + self.node.index() as u32 + } +} + +impl IntoDbId for EdgeView { + fn into_db_id(self) -> u32 { + self.edge.pid().0 as u32 + } +} diff --git a/raphtory/src/vectors/graph_entity.rs b/raphtory/src/vectors/graph_entity.rs deleted file mode 100644 index 51a6574e6f..0000000000 --- a/raphtory/src/vectors/graph_entity.rs +++ /dev/null @@ -1,107 +0,0 @@ -use crate::{ - db::{ - api::view::StaticGraphViewOps, - graph::{edge::EdgeView, node::NodeView}, - }, - prelude::{GraphViewOps, NodeViewOps}, -}; -use itertools::{chain, Itertools}; -use std::{collections::HashSet, fmt::Display}; - -pub trait GraphEntity: Sized { - fn generate_property_list( - &self, - time_fmt: &F, - filter_out: Vec<&str>, - force_static: Vec<&str>, - ) -> String - where - F: Fn(i64) -> D, - D: Display; -} - -impl<'graph, G: GraphViewOps<'graph>> GraphEntity for NodeView { - fn generate_property_list( - &self, - time_fmt: &F, - filter_out: Vec<&str>, - force_static: Vec<&str>, - ) -> String - where - F: Fn(i64) -> D, - D: Display, - { - let time_fmt = |time: i64| time_fmt(time).to_string(); - let missing = || "missing".to_owned(); - let min_time_fmt = self.earliest_time().map(time_fmt).unwrap_or_else(missing); - let min_time = format!("earliest activity: {}", min_time_fmt); - let max_time_fmt = self.latest_time().map(time_fmt).unwrap_or_else(missing); - let max_time = format!("latest activity: {}", max_time_fmt); - - let temporal_props = self - .properties() - .temporal() - .iter() - .filter(|(key, _)| !filter_out.contains(&key.as_ref())) - .filter(|(key, _)| !force_static.contains(&key.as_ref())) - .filter(|(_, v)| { - // the history of the temporal prop has more than one value - v.values() - .into_iter() - .map(|prop| prop.to_string()) - .unique() - .collect_vec() - .len() - > 1 - }) - .collect_vec(); - - let temporal_keys: HashSet<_> = temporal_props.iter().map(|(key, _)| key).collect(); - let temporal_props = temporal_props.iter().map(|(key, value)| { - let time_value_pairs = value.iter().map(|(k, v)| (k, v.to_string())); - let events = - time_value_pairs - .unique_by(|(_, value)| value.clone()) - .map(|(time, value)| { - let time = time_fmt(time); - format!("{key} changed to {value} at {time}") - }); - Itertools::intersperse(events, "\n".to_owned()).collect() - }); - - let prop_storage = self.properties(); - - let static_props = prop_storage - .keys() - .filter(|key| !filter_out.contains(&key.as_ref())) - .filter(|key| !temporal_keys.contains(key)) - .map(|key| { - let prop = prop_storage.get(&key).unwrap().to_string(); - let key = key.to_string(); - format!("{key}: {prop}") - }); - - let props = chain!(static_props, temporal_props).sorted_by(|a, b| a.len().cmp(&b.len())); - // We sort by length so when cutting out the tail of the document we don't remove small properties - - let lines = chain!([min_time, max_time], props); - Itertools::intersperse(lines, "\n".to_owned()).collect() - } -} - -impl GraphEntity for EdgeView { - // FIXME: implement this and remove underscore prefix from the parameter names - fn generate_property_list( - &self, - _time_fmt: &F, - _filter_out: Vec<&str>, - _force_static: Vec<&str>, - ) -> String - where - F: Fn(i64) -> D, - D: Display, - { - // TODO: not needed yet - "".to_owned() - } -} diff --git a/raphtory/src/vectors/mod.rs b/raphtory/src/vectors/mod.rs index d82805ab50..5ac02871eb 100644 --- a/raphtory/src/vectors/mod.rs +++ b/raphtory/src/vectors/mod.rs @@ -1,32 +1,26 @@ -use crate::{ - core::{DocumentInput, Lifespan}, - db::{ - api::view::StaticGraphViewOps, - graph::{edge::EdgeView, node::NodeView}, - }, +use crate::db::{ + api::view::StaticGraphViewOps, + graph::{edge::EdgeView, node::NodeView}, }; -use futures_util::future::BoxFuture; -use std::{error, future::Future, ops::Deref, sync::Arc}; +use std::sync::Arc; +pub mod cache; pub mod datetimeformat; -mod document_ref; -pub mod embedding_cache; +mod db; pub mod embeddings; -mod entity_id; -mod similarity_search_utils; +mod entity_ref; pub mod splitting; +mod storage; pub mod template; +mod utils; pub mod vector_selection; -mod vector_storage; pub mod vectorisable; -pub mod vectorised_cluster; pub mod vectorised_graph; pub type Embedding = Arc<[f32]>; #[derive(Debug, Clone)] pub enum DocumentEntity { - Graph { name: Option, graph: G }, Node(NodeView), Edge(EdgeView), } @@ -36,68 +30,19 @@ pub struct Document { pub entity: DocumentEntity, pub content: String, pub embedding: Embedding, - pub life: Lifespan, -} - -impl Lifespan { - #![allow(dead_code)] - pub(crate) fn event(time: i64) -> Self { - Self::Event { time } - } -} - -impl From for DocumentInput { - fn from(value: String) -> Self { - Self { - content: value, - life: Lifespan::Inherited, - } - } -} - -impl From<&str> for DocumentInput { - fn from(value: &str) -> Self { - Self { - content: value.to_owned(), - life: Lifespan::Inherited, - } - } -} - -pub(crate) type EmbeddingError = Box; -pub type EmbeddingResult = Result; - -pub trait EmbeddingFunction: Send + Sync { - fn call(&self, texts: Vec) -> BoxFuture<'static, EmbeddingResult>>; -} - -impl EmbeddingFunction for T -where - T: Fn(Vec) -> F + Send + Sync, - F: Future>> + Send + 'static, -{ - fn call(&self, texts: Vec) -> BoxFuture<'static, EmbeddingResult>> { - Box::pin(self(texts)) - } -} - -impl EmbeddingFunction for Arc { - fn call(&self, texts: Vec) -> BoxFuture<'static, EmbeddingResult>> { - Box::pin(self.deref().call(texts)) - } } #[cfg(test)] mod vector_tests { - use super::*; + use std::{fs::remove_dir_all, path::PathBuf}; + + use super::{embeddings::EmbeddingResult, *}; use crate::{ core::Prop, - prelude::{AdditionOps, Graph, GraphViewOps}, - vectors::{embeddings::openai_embedding, vectorisable::Vectorisable}, + prelude::*, + vectors::{cache::VectorCache, embeddings::openai_embedding, vectorisable::Vectorisable}, }; - use dotenv::dotenv; use itertools::Itertools; - use std::fs::remove_file; use template::DocumentTemplate; use tokio; @@ -116,7 +61,6 @@ mod vector_tests { fn custom_template() -> DocumentTemplate { DocumentTemplate { - graph_template: None, node_template: Some( "{{ name}} is a {{ node_type }} aged {{ properties.age }}".to_owned(), ), @@ -133,66 +77,41 @@ mod vector_tests { let g = Graph::new(); g.add_node(0, "test", NO_PROPS, None).unwrap(); - // the following succeeds with no cache set up - g.vectorise( - Box::new(fake_embedding), - None.into(), - true, - template.clone(), - None, - false, - ) - .await - .unwrap(); - - let path = "/tmp/raphtory/very/deep/path/embedding-cache-test"; - let _ = remove_file(path); + let path = PathBuf::from("/tmp/raphtory/very/deep/path/embedding-cache-test"); + let _ = remove_dir_all(&path); // the following creates the embeddings, and store them on the cache - g.vectorise( - Box::new(fake_embedding), - Some(path.to_owned().into()).into(), - true, - template.clone(), - None, - false, - ) - .await - .unwrap(); + { + let cache = VectorCache::on_disk(&path, fake_embedding).unwrap(); + g.vectorise(cache, template.clone(), None, false) + .await + .unwrap(); + } // the cache gets dropped here and the heed env released // the following uses the embeddings from the cache, so it doesn't call the panicking // embedding, which would make the test fail - g.vectorise( - Box::new(panicking_embedding), - Some(path.to_owned().into()).into(), - true, - template, - None, - false, - ) - .await - .unwrap(); + let cache = VectorCache::on_disk(&path, panicking_embedding).unwrap(); + g.vectorise(cache, template, None, false).await.unwrap(); } #[tokio::test] async fn test_empty_graph() { let template = custom_template(); let g = Graph::new(); - let cache = Some("/tmp/raphtory/vector-cache-lotr-test".to_owned().into()).into(); - let vectors = g - .vectorise(Box::new(fake_embedding), cache, true, template, None, false) - .await - .unwrap(); + let cache = VectorCache::in_memory(fake_embedding); + let vectors = g.vectorise(cache, template, None, false).await.unwrap(); let embedding: Embedding = fake_embedding(vec!["whatever".to_owned()]) .await .unwrap() .remove(0); - - let mut selection = vectors.documents_by_similarity(&embedding, 10, None); - selection.expand_documents_by_similarity(&embedding, 10, None); + let mut selection = vectors + .entities_by_similarity(&embedding, 10, None) + .unwrap(); + selection + .expand_entities_by_similarity(&embedding, 10, None) + .unwrap(); selection.expand(2, None); - let docs = selection.get_documents(); - + let docs = selection.get_documents().unwrap(); assert!(docs.is_empty()) } @@ -208,14 +127,9 @@ mod vector_tests { .unwrap(); let template = custom_template(); - let doc: DocumentInput = template - .node(g.node("Frodo").unwrap()) - .next() - .unwrap() - .into(); - let content = doc.content; + let content: String = template.node(g.node("Frodo").unwrap()).unwrap(); let expected_content = "Frodo is a hobbit aged 30"; - assert_eq!(content, expected_content); + assert_eq!(&content, expected_content); } #[test] @@ -225,141 +139,13 @@ mod vector_tests { .unwrap(); let template = custom_template(); - let doc: DocumentInput = template + let content: String = template .edge(g.edge("Frodo", "Gandalf").unwrap().as_ref()) - .next() - .unwrap() - .into(); - let content = doc.content; + .unwrap(); let expected_content = "Frodo appeared with Gandalf in lines: 0"; - assert_eq!(content, expected_content); + assert_eq!(&content, expected_content); } - // const FAKE_DOCUMENTS: [&str; 3] = ["doc1", "doc2", "doc3"]; - // struct FakeMultiDocumentTemplate; - - // impl DocumentTemplate for FakeMultiDocumentTemplate { - // fn graph(&self, graph: &G) -> Box> { - // DefaultTemplate.graph(graph) - // } - - // fn node(&self, _node: &NodeView) -> Box> { - // Box::new( - // Vec::from(FAKE_DOCUMENTS) - // .into_iter() - // .map(|text| text.into()), - // ) - // } - // fn edge(&self, _edge: EdgeView<&G, &G>) -> Box> { - // Box::new(std::iter::empty()) - // } - // } - - // #[tokio::test] - // async fn test_vector_store_with_multi_embedding() { - // let g = Graph::new(); - // g.add_node(0, "test", NO_PROPS, None).unwrap(); - - // let vectors = g - // .vectorise_with_template( - // Box::new(fake_embedding), - // Some(PathBuf::from("/tmp/raphtory/vector-cache-multi-test")), - // true, - // FakeMultiDocumentTemplate, - // false, - // ) - // .await; - - // let embedding = fake_embedding(vec!["whatever".to_owned()]).await.remove(0); - - // let mut selection = vectors.search_documents(&embedding, 1, None); - // selection.expand_documents_by_similarity(&embedding, 9, None); - // let docs = selection.get_documents(); - // assert_eq!(docs.len(), 3); - // // all documents are present in the result - // for doc_content in FAKE_DOCUMENTS { - // assert!( - // docs.iter().any(|doc| match doc { - // Document::Node { content, name, .. } => - // content == doc_content && name == "test", - // _ => false, - // }), - // "document {doc_content:?} is not present in the result: {docs:?}" - // ); - // } - // } - - // struct FakeTemplateWithIntervals; - - // impl DocumentTemplate for FakeTemplateWithIntervals { - // fn graph(&self, graph: &G) -> Box> { - // DefaultTemplate.graph(graph) - // } - - // fn node(&self, _node: &NodeView) -> Box> { - // let doc_event_20: DocumentInput = DocumentInput { - // content: "event at 20".to_owned(), - // life: Lifespan::Event { time: 20 }, - // }; - - // let doc_interval_30_40: DocumentInput = DocumentInput { - // content: "interval from 30 to 40".to_owned(), - // life: Lifespan::Interval { start: 30, end: 40 }, - // }; - // Box::new(vec![doc_event_20, doc_interval_30_40].into_iter()) - // } - // fn edge(&self, _edge: &EdgeView) -> Box> { - // Box::new(std::iter::empty()) - // } - // } - - // #[tokio::test] - // async fn test_vector_store_with_window() { - // let g = Graph::new(); - // g.add_node(0, "test", NO_PROPS, None).unwrap(); - // g.add_edge(40, "test", "test", NO_PROPS, None).unwrap(); - - // let vectors = g - // .vectorise_with_template( - // Box::new(fake_embedding), - // Some(PathBuf::from("/tmp/raphtory/vector-cache-window-test")), - // true, - // FakeTemplateWithIntervals, - // false, - // ) - // .await; - - // let embedding = fake_embedding(vec!["whatever".to_owned()]).await.remove(0); - // let mut selection = vectors.search_documents(&embedding, 1, None); - // selection.expand_documents_by_similarity(&embedding, 9, None); - // let docs = selection.get_documents(); - // assert_eq!(docs.len(), 2); - - // let mut selection = vectors.search_documents(&embedding, 1, Some((-10, 25))); - // selection.expand_documents_by_similarity(&embedding, 9, Some((-10, 25))); - // let docs = selection.get_documents(); - // assert!( - // match &docs[..] { - // [Document::Node { name, content, .. }] => - // name == "test" && content == "event at 20", - // _ => false, - // }, - // "{docs:?} has the wrong content" - // ); - - // let mut selection = vectors.search_documents(&embedding, 1, Some((35, 100))); - // selection.expand_documents_by_similarity(&embedding, 9, Some((35, 100))); - // let docs = selection.get_documents(); - // assert!( - // match &docs[..] { - // [Document::Node { name, content, .. }] => - // name == "test" && content == "interval from 30 to 40", - // _ => false, - // }, - // "{docs:?} has the wrong content" - // ); - // } - #[ignore = "this test needs an OpenAI API key to run"] #[tokio::test] async fn test_vector_store() { @@ -398,56 +184,49 @@ mod vector_tests { ) .unwrap(); - dotenv().ok(); + dotenv::dotenv().ok(); + let cache = VectorCache::in_memory(openai_embedding); let vectors = g - .vectorise( - Box::new(openai_embedding), - Some("/tmp/raphtory/vector-cache-lotr-test".to_owned().into()).into(), - true, - template, - None, - false, - ) + .vectorise(cache.clone(), template, None, false) .await .unwrap(); - let embedding = openai_embedding(vec!["Find a magician".to_owned()]) - .await - .unwrap() - .remove(0); + let query = "Find a magician".to_owned(); + let embedding = cache.get_single(query).await.unwrap(); let docs = vectors .nodes_by_similarity(&embedding, 1, None) - .get_documents(); + .unwrap() + .get_documents() + .unwrap(); // TODO: use the ids instead in all of these cases assert!(docs[0].content.contains("Gandalf is a wizard")); - let embedding = openai_embedding(vec!["Find a young person".to_owned()]) - .await - .unwrap() - .remove(0); + let query = "Find a young person".to_owned(); + let embedding = cache.get_single(query).await.unwrap(); let docs = vectors .nodes_by_similarity(&embedding, 1, None) - .get_documents(); + .unwrap() + .get_documents() + .unwrap(); assert!(docs[0].content.contains("Frodo is a hobbit")); // this fails when using gte-small // with window! - let embedding = openai_embedding(vec!["Find a young person".to_owned()]) - .await - .unwrap() - .remove(0); + let query = "Find a young person".to_owned(); + let embedding = cache.get_single(query).await.unwrap(); let docs = vectors .nodes_by_similarity(&embedding, 1, Some((1, 3))) - .get_documents(); - assert!(!docs[0].content.contains("Frodo is a hobbit")); // this fails when using gte-small - - let embedding = openai_embedding(vec!["Has anyone appeared with anyone else?".to_owned()]) - .await .unwrap() - .remove(0); + .get_documents() + .unwrap(); + assert!(!docs[0].content.contains("Frodo is a hobbit")); // this fails when using gte-small + let query = "Has anyone appeared with anyone else?".to_owned(); + let embedding = cache.get_single(query).await.unwrap(); let docs = vectors .edges_by_similarity(&embedding, 1, None) - .get_documents(); + .unwrap() + .get_documents() + .unwrap(); assert!(docs[0].content.contains("Frodo appeared with Gandalf")); } } diff --git a/raphtory/src/vectors/similarity_search_utils.rs b/raphtory/src/vectors/similarity_search_utils.rs deleted file mode 100644 index 03bc058ea3..0000000000 --- a/raphtory/src/vectors/similarity_search_utils.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::vectors::{document_ref::DocumentRef, Embedding}; -use itertools::Itertools; - -use super::entity_id::EntityId; - -pub(crate) fn score_documents<'a, I>( - query: &'a Embedding, - documents: I, -) -> impl Iterator + 'a -where - I: IntoIterator + 'a, -{ - documents.into_iter().map(|doc| { - let score = cosine(query, &doc.embedding); - (doc, score) - }) -} - -/// the caller is responsible for filtering out empty document vectors -pub(crate) fn score_document_groups_by_highest<'a, I>( - query: &'a Embedding, - documents: I, -) -> impl Iterator), f32)> + 'a -where - I: IntoIterator)> + 'a, -{ - documents.into_iter().map(|group| { - let scores = group.1.iter().map(|doc| cosine(query, &doc.embedding)); - let highest_score = scores.max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); - (group, highest_score) - }) -} - -/// Returns the top k nodes in descending order -pub(crate) fn find_top_k<'a, I, T>(elements: I, k: usize) -> impl Iterator + 'a -where - I: Iterator + 'a, - T: 'static, -{ - // TODO: add optimization for when this is used -> don't maintain more candidates than the max number of documents to return !!! - elements - .sorted_by(|(_, score1), (_, score2)| score1.partial_cmp(&score2).unwrap().reverse()) - // We use reverse because default sorting is ascending but we want it descending - .take(k) -} - -fn cosine(vector1: &Embedding, vector2: &Embedding) -> f32 { - assert_eq!(vector1.len(), vector2.len()); - - let dot_product: f32 = vector1.iter().zip(vector2.iter()).map(|(x, y)| x * y).sum(); - let x_length: f32 = vector1.iter().map(|x| x * x).sum(); - let y_length: f32 = vector2.iter().map(|y| y * y).sum(); - // TODO: store the length of the vector as well so we don't need to recompute it - // Vectors are already normalized for ada but nor for all the models: - // see: https://platform.openai.com/docs/guides/embeddings/which-distance-function-should-i-use - - let normalized = dot_product / (x_length.sqrt() * y_length.sqrt()); - assert!(normalized <= 1.001); - assert!(normalized >= -1.001); - normalized -} diff --git a/raphtory/src/vectors/storage.rs b/raphtory/src/vectors/storage.rs new file mode 100644 index 0000000000..2740f6a376 --- /dev/null +++ b/raphtory/src/vectors/storage.rs @@ -0,0 +1,60 @@ +use serde::{Deserialize, Serialize}; +use std::{ + fs::File, + path::{Path, PathBuf}, +}; + +use crate::{ + core::utils::errors::{GraphError, GraphResult}, + db::api::view::StaticGraphViewOps, +}; + +use super::{ + cache::VectorCache, + db::{EdgeDb, EntityDb, NodeDb}, + template::DocumentTemplate, + vectorised_graph::VectorisedGraph, +}; + +#[derive(Serialize, Deserialize)] +pub(super) struct VectorMeta { + pub(super) template: DocumentTemplate, +} + +impl VectorMeta { + pub(super) fn write_to_path(&self, path: &Path) -> Result<(), GraphError> { + let file = File::create(meta_path(path))?; + serde_json::to_writer(file, self)?; + Ok(()) + } +} + +impl VectorisedGraph { + pub fn read_from_path(path: &Path, graph: G, cache: VectorCache) -> GraphResult { + let meta_string = std::fs::read_to_string(meta_path(path))?; + let meta: VectorMeta = serde_json::from_str(&meta_string)?; + + let node_db = NodeDb::from_path(&node_vectors_path(path))?; + let edge_db = EdgeDb::from_path(&edge_vectors_path(path))?; + + Ok(VectorisedGraph { + template: meta.template, + source_graph: graph, + cache, + node_db, + edge_db, + }) + } +} + +fn meta_path(path: &Path) -> PathBuf { + path.join("meta") +} + +pub(super) fn node_vectors_path(path: &Path) -> PathBuf { + path.join("nodes") +} + +pub(super) fn edge_vectors_path(path: &Path) -> PathBuf { + path.join("edges") +} diff --git a/raphtory/src/vectors/template.rs b/raphtory/src/vectors/template.rs index 4abb4e5501..2a6b1b802c 100644 --- a/raphtory/src/vectors/template.rs +++ b/raphtory/src/vectors/template.rs @@ -1,6 +1,6 @@ use super::datetimeformat::datetimeformat; use crate::{ - core::{DocumentInput, Prop}, + core::Prop, db::{ api::properties::TemporalPropertyView, graph::{edge::EdgeView, node::NodeView}, @@ -11,7 +11,7 @@ use minijinja::{ value::{Enumerator, Object}, Environment, Template, Value, }; -use raphtory_api::core::storage::arc_str::ArcStr; +use raphtory_api::core::storage::arc_str::{ArcStr, OptionAsStr}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::error; @@ -98,38 +98,6 @@ impl<'graph, G: GraphViewOps<'graph>> From> for NodeTemplateContext } } -#[derive(Serialize)] -struct GraphTemplateContext { - properties: Value, - constant_properties: Value, - temporal_properties: Value, -} - -// FIXME: boilerplate for the properties -impl<'graph, G: GraphViewOps<'graph>> From for GraphTemplateContext { - fn from(value: G) -> Self { - Self { - properties: value - .properties() - .iter() - .map(|(key, value)| (key.to_string(), value.clone())) - .collect(), - constant_properties: value - .properties() - .constant() - .iter() - .map(|(key, value)| (key.to_string(), value.clone())) - .collect(), - temporal_properties: value - .properties() - .temporal() - .iter() - .map(|(key, prop)| (key.to_string(), Into::::into(prop))) - .collect(), - } - } -} - // FIXME: this is eagerly allocating a lot of stuff, we should implement Object instead for Prop impl From for Value { fn from(value: Prop) -> Self { @@ -159,61 +127,29 @@ impl From for Value { #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct DocumentTemplate { - pub graph_template: Option, pub node_template: Option, pub edge_template: Option, } -fn empty_iter() -> Box + Send> { - Box::new(std::iter::empty()) -} - impl DocumentTemplate { - pub(crate) fn graph<'graph, G: GraphViewOps<'graph>>( - &self, - graph: G, - ) -> Box + Send> { - match &self.graph_template { - Some(template) => { - // TODO: create the environment only once and store it on the DocumentTemplate struct - let mut env = Environment::new(); - let template = build_template(&mut env, template); - match template.render(GraphTemplateContext::from(graph)) { - Ok(mut document) => { - truncate(&mut document); - Box::new(std::iter::once(document.into())) - } - Err(error) => { - error!("Template render failed for a node, skipping: {error}"); - empty_iter() - } - } - } - None => empty_iter(), - } - } - /// A function that translate a node into an iterator of documents pub(crate) fn node<'graph, G: GraphViewOps<'graph>>( &self, node: NodeView, - ) -> Box + Send> { - match &self.node_template { - Some(template) => { - let mut env = Environment::new(); - let template = build_template(&mut env, template); - match template.render(NodeTemplateContext::from(node)) { - Ok(mut document) => { - truncate(&mut document); - Box::new(std::iter::once(document.into())) - } - Err(error) => { - error!("Template render failed for a node, skipping: {error}"); - empty_iter() - } - } + ) -> Option { + let template = self.node_template.as_str()?; + let mut env = Environment::new(); + let template = build_template(&mut env, template); + match template.render(NodeTemplateContext::from(node.clone())) { + Ok(mut document) => { + truncate(&mut document); + Some(document) + } + Err(error) => { + let node = node.name(); + error!("Template render failed for a node {node}, skipping: {error}"); + None } - None => empty_iter(), } } @@ -221,23 +157,21 @@ impl DocumentTemplate { pub(crate) fn edge<'graph, G: GraphViewOps<'graph>>( &self, edge: EdgeView, - ) -> Box + Send> { - match &self.edge_template { - Some(template) => { - let mut env = Environment::new(); - let template = build_template(&mut env, template); - match template.render(EdgeTemplateContext::from(edge)) { - Ok(mut document) => { - truncate(&mut document); - Box::new(std::iter::once(document.into())) - } - Err(error) => { - error!("Template render failed for an edge, skipping: {error}"); - empty_iter() - } - } + ) -> Option { + let template = self.edge_template.as_str()?; + let mut env = Environment::new(); + let template = build_template(&mut env, template); + match template.render(EdgeTemplateContext::from(edge.clone())) { + Ok(mut document) => { + truncate(&mut document); + Some(document) + } + Err(error) => { + let src = edge.src().name(); + let dst = edge.dst().name(); + error!("Template render failed for edge {src}->{dst}, skipping: {error}"); + None } - None => empty_iter(), } } } @@ -307,17 +241,6 @@ pub const DEFAULT_EDGE_TEMPLATE: &str = - {{ time|datetimeformat }} {% endfor %}"; -pub const DEFAULT_GRAPH_TEMPLATE: &str = "Graph with the following properties: -{% for (key, value) in constant_properties|items %} -{{ key }}: {{ value }} -{% endfor %} -{% for (key, values) in temporal_properties|items %} -{{ key }}: -{% for (time, value) in values %} - - changed to {{ value }} at {{ time|datetimeformat }} -{% endfor %} -{% endfor %}"; - #[cfg(test)] mod template_tests { use indoc::indoc; @@ -351,12 +274,10 @@ mod template_tests { let template = DocumentTemplate { node_template: Some(DEFAULT_NODE_TEMPLATE.to_owned()), - graph_template: Some(DEFAULT_GRAPH_TEMPLATE.to_owned()), edge_template: Some(DEFAULT_EDGE_TEMPLATE.to_owned()), }; - let mut docs = template.node(graph.node("node1").unwrap()); - let rendered = docs.next().unwrap().content; + let rendered = template.node(graph.node("node1").unwrap()).unwrap(); let expected = indoc! {" Node node1 has the following properties: key1: value1 @@ -367,22 +288,15 @@ mod template_tests { "}; assert_eq!(&rendered, expected); - let mut docs = template.edge(graph.edge("node1", "node2").unwrap()); - let rendered = docs.next().unwrap().content; + let rendered = template + .edge(graph.edge("node1", "node2").unwrap()) + .unwrap(); let expected = indoc! {" There is an edge from node1 to node2 with events at: - Jan 1 1970 00:00 - Jan 1 1970 00:01 "}; assert_eq!(&rendered, expected); - - let mut docs = template.graph(graph); - let rendered = docs.next().unwrap().content; - let expected = indoc! {" - Graph with the following properties: - name: test-name - "}; - assert_eq!(&rendered, expected); } #[test] @@ -424,12 +338,10 @@ mod template_tests { "}; let template = DocumentTemplate { node_template: Some(node_template.to_owned()), - graph_template: None, edge_template: None, }; - let mut docs = template.node(graph.node("node1").unwrap()); - let rendered = docs.next().unwrap().content; + let rendered = template.node(graph.node("node1").unwrap()).unwrap(); let expected = indoc! {" node node1 is an unknown entity with the following properties: temp_test: @@ -442,8 +354,7 @@ mod template_tests { "}; assert_eq!(&rendered, expected); - let mut docs = template.node(graph.node("node2").unwrap()); - let rendered = docs.next().unwrap().content; + let rendered = template.node(graph.node("node2").unwrap()).unwrap(); let expected = indoc! {" node node2 is a person with the following properties: const_test: const_test_value "}; @@ -462,12 +373,10 @@ mod template_tests { "{{ (temporal_properties.temp|first).time|datetimeformat(format=\"long\") }}"; let template = DocumentTemplate { node_template: Some(node_template.to_owned()), - graph_template: None, edge_template: None, }; - let mut docs = template.node(graph.node("node1").unwrap()); - let rendered = docs.next().unwrap().content; + let rendered = template.node(graph.node("node1").unwrap()).unwrap(); let expected = "September 9 2024 09:08:01"; assert_eq!(&rendered, expected); } diff --git a/raphtory/src/vectors/utils.rs b/raphtory/src/vectors/utils.rs new file mode 100644 index 0000000000..7367486d47 --- /dev/null +++ b/raphtory/src/vectors/utils.rs @@ -0,0 +1,23 @@ +use crate::{ + db::{api::view::StaticGraphViewOps, graph::views::window_graph::WindowedGraph}, + prelude::TimeOps, +}; +use itertools::Itertools; + +/// Returns the top k docs in descending order +pub(crate) fn find_top_k<'a, I, T>(elements: I, k: usize) -> impl Iterator + 'a +where + I: Iterator + 'a, + T: 'static, +{ + elements + .sorted_by(|(_, score1), (_, score2)| score2.partial_cmp(score1).unwrap()) // desc ordering, thus the invertion + .take(k) +} + +pub(super) fn apply_window( + graph: &G, + window: Option<(i64, i64)>, +) -> Option> { + window.map(|(start, end)| graph.window(start, end)) +} diff --git a/raphtory/src/vectors/vector_selection.rs b/raphtory/src/vectors/vector_selection.rs index 7d6c6bdfa1..0299a6b840 100644 --- a/raphtory/src/vectors/vector_selection.rs +++ b/raphtory/src/vectors/vector_selection.rs @@ -1,11 +1,9 @@ -use itertools::{chain, Itertools}; -use std::{ - collections::{HashMap, HashSet}, - ops::Deref, -}; +use either::Either; +use itertools::Itertools; +use std::{collections::HashSet, usize}; use crate::{ - core::entities::nodes::node_ref::AsNodeRef, + core::{entities::nodes::node_ref::AsNodeRef, utils::errors::GraphResult}, db::{ api::view::{DynamicGraph, StaticGraphViewOps}, graph::{edge::EdgeView, node::NodeView}, @@ -14,11 +12,11 @@ use crate::{ }; use super::{ - document_ref::DocumentRef, - entity_id::EntityId, - similarity_search_utils::{find_top_k, score_document_groups_by_highest, score_documents}, + db::EntityDb, + entity_ref::EntityRef, + utils::{apply_window, find_top_k}, vectorised_graph::VectorisedGraph, - Document, Embedding, + Document, DocumentEntity, Embedding, }; #[derive(Clone, Copy)] @@ -28,113 +26,139 @@ enum ExpansionPath { Both, } +impl ExpansionPath { + fn includes_nodes(&self) -> bool { + matches!(self, Self::Nodes | Self::Both) + } + + fn includes_edges(&self) -> bool { + matches!(self, Self::Edges | Self::Both) + } +} + +#[derive(Debug, Clone)] +struct Selected(Vec<(EntityRef, f32)>); + +impl From> for Selected { + fn from(value: Vec<(EntityRef, f32)>) -> Self { + Self(value) + } +} + +impl Selected { + fn extend(&mut self, extension: impl IntoIterator) { + self.extend_with_limit(extension, usize::MAX); + } + + fn extend_with_limit( + &mut self, + extension: impl IntoIterator, + limit: usize, + ) { + let selection_set: HashSet = + HashSet::from_iter(self.0.iter().map(|(doc, _)| doc.clone())); + let new_docs = extension + .into_iter() + .unique_by(|(entity, _)| *entity) + .filter(|(entity, _)| !selection_set.contains(entity)) + .take(limit); + self.0.extend(new_docs); + } + + fn iter(&self) -> impl Iterator { + self.0.iter() + } + + fn len(&self) -> usize { + self.0.len() + } +} + pub type DynamicVectorSelection = VectorSelection; #[derive(Clone)] pub struct VectorSelection { pub(crate) graph: VectorisedGraph, - selected_docs: Vec<(DocumentRef, f32)>, + selected: Selected, // FIXME: this is a bit error prone, might contain duplicates } impl VectorSelection { - pub(crate) fn new(graph: VectorisedGraph) -> Self { + pub(crate) fn empty(graph: VectorisedGraph) -> Self { Self { graph, - selected_docs: vec![], + selected: vec![].into(), } } - pub(crate) fn new_with_preselection( - graph: VectorisedGraph, - docs: Vec<(DocumentRef, f32)>, - ) -> Self { + pub(super) fn new(graph: VectorisedGraph, docs: Vec<(EntityRef, f32)>) -> Self { Self { graph, - selected_docs: docs, + selected: docs.into(), } } /// Return the nodes present in the current selection pub fn nodes(&self) -> Vec> { - self.selected_docs + let g = &self.graph.source_graph; + self.selected .iter() - .unique_by(|(doc, _)| &doc.entity_id) - .filter_map(|(doc, _)| match &doc.entity_id { - EntityId::Node { id } => self.graph.source_graph.node(id), - _ => None, - }) - .collect_vec() + .filter_map(|(id, _)| id.as_node_view(g)) + .collect() } /// Return the edges present in the current selection pub fn edges(&self) -> Vec> { - self.selected_docs + let g = &self.graph.source_graph; + self.selected .iter() - .unique_by(|(doc, _)| &doc.entity_id) - .filter_map(|(doc, _)| match &doc.entity_id { - EntityId::Edge { src, dst } => self.graph.source_graph.edge(src, dst), - _ => None, - }) - .collect_vec() + .filter_map(|(id, _)| id.as_edge_view(g)) + .collect() } /// Return the documents present in the current selection - pub fn get_documents(&self) -> Vec> { - self.get_documents_with_scores() + pub fn get_documents(&self) -> GraphResult>> { + Ok(self + .get_documents_with_scores()? .into_iter() .map(|(doc, _)| doc) - .collect_vec() + .collect()) } /// Return the documents alongside their scores present in the current selection - pub fn get_documents_with_scores(&self) -> Vec<(Document, f32)> { - self.selected_docs + pub fn get_documents_with_scores(&self) -> GraphResult, f32)>> { + self.selected .iter() - .map(|(doc, score)| { - ( - doc.regenerate(&self.graph.source_graph, &self.graph.template), - *score, - ) - }) - .collect_vec() + .map(|(entity, score)| self.regenerate_doc(*entity).map(|doc| (doc, *score))) + .collect() } - /// Add all the documents associated with the `nodes` to the current selection + /// Add all `nodes` to the current selection /// /// Documents added by this call are assumed to have a score of 0. + /// If any node id doesn't exist it will be ignored /// /// # Arguments /// * nodes - a list of the node ids or nodes to add pub fn add_nodes(&mut self, nodes: Vec) { - let node_documents = self.graph.node_documents.read(); - let node_docs = nodes + let new_docs = nodes .into_iter() - .flat_map(|id| { - let node = self.graph.source_graph.node(id); - let opt = node.map(|node| node_documents.get(&EntityId::from_node(node))); - opt.flatten().unwrap_or(&self.graph.empty_vec) - }) - .map(|doc| (doc.clone(), 0.0)); - self.selected_docs = extend_selection(self.selected_docs.clone(), node_docs, usize::MAX); + .filter_map(|id| Some(self.graph.source_graph.node(id)?.into())); + self.selected.extend(new_docs.map(|doc| (doc, 0.0))); } - /// Add all the documents associated with the `edges` to the current selection + /// Add all `edges` to the current selection /// /// Documents added by this call are assumed to have a score of 0. + /// If any edge doesn't exist it will be ignored /// /// # Arguments /// * edges - a list of the edge ids or edges to add pub fn add_edges(&mut self, edges: Vec<(V, V)>) { - let edge_documents = self.graph.edge_documents.read(); - let edge_docs = edges + let new_docs = edges .into_iter() - .flat_map(|(src, dst)| { - let edge = self.graph.source_graph.edge(src, dst); - let opt = edge.map(|edge| edge_documents.get(&EntityId::from_edge(edge))); - opt.flatten().unwrap_or(&self.graph.empty_vec) - }) - .map(|doc| (doc.clone(), 0.0)); - self.selected_docs = extend_selection(self.selected_docs.clone(), edge_docs, usize::MAX); + .filter_map(|(src, dst)| Some(self.graph.source_graph.edge(src, dst)?.into())); + // self.extend_selection_with_refs(new_docs); + self.selected.extend(new_docs.map(|doc| (doc, 0.0))); } /// Append all the documents in `selection` to the current selection @@ -145,12 +169,7 @@ impl VectorSelection { /// # Returns /// The selection with the new documents pub fn append(&mut self, selection: &Self) -> &Self { - self.selected_docs = extend_selection( - self.selected_docs.clone(), - selection.selected_docs.clone().into_iter(), - usize::MAX, - ); - + self.selected.extend(selection.selected.iter().cloned()); self } @@ -165,64 +184,15 @@ impl VectorSelection { /// * hops - the number of hops to carry out the expansion /// * window - the window where documents need to belong to in order to be considered pub fn expand(&mut self, hops: usize, window: Option<(i64, i64)>) { - match window { - None => self.expand_with_window(hops, window, &self.graph.source_graph.clone()), - Some((start, end)) => { - let windowed_graph = self.graph.source_graph.window(start, end); - self.expand_with_window(hops, window, &windowed_graph) - } + let nodes = self.get_nodes_in_context(window, false); + let edges = self.get_edges_in_context(window, false); + let docs = nodes.into_iter().chain(edges).map(|entity| (entity, 0.0)); + self.selected.extend(docs); + if hops > 1 { + self.expand(hops - 1, window); } } - fn expand_with_window( - &mut self, - hops: usize, - window: Option<(i64, i64)>, - windowed_graph: &W, - ) { - let node_documents = self.graph.node_documents.read(); - let edge_documents = self.graph.edge_documents.read(); - for _ in 0..hops { - let context = self - .selected_docs - .iter() - .flat_map(|(doc, _)| { - self.get_context( - doc, - node_documents.deref(), - edge_documents.deref(), - windowed_graph, - window, - ) - }) - .map(|doc| (doc.clone(), 0.0)); - self.selected_docs = extend_selection(self.selected_docs.clone(), context, usize::MAX); - } - } - - /// Add the top `limit` adjacent documents with higher score for `query` to the selection - /// - /// The expansion algorithm is a loop with two steps on each iteration: - /// 1. All the documents 1 hop away of some of the documents included on the selection (and - /// not already selected) are marked as candidates. - /// 2. Those candidates are added to the selection in descending order according to the - /// similarity score obtained against the `query`. - /// - /// This loops goes on until the number of new documents reaches a total of `limit` - /// documents or until no more documents are available - /// - /// # Arguments - /// * query - the embedding to score against - /// * window - the window where documents need to belong to in order to be considered - pub fn expand_documents_by_similarity( - &mut self, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - ) { - self.expand_documents_by_similarity_with_path(query, limit, window, ExpansionPath::Both) - } - /// Add the top `limit` adjacent entities with higher score for `query` to the selection /// /// The expansion algorithm is a loop with two steps on each iteration: @@ -242,8 +212,8 @@ impl VectorSelection { query: &Embedding, limit: usize, window: Option<(i64, i64)>, - ) { - self.expand_entities_by_similarity_with_path(query, limit, window, ExpansionPath::Both) + ) -> GraphResult<()> { + self.expand_by_similarity(query, limit, window, ExpansionPath::Both) } /// Add the top `limit` adjacent nodes with higher score for `query` to the selection @@ -259,8 +229,8 @@ impl VectorSelection { query: &Embedding, limit: usize, window: Option<(i64, i64)>, - ) { - self.expand_entities_by_similarity_with_path(query, limit, window, ExpansionPath::Nodes) + ) -> GraphResult<()> { + self.expand_by_similarity(query, limit, window, ExpansionPath::Nodes) } /// Add the top `limit` adjacent edges with higher score for `query` to the selection @@ -276,391 +246,147 @@ impl VectorSelection { query: &Embedding, limit: usize, window: Option<(i64, i64)>, - ) { - self.expand_entities_by_similarity_with_path(query, limit, window, ExpansionPath::Edges) + ) -> GraphResult<()> { + self.expand_by_similarity(query, limit, window, ExpansionPath::Edges) } - fn expand_documents_by_similarity_with_path( + fn expand_by_similarity( &mut self, query: &Embedding, limit: usize, window: Option<(i64, i64)>, path: ExpansionPath, - ) { - match window { - None => self.expand_documents_by_similarity_with_path_and_window( - query, - limit, - window, - &self.graph.source_graph.clone(), - path, - ), - Some((start, end)) => { - let windowed_graph = self.graph.source_graph.window(start, end); - self.expand_documents_by_similarity_with_path_and_window( - query, - limit, - window, - &windowed_graph, - path, - ) - } + ) -> GraphResult<()> { + let g = &self.graph.source_graph; + let view = apply_window(g, window); + let initial_size = self.selected.len(); + + let nodes: Box> = if path.includes_nodes() { + let jump = matches!(path, ExpansionPath::Nodes); + let filter = self.get_nodes_in_context(window, jump); + let nodes = self + .graph + .node_db + .top_k(query, limit, view.clone(), Some(filter))?; + Box::new(nodes) + } else { + Box::new(std::iter::empty()) + }; + + let edges: Box> = if path.includes_edges() { + let jump = matches!(path, ExpansionPath::Edges); + let filter = self.get_edges_in_context(window, jump); + let edges = self.graph.edge_db.top_k(query, limit, view, Some(filter))?; + Box::new(edges) + } else { + Box::new(std::iter::empty()) + }; + + let docs = find_top_k(nodes.chain(edges), limit).collect::>(); // collect to remove lifetime + self.selected.extend_with_limit(docs, limit); + + let increment = self.selected.len() - initial_size; + if increment > 0 && increment < limit { + self.expand_by_similarity(query, limit, window, path)? } + Ok(()) } - /// this function only exists so that we can make the type of graph generic - fn expand_documents_by_similarity_with_path_and_window( - &mut self, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - windowed_graph: &W, - path: ExpansionPath, - ) { - let node_documents = self.graph.node_documents.read(); - let edge_documents = self.graph.edge_documents.read(); - - // let mut selected_docs = self.selected_docs.clone(); - let total_limit = self.selected_docs.len() + limit; - - while self.selected_docs.len() < total_limit { - let remaining = total_limit - self.selected_docs.len(); - let candidates = self - .selected_docs - .iter() - .flat_map(|(doc, _)| { - self.get_context( - doc, - node_documents.deref(), - edge_documents.deref(), - windowed_graph, - window, - ) - }) - .flat_map(|doc| match (path, doc.entity_id.clone()) { - // this is to hope from node->node or edge->edge - (ExpansionPath::Nodes, EntityId::Edge { .. }) - | (ExpansionPath::Edges, EntityId::Node { .. }) => self.get_context( - doc, - node_documents.deref(), - edge_documents.deref(), - windowed_graph, - window, - ), - _ => Box::new(std::iter::once(doc)), - }) - .filter(|doc| match path { - ExpansionPath::Both => true, - ExpansionPath::Nodes => doc.entity_id.is_node(), - ExpansionPath::Edges => doc.entity_id.is_edge(), - }); - - let scored_candidates = score_documents(query, candidates.cloned()); - let top_sorted_candidates = find_top_k(scored_candidates, usize::MAX); - self.selected_docs = extend_selection( - self.selected_docs.clone(), - top_sorted_candidates, - total_limit, - ); - - let new_remaining = total_limit - self.selected_docs.len(); - if new_remaining == remaining { - break; // TODO: try to move this to the top condition - } - } - } - - fn expand_entities_by_similarity_with_path( - &mut self, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - path: ExpansionPath, - ) { + fn get_nodes_in_context(&self, window: Option<(i64, i64)>, jump: bool) -> HashSet { match window { - None => self.expand_entities_by_similarity_with_path_and_window( - query, - limit, - window, - &self.graph.source_graph.clone(), - path, - ), - Some((start, end)) => { - let windowed_graph = self.graph.source_graph.window(start, end); - self.expand_entities_by_similarity_with_path_and_window( - query, - limit, - window, - &windowed_graph, - path, - ) - } + Some((start, end)) => self + .get_nodes_in_context_for_view(&self.graph.source_graph.window(start, end), jump), + None => self.get_nodes_in_context_for_view(&self.graph.source_graph, jump), } } - /// this function only exists so that we can make the type of graph generic - fn expand_entities_by_similarity_with_path_and_window( - &mut self, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - windowed_graph: &W, - path: ExpansionPath, - ) { - let total_entity_limit = self.get_selected_entity_len() + limit; - - while self.get_selected_entity_len() < total_entity_limit { - let remaining = total_entity_limit - self.get_selected_entity_len(); - - let candidates: Box)>> = match path { - ExpansionPath::Both => { - let node_doc_groups = self.selected_docs.iter().flat_map(|(doc, _)| { - self.get_nodes_in_context(doc, windowed_graph, window.clone()) - }); - let edge_doc_groups = self.selected_docs.iter().flat_map(|(doc, _)| { - self.get_edges_in_context(doc, windowed_graph, window) - }); - - Box::new(chain!(node_doc_groups, edge_doc_groups)) - } - ExpansionPath::Nodes => { - let groups = self.selected_docs.iter().flat_map(|(doc, _)| { - self.get_nodes_in_context(doc, windowed_graph, window) - }); - Box::new(groups) - } - ExpansionPath::Edges => { - let groups = self.selected_docs.iter().flat_map(|(doc, _)| { - self.get_edges_in_context(doc, windowed_graph, window) - }); - Box::new(groups) - } - }; - - let scored_candidates = score_document_groups_by_highest(query, candidates); - - let top_sorted_candidates = find_top_k(scored_candidates, usize::MAX); - self.selected_docs = - self.extend_selection_with_groups(top_sorted_candidates, total_entity_limit); - - let new_remaining = total_entity_limit - self.get_selected_entity_len(); - if new_remaining == remaining { - break; // TODO: try to move this to the top condition - } + fn get_edges_in_context(&self, window: Option<(i64, i64)>, jump: bool) -> HashSet { + match window { + Some((start, end)) => self + .get_edges_in_context_for_view(&self.graph.source_graph.window(start, end), jump), + None => self.get_edges_in_context_for_view(&self.graph.source_graph, jump), } } - fn get_selected_entity_id_set(&self) -> HashSet { - HashSet::from_iter(self.selected_docs.iter().map(|doc| doc.0.entity_id.clone())) + fn get_nodes_in_context_for_view( + &self, + v: &W, + jump: bool, + ) -> HashSet { + let iter = self.selected.iter(); + iter.flat_map(|(e, _)| e.get_adjacent_nodes(v, jump)) + .collect() } - fn get_selected_entity_len(&self) -> usize { - self.get_selected_entity_id_set().len() + fn get_edges_in_context_for_view( + &self, + v: &W, + jump: bool, + ) -> HashSet { + let iter = self.selected.iter(); + iter.flat_map(|(e, _)| e.get_adjacent_edges(v, jump)) + .collect() } - // this might return the document used as input, uniqueness need to be check outside of this - fn get_context<'a, W: StaticGraphViewOps>( - &'a self, - document: &DocumentRef, - node_documents: &'a HashMap>, - edge_documents: &'a HashMap>, - windowed_graph: &'a W, - window: Option<(i64, i64)>, - ) -> Box + 'a> { - match &document.entity_id { - EntityId::Graph { .. } => Box::new(std::iter::empty()), - EntityId::Node { id } => { - let self_docs = node_documents - .get(&document.entity_id) - .unwrap_or(&self.graph.empty_vec); - match windowed_graph.node(id) { - None => Box::new(std::iter::empty()), - Some(node) => { - let edges = node.edges(); - let edge_docs = edges.into_iter().flat_map(|edge| { - let edge_id = EntityId::from_edge(edge); - edge_documents - .get(&edge_id) - .unwrap_or(&self.graph.empty_vec) - }); - Box::new( - chain!(self_docs, edge_docs).filter(move |doc| { - doc.exists_on_window(Some(windowed_graph), &window) - }), - ) - } - } - } - EntityId::Edge { src, dst } => { - let self_docs = edge_documents - .get(&document.entity_id) - .unwrap_or(&self.graph.empty_vec); - match windowed_graph.edge(src, dst) { - None => Box::new(std::iter::empty()), - Some(edge) => { - let src_id = EntityId::from_node(edge.src()); - let dst_id = EntityId::from_node(edge.dst()); - let src_docs = node_documents.get(&src_id).unwrap_or(&self.graph.empty_vec); - let dst_docs = node_documents.get(&dst_id).unwrap_or(&self.graph.empty_vec); - Box::new( - chain!(self_docs, src_docs, dst_docs).filter(move |doc| { - doc.exists_on_window(Some(windowed_graph), &window) - }), - ) - } - } - } + fn regenerate_doc(&self, entity: EntityRef) -> GraphResult> { + match entity.resolve_entity(&self.graph.source_graph).unwrap() { + Either::Left(node) => Ok(Document { + entity: DocumentEntity::Node(node.clone()), + content: self.graph.template.node(node).unwrap(), + embedding: self.graph.node_db.get_id(entity.id())?.unwrap(), + }), + Either::Right(edge) => Ok(Document { + entity: DocumentEntity::Edge(edge.clone()), + content: self.graph.template.edge(edge).unwrap(), + embedding: self.graph.edge_db.get_id(entity.id())?.unwrap(), + }), } } +} - fn nodes_into_document_groups<'a, W: StaticGraphViewOps>( - &'a self, - nodes: impl Iterator> + 'static, - windowed_graph: &'a W, - window: Option<(i64, i64)>, - ) -> Box)> + 'a> { - let groups = nodes - .map(move |node| { - let entity_id = EntityId::from_node(node); - self.graph - .node_documents - .read() - .get(&entity_id) - .map(|group| { - let docs = group - .iter() - .filter(|doc| doc.exists_on_window(Some(windowed_graph), &window)) - .cloned() - .collect_vec(); - (entity_id, docs) - }) - }) - .flatten() - .filter(|(_, docs)| docs.len() > 0); - Box::new(groups) - } - - fn edges_into_document_groups<'a, W: StaticGraphViewOps>( - &'a self, - edges: impl Iterator> + 'a, - windowed_graph: &'a W, - window: Option<(i64, i64)>, - ) -> Box)> + 'a> { - let groups = edges - .map(move |edge| { - let entity_id = EntityId::from_edge(edge); - self.graph - .edge_documents - .read() - .get(&entity_id) - .map(|group| { - let docs = group - .iter() - .filter(|doc| doc.exists_on_window(Some(windowed_graph), &window)) - .cloned() - .collect_vec(); - (entity_id, docs) - }) - }) - .flatten() - .filter(|(_, docs)| docs.len() > 0); - Box::new(groups) - } - - fn get_nodes_in_context<'a, W: StaticGraphViewOps>( - &'a self, - document: &'a DocumentRef, - windowed_graph: &'a W, - window: Option<(i64, i64)>, - ) -> Box)> + 'a> { - match &document.entity_id { - EntityId::Graph { .. } => Box::new(std::iter::empty()), - EntityId::Node { id } => match windowed_graph.node(id) { - None => Box::new(std::iter::empty()), - Some(node) => { - let nodes = node.neighbours().iter(); // TODO: make nodes_into_document_groups more flexible - self.nodes_into_document_groups(nodes, windowed_graph, window) - } - }, - EntityId::Edge { src, dst } => match windowed_graph.edge(src, dst) { - None => Box::new(std::iter::empty()), - Some(edge) => { - let nodes = [edge.src(), edge.dst()].into_iter(); - self.nodes_into_document_groups(nodes, windowed_graph, window) +// TODO: I could make get_neighbour_nodes rely on get_neighbour_edges and viceversa, reusing some code +impl EntityRef { + fn get_adjacent_nodes( + &self, + view: &G, + jump: bool, + ) -> impl Iterator { + let nodes: Box>> = + if let Some(node) = self.as_node_view(view) { + if jump { + let docs = node.neighbours().into_iter(); + Box::new(docs) + } else { + Box::new(std::iter::empty()) } - }, - } + } else if let Some(edge) = self.as_edge_view(view) { + Box::new([edge.src(), edge.dst()].into_iter()) + } else { + Box::new(std::iter::empty()) + }; + nodes.map(|node| node.into()) } - fn get_edges_in_context<'a, W: StaticGraphViewOps>( - &'a self, - document: &DocumentRef, - windowed_graph: &'a W, - window: Option<(i64, i64)>, - ) -> Box)> + 'a> { - match &document.entity_id { - EntityId::Graph { .. } => Box::new(std::iter::empty()), - EntityId::Node { id } => match windowed_graph.node(id) { - None => Box::new(std::iter::empty()), - Some(node) => { - let edges = node.edges(); - self.edges_into_document_groups(edges.into_iter(), windowed_graph, window) - } - }, - EntityId::Edge { src, dst } => match windowed_graph.edge(src, dst) { - None => Box::new(std::iter::empty()), - Some(edge) => { + fn get_adjacent_edges( + &self, + view: &G, + jump: bool, + ) -> impl Iterator { + let edges: Box>> = + if let Some(node) = self.as_node_view(view) { + let docs = node.edges().into_iter(); + Box::new(docs) + } else if let Some(edge) = self.as_edge_view(view) { + if jump { let src_edges = edge.src().edges().into_iter(); let dst_edges = edge.dst().edges().into_iter(); - let edges = chain!(src_edges, dst_edges); - self.edges_into_document_groups(edges, windowed_graph, window) + Box::new(src_edges.chain(dst_edges)) + } else { + Box::new(std::iter::empty()) } - }, - } - } - - /// this is a wrapper around `extend_selection` for adding in full entities - fn extend_selection_with_groups<'a, I>( - &self, - extension: I, - total_entity_limit: usize, - ) -> Vec<(DocumentRef, f32)> - where - I: IntoIterator), f32)>, - { - let entity_set = self.get_selected_entity_id_set(); - let entity_extension_size = total_entity_limit - self.get_selected_entity_len(); - let new_unique_entities = extension - .into_iter() - .unique_by(|((entity_id, _), _score)| entity_id.clone()) - .filter(|((entity_id, _), _score)| !entity_set.contains(entity_id)) - .take(entity_extension_size); - let documents_to_add = new_unique_entities - .flat_map(|((_, docs), score)| docs.into_iter().map(move |doc| (doc.clone(), score))); - extend_selection(self.selected_docs.clone(), documents_to_add, usize::MAX) + } else { + Box::new(std::iter::empty()) + }; + edges.map(|edge| edge.into()) } } - -/// this function assumes that extension might contain duplicates and might contain elements -/// already present in selection, and returns a sequence with no repetitions and preserving the -/// elements in selection in the same indexes -fn extend_selection( - selection: Vec<(DocumentRef, f32)>, - extension: I, - new_total_size: usize, -) -> Vec<(DocumentRef, f32)> -where - I: IntoIterator, -{ - let selection_set: HashSet = - HashSet::from_iter(selection.iter().map(|(doc, _)| doc.clone())); - let new_docs = extension - .into_iter() - .unique_by(|(doc, _)| doc.clone()) - .filter(|(doc, _)| !selection_set.contains(doc)); - selection - .into_iter() - .chain(new_docs) - .take(new_total_size) - .collect_vec() -} diff --git a/raphtory/src/vectors/vector_storage.rs b/raphtory/src/vectors/vector_storage.rs deleted file mode 100644 index 50c9738aff..0000000000 --- a/raphtory/src/vectors/vector_storage.rs +++ /dev/null @@ -1,65 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::{ - collections::HashMap, - fs::File, - io::{BufReader, BufWriter}, - path::Path, - sync::Arc, -}; - -use crate::{core::utils::errors::GraphError, db::api::view::StaticGraphViewOps}; - -use super::{ - document_ref::DocumentRef, embedding_cache::EmbeddingCache, entity_id::EntityId, - template::DocumentTemplate, vectorised_graph::VectorisedGraph, EmbeddingFunction, -}; - -#[derive(Serialize, Deserialize)] -struct VectorStorage { - template: DocumentTemplate, - graph_documents: Vec, - node_documents: HashMap>, - edge_documents: HashMap>, -} - -impl VectorisedGraph { - pub fn read_from_path( - path: &Path, - graph: G, - embedding: Arc, - cache_storage: Arc>, - ) -> Option { - // TODO: return Result instead of Option - let file = File::open(&path).ok()?; - let mut reader = BufReader::new(file); - let VectorStorage { - template, - graph_documents, - node_documents, - edge_documents, - } = bincode::deserialize_from(&mut reader).ok()?; - - Some(VectorisedGraph::new( - graph, - template, - embedding, - cache_storage, - Arc::new(graph_documents.into()), - Arc::new(node_documents.into()), - Arc::new(edge_documents.into()), - )) - } - - pub fn write_to_path(&self, path: &Path) -> Result<(), GraphError> { - let storage = VectorStorage { - template: self.template.clone(), - graph_documents: self.graph_documents.read().clone(), - node_documents: self.node_documents.read().clone(), - edge_documents: self.edge_documents.read().clone(), - }; - let file = File::create(path)?; - let mut writer = BufWriter::new(file); - bincode::serialize_into(&mut writer, &storage)?; - Ok(()) - } -} diff --git a/raphtory/src/vectors/vectorisable.rs b/raphtory/src/vectors/vectorisable.rs index d7e1dc9b42..2ea2003d50 100644 --- a/raphtory/src/vectors/vectorisable.rs +++ b/raphtory/src/vectors/vectorisable.rs @@ -1,29 +1,20 @@ use crate::{ core::utils::errors::GraphResult, - db::{ - api::view::{internal::IntoDynamic, StaticGraphViewOps}, - graph::{edge::EdgeView, node::NodeView}, - }, + db::api::view::{internal::IntoDynamic, StaticGraphViewOps}, vectors::{ - document_ref::DocumentRef, embedding_cache::EmbeddingCache, entity_id::EntityId, - template::DocumentTemplate, vectorised_graph::VectorisedGraph, EmbeddingFunction, Lifespan, + db::EntityDb, embeddings::compute_embeddings, template::DocumentTemplate, + vectorised_graph::VectorisedGraph, }, }; use async_trait::async_trait; -use itertools::Itertools; -use parking_lot::RwLock; -use std::{collections::HashMap, sync::Arc}; +use std::path::Path; use tracing::info; -const CHUNK_SIZE: usize = 1000; - -#[derive(Clone, Debug)] -struct IndexedDocumentInput { - entity_id: EntityId, - content: String, - index: usize, - life: Lifespan, -} +use super::{ + cache::VectorCache, + db::{EdgeDb, NodeDb}, + storage::{edge_vectors_path, node_vectors_path, VectorMeta}, +}; #[async_trait] pub trait Vectorisable { @@ -40,11 +31,9 @@ pub trait Vectorisable { /// A VectorisedGraph with all the documents/embeddings computed and with an initial empty selection async fn vectorise( &self, - embedding: Box, - cache: Arc>, - overwrite_cache: bool, + cache: VectorCache, template: DocumentTemplate, - graph_name: Option, + path: Option<&Path>, verbose: bool, ) -> GraphResult>; } @@ -53,230 +42,48 @@ pub trait Vectorisable { impl Vectorisable for G { async fn vectorise( &self, - embedding: Box, - cache: Arc>, - overwrite_cache: bool, + cache: VectorCache, template: DocumentTemplate, - graph_name: Option, + path: Option<&Path>, verbose: bool, ) -> GraphResult> { - let graph_docs = indexed_docs_for_graph(self, graph_name, &template); - - let nodes = self.nodes().collect().into_iter(); - let nodes_docs = nodes.flat_map(|node| indexed_docs_for_node(node, &template)); - - let edges = self.edges().collect().into_iter(); - let edges_docs = edges.flat_map(|edge| indexed_docs_for_edge(edge, &template)); - - if verbose { - info!("computing embeddings for graph"); - } - let graph_refs = compute_entity_embeddings(graph_docs, embedding.as_ref(), &cache).await?; - if verbose { info!("computing embeddings for nodes"); } - let node_refs = compute_embedding_groups(nodes_docs, embedding.as_ref(), &cache).await?; + let nodes = self.nodes(); + let node_docs = nodes + .iter() + .filter_map(|node| template.node(node).map(|doc| (node.node.0 as u32, doc))); + let node_path = path.map(node_vectors_path); + let node_vectors = compute_embeddings(node_docs, &cache); + let node_db = NodeDb::from_vectors(node_vectors, node_path).await?; if verbose { info!("computing embeddings for edges"); } - let edge_refs = compute_embedding_groups(edges_docs, embedding.as_ref(), &cache).await?; - - if overwrite_cache { - cache.iter().for_each(|cache| cache.dump_to_disk()); + let edges = self.edges(); + let edge_docs = edges.iter().filter_map(|edge| { + template + .edge(edge) + .map(|doc| (edge.edge.pid().0 as u32, doc)) + }); + let edge_path = path.map(edge_vectors_path); + let edge_vectors = compute_embeddings(edge_docs, &cache); + let edge_db = EdgeDb::from_vectors(edge_vectors, edge_path).await?; + + if let Some(path) = path { + let meta = VectorMeta { + template: template.clone(), + }; + meta.write_to_path(path)?; } - Ok(VectorisedGraph::new( - self.clone(), + Ok(VectorisedGraph { + source_graph: self.clone(), template, - embedding.into(), - cache.into(), - RwLock::new(graph_refs).into(), - RwLock::new(node_refs).into(), - RwLock::new(edge_refs).into(), - )) - } -} - -pub(crate) async fn vectorise_graph( - graph: &G, - graph_name: Option, - template: &DocumentTemplate, - embedding: &Arc, - cache_storage: &Option, -) -> GraphResult> { - let docs = indexed_docs_for_graph(graph, graph_name, template); - compute_entity_embeddings(docs, embedding.as_ref(), &cache_storage).await -} - -pub(crate) async fn vectorise_node( - node: NodeView, - template: &DocumentTemplate, - embedding: &Arc, - cache_storage: &Option, -) -> GraphResult> { - let docs = indexed_docs_for_node(node, template); - compute_entity_embeddings(docs, embedding.as_ref(), &cache_storage).await -} - -pub(crate) async fn vectorise_edge( - edge: EdgeView, - template: &DocumentTemplate, - embedding: &Arc, - cache_storage: &Option, -) -> GraphResult> { - let docs = indexed_docs_for_edge(edge, template); - compute_entity_embeddings(docs, embedding.as_ref(), &cache_storage).await -} - -fn indexed_docs_for_graph<'a, G: StaticGraphViewOps>( - graph: &'a G, - name: Option, - template: &DocumentTemplate, -) -> impl Iterator + Send + 'a { - template - .graph(graph) - .enumerate() - .map(move |(index, doc)| IndexedDocumentInput { - entity_id: EntityId::for_graph(name.clone()), - content: doc.content, - index, - life: doc.life, - }) -} - -fn indexed_docs_for_node( - node: NodeView, - template: &DocumentTemplate, -) -> impl Iterator + Send { - template - .node(node.clone()) - .enumerate() - .map(move |(index, doc)| IndexedDocumentInput { - entity_id: EntityId::from_node(node.clone()), - content: doc.content, - index, - life: doc.life, - }) -} - -fn indexed_docs_for_edge( - edge: EdgeView, - template: &DocumentTemplate, -) -> impl Iterator + Send { - template - .edge(edge.clone()) - .enumerate() - .map(move |(index, doc)| IndexedDocumentInput { - entity_id: EntityId::from_edge(edge.clone()), - content: doc.content, - index, - life: doc.life, + cache, + node_db, + edge_db, }) -} - -async fn compute_entity_embeddings( - documents: I, - embedding: &dyn EmbeddingFunction, - cache: &Option, -) -> GraphResult> -where - I: Iterator + Send, -{ - let map = compute_embedding_groups(documents, embedding, cache).await?; - Ok(map - .into_iter() - .next() - .map(|(_, refs)| refs) - .unwrap_or_else(|| vec![])) // there should be only one value here, TODO: check that's true -} - -async fn compute_embedding_groups( - documents: I, - embedding: &dyn EmbeddingFunction, - cache: &Option, -) -> GraphResult>> -where - I: Iterator, -{ - let mut embedding_groups: HashMap> = HashMap::new(); - let mut buffer = Vec::with_capacity(CHUNK_SIZE); - - for document in documents { - buffer.push(document); - if buffer.len() >= CHUNK_SIZE { - insert_chunk(&mut embedding_groups, &buffer, embedding, cache).await?; - buffer.clear(); - } - } - if buffer.len() > 0 { - insert_chunk(&mut embedding_groups, &buffer, embedding, cache).await?; } - Ok(embedding_groups) -} - -async fn insert_chunk( - embedding_groups: &mut HashMap>, - buffer: &Vec, - embedding: &dyn EmbeddingFunction, - cache: &Option, -) -> GraphResult<()> { - let doc_refs = compute_chunk(&buffer, embedding, cache).await?; - for doc in doc_refs { - match embedding_groups.get_mut(&doc.entity_id) { - Some(group) => group.push(doc), - None => { - embedding_groups.insert(doc.entity_id.clone(), vec![doc]); - } - } - } - Ok(()) -} - -async fn compute_chunk( - documents: &Vec, - embedding: &dyn EmbeddingFunction, - cache: &Option, -) -> GraphResult> { - let mut misses = vec![]; - let mut embedded = vec![]; - match cache { - Some(cache) => { - for doc in documents { - let embedding = cache.get_embedding(&doc.content); - match embedding { - Some(embedding) => embedded.push(DocumentRef::new( - doc.entity_id.clone(), - doc.index, - embedding, - doc.life, - )), - None => misses.push(doc), - } - } - } - None => misses = documents.iter().collect(), - }; - - let texts = misses.iter().map(|doc| doc.content.clone()).collect_vec(); - let embeddings = if texts.is_empty() { - vec![] - } else { - embedding.call(texts).await? - }; - - for (doc, embedding) in misses.into_iter().zip(embeddings) { - if let Some(cache) = cache { - cache.upsert_embedding(&doc.content, embedding.clone()) - }; - embedded.push(DocumentRef::new( - doc.entity_id.clone(), - doc.index, - embedding, - doc.life, - )); - } - - Ok(embedded) } diff --git a/raphtory/src/vectors/vectorised_cluster.rs b/raphtory/src/vectors/vectorised_cluster.rs deleted file mode 100644 index b59f7c8382..0000000000 --- a/raphtory/src/vectors/vectorised_cluster.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::{ - db::api::view::StaticGraphViewOps, - prelude::Graph, - vectors::{ - entity_id::EntityId, - similarity_search_utils::{find_top_k, score_documents}, - vectorised_graph::VectorisedGraph, - Document, Embedding, - }, -}; -use itertools::Itertools; -use std::collections::HashMap; - -pub struct VectorisedCluster<'a, G: StaticGraphViewOps> { - graphs: &'a HashMap>, -} - -impl<'a, G: StaticGraphViewOps> VectorisedCluster<'a, G> { - pub fn new(graphs: &'a HashMap>) -> Self { - Self { graphs } - } - - pub fn search_graph_documents( - &self, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - ) -> Vec> { - self.search_graph_documents_with_scores(query, limit, window) - .into_iter() - .map(|(document, _score)| document) - .collect_vec() - } - - pub fn search_graph_documents_with_scores( - &self, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - ) -> Vec<(Document, f32)> { - let documents = self - .graphs - .iter() - .flat_map(|(_name, graph)| graph.graph_documents.read().clone()) - .filter(|doc| doc.exists_on_window::(None, &window)) - .collect_vec(); - let scored_documents = score_documents(query, documents); - let top_k = find_top_k(scored_documents, limit); - - top_k - .map(|(doc, score)| match &doc.entity_id { - EntityId::Graph { name } => { - let name = name.clone().unwrap(); - let graph = self.graphs.get(&name).unwrap(); - (doc.regenerate(&graph.source_graph, &graph.template), score) - } - _ => panic!("got document that is not related to any graph"), - }) - .collect_vec() - } -} diff --git a/raphtory/src/vectors/vectorised_graph.rs b/raphtory/src/vectors/vectorised_graph.rs index a762cbf714..aa2c32aff7 100644 --- a/raphtory/src/vectors/vectorised_graph.rs +++ b/raphtory/src/vectors/vectorised_graph.rs @@ -1,186 +1,63 @@ use crate::{ - // core::entities::nodes::node_ref::AsNodeRef, core::{entities::nodes::node_ref::AsNodeRef, utils::errors::GraphResult}, db::api::view::{DynamicGraph, IntoDynamic, StaticGraphViewOps}, - prelude::*, - vectors::{ - document_ref::DocumentRef, - embedding_cache::EmbeddingCache, - entity_id::EntityId, - similarity_search_utils::{find_top_k, score_documents}, - template::DocumentTemplate, - Embedding, EmbeddingFunction, - }, + vectors::{template::DocumentTemplate, utils::find_top_k, Embedding}, }; -use async_trait::async_trait; -use itertools::{chain, Itertools}; -use parking_lot::RwLock; -use std::{collections::HashMap, ops::Deref, path::PathBuf, sync::Arc}; use super::{ - similarity_search_utils::score_document_groups_by_highest, + cache::VectorCache, + db::{EdgeDb, EntityDb, NodeDb}, + utils::apply_window, vector_selection::VectorSelection, - vectorisable::{vectorise_edge, vectorise_graph, vectorise_node}, - Document, }; +#[derive(Clone)] pub struct VectorisedGraph { pub(crate) source_graph: G, pub(crate) template: DocumentTemplate, - pub(crate) embedding: Arc, - pub(crate) cache_storage: Arc>, - // it is not the end of the world but we are storing the entity id twice - pub(crate) graph_documents: Arc>>, - pub(crate) node_documents: Arc>>>, // TODO: replace with FxHashMap - pub(crate) edge_documents: Arc>>>, - pub(crate) empty_vec: Vec, -} - -// This has to be here so it is shared between python and graphql -pub type DynamicVectorisedGraph = VectorisedGraph; - -#[async_trait] -impl Clone for VectorisedGraph { - fn clone(&self) -> Self { - Self::new( - self.source_graph.clone(), - self.template.clone(), - self.embedding.clone(), - self.cache_storage.clone(), - self.graph_documents.clone(), - self.node_documents.clone(), - self.edge_documents.clone(), - ) - } + pub(crate) cache: VectorCache, + pub(super) node_db: NodeDb, + pub(super) edge_db: EdgeDb, } impl VectorisedGraph { - pub fn into_dynamic(&self) -> VectorisedGraph { - VectorisedGraph::new( - self.source_graph.clone().into_dynamic(), - self.template.clone(), - self.embedding.clone(), - self.cache_storage.clone(), - self.graph_documents.clone(), - self.node_documents.clone(), - self.edge_documents.clone(), - ) + pub fn into_dynamic(self) -> VectorisedGraph { + VectorisedGraph { + source_graph: self.source_graph.clone().into_dynamic(), + template: self.template, + cache: self.cache, + node_db: self.node_db, + edge_db: self.edge_db, + } } } impl VectorisedGraph { - pub(crate) fn new( - graph: G, - template: DocumentTemplate, - embedding: Arc, - cache_storage: Arc>, - graph_documents: Arc>>, - node_documents: Arc>>>, - edge_documents: Arc>>>, - ) -> Self { - Self { - source_graph: graph, - template, - embedding, - cache_storage, - graph_documents, - node_documents, - edge_documents, - empty_vec: vec![], - } - } - pub async fn update_node(&self, node: T) -> GraphResult<()> { if let Some(node) = self.source_graph.node(node) { - let entity_id = EntityId::from_node(node.clone()); - let refs = vectorise_node( - node, - &self.template, - &self.embedding, - self.cache_storage.as_ref(), - ) - .await?; - self.node_documents.write().insert(entity_id, refs); + let id = node.node.index(); + if let Some(doc) = self.template.node(node) { + let vector = self.cache.get_single(doc).await?; + self.node_db.insert_vector(id, &vector)?; + } } Ok(()) } pub async fn update_edge(&self, src: T, dst: T) -> GraphResult<()> { if let Some(edge) = self.source_graph.edge(src, dst) { - let entity_id = EntityId::from_edge(edge.clone()); - let refs = vectorise_edge( - edge, - &self.template, - &self.embedding, - self.cache_storage.as_ref(), - ) - .await?; - self.edge_documents.write().insert(entity_id, refs); + let id = edge.edge.pid().0; + if let Some(doc) = self.template.edge(edge) { + let vector = self.cache.get_single(doc).await?; + self.edge_db.insert_vector(id, &vector)?; + } } Ok(()) } - pub async fn update_graph(&self, graph_name: Option) -> GraphResult<()> { - let refs = vectorise_graph( - &self.source_graph, - graph_name, - &self.template, - &self.embedding, - self.cache_storage.as_ref(), - ) - .await?; - *self.graph_documents.write() = refs; - Ok(()) - } - - /// Save the embeddings present in this graph to `file` so they can be further used in a call to `vectorise` - pub fn save_embeddings(&self, file: PathBuf) { - let cache = EmbeddingCache::new(file); - let node_documents = self.node_documents.read(); - let edge_documents = self.edge_documents.read(); - chain!(node_documents.iter(), edge_documents.iter()).for_each(|(_, group)| { - group.iter().for_each(|doc| { - let original = doc.regenerate(&self.source_graph, &self.template); - cache.upsert_embedding(&original.content, doc.embedding.clone()); - }) - }); - cache.dump_to_disk(); - } - /// Return an empty selection of documents pub fn empty_selection(&self) -> VectorSelection { - VectorSelection::new(self.clone()) - } - - /// Return all the graph level documents - pub fn get_graph_documents(&self) -> Vec> { - self.graph_documents - .read() - .iter() - .map(|doc| doc.regenerate(&self.source_graph, &self.template)) - .collect_vec() - } - - /// Search the top scoring documents according to `query` with no more than `limit` documents - /// - /// # Arguments - /// * query - the embedding to score against - /// * limit - the maximum number of documents to search - /// * window - the window where documents need to belong to in order to be considered - /// - /// # Returns - /// The vector selection resulting from the search - pub fn documents_by_similarity( - &self, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - ) -> VectorSelection { - let node_documents = self.node_documents.read(); - let edge_documents = self.edge_documents.read(); - let joined = chain!(node_documents.iter(), edge_documents.iter()); - let docs = self.search_top_documents(joined, query, limit, window); - VectorSelection::new_with_preselection(self.clone(), docs) + VectorSelection::empty(self.clone()) } /// Search the top scoring entities according to `query` with no more than `limit` entities @@ -197,12 +74,12 @@ impl VectorisedGraph { query: &Embedding, limit: usize, window: Option<(i64, i64)>, - ) -> VectorSelection { - let node_documents = self.node_documents.read(); - let edge_documents = self.edge_documents.read(); - let joined = chain!(node_documents.iter(), edge_documents.iter()); - let docs = self.search_top_document_groups(joined, query, limit, window); - VectorSelection::new_with_preselection(self.clone(), docs) + ) -> GraphResult> { + let view = apply_window(&self.source_graph, window); + let nodes = self.node_db.top_k(query, limit, view.clone(), None)?; + let edges = self.edge_db.top_k(query, limit, view, None)?; + let docs = find_top_k(nodes.chain(edges), limit).collect(); + Ok(VectorSelection::new(self.clone(), docs)) } /// Search the top scoring nodes according to `query` with no more than `limit` nodes @@ -219,10 +96,10 @@ impl VectorisedGraph { query: &Embedding, limit: usize, window: Option<(i64, i64)>, - ) -> VectorSelection { - let node_documents = self.node_documents.read(); - let docs = self.search_top_document_groups(node_documents.deref(), query, limit, window); - VectorSelection::new_with_preselection(self.clone(), docs) + ) -> GraphResult> { + let view = apply_window(&self.source_graph, window); + let docs = self.node_db.top_k(query, limit, view, None)?; + Ok(VectorSelection::new(self.clone(), docs.collect())) } /// Search the top scoring edges according to `query` with no more than `limit` edges @@ -239,83 +116,9 @@ impl VectorisedGraph { query: &Embedding, limit: usize, window: Option<(i64, i64)>, - ) -> VectorSelection { - let edge_documents = self.edge_documents.read(); - let docs = self.search_top_document_groups(edge_documents.deref(), query, limit, window); - VectorSelection::new_with_preselection(self.clone(), docs) - } - - fn search_top_documents<'a, I>( - &self, - document_groups: I, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - ) -> Vec<(DocumentRef, f32)> - where - I: IntoIterator)> + 'a, - { - let all_documents = document_groups - .into_iter() - .flat_map(|(_, embeddings)| embeddings); - - let window_docs: Box> = match window { - None => Box::new(all_documents), - Some((start, end)) => { - let windowed_graph = self.source_graph.window(start, end); - let filtered = all_documents.filter(move |document| { - document.exists_on_window(Some(&windowed_graph), &window) - }); - Box::new(filtered) - } - }; - - let scored_docs = score_documents(query, window_docs.cloned()); // TODO: try to remove this clone - let top_documents = find_top_k(scored_docs, limit); - top_documents.collect() - } - - fn search_top_document_groups<'a, I>( - &self, - document_groups: I, - query: &Embedding, - limit: usize, - window: Option<(i64, i64)>, - ) -> Vec<(DocumentRef, f32)> - where - I: IntoIterator)> + 'a, - { - let window_docs: Box)>> = match window { - None => Box::new( - document_groups - .into_iter() - .map(|(id, docs)| (id.clone(), docs.clone())), - // TODO: filter empty vectors here? what happens if the user inputs an empty list as the doc prop - ), - Some((start, end)) => { - let windowed_graph = self.source_graph.window(start, end); - let filtered = document_groups - .into_iter() - .map(move |(entity_id, docs)| { - let filtered_dcos = docs - .iter() - .filter(|doc| doc.exists_on_window(Some(&windowed_graph), &window)) - .cloned() - .collect_vec(); - (entity_id.clone(), filtered_dcos) - }) - .filter(|(_, docs)| docs.len() > 0); - Box::new(filtered) - } - }; - - let scored_docs = score_document_groups_by_highest(query, window_docs); - - // let scored_docs = score_documents(query, window_docs.cloned()); // TODO: try to remove this clone - let top_documents = find_top_k(scored_docs, limit); - - top_documents - .flat_map(|((_, docs), score)| docs.into_iter().map(move |doc| (doc, score))) - .collect() + ) -> GraphResult> { + let view = apply_window(&self.source_graph, window); + let docs = self.edge_db.top_k(query, limit, view, None)?; + Ok(VectorSelection::new(self.clone(), docs.collect())) } }