Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
38 changes: 33 additions & 5 deletions sds/sds.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ package sds

extern void sdsGlobalEventCallback(int ret, char* msg, size_t len, void* userData);

extern void sdsGlobalRetrievalHintProvider(char* messageId, char** hint, size_t* hintLen, void* userData);

typedef struct {
int ret;
char* msg;
Expand Down Expand Up @@ -78,6 +80,10 @@ package sds
SdsSetEventCallback(rmCtx, (SdsCallBack) sdsGlobalEventCallback, rmCtx);
}

static void cGoSdsSetRetrievalHintProvider(void* rmCtx) {
SdsSetRetrievalHintProvider(rmCtx, (SdsRetrievalHintProvider) sdsGlobalRetrievalHintProvider, rmCtx);
}

static void cGoSdsCleanupReliabilityManager(void* rmCtx, void* resp) {
SdsCleanupReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp);
}
Expand Down Expand Up @@ -158,8 +164,9 @@ func SdsGoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) {
type EventCallbacks struct {
OnMessageReady func(messageId MessageID, channelId string)
OnMessageSent func(messageId MessageID, channelId string)
OnMissingDependencies func(messageId MessageID, missingDeps []MessageID, channelId string)
OnMissingDependencies func(messageId MessageID, missingDeps []HistoryEntry, channelId string)
OnPeriodicSync func()
RetrievalHintProvider func(messageId MessageID) []byte
}

// ReliabilityManager represents an instance of a nim-sds ReliabilityManager
Expand Down Expand Up @@ -189,6 +196,7 @@ func NewReliabilityManager() (*ReliabilityManager, error) {

C.cGoSdsSetEventCallback(rm.rmCtx)
registerReliabilityManager(rm)
C.cGoSdsSetRetrievalHintProvider(rm.rmCtx)

Debug("Successfully created Reliability Manager")
return rm, nil
Expand Down Expand Up @@ -246,14 +254,33 @@ type msgEvent struct {

type missingDepsEvent struct {
MessageId MessageID `json:"messageId"`
MissingDeps []MessageID `json:"missingDeps"`
ChannelId string `json:"channelId"`
MissingDeps []HistoryEntry `json:"missingDeps"`
ChannelId string `json:"channelId"`
}

func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) {
rm.callbacks = callbacks
}

//export sdsGlobalRetrievalHintProvider
func sdsGlobalRetrievalHintProvider(messageId *C.char, hint **C.char, hintLen *C.size_t, userData unsafe.Pointer) {
msgId := C.GoString(messageId)
Debug("sdsGlobalRetrievalHintProvider called for messageId: %s", msgId)
rm, ok := rmRegistry[userData]
if ok && rm.callbacks.RetrievalHintProvider != nil {
Debug("Found RM and callback, calling provider")
hintBytes := rm.callbacks.RetrievalHintProvider(MessageID(msgId))
Debug("Provider returned hint of length: %d", len(hintBytes))
if len(hintBytes) > 0 {
*hint = (*C.char)(C.CBytes(hintBytes))
*hintLen = C.size_t(len(hintBytes))
Debug("Set hint in C memory: %s", string(hintBytes))
}
} else {
Debug("No RM found or no callback registered")
}
}

func (rm *ReliabilityManager) OnEvent(eventStr string) {

jsonEvent := jsonEvent{}
Expand Down Expand Up @@ -467,10 +494,11 @@ func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedM
}
Debug("Successfully unwrapped message")

unwrappedMessage := UnwrappedMessage{}
Debug("Unwrapped message JSON: %s", resStr)
var unwrappedMessage UnwrappedMessage
err := json.Unmarshal([]byte(resStr), &unwrappedMessage)
if err != nil {
Error("Failed to unmarshal unwrapped message")
Error("Failed to unmarshal unwrapped message: %v", err)
return nil, err
}

Expand Down
80 changes: 73 additions & 7 deletions sds/sds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestDependencies(t *testing.T) {

foundDep1 := false
for _, dep := range *unwrappedMessage2.MissingDeps {
if dep == msgID1 {
if dep.MessageID == msgID1 {
foundDep1 = true
break
}
Expand Down Expand Up @@ -239,12 +239,15 @@ func TestCallback_OnMissingDependencies(t *testing.T) {
var cbMutex sync.Mutex

callbacks := EventCallbacks{
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
require.Equal(t, channelID, chId)
cbMutex.Lock()
missingCalled = true
missingMsgID = messageId
missingDepsList = missingDeps // Copy slice
missingDepsList = make([]MessageID, len(missingDeps))
for i, dep := range missingDeps {
missingDepsList[i] = dep.MessageID
}
cbMutex.Unlock()
wg.Done()
},
Expand Down Expand Up @@ -383,7 +386,7 @@ func TestCallbacks_Combined(t *testing.T) {
// are typically relevant to the Sender. We don't expect this.
t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId)
},
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
// This callback is registered on Receiver, used for receiverRm2 below
},
}
Expand All @@ -404,7 +407,7 @@ func TestCallbacks_Combined(t *testing.T) {
}
}
},
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
// Not expected on sender
},
}
Expand Down Expand Up @@ -447,11 +450,16 @@ func TestCallbacks_Combined(t *testing.T) {
defer receiverRm2.Cleanup()

callbacksReceiver2 := EventCallbacks{
OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) {
OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) {
require.Equal(t, channelID, chId)
if messageId == msgID3 {
// Convert []HistoryEntry to []MessageID for the channel
deps := make([]MessageID, len(missingDeps))
for i, d := range missingDeps {
deps[i] = d.MessageID
}
select {
case missingChan <- missingDeps:
case missingChan <- deps:
default:
}
}
Expand Down Expand Up @@ -677,3 +685,61 @@ func TestMultiChannelCallbacks(t *testing.T) {
require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel")
require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages")
}

func TestRetrievalHints(t *testing.T) {
rm, err := NewReliabilityManager()
require.NoError(t, err)
defer rm.Cleanup()

channelID := "test-retrieval-hints"

// Set a retrieval hint provider
rm.RegisterCallbacks(EventCallbacks{
RetrievalHintProvider: func(messageId MessageID) []byte {
return []byte("hint-for-" + messageId)
},
})

// 1. Send a message to populate the history
payload1 := []byte("message one")
msgID1 := MessageID("msg-hint-1")
wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1, channelID)
require.NoError(t, err)

// 2. Receive the message to add it to history
_, err = rm.UnwrapReceivedMessage(wrappedMsg1)
require.NoError(t, err)

// 3. Send a second message, which will include the first in its causal history
payload2 := []byte("message two")
msgID2 := MessageID("msg-hint-2")
wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2, channelID)
require.NoError(t, err)

// 4. Unwrap the second message to inspect its causal history
// We need a new RM to avoid acknowledging the message
rm2, err := NewReliabilityManager()
require.NoError(t, err)
defer rm2.Cleanup()

rm2.RegisterCallbacks(EventCallbacks{
RetrievalHintProvider: func(messageId MessageID) []byte {
return []byte("hint-for-" + messageId)
},
})

unwrappedMsg2, err := rm2.UnwrapReceivedMessage(wrappedMsg2)
require.NoError(t, err)

// 5. Check that the causal history contains the retrieval hint
require.Greater(t, len(*unwrappedMsg2.MissingDeps), 0, "Expected missing dependencies")
foundDep := false
for _, dep := range *unwrappedMsg2.MissingDeps {
if dep.MessageID == msgID1 {
foundDep = true
require.Equal(t, []byte("hint-for-"+msgID1), dep.RetrievalHint, "Retrieval hint does not match")
break
}
}
require.True(t, foundDep, "Expected to find dependency %s", msgID1)
}
11 changes: 8 additions & 3 deletions sds/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package sds

type MessageID string

type HistoryEntry struct {
MessageID MessageID `json:"messageId"`
RetrievalHint []byte `json:"retrievalHint"`
}

type UnwrappedMessage struct {
Message *[]byte `json:"message"`
MissingDeps *[]MessageID `json:"missingDeps"`
ChannelId *string `json:"channelId"`
Message *[]byte `json:"message"`
MissingDeps *[]HistoryEntry `json:"missingDeps"`
ChannelId *string `json:"channelId"`
}