|
45 | 45 | #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_H |
46 | 46 |
|
47 | 47 | #include "vt/collective/reduce/allreduce/data_handler.h" |
48 | | -#include "vt/collective/reduce/allreduce/helpers.h" |
49 | 48 | #include "vt/collective/reduce/allreduce/type.h" |
50 | 49 | #include "vt/collective/reduce/scoping/strong_types.h" |
51 | 50 | #include "vt/configs/types/types_type.h" |
52 | | -#include "vt/configs/debug/debug_print.h" |
53 | 51 | #include "vt/collective/reduce/allreduce/state.h" |
54 | 52 |
|
55 | 53 | #include <memory> |
56 | | -#include <type_traits> |
57 | 54 | #include <unordered_map> |
58 | 55 |
|
59 | 56 | namespace vt::collective::reduce::allreduce { |
60 | 57 |
|
61 | 58 | struct StateHolder { |
| 59 | + using StatesVec = std::vector<std::unique_ptr<StateBase>>; |
| 60 | + |
62 | 61 | template < |
63 | 62 | typename ReducerT, typename DataT, |
64 | 63 | typename Scalar = typename DataHandler<DataT>::Scalar> |
65 | | - static auto& getState(detail::StrongVrtProxy proxy, size_t idx) { |
66 | | - return getStateImpl<ReducerT, DataT>(proxy, active_coll_states_, idx); |
67 | | - } |
| 64 | + static decltype(auto) getState(detail::StrongVrtProxy proxy, size_t idx); |
68 | 65 |
|
69 | 66 | template < |
70 | 67 | typename ReducerT, typename DataT, |
71 | 68 | typename Scalar = typename DataHandler<DataT>::Scalar> |
72 | | - static auto& getState(detail::StrongObjGroup proxy, size_t idx) { |
73 | | - return getStateImpl<ReducerT, DataT>(proxy, active_obj_states_, idx); |
74 | | - } |
| 69 | + static decltype(auto) getState(detail::StrongObjGroup proxy, size_t idx); |
75 | 70 |
|
76 | 71 | template < |
77 | 72 | typename ReducerT, typename DataT, |
78 | 73 | typename Scalar = typename DataHandler<DataT>::Scalar> |
79 | | - static auto& getState(detail::StrongGroup proxy, size_t idx) { |
80 | | - return getStateImpl<ReducerT, DataT>(proxy, active_grp_states_, idx); |
81 | | - } |
82 | | - |
83 | | - template <typename ReducerT> |
84 | | - static size_t getNextID(detail::StrongVrtProxy proxy) { |
85 | | - size_t id = 0; |
86 | | - auto& allreducers = active_coll_states_[proxy.get()]; |
87 | | - |
88 | | - if (not allreducers.empty()) { |
89 | | - // Last element is invalidated (allreduce completed) or not completed |
90 | | - // Generate new ID |
91 | | - if (not allreducers.back() or allreducers.back()->active_) { |
92 | | - id = allreducers.size(); |
93 | | - } |
94 | | - // Most recent state is not active, don't generate new ID |
95 | | - else if (not allreducers.back()->active_) { |
96 | | - id = allreducers.size() - 1; |
97 | | - } |
98 | | - } |
99 | | - |
100 | | - return id; |
101 | | - } |
102 | | - |
103 | | - template <typename ReducerT> |
104 | | - static size_t getNextID(detail::StrongObjGroup proxy) { |
105 | | - size_t id = 0; |
106 | | - auto& allreducers = active_obj_states_[proxy.get()]; |
107 | | - |
108 | | - if (not allreducers.empty()) { |
109 | | - // Last element is invalidated (allreduce completed) or not completed |
110 | | - // Generate new ID |
111 | | - if (not allreducers.back() or allreducers.back()->active_) { |
112 | | - id = allreducers.size(); |
113 | | - } |
114 | | - // Most recent state is not active, don't generate new ID |
115 | | - else if (not allreducers.back()->active_) { |
116 | | - id = allreducers.size() - 1; |
117 | | - } |
118 | | - } |
119 | | - |
120 | | - return id; |
121 | | - } |
122 | | - |
123 | | - static size_t getNextID(detail::StrongGroup group) { |
124 | | - size_t id = 0; |
125 | | - auto& allreducers = active_grp_states_[group.get()]; |
126 | | - |
127 | | - if (not allreducers.empty()) { |
128 | | - // Last element is invalidated (allreduce completed) or not completed |
129 | | - // Generate new ID |
130 | | - if (not allreducers.back() or allreducers.back()->active_) { |
131 | | - id = allreducers.size(); |
132 | | - } |
133 | | - // Most recent state is not active, don't generate new ID |
134 | | - else if (not allreducers.back()->active_) { |
135 | | - id = allreducers.size() - 1; |
136 | | - } |
137 | | - } |
138 | | - |
139 | | - return id; |
140 | | - } |
141 | | - |
142 | | - static void clearSingle(detail::StrongVrtProxy proxy, size_t idx) { |
143 | | - clearSingleImpl(proxy, active_coll_states_, idx); |
144 | | - } |
145 | | - |
146 | | - static void clearSingle(detail::StrongObjGroup proxy, size_t idx) { |
147 | | - clearSingleImpl(proxy, active_obj_states_, idx); |
148 | | - } |
| 74 | + static decltype(auto) getState(detail::StrongGroup proxy, size_t idx); |
149 | 75 |
|
150 | | - static void clearSingle(detail::StrongGroup group, size_t idx) { |
151 | | - clearSingleImpl(group, active_grp_states_, idx); |
152 | | - } |
| 76 | + static size_t getNextID(detail::StrongVrtProxy proxy); |
| 77 | + static size_t getNextID(detail::StrongObjGroup proxy); |
| 78 | + static size_t getNextID(detail::StrongGroup group); |
153 | 79 |
|
154 | | - static void clearAll(detail::StrongVrtProxy proxy) { |
155 | | - // fmt::print("Clearing all states for VrtProxy={:x}\n", proxy.get()); |
156 | | - clearAllImpl(proxy, active_coll_states_); |
157 | | - } |
158 | | - |
159 | | - static void clearAll(detail::StrongObjGroup proxy) { |
160 | | - // fmt::print("Clearing all states for Objgroup={:x}\n", proxy.get()); |
161 | | - clearAllImpl(proxy, active_obj_states_); |
162 | | - } |
| 80 | + static void clearSingle(detail::StrongVrtProxy proxy, size_t idx); |
| 81 | + static void clearSingle(detail::StrongObjGroup proxy, size_t idx); |
| 82 | + static void clearSingle(detail::StrongGroup group, size_t idx); |
163 | 83 |
|
164 | | - static void clearAll(detail::StrongGroup group) { |
165 | | - // fmt::print("Clearing all states for group={:x}\n", group.get()); |
166 | | - clearAllImpl(group, active_grp_states_); |
167 | | - } |
| 84 | + static void clearAll(detail::StrongVrtProxy proxy); |
| 85 | + static void clearAll(detail::StrongObjGroup proxy); |
| 86 | + static void clearAll(detail::StrongGroup group); |
168 | 87 |
|
169 | 88 | private: |
170 | | - template <typename ProxyT, typename MapT> |
171 | | - static void clearSingleImpl(ProxyT proxy, MapT& states_map, size_t idx) { |
172 | | - auto& states = states_map[proxy.get()]; |
173 | | - |
174 | | - auto const num_states = states.size(); |
175 | | - vtAssert( |
176 | | - num_states > idx, |
177 | | - fmt::format( |
178 | | - "Attempting to access state {} with total numer of states {}!", idx, |
179 | | - num_states)); |
| 89 | + static inline size_t collection_idx_ = 0; |
| 90 | + static inline size_t objgroup_idx_ = 0; |
| 91 | + static inline size_t group_idx_ = 0; |
180 | 92 |
|
181 | | - states.at(idx).reset(); |
182 | | - } |
183 | | - |
184 | | - template <typename ProxyT, typename MapT> |
185 | | - static void clearAllImpl(ProxyT proxy, MapT& states_map) { |
186 | | - states_map.erase(proxy.get()); |
187 | | - } |
188 | | - |
189 | | - template < |
190 | | - typename ReduceT, typename DataT, |
191 | | - typename Scalar = typename DataHandler<DataT>::Scalar, typename ProxyT, |
192 | | - typename MapT> |
193 | | - static auto& getStateImpl(ProxyT proxy, MapT& states_map, size_t idx) { |
194 | | - auto& states = states_map[proxy.get()]; |
195 | | - auto const num_states = states.size(); |
196 | | - |
197 | | - vtAssert( |
198 | | - num_states >= idx, |
199 | | - fmt::format( |
200 | | - "Attempting to access state {} with total number of states {}!", idx, |
201 | | - num_states)); |
202 | | - |
203 | | - if (idx >= num_states || num_states == 0) { |
204 | | - if constexpr (std::is_same_v<ReduceT, RabenseifnerT>) { |
205 | | - states.push_back(std::make_unique<RabenseifnerState<Scalar, DataT>>()); |
206 | | - } else { |
207 | | - states.push_back(std::make_unique<RecursiveDoublingState<DataT>>()); |
208 | | - } |
209 | | - } |
210 | | - |
211 | | - vtAssert( |
212 | | - states.at(idx), |
213 | | - fmt::format("Attempting to access invalidated state at idx={}!", idx)); |
214 | | - |
215 | | - if constexpr (std::is_same_v<ReduceT, RabenseifnerT>) { |
216 | | - auto* ptr = |
217 | | - dynamic_cast<RabenseifnerState<Scalar, DataT>*>(states.at(idx).get()); |
218 | | - vtAssert( |
219 | | - ptr, |
220 | | - fmt::format( |
221 | | - "Invalid Rabenseifner cast at idx={} with size={}!", idx, |
222 | | - states.size())); |
223 | | - return *ptr; |
224 | | - } else { |
225 | | - auto* ptr = |
226 | | - dynamic_cast<RecursiveDoublingState<DataT>*>(states.at(idx).get()); |
227 | | - vtAssert( |
228 | | - ptr, |
229 | | - fmt::format( |
230 | | - "Invalid RecursiveDoubling cast at idx={} with size={}!", idx, |
231 | | - states.size())); |
232 | | - return *ptr; |
233 | | - } |
234 | | - } |
235 | | - |
236 | | - static inline std::unordered_map< |
237 | | - VirtualProxyType, std::vector<std::unique_ptr<StateBase>>> |
| 93 | + static inline std::unordered_map<VirtualProxyType, StatesVec> |
238 | 94 | active_coll_states_ = {}; |
239 | 95 |
|
240 | | - static inline std::unordered_map< |
241 | | - ObjGroupProxyType, std::vector<std::unique_ptr<StateBase>>> |
| 96 | + static inline std::unordered_map<ObjGroupProxyType, StatesVec> |
242 | 97 | active_obj_states_ = {}; |
243 | 98 |
|
244 | | - static inline std::unordered_map< |
245 | | - GroupType, std::vector<std::unique_ptr<StateBase>>> |
246 | | - active_grp_states_ = {}; |
| 99 | + static inline std::unordered_map<GroupType, StatesVec> active_grp_states_ = |
| 100 | + {}; |
247 | 101 | }; |
248 | 102 |
|
249 | 103 | template <typename ReducerT, typename DataT> |
@@ -272,4 +126,6 @@ static inline void cleanupState(ComponentInfo info, size_t id) { |
272 | 126 |
|
273 | 127 | } // namespace vt::collective::reduce::allreduce |
274 | 128 |
|
| 129 | +#include "state_holder.impl.h" |
| 130 | + |
275 | 131 | #endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_H*/ |
0 commit comments