diff --git a/src/collectives.cpp b/src/collectives.cpp index 94f4b85..b91b62e 100644 --- a/src/collectives.cpp +++ b/src/collectives.cpp @@ -239,6 +239,7 @@ void CmiReductionsInit(void) { // node reduction must be initialized with a valid lock nodered.lock = CmiCreateLock(); // in non-smp this would just be a nullptr + nodered.red = NULL; } CsvAccess(_node_reduction_info) = noderedinfo; @@ -298,9 +299,15 @@ static void CmiClearReduction(CmiReductionID id) { auto &reduction_ref = CpvAccess(_reduction_info)[CmiGetReductionIndex(id)]; CmiReduction *red = reduction_ref; if (red != NULL) { - free(red->remotebuffer); - // we assume the user is freeing the actual messages that the buffer is - // holding or is that something we do? + // Free all the messages stored in remotebuffer + if (red->remotebuffer != NULL) { + for (int i = 0; i < red->messagesReceived; i++) { + if (red->remotebuffer[i] != NULL) { + CmiFree(red->remotebuffer[i]); + } + } + free(red->remotebuffer); + } free(red); } reduction_ref = NULL; @@ -318,10 +325,26 @@ static CmiReduction *CmiGetCreateReduction(CmiReductionID id) { // should handle the 2 cases: // 1. a reduction message arrives from a child for ex), but the parent hasn't // gotten the chance to create a reduction struct yet - // 2. a reduction structure already exists adn teh parent has initiaited its + // 2. a reduction structure already exists and the parent has initiated its // own contribution auto &reduction_ref = CpvAccess(_reduction_info)[CmiGetReductionIndex(id)]; CmiReduction *red = reduction_ref; + + // Force cleanup of stale reduction with different ID (due to wraparound) + if (red != NULL && red->ReductionID != id) { + if (red->remotebuffer != NULL) { + for (int i = 0; i < red->messagesReceived; i++) { + if (red->remotebuffer[i] != NULL) { + CmiFree(red->remotebuffer[i]); + } + } + free(red->remotebuffer); + } + free(red); + red = NULL; + reduction_ref = NULL; + } + if (reduction_ref == NULL) { CmiReduction *newred = (CmiReduction *)malloc(sizeof(CmiReduction)); newred->ReductionID = id; @@ -350,6 +373,7 @@ static CmiReduction *CmiGetCreateReduction(CmiReductionID id) { // gets called by every PE pariticapting in the reduction void CmiReduce(void *msg, int size, CmiReduceMergeFn mergeFn) { const CmiReductionID id = CmiGetNextReductionID(); + CmiSetRedID(msg, id); CmiReduction *red = CmiGetCreateReduction(id); CmiInternalReduce(msg, size, mergeFn, red); } @@ -406,6 +430,13 @@ void CmiSendReduce(CmiReduction *red) { void CmiReduceHandler(void *msg) { CmiReduction *reduction = CmiGetCreateReduction(CmiGetRedID(msg)); + // Add bounds checking to prevent buffer overflow + if (reduction->messagesReceived >= reduction->numChildren) { + CmiAbort("CmiReduceHandler: received more messages than expected (%d >= %d)", + reduction->messagesReceived, reduction->numChildren); + return; + } + // how are we ensuring the messages arrive in order again? reduction->remotebuffer[reduction->messagesReceived] = (char *)msg; reduction->messagesReceived++; @@ -418,9 +449,15 @@ static void CmiClearNodeReduction(CmiReductionID id) { CsvAccess(_node_reduction_info)[CmiGetReductionIndex(id)].red; CmiReduction *red = reduction_ref; if (red != NULL) { - free(red->remotebuffer); - // we assume the user is freeing the actual messages that the buffer is - // holding or is that something we do? + // Free all the messages stored in remotebuffer + if (red->remotebuffer != NULL) { + for (int i = 0; i < red->messagesReceived; i++) { + if (red->remotebuffer[i] != NULL) { + CmiFree(red->remotebuffer[i]); + } + } + free(red->remotebuffer); + } free(red); } reduction_ref = NULL; @@ -429,11 +466,13 @@ static void CmiClearNodeReduction(CmiReductionID id) { // lock and unlock are used to support SMP void CmiNodeReduce(void *msg, int size, CmiReduceMergeFn mergeFn) { + const CmiReductionID id = CmiGetNextNodeReductionID(); + CmiSetRedID(msg, id); + CmiNodeReduction nodeRed = - CsvAccess(_node_reduction_info)[CmiGetReductionIndex(CmiGetRedID(msg))]; + CsvAccess(_node_reduction_info)[CmiGetReductionIndex(id)]; CmiLock(nodeRed.lock); - const CmiReductionID id = CmiGetNextNodeReductionID(); CmiReduction *red = CmiGetCreateNodeReduction(id); CmiInternalNodeReduce(msg, size, mergeFn, red); @@ -453,11 +492,27 @@ static CmiReduction *CmiGetCreateNodeReduction(CmiReductionID id) { // should handle the 2 cases: // 1. a reduction message arrives from a child for ex), but the parent hasn't // gotten the chance to create a reduction struct yet - // 2. a reduction structure already exists adn teh parent has initiaited its + // 2. a reduction structure already exists and the parent has initiated its // own contribution auto &reduction_ref = CsvAccess(_node_reduction_info)[CmiGetReductionIndex(id)].red; CmiReduction *red = reduction_ref; + + // Force cleanup of stale reduction with different ID (due to wraparound) + if (red != NULL && red->ReductionID != id) { + if (red->remotebuffer != NULL) { + for (int i = 0; i < red->messagesReceived; i++) { + if (red->remotebuffer[i] != NULL) { + CmiFree(red->remotebuffer[i]); + } + } + free(red->remotebuffer); + } + free(red); + red = NULL; + reduction_ref = NULL; + } + if (reduction_ref == NULL) { CmiReduction *newred = (CmiReduction *)malloc(sizeof(CmiReduction)); newred->ReductionID = id; @@ -539,6 +594,14 @@ void CmiNodeReduceHandler(void *msg) { CmiReduction *reduction = CmiGetCreateNodeReduction(CmiGetRedID(msg)); + // Add bounds checking to prevent buffer overflow + if (reduction->messagesReceived >= reduction->numChildren) { + CmiAbort("CmiNodeReduceHandler: received more messages than expected (%d >= %d)", + reduction->messagesReceived, reduction->numChildren); + CmiUnlock(nodeRed.lock); + return; + } + // how are we ensuring the messages arrive in order again? reduction->remotebuffer[reduction->messagesReceived] = (char *)msg; reduction->messagesReceived++;