Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions library/events/json_missing_dependencies_event.nim
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
import std/json
import ./json_base_event, ../../src/[message]
import ./json_base_event, ../../src/[message], std/base64

type JsonMissingDependenciesEvent* = ref object of JsonEvent
messageId*: SdsMessageID
missingDeps: seq[SdsMessageID]
missingDeps*: seq[HistoryEntry]
channelId*: SdsChannelID

proc new*(
T: type JsonMissingDependenciesEvent,
messageId: SdsMessageID,
missingDeps: seq[SdsMessageID],
missingDeps: seq[HistoryEntry],
channelId: SdsChannelID,
): T =
return JsonMissingDependenciesEvent(
eventType: "missing_dependencies", messageId: messageId, missingDeps: missingDeps, channelId: channelId
)

method `$`*(jsonMissingDependencies: JsonMissingDependenciesEvent): string =
$(%*jsonMissingDependencies)
var node = newJObject()
node["eventType"] = %*jsonMissingDependencies.eventType
node["messageId"] = %*jsonMissingDependencies.messageId
node["channelId"] = %*jsonMissingDependencies.channelId
var missingDepsNode = newJArray()
for dep in jsonMissingDependencies.missingDeps:
var depNode = newJObject()
depNode["messageId"] = %*dep.messageId
depNode["retrievalHint"] = %*encode(dep.retrievalHint)
missingDepsNode.add(depNode)
node["missingDeps"] = missingDepsNode
$node
4 changes: 4 additions & 0 deletions library/ffi_types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ type SdsCallBack* = proc(
callerRet: cint, msg: ptr cchar, len: csize_t, userData: pointer
) {.cdecl, gcsafe, raises: [].}

type SdsRetrievalHintProvider* = proc(
messageId: cstring, hint: ptr cstring, hintLen: ptr csize_t, userData: pointer
) {.cdecl, gcsafe, raises: [].}

const RET_OK*: cint = 0
const RET_ERR*: cint = 1
const RET_MISSING_CALLBACK*: cint = 2
Expand Down
4 changes: 4 additions & 0 deletions library/libsds.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ extern "C" {

typedef void (*SdsCallBack) (int callerRet, const char* msg, size_t len, void* userData);

typedef void (*SdsRetrievalHintProvider) (const char* messageId, char** hint, size_t* hintLen, void* userData);


// --- Core API Functions ---

Expand All @@ -28,6 +30,8 @@ void* SdsNewReliabilityManager(SdsCallBack callback, void* userData);

void SdsSetEventCallback(void* ctx, SdsCallBack callback, void* userData);

void SdsSetRetrievalHintProvider(void* ctx, SdsRetrievalHintProvider callback, void* userData);

int SdsCleanupReliabilityManager(void* ctx, SdsCallBack callback, void* userData);

int SdsResetReliabilityManager(void* ctx, SdsCallBack callback, void* userData);
Expand Down
26 changes: 25 additions & 1 deletion library/libsds.nim
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ proc onMessageSent(ctx: ptr SdsContext): MessageSentCallback =
$JsonMessageSentEvent.new(messageId, channelId)

proc onMissingDependencies(ctx: ptr SdsContext): MissingDependenciesCallback =
return proc(messageId: SdsMessageID, missingDeps: seq[SdsMessageID], channelId: SdsChannelID) {.gcsafe.} =
return proc(messageId: SdsMessageID, missingDeps: seq[HistoryEntry], channelId: SdsChannelID) {.gcsafe.} =
callEventCallback(ctx, "onMissingDependencies"):
$JsonMissingDependenciesEvent.new(messageId, missingDeps, channelId)

Expand All @@ -91,6 +91,22 @@ proc onPeriodicSync(ctx: ptr SdsContext): PeriodicSyncCallback =
callEventCallback(ctx, "onPeriodicSync"):
$JsonPeriodicSyncEvent.new()

proc onRetrievalHint(ctx: ptr SdsContext): RetrievalHintProvider =
return proc(messageId: SdsMessageID): seq[byte] {.gcsafe.} =
if isNil(ctx.retrievalHintProvider):
return @[]

var hint: cstring
var hintLen: csize_t
cast[SdsRetrievalHintProvider](ctx.retrievalHintProvider)(
messageId.cstring, addr hint, addr hintLen, ctx.retrievalHintUserData
)

if not isNil(hint) and hintLen > 0:
result = newSeq[byte](hintLen)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should always avoid the use of the implicit result

copyMem(addr result[0], hint, hintLen)
deallocShared(hint)

### End of not-exported components
################################################################################

Expand Down Expand Up @@ -153,6 +169,7 @@ proc SdsNewReliabilityManager(
messageSentCb: onMessageSent(ctx),
missingDependenciesCb: onMissingDependencies(ctx),
periodicSyncCb: onPeriodicSync(ctx),
retrievalHintProvider: onRetrievalHint(ctx),
)

let retCode = handleRequest(
Expand All @@ -177,6 +194,13 @@ proc SdsSetEventCallback(
ctx[].eventCallback = cast[pointer](callback)
ctx[].eventUserData = userData

proc SdsSetRetrievalHintProvider(
ctx: ptr SdsContext, callback: SdsRetrievalHintProvider, userData: pointer
) {.dynlib, exportc.} =
initializeLibrary()
ctx[].retrievalHintProvider = cast[pointer](callback)
ctx[].retrievalHintUserData = userData

proc SdsCleanupReliabilityManager(
ctx: ptr SdsContext, callback: SdsCallBack, userData: pointer
): cint {.dynlib, exportc.} =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ proc createReliabilityManager(
rm.setCallbacks(
appCallbacks.messageReadyCb, appCallbacks.messageSentCb,
appCallbacks.missingDependenciesCb, appCallbacks.periodicSyncCb,
appCallbacks.retrievalHintProvider,
)

return ok(rm)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import std/[json, strutils, net, sequtils]
import std/[json, strutils, net, sequtils, base64]
import chronos, chronicles, results

import ../../../alloc
Expand All @@ -17,7 +17,7 @@ type SdsMessageRequest* = object

type SdsUnwrapResponse* = object
message*: seq[byte]
missingDeps*: seq[SdsMessageID]
missingDeps*: seq[HistoryEntry]

proc createShared*(
T: type SdsMessageRequest,
Expand Down Expand Up @@ -61,13 +61,23 @@ proc process*(
of UNWRAP_MESSAGE:
let messageBytes = self.message.toSeq()

let (unwrappedMessage, missingDeps, _) = unwrapReceivedMessage(rm[], messageBytes).valueOr:
let (unwrappedMessage, missingDeps, extractedChannelId) = unwrapReceivedMessage(rm[], messageBytes).valueOr:
error "UNWRAP_MESSAGE failed", error = error
return err("error processing UNWRAP_MESSAGE request: " & $error)

let res = SdsUnwrapResponse(message: unwrappedMessage, missingDeps: missingDeps)

# return the result as a json string
return ok($(%*(res)))
var node = newJObject()
node["message"] = %*res.message
node["channelId"] = %*extractedChannelId
var missingDepsNode = newJArray()
for dep in res.missingDeps:
var depNode = newJObject()
depNode["messageId"] = %*dep.messageId
depNode["retrievalHint"] = %*encode(dep.retrievalHint)
missingDepsNode.add(depNode)
node["missingDeps"] = missingDepsNode
return ok($node)

return ok("")
2 changes: 2 additions & 0 deletions library/sds_thread/sds_thread.nim
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type SdsContext* = object
userData*: pointer
eventCallback*: pointer
eventUserdata*: pointer
retrievalHintProvider*: pointer
retrievalHintUserData*: pointer
running: Atomic[bool] # To control when the thread is running

proc runSds(ctx: ptr SdsContext) {.async.} =
Expand Down
6 changes: 5 additions & 1 deletion src/message.nim
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ type
SdsMessageID* = string
SdsChannelID* = string

HistoryEntry* = object
messageId*: SdsMessageID
retrievalHint*: seq[byte] ## Optional hint for efficient retrieval (e.g., Waku message hash)

SdsMessage* = object
messageId*: SdsMessageID
lamportTimestamp*: int64
causalHistory*: seq[SdsMessageID]
causalHistory*: seq[HistoryEntry]
channelId*: SdsChannelID
content*: seq[byte]
bloomFilter*: seq[byte]
Expand Down
31 changes: 25 additions & 6 deletions src/protobuf.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ proc encode*(msg: SdsMessage): ProtoBuffer =
pb.write(1, msg.messageId)
pb.write(2, uint64(msg.lamportTimestamp))

for hist in msg.causalHistory:
pb.write(3, hist)
for entry in msg.causalHistory:
var entryPb = initProtoBuffer()
entryPb.write(1, entry.messageId)
if entry.retrievalHint.len > 0:
entryPb.write(2, entry.retrievalHint)
entryPb.finish()
pb.write(3, entryPb.buffer)

pb.write(4, msg.channelId)
pb.write(5, msg.content)
Expand All @@ -31,10 +36,24 @@ proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] =
return err(ProtobufError.missingRequiredField("lamportTimestamp"))
msg.lamportTimestamp = int64(timestamp)

var causalHistory: seq[SdsMessageID]
let histResult = pb.getRepeatedField(3, causalHistory)
if histResult.isOk:
msg.causalHistory = causalHistory
# Handle both old and new causal history formats
var historyBuffers: seq[seq[byte]]
if pb.getRepeatedField(3, historyBuffers).isOk:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nitpick one. Verbs with parenthesis; nouns without

Suggested change
if pb.getRepeatedField(3, historyBuffers).isOk:
if pb.getRepeatedField(3, historyBuffers).isOk():

# New format: repeated HistoryEntry
for histBuffer in historyBuffers:
let entryPb = initProtoBuffer(histBuffer)
var entry: HistoryEntry
if not ?entryPb.getField(1, entry.messageId):
return err(ProtobufError.missingRequiredField("HistoryEntry.messageId"))
# retrievalHint is optional
discard entryPb.getField(2, entry.retrievalHint)
msg.causalHistory.add(entry)
else:
# Try old format: repeated string
var causalHistory: seq[SdsMessageID]
let histResult = pb.getRepeatedField(3, causalHistory)
if histResult.isOk:
msg.causalHistory = toCausalHistory(causalHistory)

if not ?pb.getField(4, msg.channelId):
return err(ProtobufError.missingRequiredField("channelId"))
Expand Down
15 changes: 9 additions & 6 deletions src/reliability.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ proc newReliabilityManager*(

proc isAcknowledged*(
msg: UnacknowledgedMessage,
causalHistory: seq[SdsMessageID],
causalHistory: seq[HistoryEntry],
rbf: Option[RollingBloomFilter],
): bool =
if msg.message.messageId in causalHistory:
if msg.message.messageId in causalHistory.getMessageIds():
return true

if rbf.isSome():
Expand Down Expand Up @@ -112,7 +112,7 @@ proc wrapOutgoingMessage*(
let msg = SdsMessage(
messageId: messageId,
lamportTimestamp: channel.lamportTimestamp,
causalHistory: rm.getRecentSdsMessageIDs(rm.config.maxCausalHistory, channelId),
causalHistory: rm.getRecentHistoryEntries(rm.config.maxCausalHistory, channelId),
channelId: channelId,
content: message,
bloomFilter: bfResult.get(),
Expand Down Expand Up @@ -176,7 +176,7 @@ proc processIncomingBuffer(rm: ReliabilityManager, channelId: SdsChannelID) {.gc
proc unwrapReceivedMessage*(
rm: ReliabilityManager, message: seq[byte]
): Result[
tuple[message: seq[byte], missingDeps: seq[SdsMessageID], channelId: SdsChannelID],
tuple[message: seq[byte], missingDeps: seq[HistoryEntry], channelId: SdsChannelID],
ReliabilityError,
] =
## Unwraps a received message and processes its reliability metadata.
Expand Down Expand Up @@ -209,7 +209,7 @@ proc unwrapReceivedMessage*(
if missingDeps.len == 0:
var depsInBuffer = false
for msgId, entry in channel.incomingBuffer.pairs():
if msgId in msg.causalHistory:
if msgId in msg.causalHistory.getMessageIds():
depsInBuffer = true
break
# Check if any dependencies are still in incoming buffer
Expand All @@ -224,7 +224,7 @@ proc unwrapReceivedMessage*(
rm.onMessageReady(msg.messageId, channelId)
else:
channel.incomingBuffer[msg.messageId] =
IncomingMessage(message: msg, missingDeps: missingDeps.toHashSet())
IncomingMessage(message: msg, missingDeps: missingDeps.getMessageIds().toHashSet())
if not rm.onMissingDependencies.isNil():
rm.onMissingDependencies(msg.messageId, missingDeps, channelId)

Expand Down Expand Up @@ -271,6 +271,7 @@ proc setCallbacks*(
onMessageSent: MessageSentCallback,
onMissingDependencies: MissingDependenciesCallback,
onPeriodicSync: PeriodicSyncCallback = nil,
onRetrievalHint: RetrievalHintProvider = nil
) =
## Sets the callback functions for various events in the ReliabilityManager.
##
Expand All @@ -279,11 +280,13 @@ proc setCallbacks*(
## - onMessageSent: Callback function called when a message is confirmed as sent.
## - onMissingDependencies: Callback function called when a message has missing dependencies.
## - onPeriodicSync: Callback function called to notify about periodic sync
## - onRetrievalHint: Callback function called to get a retrieval hint for a message ID.
withLock rm.lock:
rm.onMessageReady = onMessageReady
rm.onMessageSent = onMessageSent
rm.onMissingDependencies = onMissingDependencies
rm.onPeriodicSync = onPeriodicSync
rm.onRetrievalHint = onRetrievalHint

proc checkUnacknowledgedMessages(
rm: ReliabilityManager, channelId: SdsChannelID
Expand Down
Loading