Skip to content

Commit ee0e1a9

Browse files
committed
add broadcast setting to follow settings to allow followers to broadcast under certain conditions (list of actors or percent of changes received)
1 parent 96da032 commit ee0e1a9

File tree

3 files changed

+59
-6
lines changed

3 files changed

+59
-6
lines changed

crates/corro-agent/src/agent/setup.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,16 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age
304304

305305
match transport.open_bi(addr).await {
306306
Ok((tx, rx)) => {
307-
match api::peer::follow::follow(&agent, tx, rx, Some(from), false).await {
307+
match api::peer::follow::follow(
308+
&agent,
309+
tx,
310+
rx,
311+
Some(from),
312+
false,
313+
follow.broadcast.as_ref(),
314+
)
315+
.await
316+
{
308317
Ok(dbv) => {
309318
info!("following terminated, last db version: {dbv:?}");
310319
last_from = dbv;

crates/corro-agent/src/api/peer/follow.rs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@ use corro_types::{
88
base::{CrsqlDbVersion, CrsqlSeq, Version},
99
broadcast::{BiPayload, ChangeSource, ChangeV1, Changeset, Timestamp},
1010
change::ChunkedChanges,
11+
config::FollowBroadcast,
1112
sqlite::SqlitePoolError,
1213
};
1314
use futures::{Stream, StreamExt};
1415
use metrics::counter;
1516
use quinn::{RecvStream, SendStream};
17+
use rand::{rngs::OsRng, Rng};
1618
use rusqlite::{params_from_iter, Row, ToSql};
1719
use speedy::{Readable, Writable};
1820
use tokio::{sync::mpsc, task::block_in_place};
@@ -238,6 +240,8 @@ pub async fn read_follow_msg<R: Stream<Item = std::io::Result<BytesMut>> + Unpin
238240
pub async fn recv_follow(
239241
agent: &Agent,
240242
mut read: FramedRead<RecvStream, LengthDelimitedCodec>,
243+
local_only: bool,
244+
broadcast: Option<&FollowBroadcast>,
241245
) -> Result<Option<CrsqlDbVersion>, FollowError> {
242246
let mut last_db_version = None;
243247
let tx_changes = agent.tx_changes();
@@ -255,8 +259,17 @@ pub async fn recv_follow(
255259
"received changeset for version(s) {:?} and db_version {db_version:?}",
256260
changeset.versions()
257261
);
262+
let change_src = if local_only
263+
|| broadcast
264+
.map(|bcast| should_broadcast(&changeset.actor_id, bcast))
265+
.unwrap_or(false)
266+
{
267+
ChangeSource::Broadcast
268+
} else {
269+
ChangeSource::Follow
270+
};
258271
tx_changes
259-
.send((changeset, ChangeSource::Follow))
272+
.send((changeset, change_src))
260273
.await
261274
.map_err(|_| FollowError::ChannelClosed)?;
262275
if let Some(db_version) = db_version {
@@ -270,12 +283,20 @@ pub async fn recv_follow(
270283
Ok(last_db_version)
271284
}
272285

286+
fn should_broadcast(actor_id: &ActorId, broadcast: &FollowBroadcast) -> bool {
287+
match broadcast {
288+
FollowBroadcast::ActorIds(set) => set.contains(actor_id),
289+
FollowBroadcast::Percent(percent) => OsRng.gen_range(0..100) < *percent,
290+
}
291+
}
292+
273293
pub async fn follow(
274294
agent: &Agent,
275295
mut tx: SendStream,
276296
recv: RecvStream,
277297
from: Option<CrsqlDbVersion>,
278298
local_only: bool,
299+
broadcast: Option<&FollowBroadcast>,
279300
) -> Result<Option<CrsqlDbVersion>, FollowError> {
280301
let mut codec = LengthDelimitedCodec::builder()
281302
.max_frame_length(100 * 1_024 * 1_024)
@@ -302,7 +323,7 @@ pub async fn follow(
302323
.new_codec(),
303324
);
304325

305-
recv_follow(agent, framed).await
326+
recv_follow(agent, framed, local_only, broadcast).await
306327
}
307328

308329
#[cfg(test)]

crates/corro-types/src/config.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
use std::net::{Ipv6Addr, SocketAddr, SocketAddrV6};
1+
use std::{
2+
collections::HashSet,
3+
net::{Ipv6Addr, SocketAddr, SocketAddrV6},
4+
};
25

36
use camino::Utf8PathBuf;
47
use corro_base_types::CrsqlDbVersion;
58
use serde::{Deserialize, Serialize};
69
use serde_with::{formats::PreferOne, serde_as, OneOrMany};
710

11+
use crate::actor::ActorId;
12+
813
pub const DEFAULT_GOSSIP_PORT: u16 = 4001;
914
const DEFAULT_GOSSIP_IDLE_TIMEOUT: u32 = 30;
1015

@@ -38,6 +43,15 @@ pub struct FollowConfig {
3843
pub addr: SocketAddr,
3944
#[serde(default)]
4045
pub from: FollowFrom,
46+
#[serde(default)]
47+
pub broadcast: Option<FollowBroadcast>,
48+
}
49+
50+
#[derive(Debug, Clone, Serialize, Deserialize)]
51+
#[serde(rename_all = "kebab-case")]
52+
pub enum FollowBroadcast {
53+
ActorIds(HashSet<ActorId>),
54+
Percent(u8),
4155
}
4256

4357
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
@@ -385,8 +399,17 @@ impl ConfigBuilder {
385399
self
386400
}
387401

388-
pub fn follow(mut self, addr: SocketAddr, from: FollowFrom) -> Self {
389-
self.follow = Some(FollowConfig { addr, from });
402+
pub fn follow(
403+
mut self,
404+
addr: SocketAddr,
405+
from: FollowFrom,
406+
broadcast: Option<FollowBroadcast>,
407+
) -> Self {
408+
self.follow = Some(FollowConfig {
409+
addr,
410+
from,
411+
broadcast,
412+
});
390413
self
391414
}
392415

0 commit comments

Comments
 (0)