Skip to content

Commit ed1f880

Browse files
committed
#2281: Cleanup StateHolder (move implementation to .cc impl.h files)
1 parent 33c5e98 commit ed1f880

File tree

3 files changed

+272
-167
lines changed

3 files changed

+272
-167
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
//@HEADER
3+
// *****************************************************************************
4+
//
5+
// state_holder.cc
6+
// DARMA/vt => Virtual Transport
7+
//
8+
// Copyright 2019-2024 National Technology & Engineering Solutions of Sandia, LLC
9+
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
10+
// Government retains certain rights in this software.
11+
//
12+
// Redistribution and use in source and binary forms, with or without
13+
// modification, are permitted provided that the following conditions are met:
14+
//
15+
// * Redistributions of source code must retain the above copyright notice,
16+
// this list of conditions and the following disclaimer.
17+
//
18+
// * Redistributions in binary form must reproduce the above copyright notice,
19+
// this list of conditions and the following disclaimer in the documentation
20+
// and/or other materials provided with the distribution.
21+
//
22+
// * Neither the name of the copyright holder nor the names of its
23+
// contributors may be used to endorse or promote products derived from this
24+
// software without specific prior written permission.
25+
//
26+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
30+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
32+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
33+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
34+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
35+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36+
// POSSIBILITY OF SUCH DAMAGE.
37+
//
38+
// Questions? Contact [email protected]
39+
//
40+
// *****************************************************************************
41+
//@HEADER
42+
*/
43+
44+
#include "vt/config.h"
45+
#include "state_holder.h"
46+
#include "vt/configs/error/hard_error.h"
47+
#include "vt/configs/error/config_assert.h"
48+
49+
namespace vt::collective::reduce::allreduce {
50+
51+
size_t
52+
getNextIdImpl(StateHolder::StatesVec& states, size_t idx) {
53+
size_t id = u64empty;
54+
55+
for (; idx < states.size(); ++idx) {
56+
auto& state = states.at(idx);
57+
if (not state or not state->active_) {
58+
id = idx;
59+
break;
60+
}
61+
}
62+
63+
if (id == u64empty) {
64+
id = states.size();
65+
}
66+
67+
return id;
68+
}
69+
70+
size_t StateHolder::getNextID(detail::StrongVrtProxy proxy) {
71+
auto& states = active_coll_states_[proxy.get()];
72+
73+
collection_idx_ = getNextIdImpl(states, collection_idx_);
74+
return collection_idx_;
75+
}
76+
77+
size_t StateHolder::getNextID(detail::StrongObjGroup proxy) {
78+
auto& states = active_obj_states_[proxy.get()];
79+
80+
objgroup_idx_ = getNextIdImpl(states, objgroup_idx_);
81+
return objgroup_idx_;
82+
}
83+
84+
size_t StateHolder::getNextID(detail::StrongGroup group) {
85+
auto& states = active_grp_states_[group.get()];
86+
87+
group_idx_ = getNextIdImpl(states, group_idx_);
88+
return group_idx_;
89+
}
90+
91+
static inline void
92+
clearSingleImpl(StateHolder::StatesVec& states, size_t idx) {
93+
auto const num_states = states.size();
94+
vtAssert(
95+
num_states > idx,
96+
fmt::format(
97+
"Attempting to access state {} with total numer of states {}!", idx,
98+
num_states));
99+
100+
states.at(idx).reset();
101+
}
102+
103+
void StateHolder::clearSingle(detail::StrongVrtProxy proxy, size_t idx) {
104+
auto& states = active_coll_states_[proxy.get()];
105+
106+
clearSingleImpl(states, idx);
107+
}
108+
109+
void StateHolder::clearSingle(detail::StrongObjGroup proxy, size_t idx) {
110+
auto& states = active_obj_states_[proxy.get()];
111+
112+
clearSingleImpl(states, idx);
113+
}
114+
115+
void StateHolder::clearSingle(detail::StrongGroup group, size_t idx) {
116+
auto& states = active_grp_states_[group.get()];
117+
118+
clearSingleImpl(states, idx);
119+
}
120+
121+
void StateHolder::clearAll(detail::StrongVrtProxy proxy) {
122+
active_coll_states_.erase(proxy.get());
123+
}
124+
125+
void StateHolder::clearAll(detail::StrongObjGroup proxy) {
126+
active_obj_states_.erase(proxy.get());
127+
}
128+
129+
void StateHolder::clearAll(detail::StrongGroup group) {
130+
active_grp_states_.erase(group.get());
131+
}
132+
133+
} // namespace vt::collective::reduce::allreduce

src/vt/collective/reduce/allreduce/state_holder.h

Lines changed: 23 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -45,205 +45,59 @@
4545
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_H
4646

4747
#include "vt/collective/reduce/allreduce/data_handler.h"
48-
#include "vt/collective/reduce/allreduce/helpers.h"
4948
#include "vt/collective/reduce/allreduce/type.h"
5049
#include "vt/collective/reduce/scoping/strong_types.h"
5150
#include "vt/configs/types/types_type.h"
52-
#include "vt/configs/debug/debug_print.h"
5351
#include "vt/collective/reduce/allreduce/state.h"
5452

5553
#include <memory>
56-
#include <type_traits>
5754
#include <unordered_map>
5855

5956
namespace vt::collective::reduce::allreduce {
6057

6158
struct StateHolder {
59+
using StatesVec = std::vector<std::unique_ptr<StateBase>>;
60+
6261
template <
6362
typename ReducerT, typename DataT,
6463
typename Scalar = typename DataHandler<DataT>::Scalar>
65-
static auto& getState(detail::StrongVrtProxy proxy, size_t idx) {
66-
return getStateImpl<ReducerT, DataT>(proxy, active_coll_states_, idx);
67-
}
64+
static decltype(auto) getState(detail::StrongVrtProxy proxy, size_t idx);
6865

6966
template <
7067
typename ReducerT, typename DataT,
7168
typename Scalar = typename DataHandler<DataT>::Scalar>
72-
static auto& getState(detail::StrongObjGroup proxy, size_t idx) {
73-
return getStateImpl<ReducerT, DataT>(proxy, active_obj_states_, idx);
74-
}
69+
static decltype(auto) getState(detail::StrongObjGroup proxy, size_t idx);
7570

7671
template <
7772
typename ReducerT, typename DataT,
7873
typename Scalar = typename DataHandler<DataT>::Scalar>
79-
static auto& getState(detail::StrongGroup proxy, size_t idx) {
80-
return getStateImpl<ReducerT, DataT>(proxy, active_grp_states_, idx);
81-
}
82-
83-
template <typename ReducerT>
84-
static size_t getNextID(detail::StrongVrtProxy proxy) {
85-
size_t id = 0;
86-
auto& allreducers = active_coll_states_[proxy.get()];
87-
88-
if (not allreducers.empty()) {
89-
// Last element is invalidated (allreduce completed) or not completed
90-
// Generate new ID
91-
if (not allreducers.back() or allreducers.back()->active_) {
92-
id = allreducers.size();
93-
}
94-
// Most recent state is not active, don't generate new ID
95-
else if (not allreducers.back()->active_) {
96-
id = allreducers.size() - 1;
97-
}
98-
}
99-
100-
return id;
101-
}
102-
103-
template <typename ReducerT>
104-
static size_t getNextID(detail::StrongObjGroup proxy) {
105-
size_t id = 0;
106-
auto& allreducers = active_obj_states_[proxy.get()];
107-
108-
if (not allreducers.empty()) {
109-
// Last element is invalidated (allreduce completed) or not completed
110-
// Generate new ID
111-
if (not allreducers.back() or allreducers.back()->active_) {
112-
id = allreducers.size();
113-
}
114-
// Most recent state is not active, don't generate new ID
115-
else if (not allreducers.back()->active_) {
116-
id = allreducers.size() - 1;
117-
}
118-
}
119-
120-
return id;
121-
}
122-
123-
static size_t getNextID(detail::StrongGroup group) {
124-
size_t id = 0;
125-
auto& allreducers = active_grp_states_[group.get()];
126-
127-
if (not allreducers.empty()) {
128-
// Last element is invalidated (allreduce completed) or not completed
129-
// Generate new ID
130-
if (not allreducers.back() or allreducers.back()->active_) {
131-
id = allreducers.size();
132-
}
133-
// Most recent state is not active, don't generate new ID
134-
else if (not allreducers.back()->active_) {
135-
id = allreducers.size() - 1;
136-
}
137-
}
138-
139-
return id;
140-
}
141-
142-
static void clearSingle(detail::StrongVrtProxy proxy, size_t idx) {
143-
clearSingleImpl(proxy, active_coll_states_, idx);
144-
}
145-
146-
static void clearSingle(detail::StrongObjGroup proxy, size_t idx) {
147-
clearSingleImpl(proxy, active_obj_states_, idx);
148-
}
74+
static decltype(auto) getState(detail::StrongGroup proxy, size_t idx);
14975

150-
static void clearSingle(detail::StrongGroup group, size_t idx) {
151-
clearSingleImpl(group, active_grp_states_, idx);
152-
}
76+
static size_t getNextID(detail::StrongVrtProxy proxy);
77+
static size_t getNextID(detail::StrongObjGroup proxy);
78+
static size_t getNextID(detail::StrongGroup group);
15379

154-
static void clearAll(detail::StrongVrtProxy proxy) {
155-
// fmt::print("Clearing all states for VrtProxy={:x}\n", proxy.get());
156-
clearAllImpl(proxy, active_coll_states_);
157-
}
158-
159-
static void clearAll(detail::StrongObjGroup proxy) {
160-
// fmt::print("Clearing all states for Objgroup={:x}\n", proxy.get());
161-
clearAllImpl(proxy, active_obj_states_);
162-
}
80+
static void clearSingle(detail::StrongVrtProxy proxy, size_t idx);
81+
static void clearSingle(detail::StrongObjGroup proxy, size_t idx);
82+
static void clearSingle(detail::StrongGroup group, size_t idx);
16383

164-
static void clearAll(detail::StrongGroup group) {
165-
// fmt::print("Clearing all states for group={:x}\n", group.get());
166-
clearAllImpl(group, active_grp_states_);
167-
}
84+
static void clearAll(detail::StrongVrtProxy proxy);
85+
static void clearAll(detail::StrongObjGroup proxy);
86+
static void clearAll(detail::StrongGroup group);
16887

16988
private:
170-
template <typename ProxyT, typename MapT>
171-
static void clearSingleImpl(ProxyT proxy, MapT& states_map, size_t idx) {
172-
auto& states = states_map[proxy.get()];
173-
174-
auto const num_states = states.size();
175-
vtAssert(
176-
num_states > idx,
177-
fmt::format(
178-
"Attempting to access state {} with total numer of states {}!", idx,
179-
num_states));
89+
static inline size_t collection_idx_ = 0;
90+
static inline size_t objgroup_idx_ = 0;
91+
static inline size_t group_idx_ = 0;
18092

181-
states.at(idx).reset();
182-
}
183-
184-
template <typename ProxyT, typename MapT>
185-
static void clearAllImpl(ProxyT proxy, MapT& states_map) {
186-
states_map.erase(proxy.get());
187-
}
188-
189-
template <
190-
typename ReduceT, typename DataT,
191-
typename Scalar = typename DataHandler<DataT>::Scalar, typename ProxyT,
192-
typename MapT>
193-
static auto& getStateImpl(ProxyT proxy, MapT& states_map, size_t idx) {
194-
auto& states = states_map[proxy.get()];
195-
auto const num_states = states.size();
196-
197-
vtAssert(
198-
num_states >= idx,
199-
fmt::format(
200-
"Attempting to access state {} with total number of states {}!", idx,
201-
num_states));
202-
203-
if (idx >= num_states || num_states == 0) {
204-
if constexpr (std::is_same_v<ReduceT, RabenseifnerT>) {
205-
states.push_back(std::make_unique<RabenseifnerState<Scalar, DataT>>());
206-
} else {
207-
states.push_back(std::make_unique<RecursiveDoublingState<DataT>>());
208-
}
209-
}
210-
211-
vtAssert(
212-
states.at(idx),
213-
fmt::format("Attempting to access invalidated state at idx={}!", idx));
214-
215-
if constexpr (std::is_same_v<ReduceT, RabenseifnerT>) {
216-
auto* ptr =
217-
dynamic_cast<RabenseifnerState<Scalar, DataT>*>(states.at(idx).get());
218-
vtAssert(
219-
ptr,
220-
fmt::format(
221-
"Invalid Rabenseifner cast at idx={} with size={}!", idx,
222-
states.size()));
223-
return *ptr;
224-
} else {
225-
auto* ptr =
226-
dynamic_cast<RecursiveDoublingState<DataT>*>(states.at(idx).get());
227-
vtAssert(
228-
ptr,
229-
fmt::format(
230-
"Invalid RecursiveDoubling cast at idx={} with size={}!", idx,
231-
states.size()));
232-
return *ptr;
233-
}
234-
}
235-
236-
static inline std::unordered_map<
237-
VirtualProxyType, std::vector<std::unique_ptr<StateBase>>>
93+
static inline std::unordered_map<VirtualProxyType, StatesVec>
23894
active_coll_states_ = {};
23995

240-
static inline std::unordered_map<
241-
ObjGroupProxyType, std::vector<std::unique_ptr<StateBase>>>
96+
static inline std::unordered_map<ObjGroupProxyType, StatesVec>
24297
active_obj_states_ = {};
24398

244-
static inline std::unordered_map<
245-
GroupType, std::vector<std::unique_ptr<StateBase>>>
246-
active_grp_states_ = {};
99+
static inline std::unordered_map<GroupType, StatesVec> active_grp_states_ =
100+
{};
247101
};
248102

249103
template <typename ReducerT, typename DataT>
@@ -272,4 +126,6 @@ static inline void cleanupState(ComponentInfo info, size_t id) {
272126

273127
} // namespace vt::collective::reduce::allreduce
274128

129+
#include "state_holder.impl.h"
130+
275131
#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_H*/

0 commit comments

Comments
 (0)