Skip to content

Commit acb63d1

Browse files
committed
send_messages can now be chained
1 parent 2a383cc commit acb63d1

1 file changed

Lines changed: 209 additions & 59 deletions

File tree

src/pregel.rs

Lines changed: 209 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,51 @@ impl AsRef<str> for ColumnIdentifier {
3737
}
3838
}
3939

40+
/// `SendMessage` describes one message-sending rule of a Pregel computation.
41+
/// It has two public fields, both of type `Expr`: `message_direction` and
42+
/// `send_message`. `message_direction` is the expression identifying the
43+
/// vertex that receives the message, and `send_message` is the expression
44+
/// that determines which messages to send from a vertex to its neighbors.
45+
pub struct SendMessage {
46+
/// `message_direction` is the expression identifying the vertex that receives the message.
47+
pub message_direction: Expr,
48+
/// `send_message` is the expression that determines which messages to send from a
49+
/// vertex to its neighbors.
50+
pub send_message: Expr,
51+
}
52+
53+
impl SendMessage {
54+
/// Creates a new instance of the `SendMessage` struct from the given
55+
/// message receiver and message expression.
56+
///
57+
/// Arguments:
58+
///
59+
/// * `message_direction`: A `MessageReceiver` that specifies whether the message is sent
60+
/// to the source vertex (`Src`) or the destination vertex (`Dst`) of an edge.
61+
/// * `send_message`: `send_message` is an expression that represents the message
62+
/// that will be sent from a vertex to its neighbors during the Pregel computation.
63+
/// It is an `Expr` (for example a literal or a column reference), not arbitrary Rust code.
64+
///
65+
/// Returns:
66+
///
67+
/// A new instance of the `SendMessage` struct.
68+
pub fn new(message_direction: MessageReceiver, send_message: Expr) -> Self {
69+
// The receiver is translated into a column expression: `src.id` for `Src` and `edge.dst`
70+
// for `Dst`. This is needed because of how polars works: when joining DataFrames, it
71+
// keeps only the left-hand side of the join, so the `src.id` and `edge.dst` columns
72+
// are the ones that hold the correct vertex IDs.
73+
let message_direction = match message_direction {
74+
MessageReceiver::Src => Pregel::src(ColumnIdentifier::Id),
75+
MessageReceiver::Dst => Pregel::edge(ColumnIdentifier::Dst),
76+
};
77+
// Build the `SendMessage` from the translated receiver and the untouched message expression.
78+
SendMessage {
79+
message_direction,
80+
send_message,
81+
}
82+
}
83+
}
84+
4085
/// The Pregel struct represents a Pregel computation with various parameters and
4186
/// expressions.
4287
///
@@ -90,11 +135,13 @@ pub struct Pregel {
90135
/// `initial_message` is an expression that defines the initial message that
91136
/// each vertex in the graph will receive before the computation starts.
92137
initial_message: Expr,
93-
/// `send_messages` is a tuple containing two expressions. The first expression
94-
/// determines whether the message will go from Src to Dst or vice-versa. The
95-
/// second expression represents the message sending function that determines
96-
/// which messages to send from a vertex to its neighbors.
97-
send_messages: (Expr, Expr),
138+
/// The `send_messages` property is a vector of `SendMessage` structs that represent
139+
/// the message sending functions. The `SendMessage` struct contains two expressions.
140+
/// The first expression (`message_direction`) selects the receiving vertex, i.e. whether
141+
/// the message will go from Src to Dst or vice-versa. The second expression represents
142+
/// the message sending function that determines which messages to send from a
143+
/// vertex to its neighbors.
144+
send_messages: Vec<SendMessage>,
98145
/// `aggregate_messages` is an expression that defines how messages sent to a
99146
/// vertex should be aggregated. In Pregel, messages are sent from one vertex
100147
/// to another and can be aggregated before being processed by the receiving
@@ -162,11 +209,13 @@ pub struct PregelBuilder {
162209
/// `initial_message` is an expression that defines the initial message that
163210
/// each vertex in the graph will receive before the computation starts.
164211
initial_message: Expr,
165-
/// `send_messages` is a tuple containing two expressions. The first expression
166-
/// determines whether the message will go from Src to Dst or vice-versa. The
167-
/// second expression represents the message sending function that determines
168-
/// which messages to send from a vertex to its neighbors.
169-
send_messages: (Expr, Expr),
212+
/// The `send_messages` property is a vector of `SendMessage` structs that represent
213+
/// the message sending functions. The `SendMessage` struct contains two expressions.
214+
/// The first expression (`message_direction`) selects the receiving vertex, i.e. whether
215+
/// the message will go from Src to Dst or vice-versa. The second expression represents
216+
/// the message sending function that determines which messages to send from a
217+
/// vertex to its neighbors.
218+
send_messages: Vec<SendMessage>,
170219
/// `aggregate_messages` is an expression that defines how messages sent to a
171220
/// vertex should be aggregated. In Pregel, messages are sent from one vertex
172221
/// to another and can be aggregated before being processed by the receiving
@@ -219,7 +268,7 @@ impl PregelBuilder {
219268
max_iterations: 10,
220269
vertex_column: ColumnIdentifier::Custom("aux"),
221270
initial_message: Default::default(),
222-
send_messages: (Default::default(), Default::default()),
271+
send_messages: Default::default(),
223272
aggregate_messages: Default::default(),
224273
v_prog: Default::default(),
225274
}
@@ -289,7 +338,45 @@ impl PregelBuilder {
289338
}
290339

291340
/// This function sets the message sending behavior for a Pregel computation in
292-
/// Rust.
341+
/// Rust. Chaining this method allows for multiple message sending behaviors to be
342+
/// specified for a single Pregel computation.
343+
///
344+
/// # Examples
345+
///
346+
/// ```rust
347+
/// use polars::prelude::*;
348+
/// use pregel_rs::graph_frame::GraphFrame;
349+
/// use pregel_rs::pregel::ColumnIdentifier::{Custom, Dst, Id, Src};
350+
/// use pregel_rs::pregel::{MessageReceiver, Pregel, PregelBuilder};
351+
/// use std::error::Error;
352+
///
353+
/// // Simple example of a Pregel algorithm where we chain several `send_messages` calls. In
354+
/// // this example, we send a message to the source of an edge and then to the destination of
355+
/// // the same edge. It has no real use case, but it demonstrates how to chain multiple calls.
356+
/// fn main() -> Result<(), Box<dyn Error>> {
357+
/// let edges = df![
358+
/// Src.as_ref() => [0, 1, 1, 2, 2, 3],
359+
/// Dst.as_ref() => [1, 0, 3, 1, 3, 2],
360+
/// ]?;
361+
///
362+
/// let vertices = df![
363+
/// Id.as_ref() => [0, 1, 2, 3],
364+
/// Custom("value").as_ref() => [3, 6, 2, 1],
365+
/// ]?;
366+
///
367+
/// let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
368+
/// .max_iterations(4)
369+
/// .with_vertex_column(Custom("aux"))
370+
/// .initial_message(lit(0))
371+
/// .send_messages(MessageReceiver::Src, lit(1))
372+
/// .send_messages(MessageReceiver::Dst, lit(-1))
373+
/// .aggregate_messages(Pregel::msg(None).sum())
374+
/// .v_prog(Pregel::msg(None) + lit(1))
375+
/// .build();
376+
///
377+
/// Ok(println!("{:?}", pregel.run()))
378+
/// }
379+
/// ```
293380
///
294381
/// Arguments:
295382
///
@@ -298,8 +385,7 @@ impl PregelBuilder {
298385
/// computation.
299386
/// * `send_messages`: `send_messages` is a parameter of type `Expr`. It is used to
300387
/// specify the function that will be applied to each vertex in the graph to send
301-
/// messages to its neighboring vertices. The `send_messages` function takes two
302-
/// arguments: the first argument is the vertex ID of the current vertex, and
388+
/// messages to its neighboring vertices.
303389
///
304390
/// Returns:
305391
///
@@ -308,16 +394,7 @@ impl PregelBuilder {
308394
/// multiple methods can be called on the same struct instance in a single
309395
/// expression.
310396
pub fn send_messages(mut self, to: MessageReceiver, send_messages: Expr) -> Self {
311-
// We make this in this manner because we want to use the `src.id` and `edge.dst` columns
312-
// in the send_messages function. This is because how polars works, when joining dataframes,
313-
// it will keep only the left-hand side of the joins, thus, we need to use the `src.id` and
314-
// `edge.dst` columns to get the correct vertex IDs.
315-
let to = match to {
316-
MessageReceiver::Src => Pregel::src(ColumnIdentifier::Id),
317-
MessageReceiver::Dst => Pregel::edge(ColumnIdentifier::Dst),
318-
};
319-
// Now we can set the send_messages field of the struct to the provided expression.
320-
self.send_messages = (to, send_messages);
397+
self.send_messages.push(SendMessage::new(to, send_messages));
321398
self
322399
}
323400

@@ -563,11 +640,22 @@ impl Pregel {
563640
// We create a tuple where we store the column names of the `send_messages` DataFrame. We use
564641
// the `alias` method to ensure that the column names are properly qualified. We also
565642
// do the same for the `aggregate_messages` Expr. And the same with the `v_prog` Expr.
566-
let (send_messages_ids, send_messages_msg) = self.send_messages;
567-
let (send_messages_ids, send_messages_msg) = (
568-
send_messages_ids.alias(&Self::alias(&ColumnIdentifier::Msg, ColumnIdentifier::Id)),
569-
send_messages_msg.alias(ColumnIdentifier::Pregel.as_ref()),
570-
);
643+
let (mut send_messages_ids, mut send_messages_msg): (Vec<Expr>, Vec<Expr>) = self
644+
.send_messages
645+
.iter()
646+
.map(|send_message| {
647+
let message_direction = &send_message.message_direction;
648+
let send_message_expr = &send_message.send_message;
649+
(
650+
message_direction
651+
.to_owned()
652+
.alias(&Self::alias(&ColumnIdentifier::Msg, ColumnIdentifier::Id)),
653+
send_message_expr
654+
.to_owned()
655+
.alias(ColumnIdentifier::Pregel.as_ref()),
656+
)
657+
})
658+
.unzip();
571659
let aggregate_messages = self
572660
.aggregate_messages
573661
.alias(ColumnIdentifier::Pregel.as_ref());
@@ -628,14 +716,12 @@ impl Pregel {
628716
// are computed by performing an aggregation on the `triplets_df` DataFrame. The aggregation
629717
// is performed on the `msg` column of the `triplets_df` DataFrame, and the aggregation
630718
// function is the one set by the user at the initialization of the model.
631-
let sends_messages_ids_df = &send_messages_ids;
632-
let send_messages_msg_df = &send_messages_msg;
719+
let send_messages = &mut send_messages_ids; // we create a mutable reference to the `send_messages_ids` Vector
720+
let send_messages_msg_df = &mut send_messages_msg; // we create a mutable reference to the `send_messages_msg` Vector
721+
send_messages.append(send_messages_msg_df); // we append the `send_messages_msg` Vector to the `send_messages` Vector
633722
let aggregate_messages_df = &aggregate_messages;
634723
let message_df = triplets_df
635-
.select(vec![
636-
sends_messages_ids_df.to_owned(),
637-
send_messages_msg_df.to_owned(),
638-
])
724+
.select(send_messages)
639725
.groupby([Self::msg(Some(ColumnIdentifier::Id))])
640726
.agg([aggregate_messages_df.to_owned()]);
641727
// We Compute the new values for the vertices. Note that we have to check for possibly
@@ -685,29 +771,52 @@ impl Pregel {
685771
#[cfg(test)]
686772
mod tests {
687773
use crate::graph_frame::GraphFrame;
688-
use crate::pregel::ColumnIdentifier::{Custom, Dst, Id, Src};
689-
use crate::pregel::{MessageReceiver, Pregel, PregelBuilder};
774+
use crate::pregel::{ColumnIdentifier, MessageReceiver, Pregel, PregelBuilder, SendMessage};
690775
use polars::prelude::*;
691776
use std::error::Error;
692777

693-
fn pagerank_builder(iterations: u8) -> Result<Pregel, Box<dyn Error>> {
694-
let edges = df![
695-
Src.as_ref() => [0, 0, 1, 2, 3, 4, 4, 4],
696-
Dst.as_ref() => [1, 2, 2, 3, 3, 1, 2, 3],
697-
]?;
778+
fn pagerank_graph() -> Result<GraphFrame, String> {
779+
let edges = match df![
780+
ColumnIdentifier::Src.as_ref() => [0, 0, 1, 2, 3, 4, 4, 4],
781+
ColumnIdentifier::Dst.as_ref() => [1, 2, 2, 3, 3, 1, 2, 3],
782+
] {
783+
Ok(edges) => edges,
784+
Err(_) => return Err(String::from("Error creating the edges DataFrame")),
785+
};
698786

699-
let vertices = GraphFrame::from_edges(edges.clone())?.out_degrees()?;
787+
let graph = match GraphFrame::from_edges(edges.clone()) {
788+
Ok(graph) => graph,
789+
Err(_) => return Err(String::from("Error creating the vertices DataFrame")),
790+
};
791+
792+
let vertices = match graph.out_degrees() {
793+
Ok(vertices) => vertices,
794+
Err(_) => {
795+
return Err(String::from(
796+
"Error creating the vertices out degree DataFrame",
797+
))
798+
}
799+
};
700800

801+
match GraphFrame::new(vertices, edges) {
802+
Ok(graph) => Ok(graph),
803+
Err(_) => Err(String::from("Error creating the graph")),
804+
}
805+
}
806+
807+
fn pagerank_builder(iterations: u8) -> Result<Pregel, Box<dyn Error>> {
808+
let graph = pagerank_graph()?;
701809
let damping_factor = 0.85;
702-
let num_vertices: f64 = vertices.column(Id.as_ref())?.len() as f64;
810+
let num_vertices: f64 = graph.vertices.column(ColumnIdentifier::Id.as_ref())?.len() as f64;
703811

704-
Ok(PregelBuilder::new(GraphFrame::new(vertices, edges)?)
812+
Ok(PregelBuilder::new(graph)
705813
.max_iterations(iterations)
706-
.with_vertex_column(Custom("rank"))
814+
.with_vertex_column(ColumnIdentifier::Custom("rank"))
707815
.initial_message(lit(1.0 / num_vertices))
708816
.send_messages(
709817
MessageReceiver::Dst,
710-
Pregel::src(Custom("rank")) / Pregel::src(Custom("out_degree")),
818+
Pregel::src(ColumnIdentifier::Custom("rank"))
819+
/ Pregel::src(ColumnIdentifier::Custom("out_degree")),
711820
)
712821
.aggregate_messages(Pregel::msg(None).sum())
713822
.v_prog(
@@ -777,16 +886,16 @@ mod tests {
777886

778887
fn max_value_graph() -> Result<GraphFrame, String> {
779888
let edges = match df![
780-
Src.as_ref() => [0, 1, 1, 2, 2, 3],
781-
Dst.as_ref() => [1, 0, 3, 1, 3, 2],
889+
ColumnIdentifier::Src.as_ref() => [0, 1, 1, 2, 2, 3],
890+
ColumnIdentifier::Dst.as_ref() => [1, 0, 3, 1, 3, 2],
782891
] {
783892
Ok(edges) => edges,
784893
Err(_) => return Err(String::from("Error creating the edges DataFrame")),
785894
};
786895

787896
let vertices = match df![
788-
Id.as_ref() => [0, 1, 2, 3],
789-
Custom("value").as_ref() => [3, 6, 2, 1],
897+
ColumnIdentifier::Id.as_ref() => [0, 1, 2, 3],
898+
ColumnIdentifier::Custom("value").as_ref() => [3, 6, 2, 1],
790899
] {
791900
Ok(vertices) => vertices,
792901
Err(_) => return Err(String::from("Error creating the vertices DataFrame")),
@@ -802,14 +911,17 @@ mod tests {
802911
Ok(Pregel {
803912
graph: max_value_graph()?,
804913
max_iterations: iterations,
805-
vertex_column: Custom("max_value"),
806-
initial_message: col(Custom("value").as_ref()),
807-
send_messages: (
808-
Pregel::edge(MessageReceiver::into(MessageReceiver::Dst)),
809-
Pregel::src(Custom("max_value")),
810-
),
914+
vertex_column: ColumnIdentifier::Custom("max_value"),
915+
initial_message: col(ColumnIdentifier::Custom("value").as_ref()),
916+
send_messages: vec![SendMessage::new(
917+
MessageReceiver::Dst,
918+
Pregel::src(ColumnIdentifier::Custom("value")),
919+
)],
811920
aggregate_messages: Pregel::msg(None).max(),
812-
v_prog: max_exprs([col(Custom("max_value").as_ref()), Pregel::msg(None)]),
921+
v_prog: max_exprs([
922+
col(ColumnIdentifier::Custom("max_value").as_ref()),
923+
Pregel::msg(None),
924+
]),
813925
})
814926
}
815927

@@ -874,7 +986,7 @@ mod tests {
874986
// useful to test the Pregel model.
875987
match PregelBuilder::new(graph)
876988
.max_iterations(4)
877-
.with_vertex_column(Custom("does_not_matter"))
989+
.with_vertex_column(ColumnIdentifier::Custom("aux"))
878990
.initial_message(lit(0)) // we pass the Undefined state to all vertices
879991
.send_messages(MessageReceiver::Src, lit(0))
880992
.aggregate_messages(lit(0))
@@ -886,4 +998,42 @@ mod tests {
886998
Err(_) => Err(String::from("Error running the algorithm")),
887999
}
8881000
}
1001+
1002+
// Runs a Pregel computation on the PageRank test graph with two chained
// `send_messages` calls (one addressed to `Src`, one to `Dst`) and checks the
// resulting `aux` column against fixed expected values.
#[test]
1003+
fn test_send_messages_src_dst() -> Result<(), String> {
1004+
let graph = pagerank_graph()?;
1005+
1006+
let pregel = match PregelBuilder::new(graph)
1007+
.max_iterations(4)
1008+
.with_vertex_column(ColumnIdentifier::Custom("aux"))
1009+
.initial_message(lit(0))
1010+
.send_messages(MessageReceiver::Src, lit(1))
1011+
.send_messages(MessageReceiver::Dst, lit(-1))
1012+
.aggregate_messages(Pregel::msg(None).sum())
1013+
.v_prog(Pregel::msg(None) + lit(1))
1014+
.build()
1015+
.run()
1016+
{
1017+
Ok(pregel) => pregel,
1018+
Err(_) => return Err(String::from("Error running pregel")),
1019+
};
1020+
1021+
let sorted_pregel = match pregel.sort(&["id"], false) {
1022+
Ok(sorted_pregel) => sorted_pregel,
1023+
Err(_) => return Err(String::from("Error sorting the DataFrame")),
1024+
};
1025+
1026+
let ans = match sorted_pregel.column("aux") {
1027+
Ok(ans) => ans,
1028+
Err(_) => return Err(String::from("Error retrieving the column")),
1029+
};
1030+
1031+
// Expected `aux` values, in the order produced by sorting the result on `id`.
let expected = Series::new("aux", [3, 2, 2, 2, 4]);
1032+
1033+
if ans.eq(&expected) {
1034+
Ok(())
1035+
} else {
1036+
Err(String::from("The resulting DataFrame is not correct"))
1037+
}
1038+
}
8891039
}

0 commit comments

Comments
 (0)