@@ -8,11 +8,13 @@ use corro_types::{
88 base:: { CrsqlDbVersion , CrsqlSeq , Version } ,
99 broadcast:: { BiPayload , ChangeSource , ChangeV1 , Changeset , Timestamp } ,
1010 change:: ChunkedChanges ,
11+ config:: FollowBroadcast ,
1112 sqlite:: SqlitePoolError ,
1213} ;
1314use futures:: { Stream , StreamExt } ;
1415use metrics:: counter;
1516use quinn:: { RecvStream , SendStream } ;
17+ use rand:: { rngs:: OsRng , Rng } ;
1618use rusqlite:: { params_from_iter, Row , ToSql } ;
1719use speedy:: { Readable , Writable } ;
1820use tokio:: { sync:: mpsc, task:: block_in_place} ;
@@ -238,6 +240,8 @@ pub async fn read_follow_msg<R: Stream<Item = std::io::Result<BytesMut>> + Unpin
238240pub async fn recv_follow (
239241 agent : & Agent ,
240242 mut read : FramedRead < RecvStream , LengthDelimitedCodec > ,
243+ local_only : bool ,
244+ broadcast : Option < & FollowBroadcast > ,
241245) -> Result < Option < CrsqlDbVersion > , FollowError > {
242246 let mut last_db_version = None ;
243247 let tx_changes = agent. tx_changes ( ) ;
@@ -255,8 +259,17 @@ pub async fn recv_follow(
255259 "received changeset for version(s) {:?} and db_version {db_version:?}" ,
256260 changeset. versions( )
257261 ) ;
262+ let change_src = if local_only
263+ || broadcast
264+ . map ( |bcast| should_broadcast ( & changeset. actor_id , bcast) )
265+ . unwrap_or ( false )
266+ {
267+ ChangeSource :: Broadcast
268+ } else {
269+ ChangeSource :: Follow
270+ } ;
258271 tx_changes
259- . send ( ( changeset, ChangeSource :: Follow ) )
272+ . send ( ( changeset, change_src ) )
260273 . await
261274 . map_err ( |_| FollowError :: ChannelClosed ) ?;
262275 if let Some ( db_version) = db_version {
@@ -270,12 +283,20 @@ pub async fn recv_follow(
270283 Ok ( last_db_version)
271284}
272285
286+ fn should_broadcast ( actor_id : & ActorId , broadcast : & FollowBroadcast ) -> bool {
287+ match broadcast {
288+ FollowBroadcast :: ActorIds ( set) => set. contains ( actor_id) ,
289+ FollowBroadcast :: Percent ( percent) => OsRng . gen_range ( 0 ..100 ) < * percent,
290+ }
291+ }
292+
273293pub async fn follow (
274294 agent : & Agent ,
275295 mut tx : SendStream ,
276296 recv : RecvStream ,
277297 from : Option < CrsqlDbVersion > ,
278298 local_only : bool ,
299+ broadcast : Option < & FollowBroadcast > ,
279300) -> Result < Option < CrsqlDbVersion > , FollowError > {
280301 let mut codec = LengthDelimitedCodec :: builder ( )
281302 . max_frame_length ( 100 * 1_024 * 1_024 )
@@ -302,7 +323,7 @@ pub async fn follow(
302323 . new_codec ( ) ,
303324 ) ;
304325
305- recv_follow ( agent, framed) . await
326+ recv_follow ( agent, framed, local_only , broadcast ) . await
306327}
307328
308329#[ cfg( test) ]
0 commit comments