@@ -37,6 +37,51 @@ impl AsRef<str> for ColumnIdentifier {
3737 }
3838}
3939
40+ /// This defines a struct `SendMessage` in Rust. It has two properties:
41+ /// `message_direction` and `send_message`. The `message_direction` property
42+ /// is the identifier for the direction of the message. The `send_message`
43+ /// property is the function that determines which messages to send from a
44+ /// vertex to its neighbors.
45+ pub struct SendMessage {
46+ /// `message_direction` is the identifier for the direction of the message.
47+ pub message_direction : Expr ,
48+ /// `send_message` is the function that determines which messages to send from a
49+ /// vertex to its neighbors.
50+ pub send_message : Expr ,
51+ }
52+
53+ impl SendMessage {
54+ /// The function creates a new instance of the `SendMessage` struct with the
55+ /// specified message direction and send message expression.
56+ ///
57+ /// Arguments:
58+ ///
59+ /// * `message_direction`: An enum that specifies whether the message should be sent
60+ /// to the source vertex or the destination vertex of an edge.
61+ /// * `send_message`: `send_message` is an expression that represents the message
62+ /// that will be sent from a vertex to its neighbors during the Pregel computation.
63+ /// It can be any valid Rust expression that evaluates to a DataFrame.
64+ ///
65+ /// Returns:
66+ ///
67+ /// A new instance of the `SendMessage` struct.
68+ pub fn new ( message_direction : MessageReceiver , send_message : Expr ) -> Self {
69+ // We make this in this manner because we want to use the `src.id` and `edge.dst` columns
70+ // in the send_messages function. This is because how polars works, when joining DataFrames,
71+ // it will keep only the left-hand side of the joins, thus, we need to use the `src.id` and
72+ // `edge.dst` columns to get the correct vertex IDs.
73+ let message_direction = match message_direction {
74+ MessageReceiver :: Src => Pregel :: src ( ColumnIdentifier :: Id ) ,
75+ MessageReceiver :: Dst => Pregel :: edge ( ColumnIdentifier :: Dst ) ,
76+ } ;
77+ // Now we create the `SendMessage` struct with everything set up.
78+ SendMessage {
79+ message_direction,
80+ send_message,
81+ }
82+ }
83+ }
84+
4085/// The Pregel struct represents a Pregel computation with various parameters and
4186/// expressions.
4287///
@@ -90,11 +135,13 @@ pub struct Pregel {
90135 /// `initial_message` is an expression that defines the initial message that
91136 /// each vertex in the graph will receive before the computation starts.
92137 initial_message : Expr ,
93- /// `send_messages` is a tuple containing two expressions. The first expression
94- /// determines whether the message will go from Src to Dst or vice-versa. The
95- /// second expression represents the message sending function that determines
96- /// which messages to send from a vertex to its neighbors.
97- send_messages : ( Expr , Expr ) ,
138+ /// The `send_messages` property is a vector of `SendMessage` structs that represent
139+ /// the message sending functions. The `SendMessage` struct contains two expressions.
140+ /// The first expression represents the message sending function that determines whether
141+ /// the message will go from Src to Dst or vice-versa. The second expression represents
142+ /// the message sending function that determines which messages to send from a
143+ /// vertex to its neighbors.
144+ send_messages : Vec < SendMessage > ,
98145 /// `aggregate_messages` is an expression that defines how messages sent to a
99146 /// vertex should be aggregated. In Pregel, messages are sent from one vertex
100147 /// to another and can be aggregated before being processed by the receiving
@@ -162,11 +209,13 @@ pub struct PregelBuilder {
162209 /// `initial_message` is an expression that defines the initial message that
163210 /// each vertex in the graph will receive before the computation starts.
164211 initial_message : Expr ,
165- /// `send_messages` is a tuple containing two expressions. The first expression
166- /// determines whether the message will go from Src to Dst or vice-versa. The
167- /// second expression represents the message sending function that determines
168- /// which messages to send from a vertex to its neighbors.
169- send_messages : ( Expr , Expr ) ,
212+ /// The `send_messages` property is a vector of `SendMessage` structs that represent
213+ /// the message sending functions. The `SendMessage` struct contains two expressions.
214+ /// The first expression represents the message sending function that determines whether
215+ /// the message will go from Src to Dst or vice-versa. The second expression represents
216+ /// the message sending function that determines which messages to send from a
217+ /// vertex to its neighbors.
218+ send_messages : Vec < SendMessage > ,
170219 /// `aggregate_messages` is an expression that defines how messages sent to a
171220 /// vertex should be aggregated. In Pregel, messages are sent from one vertex
172221 /// to another and can be aggregated before being processed by the receiving
@@ -219,7 +268,7 @@ impl PregelBuilder {
219268 max_iterations : 10 ,
220269 vertex_column : ColumnIdentifier :: Custom ( "aux" ) ,
221270 initial_message : Default :: default ( ) ,
222- send_messages : ( Default :: default ( ) , Default :: default ( ) ) ,
271+ send_messages : Default :: default ( ) ,
223272 aggregate_messages : Default :: default ( ) ,
224273 v_prog : Default :: default ( ) ,
225274 }
@@ -289,7 +338,45 @@ impl PregelBuilder {
289338 }
290339
291340 /// This function sets the message sending behavior for a Pregel computation in
292- /// Rust.
341+ /// Rust. Chaining this method allows for multiple message sending behaviors to be
342+ /// specified for a single Pregel computation.
343+ ///
344+ /// # Examples
345+ ///
346+ /// ```rust
347+ /// use polars::prelude::*;
348+ /// use pregel_rs::graph_frame::GraphFrame;
349+ /// use pregel_rs::pregel::ColumnIdentifier::{Custom, Dst, Id, Src};
350+ /// use pregel_rs::pregel::{MessageReceiver, Pregel, PregelBuilder};
351+ /// use std::error::Error;
352+ ///
353+ /// // Simple example of a Pregel algorithm where we chain several `send_messages` calls. In
354+ /// // this example, we send a message to the source of an edge and then to the destination of
355+ /// // the same edge. It has no real use case, but it demonstrates how to chain multiple calls.
356+ /// fn main() -> Result<(), Box<dyn Error>> {
357+ /// let edges = df![
358+ /// Src.as_ref() => [0, 1, 1, 2, 2, 3],
359+ /// Dst.as_ref() => [1, 0, 3, 1, 3, 2],
360+ /// ]?;
361+ ///
362+ /// let vertices = df![
363+ /// Id.as_ref() => [0, 1, 2, 3],
364+ /// Custom("value").as_ref() => [3, 6, 2, 1],
365+ /// ]?;
366+ ///
367+ /// let pregel = PregelBuilder::new(GraphFrame::new(vertices, edges)?)
368+ /// .max_iterations(4)
369+ /// .with_vertex_column(Custom("aux"))
370+ /// .initial_message(lit(0))
371+ /// .send_messages(MessageReceiver::Src, lit(1))
372+ /// .send_messages(MessageReceiver::Dst, lit(-1))
373+ /// .aggregate_messages(Pregel::msg(None).sum())
374+ /// .v_prog(Pregel::msg(None) + lit(1))
375+ /// .build();
376+ ///
377+ /// Ok(println!("{:?}", pregel.run()))
378+ /// }
379+ /// ```
293380 ///
294381 /// Arguments:
295382 ///
@@ -298,8 +385,7 @@ impl PregelBuilder {
298385 /// computation.
299386 /// * `send_messages`: `send_messages` is a parameter of type `Expr`. It is used to
300387 /// specify the function that will be applied to each vertex in the graph to send
301- /// messages to its neighboring vertices. The `send_messages` function takes two
302- /// arguments: the first argument is the vertex ID of the current vertex, and
388+ /// messages to its neighboring vertices.
303389 ///
304390 /// Returns:
305391 ///
@@ -308,16 +394,7 @@ impl PregelBuilder {
308394 /// multiple methods can be called on the same struct instance in a single
309395 /// expression.
310396 pub fn send_messages ( mut self , to : MessageReceiver , send_messages : Expr ) -> Self {
311- // We make this in this manner because we want to use the `src.id` and `edge.dst` columns
312- // in the send_messages function. This is because how polars works, when joining dataframes,
313- // it will keep only the left-hand side of the joins, thus, we need to use the `src.id` and
314- // `edge.dst` columns to get the correct vertex IDs.
315- let to = match to {
316- MessageReceiver :: Src => Pregel :: src ( ColumnIdentifier :: Id ) ,
317- MessageReceiver :: Dst => Pregel :: edge ( ColumnIdentifier :: Dst ) ,
318- } ;
319- // Now we can set the send_messages field of the struct to the provided expression.
320- self . send_messages = ( to, send_messages) ;
397+ self . send_messages . push ( SendMessage :: new ( to, send_messages) ) ;
321398 self
322399 }
323400
@@ -563,11 +640,22 @@ impl Pregel {
563640 // We create a tuple where we store the column names of the `send_messages` DataFrame. We use
564641 // the `alias` method to ensure that the column names are properly qualified. We also
565642 // do the same for the `aggregate_messages` Expr. And the same with the `v_prog` Expr.
566- let ( send_messages_ids, send_messages_msg) = self . send_messages ;
567- let ( send_messages_ids, send_messages_msg) = (
568- send_messages_ids. alias ( & Self :: alias ( & ColumnIdentifier :: Msg , ColumnIdentifier :: Id ) ) ,
569- send_messages_msg. alias ( ColumnIdentifier :: Pregel . as_ref ( ) ) ,
570- ) ;
643+ let ( mut send_messages_ids, mut send_messages_msg) : ( Vec < Expr > , Vec < Expr > ) = self
644+ . send_messages
645+ . iter ( )
646+ . map ( |send_message| {
647+ let message_direction = & send_message. message_direction ;
648+ let send_message_expr = & send_message. send_message ;
649+ (
650+ message_direction
651+ . to_owned ( )
652+ . alias ( & Self :: alias ( & ColumnIdentifier :: Msg , ColumnIdentifier :: Id ) ) ,
653+ send_message_expr
654+ . to_owned ( )
655+ . alias ( ColumnIdentifier :: Pregel . as_ref ( ) ) ,
656+ )
657+ } )
658+ . unzip ( ) ;
571659 let aggregate_messages = self
572660 . aggregate_messages
573661 . alias ( ColumnIdentifier :: Pregel . as_ref ( ) ) ;
@@ -628,14 +716,12 @@ impl Pregel {
628716 // are computed by performing an aggregation on the `triplets_df` DataFrame. The aggregation
629717 // is performed on the `msg` column of the `triplets_df` DataFrame, and the aggregation
630718 // function is the one set by the user at the initialization of the model.
631- let sends_messages_ids_df = & send_messages_ids;
632- let send_messages_msg_df = & send_messages_msg;
719+ let send_messages = & mut send_messages_ids; // we create a mutable reference to the `send_messages_ids` Vector
720+ let send_messages_msg_df = & mut send_messages_msg; // we create a mutable reference to the `send_messages_msg` Vector
721+ send_messages. append ( send_messages_msg_df) ; // we append the `send_messages_msg` Vector to the `send_messages` Vector
633722 let aggregate_messages_df = & aggregate_messages;
634723 let message_df = triplets_df
635- . select ( vec ! [
636- sends_messages_ids_df. to_owned( ) ,
637- send_messages_msg_df. to_owned( ) ,
638- ] )
724+ . select ( send_messages)
639725 . groupby ( [ Self :: msg ( Some ( ColumnIdentifier :: Id ) ) ] )
640726 . agg ( [ aggregate_messages_df. to_owned ( ) ] ) ;
641727 // We Compute the new values for the vertices. Note that we have to check for possibly
@@ -685,29 +771,52 @@ impl Pregel {
685771#[ cfg( test) ]
686772mod tests {
687773 use crate :: graph_frame:: GraphFrame ;
688- use crate :: pregel:: ColumnIdentifier :: { Custom , Dst , Id , Src } ;
689- use crate :: pregel:: { MessageReceiver , Pregel , PregelBuilder } ;
774+ use crate :: pregel:: { ColumnIdentifier , MessageReceiver , Pregel , PregelBuilder , SendMessage } ;
690775 use polars:: prelude:: * ;
691776 use std:: error:: Error ;
692777
693- fn pagerank_builder ( iterations : u8 ) -> Result < Pregel , Box < dyn Error > > {
694- let edges = df ! [
695- Src . as_ref( ) => [ 0 , 0 , 1 , 2 , 3 , 4 , 4 , 4 ] ,
696- Dst . as_ref( ) => [ 1 , 2 , 2 , 3 , 3 , 1 , 2 , 3 ] ,
697- ] ?;
778+ fn pagerank_graph ( ) -> Result < GraphFrame , String > {
779+ let edges = match df ! [
780+ ColumnIdentifier :: Src . as_ref( ) => [ 0 , 0 , 1 , 2 , 3 , 4 , 4 , 4 ] ,
781+ ColumnIdentifier :: Dst . as_ref( ) => [ 1 , 2 , 2 , 3 , 3 , 1 , 2 , 3 ] ,
782+ ] {
783+ Ok ( edges) => edges,
784+ Err ( _) => return Err ( String :: from ( "Error creating the edges DataFrame" ) ) ,
785+ } ;
698786
699- let vertices = GraphFrame :: from_edges ( edges. clone ( ) ) ?. out_degrees ( ) ?;
787+ let graph = match GraphFrame :: from_edges ( edges. clone ( ) ) {
788+ Ok ( graph) => graph,
789+ Err ( _) => return Err ( String :: from ( "Error creating the vertices DataFrame" ) ) ,
790+ } ;
791+
792+ let vertices = match graph. out_degrees ( ) {
793+ Ok ( vertices) => vertices,
794+ Err ( _) => {
795+ return Err ( String :: from (
796+ "Error creating the vertices out degree DataFrame" ,
797+ ) )
798+ }
799+ } ;
700800
801+ match GraphFrame :: new ( vertices, edges) {
802+ Ok ( graph) => Ok ( graph) ,
803+ Err ( _) => Err ( String :: from ( "Error creating the graph" ) ) ,
804+ }
805+ }
806+
807+ fn pagerank_builder ( iterations : u8 ) -> Result < Pregel , Box < dyn Error > > {
808+ let graph = pagerank_graph ( ) ?;
701809 let damping_factor = 0.85 ;
702- let num_vertices: f64 = vertices. column ( Id . as_ref ( ) ) ?. len ( ) as f64 ;
810+ let num_vertices: f64 = graph . vertices . column ( ColumnIdentifier :: Id . as_ref ( ) ) ?. len ( ) as f64 ;
703811
704- Ok ( PregelBuilder :: new ( GraphFrame :: new ( vertices , edges ) ? )
812+ Ok ( PregelBuilder :: new ( graph )
705813 . max_iterations ( iterations)
706- . with_vertex_column ( Custom ( "rank" ) )
814+ . with_vertex_column ( ColumnIdentifier :: Custom ( "rank" ) )
707815 . initial_message ( lit ( 1.0 / num_vertices) )
708816 . send_messages (
709817 MessageReceiver :: Dst ,
710- Pregel :: src ( Custom ( "rank" ) ) / Pregel :: src ( Custom ( "out_degree" ) ) ,
818+ Pregel :: src ( ColumnIdentifier :: Custom ( "rank" ) )
819+ / Pregel :: src ( ColumnIdentifier :: Custom ( "out_degree" ) ) ,
711820 )
712821 . aggregate_messages ( Pregel :: msg ( None ) . sum ( ) )
713822 . v_prog (
@@ -777,16 +886,16 @@ mod tests {
777886
778887 fn max_value_graph ( ) -> Result < GraphFrame , String > {
779888 let edges = match df ! [
780- Src . as_ref( ) => [ 0 , 1 , 1 , 2 , 2 , 3 ] ,
781- Dst . as_ref( ) => [ 1 , 0 , 3 , 1 , 3 , 2 ] ,
889+ ColumnIdentifier :: Src . as_ref( ) => [ 0 , 1 , 1 , 2 , 2 , 3 ] ,
890+ ColumnIdentifier :: Dst . as_ref( ) => [ 1 , 0 , 3 , 1 , 3 , 2 ] ,
782891 ] {
783892 Ok ( edges) => edges,
784893 Err ( _) => return Err ( String :: from ( "Error creating the edges DataFrame" ) ) ,
785894 } ;
786895
787896 let vertices = match df ! [
788- Id . as_ref( ) => [ 0 , 1 , 2 , 3 ] ,
789- Custom ( "value" ) . as_ref( ) => [ 3 , 6 , 2 , 1 ] ,
897+ ColumnIdentifier :: Id . as_ref( ) => [ 0 , 1 , 2 , 3 ] ,
898+ ColumnIdentifier :: Custom ( "value" ) . as_ref( ) => [ 3 , 6 , 2 , 1 ] ,
790899 ] {
791900 Ok ( vertices) => vertices,
792901 Err ( _) => return Err ( String :: from ( "Error creating the vertices DataFrame" ) ) ,
@@ -802,14 +911,17 @@ mod tests {
802911 Ok ( Pregel {
803912 graph : max_value_graph ( ) ?,
804913 max_iterations : iterations,
805- vertex_column : Custom ( "max_value" ) ,
806- initial_message : col ( Custom ( "value" ) . as_ref ( ) ) ,
807- send_messages : (
808- Pregel :: edge ( MessageReceiver :: into ( MessageReceiver :: Dst ) ) ,
809- Pregel :: src ( Custom ( "max_value " ) ) ,
810- ) ,
914+ vertex_column : ColumnIdentifier :: Custom ( "max_value" ) ,
915+ initial_message : col ( ColumnIdentifier :: Custom ( "value" ) . as_ref ( ) ) ,
916+ send_messages : vec ! [ SendMessage :: new (
917+ MessageReceiver :: Dst ,
918+ Pregel :: src( ColumnIdentifier :: Custom ( "value " ) ) ,
919+ ) ] ,
811920 aggregate_messages : Pregel :: msg ( None ) . max ( ) ,
812- v_prog : max_exprs ( [ col ( Custom ( "max_value" ) . as_ref ( ) ) , Pregel :: msg ( None ) ] ) ,
921+ v_prog : max_exprs ( [
922+ col ( ColumnIdentifier :: Custom ( "max_value" ) . as_ref ( ) ) ,
923+ Pregel :: msg ( None ) ,
924+ ] ) ,
813925 } )
814926 }
815927
@@ -874,7 +986,7 @@ mod tests {
874986 // useful to test the Pregel model.
875987 match PregelBuilder :: new ( graph)
876988 . max_iterations ( 4 )
877- . with_vertex_column ( Custom ( "does_not_matter " ) )
989+ . with_vertex_column ( ColumnIdentifier :: Custom ( "aux " ) )
878990 . initial_message ( lit ( 0 ) ) // we pass the Undefined state to all vertices
879991 . send_messages ( MessageReceiver :: Src , lit ( 0 ) )
880992 . aggregate_messages ( lit ( 0 ) )
@@ -886,4 +998,42 @@ mod tests {
886998 Err ( _) => Err ( String :: from ( "Error running the algorithm" ) ) ,
887999 }
8881000 }
1001+
1002+ #[ test]
1003+ fn test_send_messages_src_dst ( ) -> Result < ( ) , String > {
1004+ let graph = pagerank_graph ( ) ?;
1005+
1006+ let pregel = match PregelBuilder :: new ( graph)
1007+ . max_iterations ( 4 )
1008+ . with_vertex_column ( ColumnIdentifier :: Custom ( "aux" ) )
1009+ . initial_message ( lit ( 0 ) )
1010+ . send_messages ( MessageReceiver :: Src , lit ( 1 ) )
1011+ . send_messages ( MessageReceiver :: Dst , lit ( -1 ) )
1012+ . aggregate_messages ( Pregel :: msg ( None ) . sum ( ) )
1013+ . v_prog ( Pregel :: msg ( None ) + lit ( 1 ) )
1014+ . build ( )
1015+ . run ( )
1016+ {
1017+ Ok ( pregel) => pregel,
1018+ Err ( _) => return Err ( String :: from ( "Error running pregel" ) ) ,
1019+ } ;
1020+
1021+ let sorted_pregel = match pregel. sort ( & [ "id" ] , false ) {
1022+ Ok ( sorted_pregel) => sorted_pregel,
1023+ Err ( _) => return Err ( String :: from ( "Error sorting the DataFrame" ) ) ,
1024+ } ;
1025+
1026+ let ans = match sorted_pregel. column ( "aux" ) {
1027+ Ok ( ans) => ans,
1028+ Err ( _) => return Err ( String :: from ( "Error retrieving the column" ) ) ,
1029+ } ;
1030+
1031+ let expected = Series :: new ( "aux" , [ 3 , 2 , 2 , 2 , 4 ] ) ;
1032+
1033+ if ans. eq ( & expected) {
1034+ Ok ( ( ) )
1035+ } else {
1036+ Err ( String :: from ( "The resulting DataFrame is not correct" ) )
1037+ }
1038+ }
8891039}
0 commit comments