Skip to content

Commit 5c33e20

Browse files
committed
#2281: Resolve issue with incorrect index generated by StateHolder::getNextID
1 parent ed1f880 commit 5c33e20

File tree

11 files changed: 33 additions and 50 deletions.

src/vt/collective/reduce/allreduce/allreduce_holder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
#include "vt/configs/types/types_type.h"
4848
#include "vt/collective/reduce/allreduce/type.h"
4949
#include "vt/collective/reduce/scoping/strong_types.h"
50-
#include "vt/configs/types/types_sentinels.h"
5150
#include "vt/objgroup/proxy/proxy_objgroup.h"
5251

5352
#include <unordered_map>
@@ -56,6 +55,7 @@ namespace vt::collective::reduce::allreduce {
5655

5756
struct Rabenseifner;
5857
struct RecursiveDoubling;
58+
5959
struct AllreduceHolder {
6060
using RabenseifnerProxy = ObjGroupProxyType;
6161
using RecursiveDoublingProxy = ObjGroupProxyType;

src/vt/collective/reduce/allreduce/rabenseifner.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,4 @@ void Rabenseifner::initializeVrtNode() {
150150
}
151151
}
152152

153-
Rabenseifner::~Rabenseifner() {
154-
if (info_.first == ComponentT::ObjGroup) {
155-
StateHolder::clearAll(detail::StrongObjGroup{info_.second});
156-
AllreduceHolder::remove(detail::StrongObjGroup{info_.second});
157-
} else if(info_.first == ComponentT::Group){
158-
StateHolder::clearAll(detail::StrongGroup{info_.second});
159-
}
160-
}
161-
162153
} // namespace vt::collective::reduce::allreduce

src/vt/collective/reduce/allreduce/rabenseifner.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ struct Rabenseifner {
8787

8888
void initializeVrtNode();
8989

90-
~Rabenseifner();
91-
9290
/**
9391
* \brief Set final handler that will be executed with allreduce result
9492
*

src/vt/collective/reduce/allreduce/recursive_doubling.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,4 @@ void RecursiveDoubling::initializeVrtNode() {
117117
}
118118
}
119119

120-
RecursiveDoubling::~RecursiveDoubling() {
121-
if (info_.first == ComponentT::ObjGroup) {
122-
StateHolder::clearAll(detail::StrongObjGroup{info_.second});
123-
AllreduceHolder::remove(detail::StrongObjGroup{info_.second});
124-
} else if(info_.first == ComponentT::Group){
125-
StateHolder::clearAll(detail::StrongGroup{info_.second});
126-
}
127-
}
128-
129120
} // namespace vt::collective::reduce::allreduce

src/vt/collective/reduce/allreduce/recursive_doubling.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ struct RecursiveDoubling {
8989
*/
9090
void initializeVrtNode();
9191

92-
~RecursiveDoubling();
93-
9492
/**
9593
* \brief Execute the final handler callback with the reduced result.
9694
*

src/vt/collective/reduce/allreduce/recursive_doubling.impl.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ template <typename DataT, template <typename Arg> class Op>
177177
template <typename DataT, template <typename Arg> class Op>
178178
void RecursiveDoubling::adjustForPowerOfTwoHan(
179179
RecursiveDoublingMsg<DataT>* msg) {
180-
using DataType = DataHandler<DataT>;
181180
auto& state = getState<RecursiveDoublingT, DataT>(info_, msg->id_);
182181
if (not state.value_assigned_) {
183182
if (not state.initialized_) {
@@ -311,7 +310,6 @@ RecursiveDoubling::reduceIterHandler(RecursiveDoublingMsg<DataT>* msg) {
311310

312311
template <typename DataT, template <typename Arg> class Op>
313312
void RecursiveDoubling::reduceIterHan(RecursiveDoublingMsg<DataT>* msg) {
314-
using DataType = DataHandler<DataT>;
315313
auto& state = getState<RecursiveDoublingT, DataT>(info_, msg->id_);
316314

317315
if (not state.value_assigned_) {

src/vt/collective/reduce/allreduce/state_holder.cc

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ size_t
5252
getNextIdImpl(StateHolder::StatesVec& states, size_t idx) {
5353
size_t id = u64empty;
5454

55+
vt_debug_print(
56+
terse, allreduce, "getNextIdImpl idx={} size={} \n", idx, states.size());
57+
5558
for (; idx < states.size(); ++idx) {
5659
auto& state = states.at(idx);
5760
if (not state or not state->active_) {
@@ -64,28 +67,35 @@ getNextIdImpl(StateHolder::StatesVec& states, size_t idx) {
6467
id = states.size();
6568
}
6669

70+
6771
return id;
6872
}
6973

7074
size_t StateHolder::getNextID(detail::StrongVrtProxy proxy) {
71-
auto& states = active_coll_states_[proxy.get()];
75+
auto& [idx, states] = active_coll_states_[proxy.get()];
76+
77+
auto current_idx = getNextIdImpl(states, idx);
78+
idx = current_idx + 1;
7279

73-
collection_idx_ = getNextIdImpl(states, collection_idx_);
74-
return collection_idx_;
80+
return current_idx;
7581
}
7682

7783
size_t StateHolder::getNextID(detail::StrongObjGroup proxy) {
78-
auto& states = active_obj_states_[proxy.get()];
84+
auto& [idx, states] = active_obj_states_[proxy.get()];
7985

80-
objgroup_idx_ = getNextIdImpl(states, objgroup_idx_);
81-
return objgroup_idx_;
86+
auto current_idx = getNextIdImpl(states, idx);
87+
idx = current_idx + 1;
88+
89+
return current_idx;
8290
}
8391

8492
size_t StateHolder::getNextID(detail::StrongGroup group) {
85-
auto& states = active_grp_states_[group.get()];
93+
auto& [idx, states] = active_grp_states_[group.get()];
94+
95+
auto current_idx = getNextIdImpl(states, idx);
8696

87-
group_idx_ = getNextIdImpl(states, group_idx_);
88-
return group_idx_;
97+
idx = current_idx + 1;
98+
return current_idx;
8999
}
90100

91101
static inline void
@@ -101,19 +111,19 @@ clearSingleImpl(StateHolder::StatesVec& states, size_t idx) {
101111
}
102112

103113
void StateHolder::clearSingle(detail::StrongVrtProxy proxy, size_t idx) {
104-
auto& states = active_coll_states_[proxy.get()];
114+
auto& [_, states] = active_coll_states_[proxy.get()];
105115

106116
clearSingleImpl(states, idx);
107117
}
108118

109119
void StateHolder::clearSingle(detail::StrongObjGroup proxy, size_t idx) {
110-
auto& states = active_obj_states_[proxy.get()];
120+
auto& [_, states] = active_obj_states_[proxy.get()];
111121

112122
clearSingleImpl(states, idx);
113123
}
114124

115125
void StateHolder::clearSingle(detail::StrongGroup group, size_t idx) {
116-
auto& states = active_grp_states_[group.get()];
126+
auto& [_, states] = active_grp_states_[group.get()];
117127

118128
clearSingleImpl(states, idx);
119129
}

src/vt/collective/reduce/allreduce/state_holder.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ namespace vt::collective::reduce::allreduce {
5757

5858
struct StateHolder {
5959
using StatesVec = std::vector<std::unique_ptr<StateBase>>;
60+
using StatesInfo = std::pair<size_t, StatesVec>;
6061

6162
template <
6263
typename ReducerT, typename DataT,
@@ -86,17 +87,13 @@ struct StateHolder {
8687
static void clearAll(detail::StrongGroup group);
8788

8889
private:
89-
static inline size_t collection_idx_ = 0;
90-
static inline size_t objgroup_idx_ = 0;
91-
static inline size_t group_idx_ = 0;
92-
93-
static inline std::unordered_map<VirtualProxyType, StatesVec>
90+
static inline std::unordered_map<VirtualProxyType, StatesInfo>
9491
active_coll_states_ = {};
9592

96-
static inline std::unordered_map<ObjGroupProxyType, StatesVec>
93+
static inline std::unordered_map<ObjGroupProxyType, StatesInfo>
9794
active_obj_states_ = {};
9895

99-
static inline std::unordered_map<GroupType, StatesVec> active_grp_states_ =
96+
static inline std::unordered_map<GroupType, StatesInfo> active_grp_states_ =
10097
{};
10198
};
10299

src/vt/collective/reduce/allreduce/state_holder.impl.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
//@HEADER
4242
*/
4343

44-
#include "vt/collective/reduce/allreduce/state.h"
4544
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H
4645
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H
4746

@@ -80,7 +79,7 @@ template <
8079
typename Scalar = typename DataHandler<DataT>::Scalar, typename ProxyT,
8180
typename MapT>
8281
static auto& getStateImpl(ProxyT proxy, MapT& states_map, size_t idx) {
83-
auto& states = states_map[proxy.get()];
82+
auto& [_, states] = states_map[proxy.get()];
8483
auto const num_states = states.size();
8584

8685
if (idx >= num_states || num_states == 0) {

src/vt/objgroup/manager.impl.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
#include "vt/collective/reduce/allreduce/type.h"
6464
#include "vt/collective/reduce/allreduce/helpers.h"
6565
#include "vt/collective/reduce/scoping/strong_types.h"
66-
#include "vt/collective/reduce/allreduce/state_holder.h"
66+
#include "vt/collective/reduce/allreduce/allreduce_holder.h"
6767
#include "vt/pipe/pipe_manager.h"
6868

6969
#include <utility>
@@ -147,6 +147,10 @@ void ObjGroupManager::destroyCollective(ProxyType<ObjT> proxy) {
147147
if (label_iter != labels_.end()) {
148148
labels_.erase(label_iter);
149149
}
150+
151+
vt::collective::reduce::allreduce::AllreduceHolder::remove(
152+
vt::collective::reduce::detail::StrongObjGroup{proxy.getProxy()}
153+
);
150154
}
151155

152156
template <typename ObjT>

0 commit comments

Comments
 (0)