Skip to content

Commit c13a45a

Browse files
authored
Feature/counter (#732)
* add counter no tests * some refactoring * add tests + fix docu for counter * counter tests * rename test function * fix format * change aggregation default for counter * fix doxygen
1 parent 5366ba8 commit c13a45a

File tree

9 files changed

+769
-296
lines changed

9 files changed

+769
-296
lines changed

include/kamping/measurements/aggregated_tree_node.hpp

Lines changed: 128 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,18 @@
1919
#include "kamping/measurements/measurement_aggregation_definitions.hpp"
2020

2121
namespace kamping::measurements {
22-
/// @brief Class representing a node in the timer tree. Each node represents a time measurement (or multiple with
23-
/// the
24-
/// same key). A node can have multiple children which represent nested time measurements. The measurements
25-
/// associated with a node's children are executed while the node's measurement is still active.
22+
23+
/// @brief Class representing a node in a (globally) aggregated tree, i.e., a node of a timer (or counter) tree
24+
/// where the global aggregation operations have been performed and which can be printed.
2625
///
27-
/// @tparam Duration Type of a duration.
28-
template <typename Duration>
29-
class AggregatedTreeNode : public internal::TreeNode<AggregatedTreeNode<Duration>> {
26+
/// @tparam DataType Underlying data type.
27+
template <typename DataType>
28+
class AggregatedTreeNode : public internal::TreeNode<AggregatedTreeNode<DataType>> {
3029
public:
31-
using internal::TreeNode<AggregatedTreeNode<Duration>>::TreeNode;
30+
using internal::TreeNode<AggregatedTreeNode<DataType>>::TreeNode;
3231

3332
///@brief Type into which the aggregated data is stored together with the applied aggregation operation.
34-
using StorageType = std::unordered_map<GlobalAggregationMode, std::vector<ScalarOrContainer<Duration>>>;
33+
using StorageType = std::unordered_map<GlobalAggregationMode, std::vector<ScalarOrContainer<DataType>>>;
3534

3635
/// @brief Access to stored aggregated data.
3736
/// @return Reference to aggregated data.
@@ -41,9 +40,9 @@ class AggregatedTreeNode : public internal::TreeNode<AggregatedTreeNode<Duration
4140

4241
/// @brief Add scalar of type T to aggregated data storage together with the name of the applied aggregation
4342
/// operation.
44-
/// @param aggregation_mode Aggregation mode that has been applied to the duration data.
43+
/// @param aggregation_mode Aggregation mode that has been applied to the data.
4544
/// @param data Scalar resulted from applying the given aggregation operation.
46-
void add(GlobalAggregationMode aggregation_mode, std::optional<Duration> data) {
45+
void add(GlobalAggregationMode aggregation_mode, std::optional<DataType> data) {
4746
if (data) {
4847
_aggregated_data[aggregation_mode].emplace_back(data.value());
4948
}
@@ -53,11 +52,128 @@ class AggregatedTreeNode : public internal::TreeNode<AggregatedTreeNode<Duration
5352
/// operation.
5453
/// @param aggregation_mode Aggregation mode that has been applied to the data.
5554
/// @param data Vector of Scalars resulted from applying the given aggregation operation.
56-
void add(GlobalAggregationMode aggregation_mode, std::vector<Duration> const& data) {
55+
void add(GlobalAggregationMode aggregation_mode, std::vector<DataType> const& data) {
5756
_aggregated_data[aggregation_mode].emplace_back(data);
5857
}
5958

6059
public:
6160
StorageType _aggregated_data; ///< Storage of the aggregated data.
6261
};
62+
63+
/// @brief Class representing an aggregated measurement tree, i.e., a measurement tree for which the global aggregation
64+
/// has been performed.
65+
///
66+
/// @tparam DataType Type of internally stored data.
67+
template <typename DataType>
68+
class AggregatedTree {
69+
public:
70+
/// @brief Globally aggregates the measurement tree provided with \p measurement_root_node across all ranks in
71+
/// \p comm.
72+
///
73+
/// @tparam MeasurementNode Type of the measurement tree to aggregate.
74+
/// @tparam Communicator Communicator defining the scope for the global aggregation.
75+
template <typename MeasurementNode, typename Communicator>
76+
AggregatedTree(MeasurementNode const& measurement_root_node, Communicator const& comm) : _root{"root"} {
77+
aggregate(_root, measurement_root_node, comm);
78+
}
79+
80+
/// @brief Access to the root of the aggregated tree.
81+
/// @return Reference to root node of aggregated tree.
82+
auto& root() {
83+
return _root;
84+
}
85+
86+
/// @brief Access to the root of the aggregated tree.
87+
/// @return Reference to root node of aggregated tree.
88+
auto const& root() const {
89+
return _root;
90+
}
91+
92+
private:
93+
AggregatedTreeNode<DataType> _root; ///< Root node of aggregated tree.
94+
/// @brief Traverses and evaluates the given (Measurement)TreeNode and stores the result in the corresponding
95+
/// AggregatedTreeNode
96+
///
97+
/// @param aggregation_tree_node Node where the aggregated data points are stored.
98+
/// @param measurement_tree_node Node where the raw (not aggregated) data points are stored.
99+
template <typename MeasurementNode, typename Communciator>
100+
void aggregate(
101+
AggregatedTreeNode<DataType>& aggregation_tree_node,
102+
MeasurementNode& measurement_tree_node,
103+
Communciator const& comm
104+
) {
105+
KASSERT(
106+
internal::is_string_same_on_all_ranks(measurement_tree_node.name(), comm),
107+
"Currently processed MeasurementTreeNode has not the same name on all ranks -> measurement trees have "
108+
"diverged",
109+
assert::heavy_communication
110+
);
111+
KASSERT(
112+
comm.is_same_on_all_ranks(measurement_tree_node.measurements().size()),
113+
"Currently processed MeasurementTreeNode has not the same number of measurements on all ranks -> "
114+
"measurement trees have "
115+
"diverged",
116+
assert::light_communication
117+
);
118+
119+
// gather all durations at once as gathering all durations individually may deteriorate
120+
// the performance of the evaluation operation significantly.
121+
auto recv_buf = comm.gatherv(send_buf(measurement_tree_node.measurements()));
122+
auto const num_durations = measurement_tree_node.measurements().size();
123+
for (size_t duration_idx = 0; duration_idx < num_durations; ++duration_idx) {
124+
if (!comm.is_root()) {
125+
continue;
126+
}
127+
std::vector<DataType> cur_durations;
128+
cur_durations.reserve(comm.size());
129+
// gather the durations belonging to the same measurement
130+
for (size_t rank = 0; rank < comm.size(); ++rank) {
131+
cur_durations.push_back(recv_buf[duration_idx + rank * num_durations]);
132+
}
133+
134+
for (auto const& aggregation_mode: measurement_tree_node.measurements_aggregation_operations()) {
135+
aggregate_measurements_globally(aggregation_mode, cur_durations, aggregation_tree_node);
136+
}
137+
}
138+
for (auto& measurement_tree_child: measurement_tree_node.children()) {
139+
auto& aggregation_tree_child = aggregation_tree_node.find_or_insert(measurement_tree_child->name());
140+
aggregate(aggregation_tree_child, *measurement_tree_child.get(), comm);
141+
}
142+
}
143+
144+
/// @brief Computes the specified aggregation operation on an already gathered range of values.
145+
///
146+
/// @param mode Aggregation operation to perform.
147+
/// @param gathered_data Durations gathered from all participating ranks.
148+
/// @param evaluation_node Object where the aggregated and evaluated measurements are stored.
149+
void aggregate_measurements_globally(
150+
GlobalAggregationMode mode,
151+
std::vector<DataType> const& gathered_data,
152+
kamping::measurements::AggregatedTreeNode<DataType>& evaluation_node
153+
) {
154+
switch (mode) {
155+
case GlobalAggregationMode::max: {
156+
using Operation = internal::Max;
157+
evaluation_node.add(mode, Operation::compute(gathered_data));
158+
break;
159+
}
160+
case GlobalAggregationMode::min: {
161+
using Operation = internal::Min;
162+
evaluation_node.add(mode, Operation::compute(gathered_data));
163+
break;
164+
}
165+
case GlobalAggregationMode::sum: {
166+
using Operation = internal::Sum;
167+
evaluation_node.add(mode, Operation::compute(gathered_data));
168+
break;
169+
}
170+
case GlobalAggregationMode::gather: {
171+
using Operation = internal::Gather;
172+
evaluation_node.add(mode, Operation::compute(gathered_data));
173+
break;
174+
}
175+
}
176+
}
177+
};
178+
63179
} // namespace kamping::measurements
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// This file is part of KaMPIng.
2+
//
3+
// Copyright 2023 The KaMPIng Authors
4+
//
5+
// KaMPIng is free software : you can redistribute it and/or modify it under the
6+
// terms of the GNU Lesser General Public License as published by the Free
7+
// Software Foundation, either version 3 of the License, or (at your option) any
8+
// later version. KaMPIng is distributed in the hope that it will be useful, but
9+
// WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
10+
// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
11+
// for more details.
12+
//
13+
// You should have received a copy of the GNU Lesser General Public License
14+
// along with KaMPIng. If not, see <https://www.gnu.org/licenses/>.
15+
16+
/// @file
17+
/// This file contains a (distributed) counter class.
18+
19+
#pragma once
20+
21+
#include <kassert/kassert.hpp>
22+
#include <mpi.h>
23+
24+
#include "kamping/collectives/barrier.hpp"
25+
#include "kamping/collectives/gather.hpp"
26+
#include "kamping/communicator.hpp"
27+
#include "kamping/environment.hpp"
28+
#include "kamping/measurements/aggregated_tree_node.hpp"
29+
#include "kamping/measurements/internal/measurement_utils.hpp"
30+
31+
namespace kamping::measurements {
32+
33+
/// @brief Distributed counter object.
34+
/// @tparam CommunicatorType Communicator in which the measurements are
35+
/// executed.
36+
template <typename CommunicatorType = Communicator<>>
37+
class Counter {
38+
public:
39+
using DataType = std::int64_t; ///< Data type of the stored measurements.
40+
/// @brief Constructs a counter using the \c MPI_COMM_WORLD communicator.
41+
Counter() : _tree{}, _comm{comm_world()} {}
42+
43+
/// @brief Constructs a counter using a given communicator.
44+
///
45+
/// @param comm Communicator in which the measurements are executed.
46+
Counter(CommunicatorType const& comm) : _tree{}, _comm{comm} {}
47+
48+
/// @brief Creates a measurement entry with name \p name and stores \p data therein. If such an entry
49+
/// already exists with associated data entry `data_prev`, \c data will be added to it, i.e. `data_prev +
50+
/// data`.
51+
/// @param global_aggregation_modi Specify how the measurement entry is aggregated over all participating PEs when
52+
/// Counter::aggregate() is called.
53+
void
54+
add(std::string const& name,
55+
DataType const& data,
56+
std::vector<GlobalAggregationMode> const& global_aggregation_modi = std::vector<GlobalAggregationMode>{}) {
57+
add_measurement(name, data, LocalAggregationMode::accumulate, global_aggregation_modi);
58+
}
59+
60+
/// @brief Looks for a measurement entry with name \p name and appends \p data to the list of previously
61+
/// stored data. If no such entry exists, a new measurement entry with \c data as first entry will be created.
62+
/// Unlike Counter::add(), \c data is kept as a separate entry and is not summed with previously stored data.
63+
/// @param global_aggregation_modi Specify how the measurement entry is aggregated over all participating PEs when
64+
/// Counter::aggregate() is called.
65+
void append(
66+
std::string const& name,
67+
DataType const& data,
68+
std::vector<GlobalAggregationMode> const& global_aggregation_modi = std::vector<GlobalAggregationMode>{}
69+
) {
70+
add_measurement(name, data, LocalAggregationMode::append, global_aggregation_modi);
71+
}
72+
73+
/// @brief Aggregate the measurement entries globally.
74+
/// @return AggregatedTree object which encapsulates the aggregated data in a tree structure representing the
75+
/// measurements.
76+
auto aggregate() {
77+
AggregatedTree<DataType> aggregated_tree(_tree.root, _comm);
78+
return aggregated_tree;
79+
}
80+
81+
/// @brief Clears all stored measurements.
82+
void clear() {
83+
_tree.reset();
84+
}
85+
86+
/// @brief Aggregates and outputs the executed measurements. The output is
87+
/// done via the print() method of a given Printer object.
88+
///
89+
/// The print() method must accept an object of type AggregatedTreeNode and
90+
/// receives the root of the aggregated counter tree as parameter. The print()
91+
/// method is only called on the root rank of the communicator. See
92+
/// AggregatedTreeNode for the accessible data. The
93+
/// AggregatedTreeNode::children() member function can be used to navigate the
94+
/// nested measurement structure.
95+
///
96+
/// @tparam Printer Type of printer which is used to output the aggregated
97+
/// timing data. Printer must possess a member print() which accepts a
98+
/// AggregatedTreeNode as parameter.
99+
/// @param printer Printer object used to output the aggregated timing data.
100+
template <typename Printer>
101+
void aggregate_and_print(Printer&& printer) {
102+
auto const aggregated_tree = aggregate();
103+
if (_comm.is_root()) {
104+
printer.print(aggregated_tree.root());
105+
}
106+
}
107+
108+
private:
109+
internal::Tree<internal::CounterTreeNode<DataType>>
110+
_tree; ///< Tree structure in which the counted values are stored. Note that unlike for Timer, the tree is
111+
///< always a star as there is currently no functionality to allow for "nested" counting, e.g. by
112+
///< defining different phase within your algorithm.
113+
CommunicatorType const& _comm; ///< Communicator in which the time measurements take place.
114+
115+
/// @brief Adds a new measurement to the tree
116+
/// @param local_aggregation_mode Specifies how the measurement data is
117+
/// locally aggregated when there are multiple measurements at the same level
118+
/// with identical key.
119+
/// @param global_aggregation_modi Specifies how the measurement data is
120+
/// aggregated over all participating ranks when Counter::aggregate() is
121+
/// called.
122+
void add_measurement(
123+
std::string const& name,
124+
DataType const& data,
125+
LocalAggregationMode local_aggregation_mode,
126+
std::vector<GlobalAggregationMode> const& global_aggreation_modi
127+
) {
128+
auto& child = _tree.current_node->find_or_insert(name);
129+
child.aggregate_measurements_locally(data, local_aggregation_mode);
130+
if (!global_aggreation_modi.empty()) {
131+
child.measurements_aggregation_operations() = global_aggreation_modi;
132+
}
133+
}
134+
};
135+
136+
/// @brief A basic Counter that uses kamping::Communicator<> as underlying
137+
/// communicator type.
138+
using BasicCounter = Counter<Communicator<>>;
139+
140+
/// @brief Gets a reference to a kamping::measurements::BasicCounter.
141+
///
142+
/// @return A reference to a kamping::measurements::BasicCounter.
143+
inline Counter<Communicator<>>& counter() {
144+
static Counter<Communicator<>> counter;
145+
return counter;
146+
}
147+
} // namespace kamping::measurements

0 commit comments

Comments
 (0)