diff --git a/library/events/json_missing_dependencies_event.nim b/library/events/json_missing_dependencies_event.nim index ef7ae57..2b4a094 100644 --- a/library/events/json_missing_dependencies_event.nim +++ b/library/events/json_missing_dependencies_event.nim @@ -1,15 +1,15 @@ import std/json -import ./json_base_event, ../../src/[message] +import ./json_base_event, ../../src/[message], std/base64 type JsonMissingDependenciesEvent* = ref object of JsonEvent messageId*: SdsMessageID - missingDeps: seq[SdsMessageID] + missingDeps*: seq[HistoryEntry] channelId*: SdsChannelID proc new*( T: type JsonMissingDependenciesEvent, messageId: SdsMessageID, - missingDeps: seq[SdsMessageID], + missingDeps: seq[HistoryEntry], channelId: SdsChannelID, ): T = return JsonMissingDependenciesEvent( @@ -17,4 +17,15 @@ proc new*( ) method `$`*(jsonMissingDependencies: JsonMissingDependenciesEvent): string = - $(%*jsonMissingDependencies) + var node = newJObject() + node["eventType"] = %*jsonMissingDependencies.eventType + node["messageId"] = %*jsonMissingDependencies.messageId + node["channelId"] = %*jsonMissingDependencies.channelId + var missingDepsNode = newJArray() + for dep in jsonMissingDependencies.missingDeps: + var depNode = newJObject() + depNode["messageId"] = %*dep.messageId + depNode["retrievalHint"] = %*encode(dep.retrievalHint) + missingDepsNode.add(depNode) + node["missingDeps"] = missingDepsNode + $node diff --git a/library/ffi_types.nim b/library/ffi_types.nim index e2445a2..dbb531d 100644 --- a/library/ffi_types.nim +++ b/library/ffi_types.nim @@ -5,6 +5,10 @@ type SdsCallBack* = proc( callerRet: cint, msg: ptr cchar, len: csize_t, userData: pointer ) {.cdecl, gcsafe, raises: [].} +type SdsRetrievalHintProvider* = proc( + messageId: cstring, hint: ptr cstring, hintLen: ptr csize_t, userData: pointer +) {.cdecl, gcsafe, raises: [].} + const RET_OK*: cint = 0 const RET_ERR*: cint = 1 const RET_MISSING_CALLBACK*: cint = 2 diff --git a/library/libsds.h b/library/libsds.h index 886d3cb..0d9840e 100644 --- a/library/libsds.h +++ b/library/libsds.h @@ -20,6 +20,8 @@ extern "C" { typedef void (*SdsCallBack) (int callerRet, const char* msg, size_t len, void* userData); +typedef void (*SdsRetrievalHintProvider) (const char* messageId, char** hint, size_t* hintLen, void* userData); + // --- Core API Functions --- @@ -28,6 +30,8 @@ void* SdsNewReliabilityManager(SdsCallBack callback, void* userData); void SdsSetEventCallback(void* ctx, SdsCallBack callback, void* userData); +void SdsSetRetrievalHintProvider(void* ctx, SdsRetrievalHintProvider callback, void* userData); + int SdsCleanupReliabilityManager(void* ctx, SdsCallBack callback, void* userData); int SdsResetReliabilityManager(void* ctx, SdsCallBack callback, void* userData); diff --git a/library/libsds.nim b/library/libsds.nim index 53d2a4e..cd2c49b 100644 --- a/library/libsds.nim +++ b/library/libsds.nim @@ -82,7 +82,7 @@ proc onMessageSent(ctx: ptr SdsContext): MessageSentCallback = $JsonMessageSentEvent.new(messageId, channelId) proc onMissingDependencies(ctx: ptr SdsContext): MissingDependenciesCallback = - return proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + return proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = callEventCallback(ctx, "onMissingDependencies"): $JsonMissingDependenciesEvent.new(messageId, missingDeps, channelId) @@ -91,6 +91,22 @@ proc onPeriodicSync(ctx: ptr SdsContext): PeriodicSyncCallback = callEventCallback(ctx, "onPeriodicSync"): $JsonPeriodicSyncEvent.new() +proc onRetrievalHint(ctx: ptr SdsContext): RetrievalHintProvider = + return proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} = + if isNil(ctx.retrievalHintProvider): + return @[] + + var hint: cstring + var hintLen: csize_t + cast[SdsRetrievalHintProvider](ctx.retrievalHintProvider)( + messageId.cstring, addr hint, addr hintLen, ctx.retrievalHintUserData + ) + + if not isNil(hint) and hintLen > 0: + result = newSeq[byte](hintLen) + copyMem(addr result[0], hint, hintLen) + deallocShared(hint) + ### End of not-exported components ################################################################################ @@ -153,6 +169,7 @@ proc SdsNewReliabilityManager( messageSentCb: onMessageSent(ctx), missingDependenciesCb: onMissingDependencies(ctx), periodicSyncCb: onPeriodicSync(ctx), + retrievalHintProvider: onRetrievalHint(ctx), ) let retCode = handleRequest( @@ -177,6 +194,13 @@ proc SdsSetEventCallback( ctx[].eventCallback = cast[pointer](callback) ctx[].eventUserData = userData +proc SdsSetRetrievalHintProvider( + ctx: ptr SdsContext, callback: SdsRetrievalHintProvider, userData: pointer +) {.dynlib, exportc.} = + initializeLibrary() + ctx[].retrievalHintProvider = cast[pointer](callback) + ctx[].retrievalHintUserData = userData + proc SdsCleanupReliabilityManager( ctx: ptr SdsContext, callback: SdsCallBack, userData: pointer ): cint {.dynlib, exportc.} = diff --git a/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim b/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim index fd5a615..19b9b3f 100644 --- a/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim +++ b/library/sds_thread/inter_thread_communication/requests/sds_lifecycle_request.nim @@ -40,6 +40,7 @@ proc createReliabilityManager( rm.setCallbacks( appCallbacks.messageReadyCb, appCallbacks.messageSentCb, appCallbacks.missingDependenciesCb, appCallbacks.periodicSyncCb, + appCallbacks.retrievalHintProvider, ) return ok(rm) diff --git a/library/sds_thread/inter_thread_communication/requests/sds_message_request.nim b/library/sds_thread/inter_thread_communication/requests/sds_message_request.nim index d41c15a..63e0704 100644 --- a/library/sds_thread/inter_thread_communication/requests/sds_message_request.nim +++ b/library/sds_thread/inter_thread_communication/requests/sds_message_request.nim @@ -1,4 +1,4 @@ -import std/[json, strutils, net, sequtils] +import std/[json, strutils, net, sequtils, base64] import chronos, chronicles, results import ../../../alloc @@ -17,7 +17,7 @@ type SdsMessageRequest* = object type SdsUnwrapResponse* = object message*: seq[byte] - missingDeps*: seq[SdsMessageID] + missingDeps*: seq[HistoryEntry] proc createShared*( T: type SdsMessageRequest, @@ -61,13 +61,23 @@ proc process*( of UNWRAP_MESSAGE: let messageBytes = self.message.toSeq() - let (unwrappedMessage, missingDeps, _) = unwrapReceivedMessage(rm[], messageBytes).valueOr: + let (unwrappedMessage, missingDeps, extractedChannelId) = unwrapReceivedMessage(rm[], messageBytes).valueOr: error "UNWRAP_MESSAGE failed", error = error return err("error processing UNWRAP_MESSAGE request: " & $error) let res = SdsUnwrapResponse(message: unwrappedMessage, missingDeps: missingDeps) # return the result as a json string - return ok($(%*(res))) + var node = newJObject() + node["message"] = %*res.message + node["channelId"] = %*extractedChannelId + var missingDepsNode = newJArray() + for dep in res.missingDeps: + var depNode = newJObject() + depNode["messageId"] = %*dep.messageId + depNode["retrievalHint"] = %*encode(dep.retrievalHint) + missingDepsNode.add(depNode) + node["missingDeps"] = missingDepsNode + return ok($node) return ok("") diff --git a/library/sds_thread/sds_thread.nim b/library/sds_thread/sds_thread.nim index 8f23840..f54b327 100644 --- a/library/sds_thread/sds_thread.nim +++ b/library/sds_thread/sds_thread.nim @@ -20,6 +20,8 @@ type SdsContext* = object userData*: pointer eventCallback*: pointer eventUserdata*: pointer + retrievalHintProvider*: pointer + retrievalHintUserData*: pointer running: Atomic[bool] # To control when the thread is running proc runSds(ctx: ptr SdsContext) {.async.} = diff --git a/src/message.nim b/src/message.nim index f23ad13..030a023 100644 --- a/src/message.nim +++ b/src/message.nim @@ -4,10 +4,14 @@ type SdsMessageID* = string SdsChannelID* = string + HistoryEntry* = object + messageId*: SdsMessageID + retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash) + SdsMessage* = object messageId*: SdsMessageID lamportTimestamp*: int64 - causalHistory*: seq[SdsMessageID] + causalHistory*: seq[HistoryEntry] channelId*: SdsChannelID content*: seq[byte] bloomFilter*: seq[byte] diff --git a/src/protobuf.nim b/src/protobuf.nim index 1f6d600..8adbbb3 100644 --- a/src/protobuf.nim +++ b/src/protobuf.nim @@ -9,8 +9,13 @@ proc encode*(msg: SdsMessage): ProtoBuffer = pb.write(1, msg.messageId) pb.write(2, uint64(msg.lamportTimestamp)) - for hist in msg.causalHistory: - pb.write(3, hist) + for entry in msg.causalHistory: + var entryPb = initProtoBuffer() + entryPb.write(1, entry.messageId) + if entry.retrievalHint.len > 0: + entryPb.write(2, entry.retrievalHint) + entryPb.finish() + pb.write(3, entryPb.buffer) pb.write(4, msg.channelId) pb.write(5, msg.content) @@ -31,10 +36,24 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = return err(ProtobufError.missingRequiredField("lamportTimestamp")) msg.lamportTimestamp = int64(timestamp) - var causalHistory: seq[SdsMessageID] - let histResult = pb.getRepeatedField(3, causalHistory) - if histResult.isOk: - msg.causalHistory = causalHistory + # Handle both old and new causal history formats + var historyBuffers: seq[seq[byte]] + if pb.getRepeatedField(3, historyBuffers).isOk: + # New format: repeated HistoryEntry + for histBuffer in historyBuffers: + let entryPb = initProtoBuffer(histBuffer) + var entry: HistoryEntry + if not ?entryPb.getField(1, entry.messageId): + return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) + # retrievalHint is optional + discard entryPb.getField(2, entry.retrievalHint) + msg.causalHistory.add(entry) + else: + # Try old format: repeated string + var causalHistory: seq[SdsMessageID] + let histResult = pb.getRepeatedField(3, causalHistory) + if histResult.isOk: + msg.causalHistory = toCausalHistory(causalHistory) if not ?pb.getField(4, msg.channelId): return err(ProtobufError.missingRequiredField("channelId")) diff --git a/src/reliability.nim b/src/reliability.nim index a39fac3..f73d5de 100644 --- a/src/reliability.nim +++ b/src/reliability.nim @@ -24,10 +24,10 @@ proc newReliabilityManager*( proc isAcknowledged*( msg: UnacknowledgedMessage, - causalHistory: seq[SdsMessageID], + causalHistory: seq[HistoryEntry], rbf: Option[RollingBloomFilter], ): bool = - if msg.message.messageId in causalHistory: + if msg.message.messageId in causalHistory.getMessageIds(): return true if rbf.isSome(): @@ -112,7 +112,7 @@ proc wrapOutgoingMessage*( let msg = SdsMessage( messageId: messageId, lamportTimestamp: channel.lamportTimestamp, - causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory, channelId), + causalHistory: rm.getRecentHistoryEntries(rm.config.maxCausalHistory, channelId), channelId: channelId, content: message, bloomFilter: bfResult.get(), @@ -176,7 +176,7 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc proc unwrapReceivedMessage*( rm: ReliabilityManager, message: seq[byte] ): Result[ - tuple[message: seq[byte], missingDeps: seq[SdsMessageID], channelId: SdsChannelID], + tuple[message: seq[byte], missingDeps: seq[HistoryEntry], channelId: SdsChannelID], ReliabilityError, ] = ## Unwraps a received message and processes its reliability metadata. @@ -209,7 +209,7 @@ proc unwrapReceivedMessage*( if missingDeps.len == 0: var depsInBuffer = false for msgId, entry in channel.incomingBuffer.pairs(): - if msgId in msg.causalHistory: + if msgId in msg.causalHistory.getMessageIds(): depsInBuffer = true break # Check if any dependencies are still in incoming buffer @@ -224,7 +224,7 @@ proc unwrapReceivedMessage*( rm.onMessageReady(msg.messageId, channelId) else: channel.incomingBuffer[msg.messageId] = - IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet()) + IncomingMessage(message: msg, missingDeps: missingDeps.getMessageIds().toHashSet()) if not rm.onMissingDependencies.isNil(): rm.onMissingDependencies(msg.messageId, missingDeps, channelId) @@ -271,6 +271,7 @@ proc setCallbacks*( onMessageSent: MessageSentCallback, onMissingDependencies: MissingDependenciesCallback, onPeriodicSync: PeriodicSyncCallback = nil, + onRetrievalHint: RetrievalHintProvider = nil ) = ## Sets the callback functions for various events in the ReliabilityManager. ## @@ -279,11 +280,13 @@ proc setCallbacks*( ## - onMessageSent: Callback function called when a message is confirmed as sent. ## - onMissingDependencies: Callback function called when a message has missing dependencies. ## - onPeriodicSync: Callback function called to notify about periodic sync + ## - onRetrievalHint: Callback function called to get a retrieval hint for a message ID. withLock rm.lock: rm.onMessageReady = onMessageReady rm.onMessageSent = onMessageSent rm.onMissingDependencies = onMissingDependencies rm.onPeriodicSync = onPeriodicSync + rm.onRetrievalHint = onRetrievalHint proc checkUnacknowledgedMessages( rm: ReliabilityManager, channelId: SdsChannelID diff --git a/src/reliability_utils.nim b/src/reliability_utils.nim index 28248da..90055bc 100644 --- a/src/reliability_utils.nim +++ b/src/reliability_utils.nim @@ -10,9 +10,11 @@ type proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} MissingDependenciesCallback* = proc( - messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID + messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID ) {.gcsafe.} + RetrievalHintProvider* = proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} + PeriodicSyncCallback* = proc() {.gcsafe, raises: [].} AppCallbacks* = ref object @@ -20,6 +22,7 @@ type messageSentCb*: MessageSentCallback missingDependenciesCb*: MissingDependenciesCallback periodicSyncCb*: PeriodicSyncCallback + retrievalHintProvider*: RetrievalHintProvider ReliabilityConfig* = object bloomFilterCapacity*: int @@ -45,9 +48,10 @@ type onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMissingDependencies*: proc( - messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID + messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID ) {.gcsafe.} onPeriodicSync*: PeriodicSyncCallback + onRetrievalHint*: RetrievalHintProvider ReliabilityError* {.pure.} = enum reInvalidArgument @@ -120,32 +124,73 @@ proc updateLamportTimestamp*( error "Failed to update lamport timestamp", channelId = channelId, msgTs = msgTs, error = getCurrentExceptionMsg() -proc getRecentSdsMessageIDs*( +# Helper functions for HistoryEntry +proc newHistoryEntry*(messageId: SdsMessageID, retrievalHint: seq[byte] = @[]): HistoryEntry = + ## Creates a new HistoryEntry with optional retrieval hint + HistoryEntry(messageId: messageId, retrievalHint: retrievalHint) + +proc newHistoryEntry*(messageId: SdsMessageID, retrievalHint: string): HistoryEntry = + ## Creates a new HistoryEntry with string retrieval hint + HistoryEntry(messageId: messageId, retrievalHint: cast[seq[byte]](retrievalHint)) + +proc toCausalHistory*(messageIds: seq[SdsMessageID]): seq[HistoryEntry] = + ## Converts a sequence of message IDs to HistoryEntry sequence + result = newSeq[HistoryEntry](messageIds.len) + for i, msgId in messageIds: + result[i] = newHistoryEntry(msgId) + +proc getMessageIds*(causalHistory: seq[HistoryEntry]): seq[SdsMessageID] = + ## Extracts message IDs from HistoryEntry sequence + result = newSeq[SdsMessageID](causalHistory.len) + for i, entry in causalHistory: + result[i] = entry.messageId + +proc getRecentHistoryEntries*( rm: ReliabilityManager, n: int, channelId: SdsChannelID -): seq[SdsMessageID] = +): seq[HistoryEntry] = try: if channelId in rm.channels: let channel = rm.channels[channelId] - result = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1] + let recentMessageIds = channel.messageHistory[max(0, channel.messageHistory.len - n) .. ^1] + if rm.onRetrievalHint.isNil(): + return toCausalHistory(recentMessageIds) + else: + for msgId in recentMessageIds: + let hint = rm.onRetrievalHint(msgId) + result.add(newHistoryEntry(msgId, hint)) else: result = @[] except Exception: - error "Failed to get recent message IDs", + error "Failed to get recent history entries", channelId = channelId, n = n, error = getCurrentExceptionMsg() result = @[] proc checkDependencies*( - rm: ReliabilityManager, deps: seq[SdsMessageID], channelId: SdsChannelID -): seq[SdsMessageID] = - var missingDeps: seq[SdsMessageID] = @[] + rm: ReliabilityManager, deps: seq[HistoryEntry], channelId: SdsChannelID +): seq[HistoryEntry] = + var missingDeps: seq[HistoryEntry] = @[] try: if channelId in rm.channels: let channel = rm.channels[channelId] - for depId in deps: - if depId notin channel.messageHistory: - missingDeps.add(depId) + for dep in deps: + if dep.messageId notin channel.messageHistory: + # If we have a retrieval hint provider and the original dep has no hint, get one + if not rm.onRetrievalHint.isNil() and dep.retrievalHint.len == 0: + let hint = rm.onRetrievalHint(dep.messageId) + missingDeps.add(newHistoryEntry(dep.messageId, hint)) + else: + missingDeps.add(dep) else: - missingDeps = deps + # Channel doesn't exist, all deps are missing + if not rm.onRetrievalHint.isNil(): + for dep in deps: + if dep.retrievalHint.len == 0: + let hint = rm.onRetrievalHint(dep.messageId) + missingDeps.add(newHistoryEntry(dep.messageId, hint)) + else: + missingDeps.add(dep) + else: + missingDeps = deps except Exception: error "Failed to check dependencies", channelId = channelId, error = getCurrentExceptionMsg() @@ -241,4 +286,4 @@ proc removeChannel*( except Exception: error "Failed to remove channel", channelId = channelId, msg = getCurrentExceptionMsg() - return err(ReliabilityError.reInternalError) + return err(ReliabilityError.reInternalError) \ No newline at end of file diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index 7b68a86..e9d44dc 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -99,7 +99,7 @@ suite "Reliability Mechanisms": messageReadyCount += 1, proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = messageSentCount += 1, - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = missingDepsCount += 1, ) @@ -112,7 +112,7 @@ suite "Reliability Mechanisms": let msg2 = SdsMessage( messageId: id2, lamportTimestamp: 2, - causalHistory: @[id1], # msg2 depends on msg1 + causalHistory: toCausalHistory(@[id1]), # msg2 depends on msg1 channelId: testChannel, content: @[byte(2)], bloomFilter: @[], @@ -121,7 +121,7 @@ suite "Reliability Mechanisms": let msg3 = SdsMessage( messageId: id3, lamportTimestamp: 3, - causalHistory: @[id1, id2], # msg3 depends on both msg1 and msg2 + causalHistory: toCausalHistory(@[id1, id2]), # msg3 depends on both msg1 and msg2 channelId: testChannel, content: @[byte(3)], bloomFilter: @[], @@ -141,8 +141,8 @@ suite "Reliability Mechanisms": check: missingDepsCount == 1 # Should trigger missing deps callback missingDeps3.len == 2 # Should be missing both msg1 and msg2 - id1 in missingDeps3 - id2 in missingDeps3 + id1 in missingDeps3.getMessageIds() + id2 in missingDeps3.getMessageIds() # Then try processing msg2 (which only depends on msg1) let unwrapResult2 = rm.unwrapReceivedMessage(serialized2.get()) @@ -152,7 +152,7 @@ suite "Reliability Mechanisms": check: missingDepsCount == 2 # Should have triggered another missing deps callback missingDeps2.len == 1 # Should only be missing msg1 - id1 in missingDeps2 + id1 in missingDeps2.getMessageIds() messageReadyCount == 0 # No messages should be ready yet # Mark first dependency (msg1) as met @@ -176,7 +176,7 @@ suite "Reliability Mechanisms": messageReadyCount += 1, proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = messageSentCount += 1, - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = missingDepsCount += 1, ) @@ -190,7 +190,7 @@ suite "Reliability Mechanisms": let msg2 = SdsMessage( messageId: "msg2", lamportTimestamp: rm.channels[testChannel].lamportTimestamp + 1, - causalHistory: @[id1], # Include our message in causal history + causalHistory: toCausalHistory(@[id1]), # Include our message in causal history channelId: testChannel, content: @[byte(2)], bloomFilter: @[] # Test with an empty bloom filter @@ -216,7 +216,7 @@ suite "Reliability Mechanisms": discard, proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = messageSentCount += 1, - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, ) @@ -251,6 +251,59 @@ suite "Reliability Mechanisms": check messageSentCount == 1 # Our message should be acknowledged via bloom filter + test "retrieval hints": + var messageReadyCount = 0 + var messageSentCount = 0 + var missingDepsCount = 0 + + rm.setCallbacks( + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + messageReadyCount += 1, + proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = + messageSentCount += 1, + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = + missingDepsCount += 1, + nil, + proc(messageId: SdsMessageID): seq[byte] = + return cast[seq[byte]]("hint:" & messageId) + ) + + # Send a first message to populate history + let msg1 = @[byte(1)] + let id1 = "msg1" + let wrap1 = rm.wrapOutgoingMessage(msg1, id1, testChannel) + check wrap1.isOk() + + # Send a second message, which should have the first in its causal history + let msg2 = @[byte(2)] + let id2 = "msg2" + let wrap2 = rm.wrapOutgoingMessage(msg2, id2, testChannel) + check wrap2.isOk() + + # Check that the wrapped message contains the hint + let unwrappedMsg2 = deserializeMessage(wrap2.get()).get() + check unwrappedMsg2.causalHistory.len > 0 + check unwrappedMsg2.causalHistory[0].messageId == id1 + check unwrappedMsg2.causalHistory[0].retrievalHint == cast[seq[byte]]("hint:" & id1) + + # Create a message with a missing dependency + let msg3 = SdsMessage( + messageId: "msg3", + lamportTimestamp: 3, + causalHistory: toCausalHistory(@["missing-dep"]), + channelId: testChannel, + content: @[byte(3)], + bloomFilter: @[], + ) + let serialized3 = serializeMessage(msg3).get() + let unwrapResult3 = rm.unwrapReceivedMessage(serialized3) + check unwrapResult3.isOk() + let (_, missingDeps3, _) = unwrapResult3.get() + check missingDeps3.len == 1 + check missingDeps3[0].messageId == "missing-dep" + # The hint should be populated by the retrieval hint provider for missing dependencies + check missingDeps3[0].retrievalHint == cast[seq[byte]]("hint:missing-dep") + # Periodic task & Buffer management tests suite "Periodic Tasks & Buffer Management": var rm: ReliabilityManager @@ -273,7 +326,7 @@ suite "Periodic Tasks & Buffer Management": discard, proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = messageSentCount += 1, - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, ) @@ -291,7 +344,7 @@ suite "Periodic Tasks & Buffer Management": let ackMsg = SdsMessage( messageId: "ack1", lamportTimestamp: rm.channels[testChannel].lamportTimestamp + 1, - causalHistory: @["msg0", "msg2", "msg4"], + causalHistory: toCausalHistory(@["msg0", "msg2", "msg4"]), channelId: testChannel, content: @[byte(100)], bloomFilter: @[], @@ -328,7 +381,7 @@ suite "Periodic Tasks & Buffer Management": discard, proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = messageSentCount += 1, - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, ) @@ -377,7 +430,7 @@ suite "Periodic Tasks & Buffer Management": discard, proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, proc() {.gcsafe.} = syncCallCount += 1, @@ -420,7 +473,7 @@ suite "Special Cases Handling": let msgInvalid = SdsMessage( messageId: "invalid-bf", lamportTimestamp: 1, - causalHistory: @[], + causalHistory: toCausalHistory(@[]), channelId: testChannel, content: @[byte(1)], bloomFilter: @[1.byte, 2.byte, 3.byte] # Invalid filter data @@ -443,7 +496,7 @@ suite "Special Cases Handling": messageReadyCount += 1, proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = discard, - proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = discard, ) @@ -451,7 +504,7 @@ suite "Special Cases Handling": let msg = SdsMessage( messageId: "dup-msg", lamportTimestamp: 1, - causalHistory: @[], + causalHistory: toCausalHistory(@[]), channelId: testChannel, content: @[byte(1)], bloomFilter: @[], @@ -601,7 +654,7 @@ suite "Multi-Channel ReliabilityManager Tests": readyMessageCount += 1, proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} = sentMessageCount += 1, - proc(messageId: SdsMessageID, deps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} = + proc(messageId: SdsMessageID, deps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} = missingDepsCount += 1 ) @@ -624,7 +677,7 @@ suite "Multi-Channel ReliabilityManager Tests": let ackMsg1 = SdsMessage( messageId: "ack1", lamportTimestamp: rm.channels[channel1].lamportTimestamp + 1, - causalHistory: @[msgId1], # Acknowledge msg1 + causalHistory: toCausalHistory(@[msgId1]), # Acknowledge msg1 channelId: channel1, content: @[byte(100)], bloomFilter: @[], @@ -633,7 +686,7 @@ suite "Multi-Channel ReliabilityManager Tests": let ackMsg2 = SdsMessage( messageId: "ack2", lamportTimestamp: rm.channels[channel2].lamportTimestamp + 1, - causalHistory: @[msgId2], # Acknowledge msg2 + causalHistory: toCausalHistory(@[msgId2]), # Acknowledge msg2 channelId: channel2, content: @[byte(101)], bloomFilter: @[],