diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..273cb6e Binary files /dev/null and b/.DS_Store differ diff --git a/sds/sds.go b/sds/sds.go index cd8168e..754682e 100644 --- a/sds/sds.go +++ b/sds/sds.go @@ -10,6 +10,8 @@ package sds extern void sdsGlobalEventCallback(int ret, char* msg, size_t len, void* userData); + extern void sdsGlobalRetrievalHintProvider(char* messageId, char** hint, size_t* hintLen, void* userData); + typedef struct { int ret; char* msg; @@ -78,6 +80,10 @@ package sds SdsSetEventCallback(rmCtx, (SdsCallBack) sdsGlobalEventCallback, rmCtx); } + static void cGoSdsSetRetrievalHintProvider(void* rmCtx) { + SdsSetRetrievalHintProvider(rmCtx, (SdsRetrievalHintProvider) sdsGlobalRetrievalHintProvider, rmCtx); + } + static void cGoSdsCleanupReliabilityManager(void* rmCtx, void* resp) { SdsCleanupReliabilityManager(rmCtx, (SdsCallBack) SdsGoCallback, resp); } @@ -158,8 +164,9 @@ func SdsGoCallback(ret C.int, msg *C.char, len C.size_t, resp unsafe.Pointer) { type EventCallbacks struct { OnMessageReady func(messageId MessageID, channelId string) OnMessageSent func(messageId MessageID, channelId string) - OnMissingDependencies func(messageId MessageID, missingDeps []MessageID, channelId string) + OnMissingDependencies func(messageId MessageID, missingDeps []HistoryEntry, channelId string) OnPeriodicSync func() + RetrievalHintProvider func(messageId MessageID) []byte } // ReliabilityManager represents an instance of a nim-sds ReliabilityManager @@ -189,6 +196,7 @@ func NewReliabilityManager() (*ReliabilityManager, error) { C.cGoSdsSetEventCallback(rm.rmCtx) registerReliabilityManager(rm) + C.cGoSdsSetRetrievalHintProvider(rm.rmCtx) Debug("Successfully created Reliability Manager") return rm, nil @@ -246,14 +254,33 @@ type msgEvent struct { type missingDepsEvent struct { MessageId MessageID `json:"messageId"` - MissingDeps []MessageID `json:"missingDeps"` - ChannelId string `json:"channelId"` + MissingDeps []HistoryEntry `json:"missingDeps"` + ChannelId string `json:"channelId"` } func (rm *ReliabilityManager) RegisterCallbacks(callbacks EventCallbacks) { rm.callbacks = callbacks } +//export sdsGlobalRetrievalHintProvider +func sdsGlobalRetrievalHintProvider(messageId *C.char, hint **C.char, hintLen *C.size_t, userData unsafe.Pointer) { + msgId := C.GoString(messageId) + Debug("sdsGlobalRetrievalHintProvider called for messageId: %s", msgId) + rm, ok := rmRegistry[userData] + if ok && rm.callbacks.RetrievalHintProvider != nil { + Debug("Found RM and callback, calling provider") + hintBytes := rm.callbacks.RetrievalHintProvider(MessageID(msgId)) + Debug("Provider returned hint of length: %d", len(hintBytes)) + if len(hintBytes) > 0 { + *hint = (*C.char)(C.CBytes(hintBytes)) + *hintLen = C.size_t(len(hintBytes)) + Debug("Set hint in C memory: %s", string(hintBytes)) + } + } else { + Debug("No RM found or no callback registered") + } +} + func (rm *ReliabilityManager) OnEvent(eventStr string) { jsonEvent := jsonEvent{} @@ -467,10 +494,11 @@ func (rm *ReliabilityManager) UnwrapReceivedMessage(message []byte) (*UnwrappedM } Debug("Successfully unwrapped message") - unwrappedMessage := UnwrappedMessage{} + Debug("Unwrapped message JSON: %s", resStr) + var unwrappedMessage UnwrappedMessage err := json.Unmarshal([]byte(resStr), &unwrappedMessage) if err != nil { - Error("Failed to unmarshal unwrapped message") + Error("Failed to unmarshal unwrapped message: %v", err) return nil, err } diff --git a/sds/sds_test.go b/sds/sds_test.go index 5e80cf7..963fd59 100644 --- a/sds/sds_test.go +++ b/sds/sds_test.go @@ -80,7 +80,7 @@ func TestDependencies(t *testing.T) { foundDep1 := false for _, dep := range *unwrappedMessage2.MissingDeps { - if dep == msgID1 { + if dep.MessageID == msgID1 { foundDep1 = true break } @@ -239,12 +239,15 @@ func TestCallback_OnMissingDependencies(t *testing.T) { var cbMutex sync.Mutex callbacks := EventCallbacks{ - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) { require.Equal(t, channelID, chId) cbMutex.Lock() missingCalled = true missingMsgID = messageId - missingDepsList = missingDeps // Copy slice + missingDepsList = make([]MessageID, len(missingDeps)) + for i, dep := range missingDeps { + missingDepsList[i] = dep.MessageID + } cbMutex.Unlock() wg.Done() }, @@ -383,7 +386,7 @@ func TestCallbacks_Combined(t *testing.T) { // are typically relevant to the Sender. We don't expect this. t.Errorf("Unexpected OnMessageSent call on Receiver for %s", messageId) }, - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) { // This callback is registered on Receiver, used for receiverRm2 below }, } @@ -404,7 +407,7 @@ func TestCallbacks_Combined(t *testing.T) { } } }, - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) { // Not expected on sender }, } @@ -447,11 +450,16 @@ func TestCallbacks_Combined(t *testing.T) { defer receiverRm2.Cleanup() callbacksReceiver2 := EventCallbacks{ - OnMissingDependencies: func(messageId MessageID, missingDeps []MessageID, chId string) { + OnMissingDependencies: func(messageId MessageID, missingDeps []HistoryEntry, chId string) { require.Equal(t, channelID, chId) if messageId == msgID3 { + // Convert []HistoryEntry to []MessageID for the channel + deps := make([]MessageID, len(missingDeps)) + for i, d := range missingDeps { + deps[i] = d.MessageID + } select { - case missingChan <- missingDeps: + case missingChan <- deps: default: } } @@ -677,3 +685,61 @@ func TestMultiChannelCallbacks(t *testing.T) { require.Equal(t, channel2, readyMessages[ackID2_ch2], "OnMessageReady for ack2 has incorrect channel") require.Len(t, readyMessages, 2, "Expected exactly 2 ready messages") } + +func TestRetrievalHints(t *testing.T) { + rm, err := NewReliabilityManager() + require.NoError(t, err) + defer rm.Cleanup() + + channelID := "test-retrieval-hints" + + // Set a retrieval hint provider + rm.RegisterCallbacks(EventCallbacks{ + RetrievalHintProvider: func(messageId MessageID) []byte { + return []byte("hint-for-" + messageId) + }, + }) + + // 1. Send a message to populate the history + payload1 := []byte("message one") + msgID1 := MessageID("msg-hint-1") + wrappedMsg1, err := rm.WrapOutgoingMessage(payload1, msgID1, channelID) + require.NoError(t, err) + + // 2. Receive the message to add it to history + _, err = rm.UnwrapReceivedMessage(wrappedMsg1) + require.NoError(t, err) + + // 3. Send a second message, which will include the first in its causal history + payload2 := []byte("message two") + msgID2 := MessageID("msg-hint-2") + wrappedMsg2, err := rm.WrapOutgoingMessage(payload2, msgID2, channelID) + require.NoError(t, err) + + // 4. Unwrap the second message to inspect its causal history + // We need a new RM to avoid acknowledging the message + rm2, err := NewReliabilityManager() + require.NoError(t, err) + defer rm2.Cleanup() + + rm2.RegisterCallbacks(EventCallbacks{ + RetrievalHintProvider: func(messageId MessageID) []byte { + return []byte("hint-for-" + messageId) + }, + }) + + unwrappedMsg2, err := rm2.UnwrapReceivedMessage(wrappedMsg2) + require.NoError(t, err) + + // 5. Check that the causal history contains the retrieval hint + require.Greater(t, len(*unwrappedMsg2.MissingDeps), 0, "Expected missing dependencies") + foundDep := false + for _, dep := range *unwrappedMsg2.MissingDeps { + if dep.MessageID == msgID1 { + foundDep = true + require.Equal(t, []byte("hint-for-"+msgID1), dep.RetrievalHint, "Retrieval hint does not match") + break + } + } + require.True(t, foundDep, "Expected to find dependency %s", msgID1) +} diff --git a/sds/types.go b/sds/types.go index 8d788ca..6606fd4 100644 --- a/sds/types.go +++ b/sds/types.go @@ -2,8 +2,13 @@ package sds type MessageID string +type HistoryEntry struct { + MessageID MessageID `json:"messageId"` + RetrievalHint []byte `json:"retrievalHint"` +} + type UnwrappedMessage struct { - Message *[]byte `json:"message"` - MissingDeps *[]MessageID `json:"missingDeps"` - ChannelId *string `json:"channelId"` + Message *[]byte `json:"message"` + MissingDeps *[]HistoryEntry `json:"missingDeps"` + ChannelId *string `json:"channelId"` }