Skip to content

Commit 3265c33

Browse files
committed
Implement inbound RPC.partial dispatch (step 3)
Replaces the stub in GossipRouter.processPartialMessageExtension with the full flow: drop RPCs missing topicID or groupID, then delegate to PartialMessagesAdapterImpl which gets-or-creates the GroupState (with DoS cap enforcement) and calls handler.onIncomingRpc with the live peerStates map.
1 parent 6adf574 commit 3265c33

5 files changed

Lines changed: 356 additions & 7 deletions

File tree

docs/partial-messages.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ Mirror this checklist in issue #435.
364364
`PublishAction<PeerState>` (with `nextPeerState`),
365365
`PublishActionsFn<PeerState>`, `PartialMessagesPeerFeedback`, and
366366
`GroupState` container with TTL + DoS caps. No routing yet.
367-
- [ ] **Step 3** — Inbound `RPC.partial` dispatch: replace the stub at
367+
- [x] **Step 3** — Inbound `RPC.partial` dispatch: replace the stub at
368368
`GossipRouter.kt:476` with the full flow (validate caps, create/update
369369
group state, call `onIncomingRpc`).
370370
- [ ] **Step 4** — Outbound `publishPartial(...)` on the `Gossip` facade;

libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -525,12 +525,19 @@ open class GossipRouter(
525525
partialMessagesExtension: Rpc.PartialMessagesExtension,
526526
receivedFrom: PeerHandler
527527
) {
528-
logger.trace(
529-
"Processing partial message extension message {} from {}",
530-
partialMessagesExtension.toString(),
531-
receivedFrom.peerId
532-
)
533-
// TODO: implement partial message handling (https://github.com/libp2p/jvm-libp2p/issues/435)
528+
val topic = partialMessagesExtension.topicID
529+
if (!partialMessagesExtension.hasTopicID() || topic.isEmpty()) {
530+
logger.debug("Dropping partial message from {}: missing topicID", receivedFrom.peerId)
531+
return
532+
}
533+
534+
if (!partialMessagesExtension.hasGroupID() || partialMessagesExtension.groupID.isEmpty) {
535+
logger.debug("Dropping partial message from {}: missing groupID", receivedFrom.peerId)
536+
return
537+
}
538+
539+
logger.trace("Processing partial message extension for topic {} from {}", topic, receivedFrom.peerId)
540+
partialMessages?.onIncomingRpc(topic, receivedFrom.peerId, partialMessagesExtension)
534541
}
535542

536543
override fun broadcastInbound(msgs: List<PubsubMessage>, receivedFrom: PeerHandler) {

libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesAdapter.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package io.libp2p.pubsub.gossip.partialmessages
22

33
import io.libp2p.core.PeerId
44
import io.libp2p.pubsub.Topic
5+
import pubsub.pb.Rpc
56

67
/**
78
* Type-erased view of the partial-messages subsystem used by [io.libp2p.pubsub.gossip.GossipRouter].
@@ -12,6 +13,7 @@ internal interface PartialMessagesAdapter {
1213
fun onPeerDisconnected(peer: PeerId)
1314
fun onTopicUnsubscribed(topic: Topic)
1415
fun onHeartbeat()
16+
fun onIncomingRpc(topic: Topic, from: PeerId, rpc: Rpc.PartialMessagesExtension)
1517
}
1618

1719
/**
@@ -31,4 +33,10 @@ internal class PartialMessagesAdapterImpl<PeerState>(
3133
override fun onPeerDisconnected(peer: PeerId) = stateStore.onPeerDisconnected(peer)
3234
override fun onTopicUnsubscribed(topic: Topic) = stateStore.onTopicUnsubscribed(topic)
3335
override fun onHeartbeat() = stateStore.onHeartbeat()
36+
37+
override fun onIncomingRpc(topic: Topic, from: PeerId, rpc: Rpc.PartialMessagesExtension) {
38+
val groupId = rpc.groupID.toByteArray().toGroupId()
39+
val groupState = stateStore.getOrCreatePeerGroup(topic, groupId, from) ?: return
40+
handler.onIncomingRpc(from, groupState.peerStates, rpc, feedback)
41+
}
3442
}
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
package io.libp2p.pubsub.gossip.extensions
2+
3+
import com.google.protobuf.ByteString
4+
import io.libp2p.core.PeerId
5+
import io.libp2p.pubsub.PubsubProtocol
6+
import io.libp2p.pubsub.Topic
7+
import io.libp2p.pubsub.gossip.GossipExtension
8+
import io.libp2p.pubsub.gossip.GossipTestsBase
9+
import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesHandler
10+
import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesPeerFeedback
11+
import org.assertj.core.api.Assertions.assertThat
12+
import org.junit.jupiter.api.Test
13+
import pubsub.pb.Rpc
14+
import java.util.concurrent.CopyOnWriteArrayList
15+
16+
private const val TIMEOUT_MS = 500L
17+
18+
class PartialMessagesInboundRpcTest : GossipTestsBase() {
19+
20+
private val topicId = "test-topic"
21+
private val groupIdBytes = "group-1".toByteArray()
22+
23+
/** Records each [onIncomingRpc] call for assertion in tests. */
24+
data class IncomingCall(val from: PeerId, val rpc: Rpc.PartialMessagesExtension)
25+
26+
private val incomingCalls = CopyOnWriteArrayList<IncomingCall>()
27+
28+
private val capturingHandler: PartialMessagesHandler<Unit> =
29+
object : PartialMessagesHandler<Unit> {
30+
override fun onIncomingRpc(
31+
from: PeerId,
32+
peerStates: Map<PeerId, Unit>,
33+
rpc: Rpc.PartialMessagesExtension,
34+
feedback: PartialMessagesPeerFeedback
35+
) {
36+
incomingCalls += IncomingCall(from, rpc)
37+
}
38+
39+
override fun onEmitGossip(
40+
topic: Topic,
41+
groupId: ByteArray,
42+
gossipPeers: Collection<PeerId>,
43+
peerStates: Map<PeerId, Unit>,
44+
feedback: PartialMessagesPeerFeedback
45+
) {}
46+
}
47+
48+
private fun newTest() = TwoRoutersTest(
49+
protocol = PubsubProtocol.Gossip_V_1_3,
50+
enabledGossipExtensions = listOf(GossipExtension.PARTIAL_MESSAGES),
51+
partialMessagesHandler = capturingHandler,
52+
)
53+
54+
private fun partialRpcWith(
55+
topicId: String? = this.topicId,
56+
groupId: ByteArray? = groupIdBytes
57+
): Rpc.RPC {
58+
val ext = Rpc.PartialMessagesExtension.newBuilder().apply {
59+
if (topicId != null) setTopicID(topicId)
60+
if (groupId != null) setGroupID(ByteString.copyFrom(groupId))
61+
}.build()
62+
return Rpc.RPC.newBuilder().setPartial(ext).build()
63+
}
64+
65+
private fun controlExtensionsWithPartial(): Rpc.RPC =
66+
Rpc.RPC.newBuilder().setControl(
67+
Rpc.ControlMessage.newBuilder().setExtensions(
68+
Rpc.ControlExtensions.newBuilder().setPartialMessages(true)
69+
)
70+
).build()
71+
72+
// Drains any currently queued messages from the mock router's outbox
73+
// so later assertions start from a clean slate.
74+
private fun TwoRoutersTest.flushRouter() =
75+
gossipRouter.submitOnEventThread {}.join()
76+
77+
@Test
78+
fun `valid partial RPC after ControlExtensions dispatches to handler`() {
79+
val test = newTest()
80+
test.flushRouter()
81+
82+
test.mockRouter.sendToSingle(controlExtensionsWithPartial())
83+
test.mockRouter.sendToSingle(partialRpcWith())
84+
test.flushRouter()
85+
86+
assertThat(incomingCalls).hasSize(1)
87+
assertThat(incomingCalls[0].rpc.topicID).isEqualTo(topicId)
88+
assertThat(incomingCalls[0].rpc.groupID.toByteArray()).isEqualTo(groupIdBytes)
89+
}
90+
91+
@Test
92+
fun `partial RPC without prior ControlExtensions is ignored`() {
93+
val test = newTest()
94+
test.flushRouter()
95+
96+
test.mockRouter.sendToSingle(partialRpcWith())
97+
test.flushRouter()
98+
99+
assertThat(incomingCalls).isEmpty()
100+
}
101+
102+
@Test
103+
fun `partial RPC with missing topicID is dropped`() {
104+
val test = newTest()
105+
test.flushRouter()
106+
107+
test.mockRouter.sendToSingle(controlExtensionsWithPartial())
108+
test.mockRouter.sendToSingle(partialRpcWith(topicId = null))
109+
test.flushRouter()
110+
111+
assertThat(incomingCalls).isEmpty()
112+
}
113+
114+
@Test
115+
fun `partial RPC with empty topicID is dropped`() {
116+
val test = newTest()
117+
test.flushRouter()
118+
119+
test.mockRouter.sendToSingle(controlExtensionsWithPartial())
120+
test.mockRouter.sendToSingle(partialRpcWith(topicId = ""))
121+
test.flushRouter()
122+
123+
assertThat(incomingCalls).isEmpty()
124+
}
125+
126+
@Test
127+
fun `partial RPC with missing groupID is dropped`() {
128+
val test = newTest()
129+
test.flushRouter()
130+
131+
test.mockRouter.sendToSingle(controlExtensionsWithPartial())
132+
test.mockRouter.sendToSingle(partialRpcWith(groupId = null))
133+
test.flushRouter()
134+
135+
assertThat(incomingCalls).isEmpty()
136+
}
137+
138+
@Test
139+
fun `partial RPC with empty groupID is dropped`() {
140+
val test = newTest()
141+
test.flushRouter()
142+
143+
test.mockRouter.sendToSingle(controlExtensionsWithPartial())
144+
test.mockRouter.sendToSingle(partialRpcWith(groupId = ByteArray(0)))
145+
test.flushRouter()
146+
147+
assertThat(incomingCalls).isEmpty()
148+
}
149+
150+
@Test
151+
fun `partial RPC when extension is disabled is ignored`() {
152+
val test = TwoRoutersTest(
153+
protocol = PubsubProtocol.Gossip_V_1_3,
154+
enabledGossipExtensions = listOf(),
155+
)
156+
test.flushRouter()
157+
158+
test.mockRouter.sendToSingle(controlExtensionsWithPartial())
159+
test.mockRouter.sendToSingle(partialRpcWith())
160+
test.flushRouter()
161+
162+
assertThat(incomingCalls).isEmpty()
163+
}
164+
165+
@Test
166+
fun `multiple valid partial RPCs for different groups all dispatched`() {
167+
val test = newTest()
168+
test.flushRouter()
169+
170+
test.mockRouter.sendToSingle(controlExtensionsWithPartial())
171+
test.mockRouter.sendToSingle(partialRpcWith(groupId = "g1".toByteArray()))
172+
test.mockRouter.sendToSingle(partialRpcWith(groupId = "g2".toByteArray()))
173+
test.flushRouter()
174+
175+
assertThat(incomingCalls).hasSize(2)
176+
}
177+
}
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package io.libp2p.pubsub.gossip.partialmessages
2+
3+
import com.google.protobuf.ByteString
4+
import io.libp2p.core.PeerId
5+
import io.libp2p.pubsub.Topic
6+
import org.assertj.core.api.Assertions.assertThat
7+
import org.junit.jupiter.api.BeforeEach
8+
import org.junit.jupiter.api.Test
9+
import pubsub.pb.Rpc
10+
11+
class PartialMessagesAdapterImplTest {
12+
13+
private val topic = "test-topic"
14+
private val groupIdBytes = "group-1".toByteArray()
15+
private lateinit var peer1: PeerId
16+
private lateinit var peer2: PeerId
17+
private lateinit var capturedCalls: MutableList<IncomingRpcCall<String>>
18+
private lateinit var adapter: PartialMessagesAdapterImpl<String>
19+
20+
data class IncomingRpcCall<S>(
21+
val from: PeerId,
22+
val peerStates: Map<PeerId, S>,
23+
val rpc: Rpc.PartialMessagesExtension
24+
)
25+
26+
private fun makeHandler() = object : PartialMessagesHandler<String> {
27+
override fun onIncomingRpc(
28+
from: PeerId,
29+
peerStates: Map<PeerId, String>,
30+
rpc: Rpc.PartialMessagesExtension,
31+
feedback: PartialMessagesPeerFeedback
32+
) {
33+
capturedCalls += IncomingRpcCall(from, peerStates, rpc)
34+
}
35+
36+
override fun onEmitGossip(
37+
topic: Topic,
38+
groupId: ByteArray,
39+
gossipPeers: Collection<PeerId>,
40+
peerStates: Map<PeerId, String>,
41+
feedback: PartialMessagesPeerFeedback
42+
) {}
43+
}
44+
45+
@BeforeEach
46+
fun setup() {
47+
peer1 = PeerId.random()
48+
peer2 = PeerId.random()
49+
capturedCalls = mutableListOf()
50+
adapter = PartialMessagesAdapterImpl(
51+
handler = makeHandler(),
52+
stateStore = PartialGroupStateStore(groupTtlHeartbeats = 3),
53+
feedback = NopPartialMessagesFeedback
54+
)
55+
}
56+
57+
private fun buildRpc(
58+
topicId: String = topic,
59+
groupId: ByteArray = groupIdBytes,
60+
partialMessage: ByteArray? = null,
61+
partsMetadata: ByteArray? = null
62+
): Rpc.PartialMessagesExtension =
63+
Rpc.PartialMessagesExtension.newBuilder()
64+
.setTopicID(topicId)
65+
.setGroupID(ByteString.copyFrom(groupId))
66+
.apply {
67+
if (partialMessage != null) setPartialMessage(ByteString.copyFrom(partialMessage))
68+
if (partsMetadata != null) setPartsMetadata(ByteString.copyFrom(partsMetadata))
69+
}
70+
.build()
71+
72+
@Test
73+
fun `dispatches valid RPC to handler`() {
74+
val rpc = buildRpc()
75+
76+
adapter.onIncomingRpc(topic, peer1, rpc)
77+
78+
assertThat(capturedCalls).hasSize(1)
79+
assertThat(capturedCalls[0].from).isEqualTo(peer1)
80+
assertThat(capturedCalls[0].rpc).isEqualTo(rpc)
81+
}
82+
83+
@Test
84+
fun `peerStates map is empty on first RPC for a fresh group`() {
85+
adapter.onIncomingRpc(topic, peer1, buildRpc())
86+
87+
assertThat(capturedCalls[0].peerStates).isEmpty()
88+
}
89+
90+
@Test
91+
fun `second RPC for the same group reuses the same peerStates object`() {
92+
adapter.onIncomingRpc(topic, peer1, buildRpc())
93+
adapter.onIncomingRpc(topic, peer2, buildRpc())
94+
95+
assertThat(capturedCalls).hasSize(2)
96+
// Both calls receive the same live GroupState.peerStates reference
97+
assertThat(capturedCalls[0].peerStates).isSameAs(capturedCalls[1].peerStates)
98+
}
99+
100+
@Test
101+
fun `optional partialMessage and partsMetadata are forwarded to handler`() {
102+
val rpc = buildRpc(
103+
partialMessage = byteArrayOf(1, 2, 3),
104+
partsMetadata = byteArrayOf(0xFF.toByte())
105+
)
106+
107+
adapter.onIncomingRpc(topic, peer1, rpc)
108+
109+
assertThat(capturedCalls[0].rpc.partialMessage.toByteArray()).isEqualTo(byteArrayOf(1, 2, 3))
110+
assertThat(capturedCalls[0].rpc.partsMetadata.toByteArray()).isEqualTo(byteArrayOf(0xFF.toByte()))
111+
}
112+
113+
@Test
114+
fun `handler not called when per-topic DoS cap is exceeded`() {
115+
val store = PartialGroupStateStore<String>(
116+
groupTtlHeartbeats = 3,
117+
peerInitiatedGroupLimitPerTopic = 1
118+
)
119+
val capped = PartialMessagesAdapterImpl(
120+
handler = makeHandler(),
121+
stateStore = store,
122+
feedback = NopPartialMessagesFeedback
123+
)
124+
125+
capped.onIncomingRpc(topic, peer1, buildRpc(groupId = "g1".toByteArray()))
126+
capped.onIncomingRpc(topic, peer1, buildRpc(groupId = "g2".toByteArray()))
127+
128+
assertThat(capturedCalls).hasSize(1)
129+
}
130+
131+
@Test
132+
fun `handler not called when per-peer DoS cap is exceeded`() {
133+
val store = PartialGroupStateStore<String>(
134+
groupTtlHeartbeats = 3,
135+
peerInitiatedGroupLimitPerTopicPerPeer = 1
136+
)
137+
val capped = PartialMessagesAdapterImpl(
138+
handler = makeHandler(),
139+
stateStore = store,
140+
feedback = NopPartialMessagesFeedback
141+
)
142+
143+
capped.onIncomingRpc(topic, peer1, buildRpc(groupId = "g1".toByteArray()))
144+
capped.onIncomingRpc(topic, peer1, buildRpc(groupId = "g2".toByteArray()))
145+
146+
assertThat(capturedCalls).hasSize(1)
147+
}
148+
149+
@Test
150+
fun `different topics create independent groups`() {
151+
adapter.onIncomingRpc("topic-a", peer1, buildRpc(topicId = "topic-a"))
152+
adapter.onIncomingRpc("topic-b", peer1, buildRpc(topicId = "topic-b"))
153+
154+
assertThat(capturedCalls).hasSize(2)
155+
assertThat(capturedCalls[0].peerStates).isNotSameAs(capturedCalls[1].peerStates)
156+
}
157+
}

0 commit comments

Comments
 (0)