diff --git a/docs/partial-messages.md b/docs/partial-messages.md index 4c715a7b..d8d73a03 100644 --- a/docs/partial-messages.md +++ b/docs/partial-messages.md @@ -364,10 +364,10 @@ Mirror this checklist in issue #435. `PublishAction` (with `nextPeerState`), `PublishActionsFn`, `PartialMessagesPeerFeedback`, and `GroupState` container with TTL + DoS caps. No routing yet. -- [ ] **Step 3** — Inbound `RPC.partial` dispatch: replace the stub at +- [x] **Step 3** — Inbound `RPC.partial` dispatch: replace the stub at `GossipRouter.kt:476` with the full flow (validate caps, create/update group state, call `onIncomingRpc`). -- [ ] **Step 4** — Outbound `publishPartial(...)` on the `Gossip` facade; +- [x] **Step 4** — Outbound `publishPartial(...)` on the `Gossip` facade; route through `GossipRpcPartsQueue` (do **not** bypass — PR #433 got this wrong). Enforce the "omit `partialMessage` when peer supports but didn't request" MUST. diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt index a72b93cc..0fee5364 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt @@ -147,11 +147,22 @@ abstract class AbstractRouter( override fun onPeerActive(peer: PeerHandler) { val partsQueue = pendingRpcParts.getQueue(peer) subscribedTopics.forEach { - partsQueue.addSubscribe(it) + enqueueSubscribe(partsQueue, it) } flushPending(peer) } + /** + * Enqueues a subscribe announcement for [topic] onto [partsQueue]. + * + * The default implementation emits a bare subscribe with no per-topic options. + * Subclasses (e.g. GossipRouter) override this to attach per-topic options + * such as partial-message flags. + */ + protected open fun enqueueSubscribe(partsQueue: RpcPartsQueue, topic: Topic) { + partsQueue.addSubscribe(topic) + } + protected open fun notifyMalformedMessage(peer: PeerHandler) {} protected open fun notifyUnseenMessage(peer: PeerHandler, msg: PubsubMessage) {} protected open fun notifyNonSubscribedMessage(peer: PeerHandler, msg: Rpc.Message) {} @@ -172,7 +183,17 @@ abstract class AbstractRouter( } try { - val subscriptions = msg.subscriptionsList.map { PubsubSubscription(it.topicid, it.subscribe) } + val subscriptions = msg.subscriptionsList.map { + // Per partial-messages spec: flags MUST be ignored on subscribe=false, and the + // receiving side coerces supportsSendingPartial := requestsPartial || supportsSendingPartial. + // The coercion rule is also applied on the outbound side by GossipRouter. + PubsubSubscription( + topic = it.topicid, + subscribe = it.subscribe, + requestsPartial = it.subscribe && it.requestsPartial, + supportsSendingPartial = it.subscribe && (it.supportsSendingPartial || it.requestsPartial) + ) + } subscriptionFilter .filterIncomingSubscriptions(subscriptions, peersTopics.getByFirst(peer)) .forEach { handleMessageSubscriptions(peer, it) } @@ -301,7 +322,20 @@ abstract class AbstractRouter( } } - private fun handleMessageSubscriptions(peer: PeerHandler, msg: PubsubSubscription) { + /** + * Applies a single filtered inbound subscription to the router's state. + * + * Called once per `SubOpts` on the pubsub event loop, after + * [SubscriptionFilter.filterIncomingSubscriptions] has run. Subclasses may + * override to react to subscription state changes (for example, to track + * per-topic capability flags). Overrides MUST call `super` so that + * [peersTopics] stays in sync. + * + * [msg] carries the protocol-level flags already normalised by the caller: + * for `subscribe=false` frames, extension flags are zeroed before reaching + * this method. + */ + protected open fun handleMessageSubscriptions(peer: PeerHandler, msg: PubsubSubscription) { if (msg.subscribe) { peersTopics.add(peer, msg.topic) } else { @@ -319,7 +353,7 @@ abstract class AbstractRouter( } protected open fun subscribe(topic: Topic) { - activePeers.forEach { pendingRpcParts.getQueue(it).addSubscribe(topic) } + activePeers.forEach { enqueueSubscribe(pendingRpcParts.getQueue(it), topic) } subscribedTopics += topic } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt index c960fdb6..cce99b60 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt @@ -13,7 +13,12 @@ typealias Topic = String typealias MessageId = WBytes typealias PubsubMessageFactory = (Rpc.Message) -> PubsubMessage -data class PubsubSubscription(val topic: Topic, val subscribe: Boolean) +data class PubsubSubscription( + val topic: Topic, + val subscribe: Boolean, + val requestsPartial: Boolean = false, + val supportsSendingPartial: Boolean = false +) interface PubsubMessage { val protobufMessage: Rpc.Message diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcPartsQueue.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcPartsQueue.kt index 11af5f8d..05c6a623 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcPartsQueue.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/RpcPartsQueue.kt @@ -9,14 +9,23 @@ interface RpcPartsQueue { fun addPublish(message: Rpc.Message) fun addSubscribe(topic: Topic) { - addSubscription(topic, SubscriptionStatus.Subscribed) + addSubscribe(topic, requestsPartial = false, supportsSendingPartial = false) + } + + fun addSubscribe(topic: Topic, requestsPartial: Boolean, supportsSendingPartial: Boolean) { + addSubscription(topic, SubscriptionStatus.Subscribed, requestsPartial, supportsSendingPartial) } fun addUnsubscribe(topic: Topic) { - addSubscription(topic, SubscriptionStatus.Unsubscribed) + addSubscription(topic, SubscriptionStatus.Unsubscribed, requestsPartial = false, supportsSendingPartial = false) } - fun addSubscription(topic: Topic, status: SubscriptionStatus) + fun addSubscription( + topic: Topic, + status: SubscriptionStatus, + requestsPartial: Boolean, + supportsSendingPartial: Boolean + ) fun takeMerged(): List } @@ -38,11 +47,20 @@ open class DefaultRpcPartsQueue : RpcPartsQueue { } } - protected data class SubscriptionPart(val topic: Topic, val status: RpcPartsQueue.SubscriptionStatus) : AbstractPart { + protected data class SubscriptionPart( + val topic: Topic, + val status: RpcPartsQueue.SubscriptionStatus, + val requestsPartial: Boolean = false, + val supportsSendingPartial: Boolean = false + ) : AbstractPart { override fun appendToBuilder(builder: Rpc.RPC.Builder) { - builder.addSubscriptionsBuilder().apply { - setTopicid(topic) - setSubscribe(status == RpcPartsQueue.SubscriptionStatus.Subscribed) + val subBuilder = builder.addSubscriptionsBuilder() + subBuilder.topicid = topic + subBuilder.subscribe = status == RpcPartsQueue.SubscriptionStatus.Subscribed + // Per spec: partial flags MUST NOT be sent on unsubscribe (subscribe=false). + if (status == RpcPartsQueue.SubscriptionStatus.Subscribed) { + if (requestsPartial) subBuilder.requestsPartial = true + if (supportsSendingPartial) subBuilder.supportsSendingPartial = true } } } @@ -57,8 +75,13 @@ open class DefaultRpcPartsQueue : RpcPartsQueue { addPart(PublishPart(message)) } - override fun addSubscription(topic: Topic, status: RpcPartsQueue.SubscriptionStatus) { - addPart(SubscriptionPart(topic, status)) + override fun addSubscription( + topic: Topic, + status: RpcPartsQueue.SubscriptionStatus, + requestsPartial: Boolean, + supportsSendingPartial: Boolean + ) { + addPart(SubscriptionPart(topic, status, requestsPartial, supportsSendingPartial)) } override fun takeMerged(): List { diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt index 39100f10..7628345d 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt @@ -10,7 +10,9 @@ import io.libp2p.core.multistream.ProtocolDescriptor import io.libp2p.core.pubsub.PubsubApi import io.libp2p.pubsub.PubsubApiImpl import io.libp2p.pubsub.PubsubProtocol +import io.libp2p.pubsub.Topic import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder +import io.libp2p.pubsub.gossip.partialmessages.PublishActionsFn import io.netty.channel.ChannelHandler import org.slf4j.LoggerFactory import java.util.concurrent.CompletableFuture @@ -32,6 +34,20 @@ class Gossip @JvmOverloads constructor( return router.score.getCachedScore(peerId) } + /** + * Queues outbound [pubsub.pb.Rpc.PartialMessagesExtension] RPCs for [topic]/[groupId] + * by invoking the client's [actionsFn] on the current group state. + * + * Submits to the pubsub event thread; the returned future completes when the RPCs + * have been enqueued and flushed. + */ + fun publishPartial( + topic: Topic, + groupId: ByteArray, + actionsFn: PublishActionsFn<*> + ): CompletableFuture = + router.submitOnEventThread { router.publishPartial(topic, groupId, actionsFn) } + override val protocolDescriptor = when (router.protocol) { PubsubProtocol.Gossip_V_1_3 -> { diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt index bdfe6905..3497882b 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt @@ -7,6 +7,9 @@ import io.libp2p.core.pubsub.ValidationResult import io.libp2p.etc.types.* import io.libp2p.etc.util.P2PService import io.libp2p.pubsub.* +import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesAdapter +import io.libp2p.pubsub.gossip.partialmessages.PublishActionsFn +import io.libp2p.pubsub.gossip.partialmessages.toGroupId import org.slf4j.LoggerFactory import pubsub.pb.Rpc import java.time.Duration @@ -134,6 +137,56 @@ open class GossipRouter( override val pendingRpcParts = PendingRpcPartsMap { DefaultGossipRpcPartsQueue(params) } val gossipExtensionsState = GossipExtensionsState(gossipExtensionsConfig) + val partialSubscriptionState = PartialSubscriptionState() + internal var partialMessages: PartialMessagesAdapter? = null + + /** + * Local per-topic subscription options that affect outbound subscribe announcements. + * Accessed only on the pubsub event loop. + */ + private val localTopicPartialFlags: MutableMap = mutableMapOf() + + /** + * Configures the partial-messages flags advertised on this node's subscribe + * announcements for [topic]. Must be called before [subscribe] for the flags + * to take effect on the initial announcement; a subsequent call will affect + * later re-announcements (e.g. on new peer activation). + * + * Per spec, the send-side also applies the coercion + * `supportsSendingPartial := requestsPartial || supportsSendingPartial`. + */ + fun setTopicPartialFlags(topic: Topic, requestsPartial: Boolean, supportsSendingPartial: Boolean) { + runOnEventThread { + val coerced = PartialSubFlags.coerce(requestsPartial, supportsSendingPartial) + if (coerced == PartialSubFlags.NONE) { + localTopicPartialFlags -= topic + } else { + localTopicPartialFlags[topic] = coerced + } + } + } + + /** + * Queues outbound [pubsub.pb.Rpc.PartialMessagesExtension] RPCs for [topic]/[groupId] + * by invoking the client's [actionsFn] on the current group state. + * + * Must be called on the pubsub event thread. + */ + fun publishPartial(topic: Topic, groupId: ByteArray, actionsFn: PublishActionsFn<*>) { + val adapter = partialMessages ?: return + val gid = groupId.toGroupId() + + fun peerRequestsPartial(peerId: PeerId) = + partialSubscriptionState.peerRequestsPartial(topic, peerId) + + fun enqueue(peerId: PeerId, partialMessage: ByteArray?, partsMetadata: ByteArray?) { + val peerHandler = activePeers.find { it.peerId == peerId } ?: return + pendingRpcParts.getQueue(peerHandler).addPartialMessage(topic, groupId, partialMessage, partsMetadata) + } + + adapter.publishPartial(topic, gid, actionsFn, ::peerRequestsPartial, ::enqueue) + flushAllPending() + } private fun setBackOff(peer: PeerHandler, topic: Topic) = setBackOff(peer, topic, params.pruneBackoff.toMillis()) private fun setBackOff(peer: PeerHandler, topic: Topic, delay: Long) { @@ -161,9 +214,28 @@ open class GossipRouter( acceptRequestsWhitelist -= peer pendingRpcParts.popQueue(peer) // discard them gossipExtensionsState.onPeerDisconnected(peer.peerId) + partialSubscriptionState.onPeerDisconnected(peer.peerId) super.onPeerDisconnected(peer) } + override fun enqueueSubscribe(partsQueue: RpcPartsQueue, topic: Topic) { + val flags = localTopicPartialFlags[topic] ?: PartialSubFlags.NONE + partsQueue.addSubscribe(topic, flags.requestsPartial, flags.supportsSendingPartial) + } + + override fun handleMessageSubscriptions(peer: PeerHandler, msg: PubsubSubscription) { + super.handleMessageSubscriptions(peer, msg) + if (msg.subscribe) { + partialSubscriptionState.setPeerFlags( + msg.topic, + peer.peerId, + PartialSubFlags(msg.requestsPartial, msg.supportsSendingPartial) + ) + } else { + partialSubscriptionState.removePeerFlags(msg.topic, peer.peerId) + } + } + override fun onPeerActive(peer: PeerHandler) { super.onPeerActive(peer) eventBroadcaster.notifyConnected(peer.peerId, peer.getRemoteAddress()) @@ -477,12 +549,19 @@ open class GossipRouter( partialMessagesExtension: Rpc.PartialMessagesExtension, receivedFrom: PeerHandler ) { - logger.trace( - "Processing partial message extension message {} from {}", - partialMessagesExtension.toString(), - receivedFrom.peerId - ) - // TODO: implement partial message handling (https://github.com/libp2p/jvm-libp2p/issues/435) + val topic = partialMessagesExtension.topicID + if (!partialMessagesExtension.hasTopicID() || topic.isEmpty()) { + logger.debug("Dropping partial message from {}: missing topicID", receivedFrom.peerId) + return + } + + if (!partialMessagesExtension.hasGroupID() || partialMessagesExtension.groupID.isEmpty) { + logger.debug("Dropping partial message from {}: missing groupID", receivedFrom.peerId) + return + } + + logger.trace("Processing partial message extension for topic {} from {}", topic, receivedFrom.peerId) + partialMessages?.onIncomingRpc(topic, receivedFrom.peerId, partialMessagesExtension) } override fun broadcastInbound(msgs: List, receivedFrom: PeerHandler) { @@ -615,6 +694,8 @@ open class GossipRouter( super.unsubscribe(topic) mesh[topic]?.copy()?.forEach { prune(it, topic) } mesh -= topic + localTopicPartialFlags -= topic + partialSubscriptionState.removeTopic(topic) } private fun catchingHeartbeat() { diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt index 32e5c908..72b581f3 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt @@ -29,6 +29,8 @@ interface GossipRpcPartsQueue : RpcPartsQueue { // TODO Need to check if we should handle when control extension and extension messages could be separated by split (https://github.com/libp2p/jvm-libp2p/issues/440) fun addControlExtensions(ctrlMessage: Rpc.ControlExtensions) + + fun addPartialMessage(topic: Topic, groupId: ByteArray, partialMessage: ByteArray?, partsMetadata: ByteArray?) } /** @@ -90,6 +92,23 @@ open class DefaultGossipRpcPartsQueue( } } + // Not a data class: ByteArray fields break equals/hashCode in data classes. + protected class PartialMessagePart( + val topic: Topic, + val groupId: ByteArray, + val partialMessage: ByteArray?, + val partsMetadata: ByteArray? + ) : AbstractPart { + override fun appendToBuilder(builder: Rpc.RPC.Builder) { + val pmBuilder = Rpc.PartialMessagesExtension.newBuilder() + .setTopicID(topic) + .setGroupID(groupId.toProtobuf()) + partialMessage?.let { pmBuilder.setPartialMessage(it.toProtobuf()) } + partsMetadata?.let { pmBuilder.setPartsMetadata(it.toProtobuf()) } + builder.setPartial(pmBuilder.build()) + } + } + override fun addIHave(messageId: MessageId, topic: Topic) { addPart(IHavePart(messageId, topic)) } @@ -114,6 +133,10 @@ open class DefaultGossipRpcPartsQueue( addPart(ControlExtensionPart(ctrlMessage)) } + override fun addPartialMessage(topic: Topic, groupId: ByteArray, partialMessage: ByteArray?, partsMetadata: ByteArray?) { + addPart(PartialMessagePart(topic, groupId, partialMessage, partsMetadata)) + } + override fun takeMerged(): List { val ret = mutableListOf() var partIdx = 0 @@ -126,10 +149,12 @@ open class DefaultGossipRpcPartsQueue( var iWantCount = params.maxIWantMessageIds ?: Int.MAX_VALUE var graftCount = params.maxGraftMessages ?: Int.MAX_VALUE var pruneCount = params.maxPruneMessages ?: Int.MAX_VALUE + // proto field `partial` is optional (not repeated): at most 1 per RPC + var partialCount = 1 while (partIdx < parts.size && publishCount > 0 && subscriptionCount > 0 && iHaveCount > 0 && - iWantCount > 0 && graftCount > 0 && pruneCount > 0 + iWantCount > 0 && graftCount > 0 && pruneCount > 0 && partialCount > 0 ) { val part = parts[partIdx++] when (part) { @@ -139,6 +164,7 @@ open class DefaultGossipRpcPartsQueue( is IWantPart -> iWantCount-- is GraftPart -> graftCount-- is PrunePart -> pruneCount-- + is PartialMessagePart -> partialCount-- } part.appendToBuilder(builder) diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/PartialSubscriptionState.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/PartialSubscriptionState.kt new file mode 100644 index 00000000..ae4b7daf --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/PartialSubscriptionState.kt @@ -0,0 +1,89 @@ +package io.libp2p.pubsub.gossip + +import io.libp2p.core.PeerId +import io.libp2p.pubsub.Topic + +data class PartialSubFlags( + val requestsPartial: Boolean, + val supportsSendingPartial: Boolean +) { + companion object { + val NONE = PartialSubFlags(requestsPartial = false, supportsSendingPartial = false) + + /** + * Applies the partial-messages spec coercion + * `supportsSendingPartial := requestsPartial || supportsSendingPartial`. + * + * Per the spec, this rule MUST be applied by both the sender (when + * advertising flags outbound) and the receiver (when parsing inbound + * `SubOpts`). Callers are expected to have already zeroed the flags + * for `subscribe=false` frames before calling this helper. + */ + fun coerce(requestsPartial: Boolean, supportsSendingPartial: Boolean): PartialSubFlags = + PartialSubFlags( + requestsPartial = requestsPartial, + supportsSendingPartial = supportsSendingPartial || requestsPartial + ) + } +} + +/** + * Per-topic, per-peer partial-messages subscription state. + * + * Tracks, for each `(topic, peer)`, the remote peer's `requestsPartial` / + * `supportsSendingPartial` flags as most recently announced via a subscribe + * `SubOpts`. Unsubscribes and peer disconnects drop the corresponding state. + * + * NOT thread-safe: accessed only on the pubsub event loop. + */ +class PartialSubscriptionState { + + private val byTopic: MutableMap> = mutableMapOf() + + /** + * Stores [flags] for `(topic, peer)`. + * + * Passing [PartialSubFlags.NONE] (or any equivalent `PartialSubFlags(false, false)`) + * is treated as a removal: the peer's entry is dropped and, if it was the + * last peer for the topic, the topic entry is GC'd. This keeps the snapshot + * invariant "present ⇔ non-default flags". + */ + fun setPeerFlags(topic: Topic, peer: PeerId, flags: PartialSubFlags) { + if (flags == PartialSubFlags.NONE) { + removePeerFlags(topic, peer) + return + } + byTopic.getOrPut(topic) { mutableMapOf() }[peer] = flags + } + + fun removePeerFlags(topic: Topic, peer: PeerId) { + val peers = byTopic[topic] ?: return + peers.remove(peer) + if (peers.isEmpty()) byTopic.remove(topic) + } + + fun removeTopic(topic: Topic) { + byTopic.remove(topic) + } + + fun onPeerDisconnected(peer: PeerId) { + val emptied = mutableListOf() + for ((topic, peers) in byTopic) { + peers.remove(peer) + if (peers.isEmpty()) emptied += topic + } + emptied.forEach { byTopic.remove(it) } + } + + fun peerFlags(topic: Topic, peer: PeerId): PartialSubFlags = + byTopic[topic]?.get(peer) ?: PartialSubFlags.NONE + + fun peerRequestsPartial(topic: Topic, peer: PeerId) = + peerFlags(topic, peer).requestsPartial + + fun peerSupportsSendingPartial(topic: Topic, peer: PeerId) = + peerFlags(topic, peer).supportsSendingPartial + + internal fun snapshot(): Map> = + byTopic.mapValues { (_, v) -> v.toMap() } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt index 214d4b06..0d3cc0a9 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/builders/GossipRouterBuilder.kt @@ -5,6 +5,7 @@ import io.libp2p.core.pubsub.ValidationResult import io.libp2p.etc.types.lazyVar import io.libp2p.pubsub.* import io.libp2p.pubsub.gossip.* +import io.libp2p.pubsub.gossip.partialmessages.* import java.util.* import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService @@ -40,6 +41,13 @@ open class GossipRouterBuilder( }, val gossipRouterEventListeners: MutableList = mutableListOf(), val enabledGossipExtensions: List = mutableListOf(), + + /** + * Client-supplied handler for the partial-messages extension. + * Required when [GossipExtension.PARTIAL_MESSAGES] is enabled; a build-time + * error is thrown if the extension is enabled without a handler. + */ + var partialMessagesHandler: PartialMessagesHandler<*>? = null, ) { var seenCache: SeenCache> by lazyVar { TTLSeenCache(SimpleSeenCache(), params.seenTTL, currentTimeSupplier) } @@ -73,12 +81,28 @@ open class GossipRouterBuilder( ) router.eventBroadcaster.listeners += gossipRouterEventListeners + router.partialMessages = buildPartialMessagesAdapter() return router } + @Suppress("UNCHECKED_CAST") + private fun buildPartialMessagesAdapter(): PartialMessagesAdapter? { + val handler = partialMessagesHandler ?: return null + return PartialMessagesAdapterImpl( + handler = handler as PartialMessagesHandler, + stateStore = PartialGroupStateStore(), + feedback = NopPartialMessagesFeedback, + ) + } + open fun build(): GossipRouter { if (disposed) throw RuntimeException("The builder was already used") disposed = true + if (enabledGossipExtensions.contains(GossipExtension.PARTIAL_MESSAGES) && partialMessagesHandler == null) { + throw IllegalStateException( + "GossipExtension.PARTIAL_MESSAGES is enabled but no partialMessagesHandler was provided" + ) + } return createGossipRouter() } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialGroupStateStore.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialGroupStateStore.kt new file mode 100644 index 00000000..e3433410 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialGroupStateStore.kt @@ -0,0 +1,174 @@ +package io.libp2p.pubsub.gossip.partialmessages + +import io.libp2p.core.PeerId +import io.libp2p.pubsub.Topic +import org.slf4j.LoggerFactory + +private val logger = LoggerFactory.getLogger(PartialGroupStateStore::class.java) + +const val DEFAULT_GROUP_TTL_HEARTBEATS = 5 +const val DEFAULT_PEER_INITIATED_GROUP_LIMIT_PER_TOPIC = 255 +const val DEFAULT_PEER_INITIATED_GROUP_LIMIT_PER_TOPIC_PER_PEER = 8 + +/** + * Stable, value-based identity for a partial-messages group ID. + * + * Wraps a raw [ByteArray] so it can be used as a [HashMap] key with + * content equality rather than reference equality. + */ +class GroupId(val bytes: ByteArray) { + override fun equals(other: Any?): Boolean = + other is GroupId && bytes.contentEquals(other.bytes) + override fun hashCode(): Int = bytes.contentHashCode() + override fun toString(): String = bytes.joinToString("") { "%02x".format(it) } +} + +fun ByteArray.toGroupId(): GroupId = GroupId(this) + +/** + * Per-(topic, groupId) state container. + * + * [peerStates] is mutable and updated as parts arrive. + * [ttlInHeartbeats] is decremented each heartbeat and reset on [PartialGroupStateStore.resetTtl]. + * [initiatingPeer] is non-null iff [peerInitiated] is true. + * + * NOT thread-safe: accessed only on the pubsub event loop. + */ +class GroupState( + var ttlInHeartbeats: Int, + val peerInitiated: Boolean, + val initiatingPeer: PeerId? +) { + val peerStates: MutableMap = mutableMapOf() +} + +/** + * Stores and manages per-(topic, groupId) [GroupState] entries for the partial-messages + * extension. + * + * DoS caps (matching go-libp2p defaults): + * - [peerInitiatedGroupLimitPerTopic]: max peer-initiated groups across all peers per topic. + * - [peerInitiatedGroupLimitPerTopicPerPeer]: max peer-initiated groups per (topic, peer). + * + * NOT thread-safe: all access must be serialised on the pubsub event loop. + */ +class PartialGroupStateStore( + val groupTtlHeartbeats: Int = DEFAULT_GROUP_TTL_HEARTBEATS, + val peerInitiatedGroupLimitPerTopic: Int = DEFAULT_PEER_INITIATED_GROUP_LIMIT_PER_TOPIC, + val peerInitiatedGroupLimitPerTopicPerPeer: Int = DEFAULT_PEER_INITIATED_GROUP_LIMIT_PER_TOPIC_PER_PEER +) { + private val groups: HashMap>> = hashMapOf() + + fun getGroup(topic: Topic, groupId: GroupId): GroupState? = + groups[topic]?.get(groupId) + + /** + * Returns the group for (topic, groupId), creating it as a locally-initiated group + * if absent. Resets the TTL if the group already exists. + */ + fun getOrCreateLocalGroup(topic: Topic, groupId: GroupId): GroupState { + val topicGroups = groups.getOrPut(topic) { hashMapOf() } + val existing = topicGroups[groupId] + if (existing != null) { + existing.ttlInHeartbeats = groupTtlHeartbeats + return existing + } + return GroupState( + ttlInHeartbeats = groupTtlHeartbeats, + peerInitiated = false, + initiatingPeer = null + ).also { topicGroups[groupId] = it } + } + + /** + * Returns the group for (topic, groupId), creating it as a peer-initiated group if absent. + * Returns null and drops the RPC if either DoS cap would be exceeded. + */ + fun getOrCreatePeerGroup(topic: Topic, groupId: GroupId, peer: PeerId): GroupState? { + val topicGroups = groups.getOrPut(topic) { hashMapOf() } + val existing = topicGroups[groupId] + if (existing != null) return existing + + val totalPeerInitiated = topicGroups.values.count { it.peerInitiated } + if (totalPeerInitiated >= peerInitiatedGroupLimitPerTopic) { + logger.debug( + "Dropping peer-initiated group {} from {}: per-topic cap {} reached for topic {}", + groupId, + peer, + peerInitiatedGroupLimitPerTopic, + topic + ) + return null + } + + val peerTotal = topicGroups.values.count { it.initiatingPeer == peer } + if (peerTotal >= peerInitiatedGroupLimitPerTopicPerPeer) { + logger.debug( + "Dropping peer-initiated group {} from {}: per-peer cap {} reached for topic {}", + groupId, + peer, + peerInitiatedGroupLimitPerTopicPerPeer, + topic + ) + return null + } + + return GroupState( + ttlInHeartbeats = groupTtlHeartbeats, + peerInitiated = true, + initiatingPeer = peer + ).also { topicGroups[groupId] = it } + } + + /** Resets the TTL for (topic, groupId). Called by publishPartial. */ + fun resetTtl(topic: Topic, groupId: GroupId) { + groups[topic]?.get(groupId)?.let { it.ttlInHeartbeats = groupTtlHeartbeats } + } + + /** Returns a read-only snapshot of all groups for [topic]. */ + fun groupsForTopic(topic: Topic): Map> = + groups[topic] ?: emptyMap() + + /** + * Decrements TTLs and garbage-collects expired groups (TTL ≤ 0) and + * groups whose peerStates map has become empty. + */ + fun onHeartbeat() { + val topicIter = groups.entries.iterator() + while (topicIter.hasNext()) { + val (_, topicGroups) = topicIter.next() + val groupIter = topicGroups.entries.iterator() + while (groupIter.hasNext()) { + val (_, group) = groupIter.next() + group.ttlInHeartbeats-- + if (group.ttlInHeartbeats <= 0 || group.peerStates.isEmpty()) { + groupIter.remove() + } + } + if (topicGroups.isEmpty()) topicIter.remove() + } + } + + /** + * Removes [peer] from all group peerStates; garbage-collects groups that + * become empty as a result. + */ + fun onPeerDisconnected(peer: PeerId) { + val topicIter = groups.entries.iterator() + while (topicIter.hasNext()) { + val (_, topicGroups) = topicIter.next() + val groupIter = topicGroups.entries.iterator() + while (groupIter.hasNext()) { + val (_, group) = groupIter.next() + group.peerStates.remove(peer) + if (group.peerStates.isEmpty()) groupIter.remove() + } + if (topicGroups.isEmpty()) topicIter.remove() + } + } + + /** Drops all group state for [topic] (called when we unsubscribe). */ + fun onTopicUnsubscribed(topic: Topic) { + groups.remove(topic) + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesAdapter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesAdapter.kt new file mode 100644 index 00000000..2f3e0737 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesAdapter.kt @@ -0,0 +1,86 @@ +package io.libp2p.pubsub.gossip.partialmessages + +import io.libp2p.core.PeerId +import io.libp2p.pubsub.Topic +import org.slf4j.LoggerFactory +import pubsub.pb.Rpc + +private val logger = LoggerFactory.getLogger(PartialMessagesAdapterImpl::class.java) + +/** + * Type-erased view of the partial-messages subsystem used by [io.libp2p.pubsub.gossip.GossipRouter]. + * + * All methods are called on the pubsub event thread. + */ +internal interface PartialMessagesAdapter { + fun onPeerDisconnected(peer: PeerId) + fun onTopicUnsubscribed(topic: Topic) + fun onHeartbeat() + fun onIncomingRpc(topic: Topic, from: PeerId, rpc: Rpc.PartialMessagesExtension) + + /** + * Executes the client's [PublishActionsFn], updates group state, and enqueues + * outbound [Rpc.PartialMessagesExtension] RPCs via [enqueueFn]. + * + * [peerRequestsPartial] is used to enforce the spec MUST: omit [PublishAction.partialMessage] + * when the peer supports but did not request partial messages. + */ + fun publishPartial( + topic: Topic, + groupId: GroupId, + actionsFn: PublishActionsFn<*>, + peerRequestsPartial: (PeerId) -> Boolean, + enqueueFn: (PeerId, ByteArray?, ByteArray?) -> Unit + ) +} + +/** + * Bridges [GossipRouter] (which has no [PeerState] type parameter) to the typed + * [PartialMessagesHandler] and [PartialGroupStateStore]. + * + * Created once in [io.libp2p.pubsub.gossip.builders.GossipRouterBuilder] with an + * unchecked cast that is safe because [PeerState] is captured and used consistently + * throughout the lifetime of this object. + */ +internal class PartialMessagesAdapterImpl( + val handler: PartialMessagesHandler, + val stateStore: PartialGroupStateStore, + val feedback: PartialMessagesPeerFeedback +) : PartialMessagesAdapter { + + override fun onPeerDisconnected(peer: PeerId) = stateStore.onPeerDisconnected(peer) + override fun onTopicUnsubscribed(topic: Topic) = stateStore.onTopicUnsubscribed(topic) + override fun onHeartbeat() = stateStore.onHeartbeat() + + override fun onIncomingRpc(topic: Topic, from: PeerId, rpc: Rpc.PartialMessagesExtension) { + val groupId = rpc.groupID.toByteArray().toGroupId() + val groupState = stateStore.getOrCreatePeerGroup(topic, groupId, from) ?: return + handler.onIncomingRpc(from, groupState.peerStates, rpc, feedback) + } + + @Suppress("UNCHECKED_CAST") + override fun publishPartial( + topic: Topic, + groupId: GroupId, + actionsFn: PublishActionsFn<*>, + peerRequestsPartial: (PeerId) -> Boolean, + enqueueFn: (PeerId, ByteArray?, ByteArray?) -> Unit + ) { + val typedFn = actionsFn as PublishActionsFn + val groupState = stateStore.getOrCreateLocalGroup(topic, groupId) + for ((peerId, action) in typedFn.decide(groupState.peerStates, peerRequestsPartial)) { + if (action.error != null) { + logger.debug("Skipping partial publish to {}: {}", peerId, action.error.message) + continue + } + // Spec MUST: omit partialMessage if peer supports but didn't request + val effectivePartialMessage = if (peerRequestsPartial(peerId)) action.partialMessage else null + if (effectivePartialMessage != null || action.partsMetadata != null) { + enqueueFn(peerId, effectivePartialMessage, action.partsMetadata) + } + if (action.nextPeerState != null) { + groupState.peerStates[peerId] = action.nextPeerState + } + } + } +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesHandler.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesHandler.kt new file mode 100644 index 00000000..e210c4d9 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesHandler.kt @@ -0,0 +1,51 @@ +package io.libp2p.pubsub.gossip.partialmessages + +import io.libp2p.core.PeerId +import io.libp2p.pubsub.Topic +import pubsub.pb.Rpc + +/** + * Client-supplied handler for the partial-messages extension. + * + * Both callbacks run on the pubsub event thread and MUST be fast and non-blocking. + * Dispatch heavy work (decoding, KZG validation) to a separate executor. + * + * @param PeerState opaque per-(topic, groupId, peerId) state that the library + * stores and passes back; the library never interprets it. + */ +interface PartialMessagesHandler { + + /** + * Called on every inbound [Rpc.PartialMessagesExtension] RPC. + * + * Any of [rpc].partialMessage and [rpc].partsMetadata may be absent; all + * four combinations are valid wire messages. + * + * [peerStates] reflects the current state for this (topic, groupId) pair across + * all peers. The map is a live view — do not retain a reference outside this call. + */ + fun onIncomingRpc( + from: PeerId, + peerStates: Map, + rpc: Rpc.PartialMessagesExtension, + feedback: PartialMessagesPeerFeedback + ) + + /** + * Called once per locally-initiated group during the gossipsub heartbeat for + * gossip targets that are partial-capable on [topic]. + * + * The client typically responds by calling [io.libp2p.pubsub.gossip.Gossip.publishPartial] + * for the same (topic, groupId). + * + * [peerStates] reflects the current state for this group across all peers. + * The map is a live view — do not retain a reference outside this call. + */ + fun onEmitGossip( + topic: Topic, + groupId: ByteArray, + gossipPeers: Collection, + peerStates: Map, + feedback: PartialMessagesPeerFeedback + ) +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesPeerFeedback.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesPeerFeedback.kt new file mode 100644 index 00000000..5c3e4caf --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesPeerFeedback.kt @@ -0,0 +1,14 @@ +package io.libp2p.pubsub.gossip.partialmessages + +import io.libp2p.core.PeerId +import io.libp2p.pubsub.Topic + +enum class FeedbackKind { USEFUL, INVALID, IGNORED } + +interface PartialMessagesPeerFeedback { + fun reportFeedback(topic: Topic, peer: PeerId, kind: FeedbackKind) +} + +internal object NopPartialMessagesFeedback : PartialMessagesPeerFeedback { + override fun reportFeedback(topic: Topic, peer: PeerId, kind: FeedbackKind) {} +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PublishActions.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PublishActions.kt new file mode 100644 index 00000000..c0b54c1a --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/partialmessages/PublishActions.kt @@ -0,0 +1,32 @@ +package io.libp2p.pubsub.gossip.partialmessages + +import io.libp2p.core.PeerId + +/** + * Encodes what the library should send to one peer for a single + * [io.libp2p.pubsub.gossip.Gossip.publishPartial] call. + * + * [nextPeerState] is applied atomically by the library per peer after the + * send; null means "leave the existing state unchanged". + */ +data class PublishAction( + val partialMessage: ByteArray? = null, + val partsMetadata: ByteArray? = null, + val nextPeerState: PeerState? = null, + val error: Throwable? = null +) + +/** + * Decision function supplied by the client to [io.libp2p.pubsub.gossip.Gossip.publishPartial]. + * + * [decide] is called on the pubsub event thread with the current peer state map + * and a predicate for checking whether a peer requested partial for the topic. + * It must return a sequence of (peerId, action) pairs — one per peer that + * should receive an outbound [pubsub.pb.Rpc.PartialMessagesExtension] RPC. + */ +fun interface PublishActionsFn { + fun decide( + peerStates: Map, + peerRequestsPartial: (PeerId) -> Boolean + ): Sequence>> +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterBuilderTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterBuilderTest.kt index 224b5a7a..f0f6135d 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterBuilderTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRouterBuilderTest.kt @@ -1,11 +1,21 @@ package io.libp2p.pubsub.gossip +import io.libp2p.core.PeerId +import io.libp2p.pubsub.Topic import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder +import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesHandler +import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesPeerFeedback import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import pubsub.pb.Rpc class GossipRouterBuilderTest { + private val nopHandler: PartialMessagesHandler = object : PartialMessagesHandler { + override fun onIncomingRpc(from: PeerId, peerStates: Map, rpc: Rpc.PartialMessagesExtension, feedback: PartialMessagesPeerFeedback) {} + override fun onEmitGossip(topic: Topic, groupId: ByteArray, gossipPeers: Collection, peerStates: Map, feedback: PartialMessagesPeerFeedback) {} + } + @Test fun `builds GossipRouter with both extensions disabled by default`() { val router = GossipRouterBuilder().build() @@ -36,6 +46,7 @@ class GossipRouterBuilderTest { GossipExtension.TEST_EXTENSION, GossipExtension.PARTIAL_MESSAGES, ) + .apply { partialMessagesHandler = nopHandler } .build() val localSupport = router.gossipExtensionsState.localExtensionSupport diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt index 1917310e..b14cb227 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipTestsBase.kt @@ -8,6 +8,8 @@ import io.libp2p.pubsub.* import io.libp2p.pubsub.DeterministicFuzz.Companion.createGossipFuzzRouterFactory import io.libp2p.pubsub.DeterministicFuzz.Companion.createMockFuzzRouterFactory import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder +import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesHandler +import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesPeerFeedback import io.netty.handler.logging.LogLevel import pubsub.pb.Rpc @@ -15,6 +17,28 @@ abstract class GossipTestsBase { protected val GossipScore.testPeerScores get() = (this as DefaultGossipScore).peerScores + /** + * No-op [PartialMessagesHandler] for use in tests that enable the partial-messages + * extension but don't exercise handler behaviour. + */ + protected val nopPartialMessagesHandler: PartialMessagesHandler = + object : PartialMessagesHandler { + override fun onIncomingRpc( + from: PeerId, + peerStates: Map, + rpc: pubsub.pb.Rpc.PartialMessagesExtension, + feedback: PartialMessagesPeerFeedback + ) {} + + override fun onEmitGossip( + topic: Topic, + groupId: ByteArray, + gossipPeers: Collection, + peerStates: Map, + feedback: PartialMessagesPeerFeedback + ) {} + } + protected fun newProtoMessage(topic: Topic, seqNo: Long, data: ByteArray) = Rpc.Message.newBuilder() .addTopicIDs(topic) @@ -63,8 +87,8 @@ abstract class GossipTestsBase { val scoreParams: GossipScoreParams = GossipScoreParams(), val mockRouterFactory: DeterministicFuzzRouterFactory = createMockFuzzRouterFactory(), val protocol: PubsubProtocol = PubsubProtocol.Gossip_V_1_1, - val enabledGossipExtensions: List = listOf(GossipExtension.TEST_EXTENSION) - + val enabledGossipExtensions: List = listOf(GossipExtension.TEST_EXTENSION), + val partialMessagesHandler: PartialMessagesHandler<*>? = null, ) { val fuzz = DeterministicFuzz() val gossipRouterBuilderFactory = { @@ -72,7 +96,8 @@ abstract class GossipTestsBase { protocol = protocol, params = coreParams, scoreParams = scoreParams, - enabledGossipExtensions = enabledGossipExtensions + enabledGossipExtensions = enabledGossipExtensions, + partialMessagesHandler = partialMessagesHandler, ) } val router1 = fuzz.createTestRouter(createGossipFuzzRouterFactory(gossipRouterBuilderFactory)) diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/PartialSubscriptionStateTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/PartialSubscriptionStateTest.kt new file mode 100644 index 00000000..60a05c2a --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/PartialSubscriptionStateTest.kt @@ -0,0 +1,153 @@ +package io.libp2p.pubsub.gossip + +import io.libp2p.core.PeerId +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +class PartialSubscriptionStateTest { + + private lateinit var state: PartialSubscriptionState + private lateinit var peer1: PeerId + private lateinit var peer2: PeerId + private lateinit var peer3: PeerId + + private val topicA = "topic-a" + private val topicB = "topic-b" + + @BeforeEach + fun setup() { + state = PartialSubscriptionState() + peer1 = PeerId.random() + peer2 = PeerId.random() + peer3 = PeerId.random() + } + + @Test + fun `unknown peer returns NONE`() { + assertThat(state.peerFlags(topicA, peer1)).isEqualTo(PartialSubFlags.NONE) + assertThat(state.peerRequestsPartial(topicA, peer1)).isFalse() + assertThat(state.peerSupportsSendingPartial(topicA, peer1)).isFalse() + } + + @Test + fun `setPeerFlags stores and peerFlags reads back`() { + val flags = PartialSubFlags(requestsPartial = true, supportsSendingPartial = true) + state.setPeerFlags(topicA, peer1, flags) + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo(flags) + assertThat(state.peerRequestsPartial(topicA, peer1)).isTrue() + assertThat(state.peerSupportsSendingPartial(topicA, peer1)).isTrue() + } + + @Test + fun `setPeerFlags with NONE removes entry`() { + state.setPeerFlags(topicA, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = false)) + state.setPeerFlags(topicA, peer1, PartialSubFlags.NONE) + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo(PartialSubFlags.NONE) + assertThat(state.snapshot()).doesNotContainKey(topicA) + } + + @Test + fun `setPeerFlags overwrites previous flags for same peer and topic`() { + state.setPeerFlags(topicA, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + state.setPeerFlags(topicA, peer1, PartialSubFlags(requestsPartial = false, supportsSendingPartial = true)) + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo( + PartialSubFlags(requestsPartial = false, supportsSendingPartial = true) + ) + } + + @Test + fun `removePeerFlags drops the peer's entry and GCs empty topic`() { + val flags = PartialSubFlags(requestsPartial = true, supportsSendingPartial = true) + state.setPeerFlags(topicA, peer1, flags) + + state.removePeerFlags(topicA, peer1) + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo(PartialSubFlags.NONE) + assertThat(state.snapshot()).doesNotContainKey(topicA) + } + + @Test + fun `removePeerFlags on unknown peer or topic is a no-op`() { + state.removePeerFlags(topicA, peer1) // nothing stored yet + assertThat(state.snapshot()).isEmpty() + + state.setPeerFlags(topicA, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + state.removePeerFlags(topicB, peer1) // topic mismatch + state.removePeerFlags(topicA, peer2) // peer mismatch + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo( + PartialSubFlags(requestsPartial = true, supportsSendingPartial = true) + ) + } + + @Test + fun `removeTopic drops all peers for that topic, leaves other topics intact`() { + state.setPeerFlags(topicA, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + state.setPeerFlags(topicA, peer2, PartialSubFlags(requestsPartial = false, supportsSendingPartial = true)) + state.setPeerFlags(topicB, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + + state.removeTopic(topicA) + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo(PartialSubFlags.NONE) + assertThat(state.peerFlags(topicA, peer2)).isEqualTo(PartialSubFlags.NONE) + assertThat(state.peerFlags(topicB, peer1)).isEqualTo( + PartialSubFlags(requestsPartial = true, supportsSendingPartial = true) + ) + } + + @Test + fun `onPeerDisconnected clears the peer across all topics, leaves other peers intact`() { + state.setPeerFlags(topicA, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + state.setPeerFlags(topicA, peer2, PartialSubFlags(requestsPartial = false, supportsSendingPartial = true)) + state.setPeerFlags(topicB, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + state.setPeerFlags(topicB, peer3, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + + state.onPeerDisconnected(peer1) + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo(PartialSubFlags.NONE) + assertThat(state.peerFlags(topicB, peer1)).isEqualTo(PartialSubFlags.NONE) + assertThat(state.peerFlags(topicA, peer2)).isEqualTo( + PartialSubFlags(requestsPartial = false, supportsSendingPartial = true) + ) + assertThat(state.peerFlags(topicB, peer3)).isEqualTo( + PartialSubFlags(requestsPartial = true, supportsSendingPartial = true) + ) + } + + @Test + fun `onPeerDisconnected GCs topics that become empty`() { + state.setPeerFlags(topicA, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + state.setPeerFlags(topicB, peer2, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + + state.onPeerDisconnected(peer1) + + assertThat(state.snapshot()).doesNotContainKey(topicA) + assertThat(state.snapshot()).containsKey(topicB) + } + + @Test + fun `onPeerDisconnected on unknown peer is a no-op`() { + state.setPeerFlags(topicA, peer1, PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + + state.onPeerDisconnected(peer2) + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo( + PartialSubFlags(requestsPartial = true, supportsSendingPartial = true) + ) + } + + @Test + fun `peer independence on same topic`() { + val flags1 = PartialSubFlags(requestsPartial = true, supportsSendingPartial = true) + val flags2 = PartialSubFlags(requestsPartial = false, supportsSendingPartial = true) + state.setPeerFlags(topicA, peer1, flags1) + state.setPeerFlags(topicA, peer2, flags2) + + assertThat(state.peerFlags(topicA, peer1)).isEqualTo(flags1) + assertThat(state.peerFlags(topicA, peer2)).isEqualTo(flags2) + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt index c39f4d51..0a78acd8 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt @@ -138,7 +138,8 @@ class GossipExtensionsMessageHandlingTest : GossipTestsBase() { enabledGossipExtensions = listOf( GossipExtension.TEST_EXTENSION, GossipExtension.PARTIAL_MESSAGES - ) + ), + partialMessagesHandler = nopPartialMessagesHandler, ) val receivedMessage = test.mockRouter.waitForMessage( @@ -198,6 +199,7 @@ class GossipExtensionsMessageHandlingTest : GossipTestsBase() { val test = TwoRoutersTest( protocol = PubsubProtocol.Gossip_V_1_3, enabledGossipExtensions = listOf(GossipExtension.PARTIAL_MESSAGES), + partialMessagesHandler = nopPartialMessagesHandler, // Creating GossipScoreParams with behaviourPenaltyWeight (peer bad behavior affecting // score). Here we are not interested if the weight is "correct". What we want to see if // that a peer is penalized for sending more than one ControlExtensions message. diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialMessagesInboundRpcTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialMessagesInboundRpcTest.kt new file mode 100644 index 00000000..682f7282 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialMessagesInboundRpcTest.kt @@ -0,0 +1,177 @@ +package io.libp2p.pubsub.gossip.extensions + +import com.google.protobuf.ByteString +import io.libp2p.core.PeerId +import io.libp2p.pubsub.PubsubProtocol +import io.libp2p.pubsub.Topic +import io.libp2p.pubsub.gossip.GossipExtension +import io.libp2p.pubsub.gossip.GossipTestsBase +import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesHandler +import io.libp2p.pubsub.gossip.partialmessages.PartialMessagesPeerFeedback +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc +import java.util.concurrent.CopyOnWriteArrayList + +private const val TIMEOUT_MS = 500L + +class PartialMessagesInboundRpcTest : GossipTestsBase() { + + private val topicId = "test-topic" + private val groupIdBytes = "group-1".toByteArray() + + /** Records each [onIncomingRpc] call for assertion in tests. */ + data class IncomingCall(val from: PeerId, val rpc: Rpc.PartialMessagesExtension) + + private val incomingCalls = CopyOnWriteArrayList() + + private val capturingHandler: PartialMessagesHandler = + object : PartialMessagesHandler { + override fun onIncomingRpc( + from: PeerId, + peerStates: Map, + rpc: Rpc.PartialMessagesExtension, + feedback: PartialMessagesPeerFeedback + ) { + incomingCalls += IncomingCall(from, rpc) + } + + override fun onEmitGossip( + topic: Topic, + groupId: ByteArray, + gossipPeers: Collection, + peerStates: Map, + feedback: PartialMessagesPeerFeedback + ) {} + } + + private fun newTest() = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3, + enabledGossipExtensions = listOf(GossipExtension.PARTIAL_MESSAGES), + partialMessagesHandler = capturingHandler, + ) + + private fun partialRpcWith( + topicId: String? = this.topicId, + groupId: ByteArray? = groupIdBytes + ): Rpc.RPC { + val ext = Rpc.PartialMessagesExtension.newBuilder().apply { + if (topicId != null) setTopicID(topicId) + if (groupId != null) setGroupID(ByteString.copyFrom(groupId)) + }.build() + return Rpc.RPC.newBuilder().setPartial(ext).build() + } + + private fun controlExtensionsWithPartial(): Rpc.RPC = + Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().setExtensions( + Rpc.ControlExtensions.newBuilder().setPartialMessages(true) + ) + ).build() + + // Drains any currently queued messages from the mock router's outbox + // so later assertions start from a clean slate. + private fun TwoRoutersTest.flushRouter() = + gossipRouter.submitOnEventThread {}.join() + + @Test + fun `valid partial RPC after ControlExtensions dispatches to handler`() { + val test = newTest() + test.flushRouter() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(partialRpcWith()) + test.flushRouter() + + assertThat(incomingCalls).hasSize(1) + assertThat(incomingCalls[0].rpc.topicID).isEqualTo(topicId) + assertThat(incomingCalls[0].rpc.groupID.toByteArray()).isEqualTo(groupIdBytes) + } + + @Test + fun `partial RPC without prior ControlExtensions is ignored`() { + val test = newTest() + test.flushRouter() + + test.mockRouter.sendToSingle(partialRpcWith()) + test.flushRouter() + + assertThat(incomingCalls).isEmpty() + } + + @Test + fun `partial RPC with missing topicID is dropped`() { + val test = newTest() + test.flushRouter() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(partialRpcWith(topicId = null)) + test.flushRouter() + + assertThat(incomingCalls).isEmpty() + } + + @Test + fun `partial RPC with empty topicID is dropped`() { + val test = newTest() + test.flushRouter() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(partialRpcWith(topicId = "")) + test.flushRouter() + + assertThat(incomingCalls).isEmpty() + } + + @Test + fun `partial RPC with missing groupID is dropped`() { + val test = newTest() + test.flushRouter() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(partialRpcWith(groupId = null)) + test.flushRouter() + + assertThat(incomingCalls).isEmpty() + } + + @Test + fun `partial RPC with empty groupID is dropped`() { + val test = newTest() + test.flushRouter() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(partialRpcWith(groupId = ByteArray(0))) + test.flushRouter() + + assertThat(incomingCalls).isEmpty() + } + + @Test + fun `partial RPC when extension is disabled is ignored`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3, + enabledGossipExtensions = listOf(), + ) + test.flushRouter() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(partialRpcWith()) + test.flushRouter() + + assertThat(incomingCalls).isEmpty() + } + + @Test + fun `multiple valid partial RPCs for different groups all dispatched`() { + val test = newTest() + test.flushRouter() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(partialRpcWith(groupId = "g1".toByteArray())) + test.mockRouter.sendToSingle(partialRpcWith(groupId = "g2".toByteArray())) + test.flushRouter() + + assertThat(incomingCalls).hasSize(2) + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialMessagesOutboundRpcTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialMessagesOutboundRpcTest.kt new file mode 100644 index 00000000..9fd02ace --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialMessagesOutboundRpcTest.kt @@ -0,0 +1,172 @@ +package io.libp2p.pubsub.gossip.extensions + +import com.google.protobuf.ByteString +import io.libp2p.core.PeerId +import io.libp2p.pubsub.PubsubProtocol +import io.libp2p.pubsub.gossip.GossipExtension +import io.libp2p.pubsub.gossip.GossipTestsBase +import io.libp2p.pubsub.gossip.partialmessages.PublishAction +import io.libp2p.pubsub.gossip.partialmessages.PublishActionsFn +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +private const val TIMEOUT_MS = 500L + +class PartialMessagesOutboundRpcTest : GossipTestsBase() { + + private val topicId = "test-topic" + private val groupIdBytes = "group-1".toByteArray() + + private fun newTest() = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3, + enabledGossipExtensions = listOf(GossipExtension.PARTIAL_MESSAGES), + partialMessagesHandler = nopPartialMessagesHandler, + ) + + private fun controlExtensionsWithPartial(): Rpc.RPC = + Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().setExtensions( + Rpc.ControlExtensions.newBuilder().setPartialMessages(true) + ) + ).build() + + private fun subscribeRpc( + topic: String, + requestsPartial: Boolean, + supportsSendingPartial: Boolean + ): Rpc.RPC = + Rpc.RPC.newBuilder().addSubscriptions( + Rpc.RPC.SubOpts.newBuilder() + .setTopicid(topic) + .setSubscribe(true) + .setRequestsPartial(requestsPartial) + .setSupportsSendingPartial(supportsSendingPartial) + ).build() + + private fun TwoRoutersTest.flushRouter() = + gossipRouter.submitOnEventThread {}.join() + + private fun TwoRoutersTest.peerIdOfMockRouter(): PeerId = router2.peerId + + @Test + fun `publishPartial delivers partial RPC to peer that requested partial`() { + val test = newTest() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(subscribeRpc(topicId, requestsPartial = true, supportsSendingPartial = true)) + test.flushRouter() + + val payload = byteArrayOf(1, 2, 3) + val meta = byteArrayOf(0xAA.toByte()) + val peerId = test.peerIdOfMockRouter() + + val actionsFn = PublishActionsFn { _, _ -> + sequenceOf(peerId to PublishAction(partialMessage = payload, partsMetadata = meta)) + } + + test.gossipRouter.publishPartial(topicId, groupIdBytes, actionsFn) + + val received = test.mockRouter.waitForMessage({ it.hasPartial() }, TIMEOUT_MS) + assertThat(received.partial.topicID).isEqualTo(topicId) + assertThat(received.partial.groupID).isEqualTo(ByteString.copyFrom(groupIdBytes)) + assertThat(received.partial.partialMessage.toByteArray()).isEqualTo(payload) + assertThat(received.partial.partsMetadata.toByteArray()).isEqualTo(meta) + } + + @Test + fun `publishPartial omits partialMessage when peer supports but did not request`() { + val test = newTest() + + // Peer supports sending partial but did NOT request partial messages + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(subscribeRpc(topicId, requestsPartial = false, supportsSendingPartial = true)) + test.flushRouter() + + val payload = byteArrayOf(1, 2, 3) + val meta = byteArrayOf(0xAA.toByte()) + val peerId = test.peerIdOfMockRouter() + + val actionsFn = PublishActionsFn { _, _ -> + sequenceOf(peerId to PublishAction(partialMessage = payload, partsMetadata = meta)) + } + + test.gossipRouter.publishPartial(topicId, groupIdBytes, actionsFn) + + val received = test.mockRouter.waitForMessage({ it.hasPartial() }, TIMEOUT_MS) + // partsMetadata is present; partialMessage MUST be absent (spec MUST) + assertThat(received.partial.hasPartialMessage()).isFalse() + assertThat(received.partial.partsMetadata.toByteArray()).isEqualTo(meta) + } + + @Test + fun `publishPartial sends nothing when actionsFn returns empty sequence`() { + val test = newTest() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(subscribeRpc(topicId, requestsPartial = true, supportsSendingPartial = true)) + test.flushRouter() + + val actionsFn = PublishActionsFn { _, _ -> emptySequence() } + + test.gossipRouter.publishPartial(topicId, groupIdBytes, actionsFn) + test.flushRouter() + + assertThat(test.mockRouter.inboundMessages.none { it.hasPartial() }).isTrue() + } + + @Test + fun `publishPartial sends nothing when adapter is not configured`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3, + enabledGossipExtensions = listOf(), + ) + test.flushRouter() + + val peerId = test.peerIdOfMockRouter() + val actionsFn = PublishActionsFn { _, _ -> + sequenceOf(peerId to PublishAction(partsMetadata = byteArrayOf(1))) + } + + test.gossipRouter.publishPartial(topicId, groupIdBytes, actionsFn) + test.flushRouter() + + assertThat(test.mockRouter.inboundMessages.none { it.hasPartial() }).isTrue() + } + + @Test + fun `publishPartial two groups produce two separate RPCs`() { + val test = newTest() + + test.mockRouter.sendToSingle(controlExtensionsWithPartial()) + test.mockRouter.sendToSingle(subscribeRpc(topicId, requestsPartial = true, supportsSendingPartial = true)) + test.flushRouter() + + val peerId = test.peerIdOfMockRouter() + val groupA = "group-a".toByteArray() + val groupB = "group-b".toByteArray() + + test.gossipRouter.publishPartial( + topicId, + groupA, + PublishActionsFn { _, _ -> sequenceOf(peerId to PublishAction(partsMetadata = byteArrayOf(1))) } + ) + test.gossipRouter.publishPartial( + topicId, + groupB, + PublishActionsFn { _, _ -> sequenceOf(peerId to PublishAction(partsMetadata = byteArrayOf(2))) } + ) + + val rpc1 = test.mockRouter.waitForMessage({ it.hasPartial() }, TIMEOUT_MS) + val rpc2 = test.mockRouter.waitForMessage({ it.hasPartial() }, TIMEOUT_MS) + + val groupIds = setOf( + rpc1.partial.groupID.toByteArray().toList(), + rpc2.partial.groupID.toByteArray().toList() + ) + assertThat(groupIds).containsExactlyInAnyOrder( + groupA.toList(), + groupB.toList() + ) + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialSubscriptionWireTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialSubscriptionWireTest.kt new file mode 100644 index 00000000..9c3c9ac5 --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/PartialSubscriptionWireTest.kt @@ -0,0 +1,205 @@ +package io.libp2p.pubsub.gossip.extensions + +import io.libp2p.core.PeerId +import io.libp2p.pubsub.PubsubProtocol +import io.libp2p.pubsub.Topic +import io.libp2p.pubsub.gossip.GossipExtension +import io.libp2p.pubsub.gossip.GossipTestsBase +import io.libp2p.pubsub.gossip.PartialSubFlags +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +class PartialSubscriptionWireTest : GossipTestsBase() { + + private val topicA = "topic-a" + private val topicB = "topic-b" + + private fun newTest() = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3, + enabledGossipExtensions = listOf(GossipExtension.PARTIAL_MESSAGES), + partialMessagesHandler = nopPartialMessagesHandler, + ) + + private fun Rpc.RPC.firstSubscribeFor(topic: String): Rpc.RPC.SubOpts? = + subscriptionsList.firstOrNull { it.topicid == topic && it.subscribe } + + private fun Rpc.RPC.firstUnsubscribeFor(topic: String): Rpc.RPC.SubOpts? = + subscriptionsList.firstOrNull { it.topicid == topic && !it.subscribe } + + /** + * Reads `partialSubscriptionState.peerFlags` on the pubsub event loop so the + * test thread establishes a happens-before with any pending event-loop + * mutations. The state container is documented as not thread-safe; direct + * access from the test thread risks `ConcurrentModificationException` and + * stale reads. + */ + private fun TwoRoutersTest.peerFlagsOnEventLoop(topic: Topic, peer: PeerId): PartialSubFlags = + gossipRouter.submitOnEventThread { + gossipRouter.partialSubscriptionState.peerFlags(topic, peer) + }.join() + + private fun TwoRoutersTest.snapshotPartialStateOnEventLoop(): Map> = + gossipRouter.submitOnEventThread { + gossipRouter.partialSubscriptionState.snapshot() + }.join() + + @Test + fun `outbound subscribe carries configured partial flags with send-side coercion`() { + val test = newTest() + + test.gossipRouter.setTopicPartialFlags(topicA, requestsPartial = true, supportsSendingPartial = false) + test.gossipRouter.subscribe(topicA) + + val received = test.mockRouter.waitForMessage({ it.firstSubscribeFor(topicA) != null }) + val sub = received.firstSubscribeFor(topicA)!! + assertThat(sub.requestsPartial).isTrue() + // spec coercion: supportsSendingPartial := requestsPartial || supportsSendingPartial + assertThat(sub.supportsSendingPartial).isTrue() + } + + @Test + fun `outbound subscribe with only supportsSendingPartial carries only that flag`() { + val test = newTest() + + test.gossipRouter.setTopicPartialFlags(topicA, requestsPartial = false, supportsSendingPartial = true) + test.gossipRouter.subscribe(topicA) + + val received = test.mockRouter.waitForMessage({ it.firstSubscribeFor(topicA) != null }) + val sub = received.firstSubscribeFor(topicA)!! + assertThat(sub.requestsPartial).isFalse() + assertThat(sub.supportsSendingPartial).isTrue() + } + + @Test + fun `outbound subscribe without configured flags has both flags absent`() { + val test = newTest() + + test.gossipRouter.subscribe(topicA) + + val received = test.mockRouter.waitForMessage({ it.firstSubscribeFor(topicA) != null }) + val sub = received.firstSubscribeFor(topicA)!! + assertThat(sub.hasRequestsPartial()).isFalse() + assertThat(sub.hasSupportsSendingPartial()).isFalse() + } + + @Test + fun `outbound unsubscribe never carries partial flags`() { + val test = newTest() + + test.gossipRouter.setTopicPartialFlags(topicA, requestsPartial = true, supportsSendingPartial = true) + test.gossipRouter.subscribe(topicA) + test.mockRouter.waitForMessage({ it.firstSubscribeFor(topicA) != null }) + + test.gossipRouter.unsubscribe(topicA) + + val received = test.mockRouter.waitForMessage({ it.firstUnsubscribeFor(topicA) != null }) + val unsub = received.firstUnsubscribeFor(topicA)!! + assertThat(unsub.hasRequestsPartial()).isFalse() + assertThat(unsub.hasSupportsSendingPartial()).isFalse() + } + + @Test + fun `inbound subscribe with requestsPartial only stores coerced flags`() { + val test = newTest() + + val rpc = subscribeRpc(topicA, requestsPartial = true, supportsSendingPartial = false) + test.mockRouter.sendToSingle(rpc) + + val peerId = test.router2.peerId + // Receive-side coercion: supportsSendingPartial := requestsPartial || supportsSendingPartial + assertThat(test.peerFlagsOnEventLoop(topicA, peerId)) + .isEqualTo(PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + } + + @Test + fun `inbound subscribe with supportsSendingPartial only stores that flag verbatim`() { + val test = newTest() + + val rpc = subscribeRpc(topicA, requestsPartial = false, supportsSendingPartial = true) + test.mockRouter.sendToSingle(rpc) + + assertThat(test.peerFlagsOnEventLoop(topicA, test.router2.peerId)) + .isEqualTo(PartialSubFlags(requestsPartial = false, supportsSendingPartial = true)) + } + + @Test + fun `inbound subscribe with both flags false leaves state empty`() { + val test = newTest() + + val rpc = subscribeRpc(topicA, requestsPartial = false, supportsSendingPartial = false) + test.mockRouter.sendToSingle(rpc) + + assertThat(test.peerFlagsOnEventLoop(topicA, test.router2.peerId)) + .isEqualTo(PartialSubFlags.NONE) + assertThat(test.snapshotPartialStateOnEventLoop()).doesNotContainKey(topicA) + } + + @Test + fun `inbound unsubscribe ignores flags and clears any prior peer state`() { + val test = newTest() + val peerId = test.router2.peerId + + test.mockRouter.sendToSingle(subscribeRpc(topicA, requestsPartial = true, supportsSendingPartial = true)) + assertThat(test.peerFlagsOnEventLoop(topicA, peerId)) + .isEqualTo(PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + + // Unsubscribe with malicious flags set: flags MUST be ignored, state MUST be cleared. + val unsub = Rpc.RPC.newBuilder().addSubscriptions( + Rpc.RPC.SubOpts.newBuilder() + .setTopicid(topicA) + .setSubscribe(false) + .setRequestsPartial(true) + .setSupportsSendingPartial(true) + ).build() + test.mockRouter.sendToSingle(unsub) + + assertThat(test.peerFlagsOnEventLoop(topicA, peerId)) + .isEqualTo(PartialSubFlags.NONE) + } + + @Test + fun `peer disconnect clears stored partial subscription state`() { + val test = newTest() + val peerId = test.router2.peerId + + test.mockRouter.sendToSingle(subscribeRpc(topicA, requestsPartial = true, supportsSendingPartial = true)) + assertThat(test.peerFlagsOnEventLoop(topicA, peerId)) + .isEqualTo(PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + + test.connection.disconnect() + + assertThat(test.peerFlagsOnEventLoop(topicA, peerId)) + .isEqualTo(PartialSubFlags.NONE) + } + + @Test + fun `local unsubscribe clears stored partial subscription state for that topic`() { + val test = newTest() + val peerId = test.router2.peerId + + test.gossipRouter.subscribe(topicA) + test.mockRouter.sendToSingle(subscribeRpc(topicA, requestsPartial = true, supportsSendingPartial = true)) + test.mockRouter.sendToSingle(subscribeRpc(topicB, requestsPartial = true, supportsSendingPartial = true)) + + assertThat(test.peerFlagsOnEventLoop(topicA, peerId)) + .isEqualTo(PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + + test.gossipRouter.unsubscribe(topicA) + + assertThat(test.peerFlagsOnEventLoop(topicA, peerId)) + .isEqualTo(PartialSubFlags.NONE) + // Other topic state preserved + assertThat(test.peerFlagsOnEventLoop(topicB, peerId)) + .isEqualTo(PartialSubFlags(requestsPartial = true, supportsSendingPartial = true)) + } + + private fun subscribeRpc(topic: String, requestsPartial: Boolean, supportsSendingPartial: Boolean): Rpc.RPC = + Rpc.RPC.newBuilder().addSubscriptions( + Rpc.RPC.SubOpts.newBuilder() + .setTopicid(topic) + .setSubscribe(true) + .setRequestsPartial(requestsPartial) + .setSupportsSendingPartial(supportsSendingPartial) + ).build() +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialGroupStateStoreTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialGroupStateStoreTest.kt new file mode 100644 index 00000000..7d1c0cdb --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialGroupStateStoreTest.kt @@ -0,0 +1,245 @@ +package io.libp2p.pubsub.gossip.partialmessages + +import io.libp2p.core.PeerId +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +class PartialGroupStateStoreTest { + + private lateinit var store: PartialGroupStateStore + private lateinit var peer1: PeerId + private lateinit var peer2: PeerId + + private val topicA = "topic-a" + private val topicB = "topic-b" + private val groupId1 = "group-1".toByteArray().toGroupId() + private val groupId2 = "group-2".toByteArray().toGroupId() + + @BeforeEach + fun setup() { + store = PartialGroupStateStore(groupTtlHeartbeats = 3) + peer1 = PeerId.random() + peer2 = PeerId.random() + } + + // --- GroupId equality --- + + @Test + fun `GroupId equality is content-based`() { + val a = "abc".toByteArray().toGroupId() + val b = "abc".toByteArray().toGroupId() + val c = "xyz".toByteArray().toGroupId() + assertThat(a).isEqualTo(b) + assertThat(a).isNotEqualTo(c) + assertThat(a.hashCode()).isEqualTo(b.hashCode()) + } + + @Test + fun `GroupId works as HashMap key`() { + val map = HashMap() + map["abc".toByteArray().toGroupId()] = 42 + assertThat(map["abc".toByteArray().toGroupId()]).isEqualTo(42) + } + + // --- local groups --- + + @Test + fun `getOrCreateLocalGroup creates a new group`() { + val group = store.getOrCreateLocalGroup(topicA, groupId1) + assertThat(group.peerInitiated).isFalse() + assertThat(group.initiatingPeer).isNull() + assertThat(group.ttlInHeartbeats).isEqualTo(3) + assertThat(store.getGroup(topicA, groupId1)).isSameAs(group) + } + + @Test + fun `getOrCreateLocalGroup resets TTL on existing group`() { + val group = store.getOrCreateLocalGroup(topicA, groupId1) + group.ttlInHeartbeats = 1 + store.getOrCreateLocalGroup(topicA, groupId1) + assertThat(group.ttlInHeartbeats).isEqualTo(3) + } + + @Test + fun `getOrCreateLocalGroup returns same object on repeated calls`() { + val g1 = store.getOrCreateLocalGroup(topicA, groupId1) + val g2 = store.getOrCreateLocalGroup(topicA, groupId1) + assertThat(g1).isSameAs(g2) + } + + // --- peer groups --- + + @Test + fun `getOrCreatePeerGroup creates a peer-initiated group`() { + val group = store.getOrCreatePeerGroup(topicA, groupId1, peer1) + assertThat(group).isNotNull() + assertThat(group!!.peerInitiated).isTrue() + assertThat(group.initiatingPeer).isEqualTo(peer1) + assertThat(group.ttlInHeartbeats).isEqualTo(3) + } + + @Test + fun `getOrCreatePeerGroup returns existing group`() { + val g1 = store.getOrCreatePeerGroup(topicA, groupId1, peer1) + val g2 = store.getOrCreatePeerGroup(topicA, groupId1, peer1) + assertThat(g1).isSameAs(g2) + } + + @Test + fun `per-topic cap rejects new peer-initiated groups`() { + val smallCapStore = PartialGroupStateStore( + peerInitiatedGroupLimitPerTopic = 2 + ) + val g1id = "g1".toByteArray().toGroupId() + val g2id = "g2".toByteArray().toGroupId() + val g3id = "g3".toByteArray().toGroupId() + + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, g1id, peer1)).isNotNull() + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, g2id, peer1)).isNotNull() + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, g3id, peer1)).isNull() + } + + @Test + fun `per-topic cap does not count local-initiated groups`() { + val smallCapStore = PartialGroupStateStore( + peerInitiatedGroupLimitPerTopic = 1 + ) + smallCapStore.getOrCreateLocalGroup(topicA, "local1".toByteArray().toGroupId()) + smallCapStore.getOrCreateLocalGroup(topicA, "local2".toByteArray().toGroupId()) + + // Only 0 peer-initiated groups, so cap not reached + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, groupId1, peer1)).isNotNull() + // Now cap reached (1 peer-initiated) + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, groupId2, peer2)).isNull() + } + + @Test + fun `per-peer cap rejects new peer-initiated groups for that peer`() { + val smallCapStore = PartialGroupStateStore( + peerInitiatedGroupLimitPerTopicPerPeer = 2 + ) + val g1id = "g1".toByteArray().toGroupId() + val g2id = "g2".toByteArray().toGroupId() + val g3id = "g3".toByteArray().toGroupId() + + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, g1id, peer1)).isNotNull() + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, g2id, peer1)).isNotNull() + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, g3id, peer1)).isNull() + + // peer2 should still be allowed (different peer) + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, g3id, peer2)).isNotNull() + } + + @Test + fun `per-peer cap is per-topic — other topics are unaffected`() { + val smallCapStore = PartialGroupStateStore( + peerInitiatedGroupLimitPerTopicPerPeer = 1 + ) + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, groupId1, peer1)).isNotNull() + assertThat(smallCapStore.getOrCreatePeerGroup(topicA, groupId2, peer1)).isNull() + assertThat(smallCapStore.getOrCreatePeerGroup(topicB, groupId1, peer1)).isNotNull() + } + + // --- TTL and heartbeat GC --- + + @Test + fun `onHeartbeat decrements TTL`() { + val group = store.getOrCreateLocalGroup(topicA, groupId1) + group.peerStates[peer1] = "state" // prevent GC by empty-peerStates rule + store.onHeartbeat() + assertThat(group.ttlInHeartbeats).isEqualTo(2) + } + + @Test + fun `onHeartbeat removes group when TTL reaches zero`() { + val group = store.getOrCreateLocalGroup(topicA, groupId1) + group.peerStates[peer1] = "state" // prevent GC by empty-peerStates rule + repeat(3) { store.onHeartbeat() } + assertThat(store.getGroup(topicA, groupId1)).isNull() + } + + @Test + fun `onHeartbeat removes group when peerStates is empty`() { + val group = store.getOrCreateLocalGroup(topicA, groupId1) + group.peerStates[peer1] = "state" + group.peerStates.remove(peer1) + store.onHeartbeat() + assertThat(store.getGroup(topicA, groupId1)).isNull() + } + + @Test + fun `onHeartbeat does not remove group with non-empty peerStates before TTL`() { + val group = store.getOrCreateLocalGroup(topicA, groupId1) + group.peerStates[peer1] = "state" + store.onHeartbeat() + assertThat(store.getGroup(topicA, groupId1)).isSameAs(group) + } + + @Test + fun `resetTtl refreshes TTL for a group`() { + val group = store.getOrCreateLocalGroup(topicA, groupId1) + group.peerStates[peer1] = "state" // prevent GC by empty-peerStates rule + repeat(2) { store.onHeartbeat() } + assertThat(group.ttlInHeartbeats).isEqualTo(1) + store.resetTtl(topicA, groupId1) + assertThat(group.ttlInHeartbeats).isEqualTo(3) + } + + // --- peer disconnect --- + + @Test + fun `onPeerDisconnected removes peer from all group peerStates`() { + val group1 = store.getOrCreateLocalGroup(topicA, groupId1) + val group2 = store.getOrCreateLocalGroup(topicB, groupId2) + group1.peerStates[peer1] = "state1" + group1.peerStates[peer2] = "state2" + group2.peerStates[peer1] = "state3" + + store.onPeerDisconnected(peer1) + + assertThat(group1.peerStates).containsOnlyKeys(peer2) + assertThat(group2.peerStates).isEmpty() + } + + @Test + fun `onPeerDisconnected GCs groups whose peerStates become empty`() { + val group = store.getOrCreateLocalGroup(topicA, groupId1) + group.peerStates[peer1] = "only-state" + + store.onPeerDisconnected(peer1) + + assertThat(store.getGroup(topicA, groupId1)).isNull() + assertThat(store.groupsForTopic(topicA)).isEmpty() + } + + // --- topic unsubscribe --- + + @Test + fun `onTopicUnsubscribed removes all groups for that topic`() { + store.getOrCreateLocalGroup(topicA, groupId1) + store.getOrCreateLocalGroup(topicA, groupId2) + store.getOrCreateLocalGroup(topicB, groupId1) + + store.onTopicUnsubscribed(topicA) + + assertThat(store.groupsForTopic(topicA)).isEmpty() + assertThat(store.groupsForTopic(topicB)).isNotEmpty() + } + + // --- groupsForTopic --- + + @Test + fun `groupsForTopic returns empty map for unknown topic`() { + assertThat(store.groupsForTopic("unknown-topic")).isEmpty() + } + + @Test + fun `groupsForTopic returns all groups for the topic`() { + store.getOrCreateLocalGroup(topicA, groupId1) + store.getOrCreatePeerGroup(topicA, groupId2, peer1) + + assertThat(store.groupsForTopic(topicA)).hasSize(2) + assertThat(store.groupsForTopic(topicB)).isEmpty() + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesAdapterImplTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesAdapterImplTest.kt new file mode 100644 index 00000000..61c3b95b --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/partialmessages/PartialMessagesAdapterImplTest.kt @@ -0,0 +1,277 @@ +package io.libp2p.pubsub.gossip.partialmessages + +import com.google.protobuf.ByteString +import io.libp2p.core.PeerId +import io.libp2p.pubsub.Topic +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +class PartialMessagesAdapterImplTest { + + private val topic = "test-topic" + private val groupIdBytes = "group-1".toByteArray() + private lateinit var peer1: PeerId + private lateinit var peer2: PeerId + private lateinit var capturedCalls: MutableList> + private lateinit var adapter: PartialMessagesAdapterImpl + + data class IncomingRpcCall( + val from: PeerId, + val peerStates: Map, + val rpc: Rpc.PartialMessagesExtension + ) + + private fun makeHandler() = object : PartialMessagesHandler { + override fun onIncomingRpc( + from: PeerId, + peerStates: Map, + rpc: Rpc.PartialMessagesExtension, + feedback: PartialMessagesPeerFeedback + ) { + capturedCalls += IncomingRpcCall(from, peerStates, rpc) + } + + override fun onEmitGossip( + topic: Topic, + groupId: ByteArray, + gossipPeers: Collection, + peerStates: Map, + feedback: PartialMessagesPeerFeedback + ) {} + } + + @BeforeEach + fun setup() { + peer1 = PeerId.random() + peer2 = PeerId.random() + capturedCalls = mutableListOf() + adapter = PartialMessagesAdapterImpl( + handler = makeHandler(), + stateStore = PartialGroupStateStore(groupTtlHeartbeats = 3), + feedback = NopPartialMessagesFeedback + ) + } + + private fun buildRpc( + topicId: String = topic, + groupId: ByteArray = groupIdBytes, + partialMessage: ByteArray? = null, + partsMetadata: ByteArray? = null + ): Rpc.PartialMessagesExtension = + Rpc.PartialMessagesExtension.newBuilder() + .setTopicID(topicId) + .setGroupID(ByteString.copyFrom(groupId)) + .apply { + if (partialMessage != null) setPartialMessage(ByteString.copyFrom(partialMessage)) + if (partsMetadata != null) setPartsMetadata(ByteString.copyFrom(partsMetadata)) + } + .build() + + @Test + fun `dispatches valid RPC to handler`() { + val rpc = buildRpc() + + adapter.onIncomingRpc(topic, peer1, rpc) + + assertThat(capturedCalls).hasSize(1) + assertThat(capturedCalls[0].from).isEqualTo(peer1) + assertThat(capturedCalls[0].rpc).isEqualTo(rpc) + } + + @Test + fun `peerStates map is empty on first RPC for a fresh group`() { + adapter.onIncomingRpc(topic, peer1, buildRpc()) + + assertThat(capturedCalls[0].peerStates).isEmpty() + } + + @Test + fun `second RPC for the same group reuses the same peerStates object`() { + adapter.onIncomingRpc(topic, peer1, buildRpc()) + adapter.onIncomingRpc(topic, peer2, buildRpc()) + + assertThat(capturedCalls).hasSize(2) + // Both calls receive the same live GroupState.peerStates reference + assertThat(capturedCalls[0].peerStates).isSameAs(capturedCalls[1].peerStates) + } + + @Test + fun `optional partialMessage and partsMetadata are forwarded to handler`() { + val rpc = buildRpc( + partialMessage = byteArrayOf(1, 2, 3), + partsMetadata = byteArrayOf(0xFF.toByte()) + ) + + adapter.onIncomingRpc(topic, peer1, rpc) + + assertThat(capturedCalls[0].rpc.partialMessage.toByteArray()).isEqualTo(byteArrayOf(1, 2, 3)) + assertThat(capturedCalls[0].rpc.partsMetadata.toByteArray()).isEqualTo(byteArrayOf(0xFF.toByte())) + } + + @Test + fun `handler not called when per-topic DoS cap is exceeded`() { + val store = PartialGroupStateStore( + groupTtlHeartbeats = 3, + peerInitiatedGroupLimitPerTopic = 1 + ) + val capped = PartialMessagesAdapterImpl( + handler = makeHandler(), + stateStore = store, + feedback = NopPartialMessagesFeedback + ) + + capped.onIncomingRpc(topic, peer1, buildRpc(groupId = "g1".toByteArray())) + capped.onIncomingRpc(topic, peer1, buildRpc(groupId = "g2".toByteArray())) + + assertThat(capturedCalls).hasSize(1) + } + + @Test + fun `handler not called when per-peer DoS cap is exceeded`() { + val store = PartialGroupStateStore( + groupTtlHeartbeats = 3, + peerInitiatedGroupLimitPerTopicPerPeer = 1 + ) + val capped = PartialMessagesAdapterImpl( + handler = makeHandler(), + stateStore = store, + feedback = NopPartialMessagesFeedback + ) + + capped.onIncomingRpc(topic, peer1, buildRpc(groupId = "g1".toByteArray())) + capped.onIncomingRpc(topic, peer1, buildRpc(groupId = "g2".toByteArray())) + + assertThat(capturedCalls).hasSize(1) + } + + @Test + fun `different topics create independent groups`() { + adapter.onIncomingRpc("topic-a", peer1, buildRpc(topicId = "topic-a")) + adapter.onIncomingRpc("topic-b", peer1, buildRpc(topicId = "topic-b")) + + assertThat(capturedCalls).hasSize(2) + assertThat(capturedCalls[0].peerStates).isNotSameAs(capturedCalls[1].peerStates) + } + + // ---- publishPartial ---- + + @Test + fun `publishPartial enqueues RPC for peer that requests partial`() { + val enqueued = mutableListOf>() + val payload = byteArrayOf(1, 2, 3) + val meta = byteArrayOf(0xAA.toByte()) + val actionsFn = PublishActionsFn { _, _ -> + sequenceOf(peer1 to PublishAction(partialMessage = payload, partsMetadata = meta)) + } + + adapter.publishPartial( + topic = topic, + groupId = groupIdBytes.toGroupId(), + actionsFn = actionsFn, + peerRequestsPartial = { true }, + enqueueFn = { p, pm, meta2 -> enqueued += Triple(p, pm, meta2) } + ) + + assertThat(enqueued).hasSize(1) + assertThat(enqueued[0].first).isEqualTo(peer1) + assertThat(enqueued[0].second).isEqualTo(payload) + assertThat(enqueued[0].third).isEqualTo(meta) + } + + @Test + fun `publishPartial omits partialMessage when peerRequestsPartial is false`() { + val enqueued = mutableListOf>() + val payload = byteArrayOf(1, 2, 3) + val meta = byteArrayOf(0xAA.toByte()) + val actionsFn = PublishActionsFn { _, _ -> + sequenceOf(peer1 to PublishAction(partialMessage = payload, partsMetadata = meta)) + } + + adapter.publishPartial( + topic = topic, + groupId = groupIdBytes.toGroupId(), + actionsFn = actionsFn, + peerRequestsPartial = { false }, + enqueueFn = { p, pm, meta2 -> enqueued += Triple(p, pm, meta2) } + ) + + assertThat(enqueued).hasSize(1) + assertThat(enqueued[0].second).isNull() + assertThat(enqueued[0].third).isEqualTo(meta) + } + + @Test + fun `publishPartial skips peer when action contains an error`() { + val enqueued = mutableListOf() + val actionsFn = PublishActionsFn { _, _ -> + sequenceOf(peer1 to PublishAction(error = RuntimeException("oops"))) + } + + adapter.publishPartial( + topic = topic, + groupId = groupIdBytes.toGroupId(), + actionsFn = actionsFn, + peerRequestsPartial = { true }, + enqueueFn = { p, _, _ -> enqueued += p } + ) + + assertThat(enqueued).isEmpty() + } + + @Test + fun `publishPartial does not call enqueueFn when both partialMessage and partsMetadata are null`() { + val enqueued = mutableListOf() + val actionsFn = PublishActionsFn { _, _ -> + sequenceOf(peer1 to PublishAction()) + } + + adapter.publishPartial( + topic = topic, + groupId = groupIdBytes.toGroupId(), + actionsFn = actionsFn, + peerRequestsPartial = { true }, + enqueueFn = { p, _, _ -> enqueued += p } + ) + + assertThat(enqueued).isEmpty() + } + + @Test + fun `publishPartial stores nextPeerState in group`() { + val actionsFn = PublishActionsFn { _, _ -> + sequenceOf(peer1 to PublishAction(partsMetadata = byteArrayOf(1), nextPeerState = "state-for-peer1")) + } + + adapter.publishPartial( + topic = topic, + groupId = groupIdBytes.toGroupId(), + actionsFn = actionsFn, + peerRequestsPartial = { true }, + enqueueFn = { _, _, _ -> } + ) + + val group = adapter.stateStore.getGroup(topic, groupIdBytes.toGroupId()) + assertThat(group?.peerStates?.get(peer1)).isEqualTo("state-for-peer1") + } + + @Test + fun `publishPartial provides peerRequestsPartial predicate to decide`() { + val predicateCapture = mutableListOf() + val actionsFn = PublishActionsFn { _, peerRequestsPartial -> + predicateCapture += peerRequestsPartial(peer1) + emptySequence() + } + + adapter.publishPartial( + topic = topic, + groupId = groupIdBytes.toGroupId(), + actionsFn = actionsFn, + peerRequestsPartial = { it == peer1 }, + enqueueFn = { _, _, _ -> } + ) + + assertThat(predicateCapture).containsExactly(true) + } +}