Skip to content

feat(interactive): Refactoring input and output for builtin procedures #4556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flex/engines/graph_db/app/builtin/k_hop_neighbors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace gs {

results::CollectiveResults KNeighbors::Query(const GraphDBSession& sess,
std::string label_name,
int64_t vertex_id, int32_t k) {
std::string vertex_id, int32_t k) {
auto txn = sess.GetReadTransaction();
const Schema& schema_ = txn.schema();

Expand Down Expand Up @@ -117,7 +117,7 @@ results::CollectiveResults KNeighbors::Query(const GraphDBSession& sess,
->mutable_entry()
->mutable_element()
->mutable_object()
->set_i64(txn.GetVertexId(vertex_.first, vertex_.second).AsInt64());
->set_str(txn.GetVertexId(vertex_.first, vertex_.second).to_string());
}

txn.Commit();
Expand Down
5 changes: 3 additions & 2 deletions flex/engines/graph_db/app/builtin/k_hop_neighbors.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
#include "flex/engines/hqps_db/app/interactive_app_base.h"

namespace gs {
class KNeighbors : public CypherReadAppBase<std::string, int64_t, int32_t> {
class KNeighbors : public CypherReadAppBase<std::string, std::string, int32_t> {
public:
KNeighbors() {}
results::CollectiveResults Query(const GraphDBSession& sess,
std::string label_name, int64_t vertex_id,
std::string label_name,
std::string vertex_id,
int32_t hop_range) override;
};

Expand Down
167 changes: 78 additions & 89 deletions flex/engines/graph_db/app/builtin/shortest_path_among_three.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,53 @@
* limitations under the License.
*/
#include "flex/engines/graph_db/app/builtin/shortest_path_among_three.h"
#include "flex/engines/graph_db/runtime/common/rt_any.h"
#include "flex/engines/graph_db/runtime/common/types.h"

namespace gs {

void sink_shortest_path(const ReadTransaction& tx,
results::CollectiveResults& results,
const std::vector<std::pair<label_t, vid_t>>& nodes,
const std::vector<label_t>& edge_labels) {
for (size_t i = 0; i < nodes.size(); ++i) {
LOG(INFO) << "sink_shortest_path: " << static_cast<int>(nodes[i].first)
<< " " << nodes[i].second;
}
for (size_t i = 0; i < edge_labels.size(); ++i) {
LOG(INFO) << "sink_shortest_path: " << static_cast<int>(edge_labels[i]);
}
Schema schema_ = tx.schema();
std::string result_path;
auto path = results.add_results()
->mutable_record()
->add_columns()
->mutable_entry()
->mutable_element()
->mutable_graph_path();
CHECK(nodes.size() == edge_labels.size() + 1);
for (size_t i = 0; i < nodes.size(); ++i) {
auto vertex_in_path = path->add_path();
auto node = vertex_in_path->mutable_vertex();
node->mutable_label()->set_id(nodes[i].first);
node->set_id(
runtime::encode_unique_vertex_id(nodes[i].first, nodes[i].second));
if (i < edge_labels.size()) {
auto edge_in_path = path->add_path();
auto edge = edge_in_path->mutable_edge();
edge->mutable_src_label()->set_id(nodes[i].first);
edge->mutable_dst_label()->set_id(nodes[i + 1].first);
edge->mutable_label()->set_id(edge_labels[i]);
edge->set_id(runtime::encode_unique_edge_id(
edge_labels[i], nodes[i].second, nodes[i + 1].second));
edge->set_src_id(
runtime::encode_unique_vertex_id(nodes[i].first, nodes[i].second));
edge->set_dst_id(runtime::encode_unique_vertex_id(nodes[i + 1].first,
nodes[i + 1].second));
}
}
}

results::CollectiveResults ShortestPathAmongThree::Query(
const GraphDBSession& sess, std::string label_name1, std::string oid1_str,
std::string label_name2, std::string oid2_str, std::string label_name3,
Expand Down Expand Up @@ -60,63 +104,39 @@ results::CollectiveResults ShortestPathAmongThree::Query(
return {};
}
// get the three shortest paths
std::vector<std::pair<label_t, vid_t>> v1v2result_;
std::vector<std::pair<label_t, vid_t>> v2v3result_;
std::vector<std::pair<label_t, vid_t>> v1v3result_;
std::vector<std::pair<label_t, vid_t>> v1v2result_, v2v3result_, v1v3result_;
std::vector<label_t> v1v2edge_labels_, v2v3edge_labels_, v1v3edge_labels_;

bool find_flag = true;
if (!ShortestPath(txn, label_v1, index_v1, label_v2, index_v2, v1v2result_)) {
if (!ShortestPath(txn, label_v1, index_v1, label_v2, index_v2, v1v2result_,
v1v2edge_labels_)) {
find_flag = false;
}
if (find_flag &&
!ShortestPath(txn, label_v2, index_v2, label_v3, index_v3, v2v3result_)) {
if (find_flag && !ShortestPath(txn, label_v2, index_v2, label_v3, index_v3,
v2v3result_, v2v3edge_labels_)) {
find_flag = false;
}
if (find_flag &&
!ShortestPath(txn, label_v1, index_v1, label_v3, index_v3, v1v3result_)) {
if (find_flag && !ShortestPath(txn, label_v1, index_v1, label_v3, index_v3,
v1v3result_, v1v3edge_labels_)) {
find_flag = false;
}
std::string result_path = "";
if (find_flag) {
// connect the two shortest paths among three
std::vector<std::pair<label_t, vid_t>> TSP =
ConnectPath(v1v2result_, v2v3result_, v1v3result_);
results::CollectiveResults results;

// construct return result
for (auto it = TSP.begin(); it != TSP.end(); ++it) {
if (std::next(it) != TSP.end()) {
result_path +=
"(" + schema_.get_vertex_label_name(it->first) + "," +
std::to_string(txn.GetVertexId(it->first, it->second).AsInt64()) +
")" + "--";
} else {
result_path +=
"(" + schema_.get_vertex_label_name(it->first) + "," +
std::to_string(txn.GetVertexId(it->first, it->second).AsInt64()) +
")";
}
}
} else {
result_path = "no path find!";
if (find_flag) {
sink_shortest_path(txn, results, v1v2result_, v1v2edge_labels_);
sink_shortest_path(txn, results, v2v3result_, v2v3edge_labels_);
sink_shortest_path(txn, results, v1v3result_, v1v3edge_labels_);
}

// create result string
results::CollectiveResults results;
auto result = results.add_results();
result->mutable_record()
->add_columns()
->mutable_entry()
->mutable_element()
->mutable_object()
->set_str(result_path);
LOG(INFO) << "results: " << results.DebugString();

txn.Commit();
return results;
}

bool ShortestPathAmongThree::ShortestPath(
const gs::ReadTransaction& txn, label_t v1_l, vid_t v1_index, label_t v2_l,
vid_t v2_index, std::vector<std::pair<label_t, vid_t>>& result_) {
vid_t v2_index, std::vector<std::pair<label_t, vid_t>>& result_,
std::vector<label_t>& edge_labels) {
Schema schema_ = txn.schema();
label_t vertex_size_ = (int) schema_.vertex_label_num();
label_t edge_size_ = (int) schema_.edge_label_num();
Expand All @@ -128,18 +148,16 @@ bool ShortestPathAmongThree::ShortestPath(
}
};

std::unordered_map<std::pair<label_t, vid_t>, std::pair<label_t, vid_t>,
pair_hash>
std::unordered_map<std::pair<label_t, vid_t>,
std::tuple<label_t, vid_t, label_t>, pair_hash>
parent;
std::vector<label_t> nei_label_;
std::vector<vid_t> nei_index_;

parent[std::make_pair(v1_l, v1_index)] =
std::make_pair((label_t) UINT8_MAX, (vid_t) UINT32_MAX);
parent[std::make_pair(v1_l, v1_index)] = std::make_tuple(
(label_t) UINT8_MAX, (vid_t) UINT32_MAX, (label_t) UINT8_MAX);
nei_label_.push_back(v1_l);
nei_index_.push_back(v1_index);
std::unordered_set<std::pair<label_t, vid_t>, pair_hash> visit;
visit.insert(std::make_pair(v1_l, v1_index));

std::vector<label_t> next_nei_labels_;
std::vector<vid_t> next_nei_indexs_;
Expand All @@ -155,12 +173,11 @@ bool ShortestPathAmongThree::ShortestPath(
k); // 1.self_label 2.self_index 3.edge_label 4.nei_label
while (outedges.IsValid()) {
auto neighbor = outedges.GetNeighbor();
if (visit.find(std::make_pair(j, neighbor)) == visit.end()) {
if (parent.find(std::make_pair(j, neighbor)) == parent.end()) {
next_nei_labels_.push_back(j);
next_nei_indexs_.push_back(neighbor);
visit.insert(std::make_pair(j, neighbor));
parent[std::make_pair(j, neighbor)] =
std::make_pair(nei_label_[i], nei_index_[i]);
std::make_tuple(nei_label_[i], nei_index_[i], k);
if (std::make_pair(j, neighbor) ==
std::make_pair(v2_l, v2_index)) {
find = true;
Expand All @@ -177,12 +194,11 @@ bool ShortestPathAmongThree::ShortestPath(
k); // 1.self_label 2.self_index 3.edge_label 4.nei_label
while (inedges.IsValid()) {
auto neighbor = inedges.GetNeighbor();
if (visit.find(std::make_pair(j, neighbor)) == visit.end()) {
if (parent.find(std::make_pair(j, neighbor)) == parent.end()) {
next_nei_labels_.push_back(j);
next_nei_indexs_.push_back(neighbor);
visit.insert(std::make_pair(j, neighbor));
parent[std::make_pair(j, neighbor)] =
std::make_pair(nei_label_[i], nei_index_[i]);
std::make_tuple(nei_label_[i], nei_index_[i], k);
if (std::make_pair(j, neighbor) ==
std::make_pair(v2_l, v2_index)) {
find = true;
Expand All @@ -207,11 +223,16 @@ bool ShortestPathAmongThree::ShortestPath(
next_nei_indexs_.clear();
}
if (find) {
std::pair<label_t, vid_t> vertex_v = std::make_pair(v2_l, v2_index);
while (vertex_v !=
std::make_pair((label_t) UINT8_MAX, (vid_t) UINT32_MAX)) {
result_.push_back(vertex_v);
vertex_v = parent[vertex_v];
std::pair<label_t, vid_t> vertex_key = std::make_pair(v2_l, v2_index);
std::tuple<label_t, vid_t, label_t> vertex_value;
while (std::get<0>(vertex_key) != UINT8_MAX) {
result_.push_back(vertex_key);
vertex_value = parent[vertex_key];
vertex_key =
std::make_pair(std::get<0>(vertex_value), std::get<1>(vertex_value));
if (std::get<2>(vertex_value) != UINT8_MAX) {
edge_labels.push_back(std::get<2>(vertex_value));
}
}
std::reverse(result_.begin(), result_.end());
return true;
Expand All @@ -220,38 +241,6 @@ bool ShortestPathAmongThree::ShortestPath(
}
}

std::vector<std::pair<label_t, vid_t>> ShortestPathAmongThree::ConnectPath(
const std::vector<std::pair<label_t, vid_t>>& path1,
const std::vector<std::pair<label_t, vid_t>>& path2,
const std::vector<std::pair<label_t, vid_t>>& path3) {
std::vector<std::pair<label_t, vid_t>> TSP;
size_t v1v2size = path1.size();
size_t v2v3size = path2.size();
size_t v1v3size = path3.size();
if (v1v2size <= v2v3size && v1v3size <= v2v3size) {
for (size_t i = v1v2size; i > 0; i--) {
TSP.push_back(path1[i - 1]);
}
for (size_t i = 1; i < v1v3size; i++) {
TSP.push_back(path3[i]);
}
} else if (v1v2size <= v1v3size && v2v3size <= v1v3size) {
for (size_t i = 0; i < v1v2size; i++) {
TSP.push_back(path1[i]);
}
for (size_t i = 1; i < v2v3size; i++) {
TSP.push_back(path2[i]);
}
} else if (v2v3size <= v1v2size && v1v3size <= v1v2size) {
for (size_t i = 0; i < v2v3size; i++) {
TSP.push_back(path2[i]);
}
for (size_t i = v1v3size - 1; i > 0; i--) {
TSP.push_back(path3[i - 1]);
}
}
return TSP;
}
AppWrapper ShortestPathAmongThreeFactory::CreateApp(const GraphDB& db) {
return AppWrapper(new ShortestPathAmongThree(), NULL);
}
Expand Down
7 changes: 2 additions & 5 deletions flex/engines/graph_db/app/builtin/shortest_path_among_three.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,8 @@ class ShortestPathAmongThree
private:
bool ShortestPath(const gs::ReadTransaction& txn, label_t v1_l,
vid_t v1_index, label_t v2_l, vid_t v2_index,
std::vector<std::pair<label_t, vid_t>>& result_);
std::vector<std::pair<label_t, vid_t>> ConnectPath(
const std::vector<std::pair<label_t, vid_t>>& path1,
const std::vector<std::pair<label_t, vid_t>>& path2,
const std::vector<std::pair<label_t, vid_t>>& path3);
std::vector<std::pair<label_t, vid_t>>& result,
std::vector<label_t>& edge_labels);
};

class ShortestPathAmongThreeFactory : public AppFactoryBase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def test_builtin_procedure(interactive_session, neo4j_session, create_modern_gra
create_modern_graph,
"k_neighbors",
'"person"',
"1L",
'"1"',
"2",
)

Expand Down
2 changes: 1 addition & 1 deletion flex/storages/metadata/graph_meta_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ const std::vector<PluginMeta>& get_builtin_plugin_metas() {
k_neighbors.creation_time = GetCurrentTimeStamp();
k_neighbors.update_time = GetCurrentTimeStamp();
k_neighbors.params.push_back({"label_name", PropertyType::kString, true});
k_neighbors.params.push_back({"oid", PropertyType::kInt64, false});
k_neighbors.params.push_back({"oid", PropertyType::kString, false});
k_neighbors.params.push_back({"k", PropertyType::kInt32, false});
k_neighbors.returns.push_back({"label_name", PropertyType::kString});
k_neighbors.returns.push_back({"vertex_oid", PropertyType::kInt64});
Expand Down
Loading