@@ -11,14 +11,9 @@ namespace deeptime {
11
11
namespace clustering {
12
12
namespace kmeans {
13
13
14
- template <typename T>
14
+ template <typename Metric, typename T>
15
15
inline std::tuple<np_array<T>, np_array<int >> cluster (const np_array_nfc<T> &np_chunk,
16
- const np_array_nfc<T> &np_centers, int n_threads,
17
- const Metric *metric) {
18
- if (metric == nullptr ) {
19
- metric = default_metric ();
20
- }
21
-
16
+ const np_array_nfc<T> &np_centers, int n_threads) {
22
17
if (np_chunk.ndim () != 2 ) {
23
18
throw std::runtime_error (R"( Number of dimensions of "chunk" ain't 2.)" );
24
19
}
@@ -51,9 +46,9 @@ inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_
51
46
for (pybind11::ssize_t i = 0 ; i < n_frames; ++i) {
52
47
int argMinDist = 0 ;
53
48
{
54
- T minDist = metric-> compute (&chunk (i, 0 ), ¢ers (0 , 0 ), dim);
49
+ T minDist = Metric:: template compute (&chunk (i, 0 ), ¢ers (0 , 0 ), dim);
55
50
for (std::size_t j = 1 ; j < n_centers; ++j) {
56
- auto dist = metric-> compute (&chunk (i, 0 ), ¢ers (j, 0 ), dim);
51
+ auto dist = Metric:: template compute (&chunk (i, 0 ), ¢ers (j, 0 ), dim);
57
52
if (dist < minDist) {
58
53
minDist = dist;
59
54
argMinDist = j;
@@ -77,7 +72,7 @@ inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_
77
72
for (pybind11::ssize_t i = 0 ; i < n_frames; ++i) {
78
73
std::vector<T> dists (n_centers);
79
74
for (std::size_t j = 0 ; j < n_centers; ++j) {
80
- dists[j] = metric-> compute (&chunk (i, 0 ), ¢ers (j, 0 ), dim);
75
+ dists[j] = Metric:: template compute (&chunk (i, 0 ), ¢ers (j, 0 ), dim);
81
76
}
82
77
#pragma omp flush(dists)
83
78
@@ -106,9 +101,9 @@ inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_
106
101
for (auto i = begin; i < end; ++i) {
107
102
std::size_t argMinDist = 0 ;
108
103
{
109
- T minDist = metric-> compute (&chunk (i, 0 ), ¢ers (0 , 0 ), dim);
104
+ T minDist = Metric:: template compute (&chunk (i, 0 ), ¢ers (0 , 0 ), dim);
110
105
for (std::size_t j = 1 ; j < n_centers; ++j) {
111
- auto dist = metric-> compute (&chunk (i, 0 ), ¢ers (j, 0 ), dim);
106
+ auto dist = Metric:: template compute (&chunk (i, 0 ), ¢ers (j, 0 ), dim);
112
107
if (dist < minDist) {
113
108
minDist = dist;
114
109
argMinDist = j;
@@ -151,13 +146,10 @@ inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_
151
146
return std::make_tuple (newCenters, std::move (assignments));
152
147
}
153
148
154
- template <typename T>
149
+ template <typename Metric, typename T>
155
150
inline std::tuple<np_array_nfc<T>, int , int , np_array<T>> cluster_loop (
156
151
const np_array_nfc<T> &np_chunk, const np_array_nfc<T> &np_centers,
157
- int n_threads, int max_iter, T tolerance, py::object &callback, const Metric *metric) {
158
- if (metric == nullptr ) {
159
- metric = default_metric ();
160
- }
152
+ int n_threads, int max_iter, T tolerance, py::object &callback) {
161
153
int it = 0 ;
162
154
bool converged = false ;
163
155
T rel_change;
@@ -168,10 +160,10 @@ inline std::tuple<np_array_nfc<T>, int, int, np_array<T>> cluster_loop(
168
160
inertias.reserve (max_iter);
169
161
170
162
do {
171
- auto clusterResult = cluster<T >(np_chunk, currentCenters, n_threads, metric );
163
+ auto clusterResult = cluster<Metric >(np_chunk, currentCenters, n_threads);
172
164
currentCenters = std::get<0 >(clusterResult);
173
165
const auto &assignments = std::get<1 >(clusterResult);
174
- auto cost = costFunction (np_chunk, currentCenters, assignments, n_threads, metric );
166
+ auto cost = costFunction<Metric> (np_chunk, currentCenters, assignments, n_threads);
175
167
inertias.push_back (cost);
176
168
rel_change = (cost != 0.0 ) ? std::abs (cost - prev_cost) / cost : 0 ;
177
169
prev_cost = cost;
@@ -193,12 +185,9 @@ inline std::tuple<np_array_nfc<T>, int, int, np_array<T>> cluster_loop(
193
185
return std::make_tuple (currentCenters, res, it, npInertias);
194
186
}
195
187
196
- template <typename T>
188
+ template <typename Metric, typename T>
197
189
inline T costFunction (const np_array_nfc<T> &np_data, const np_array_nfc<T> &np_centers,
198
- const np_array<int > &assignments, int n_threads, const Metric *metric) {
199
- if (metric == nullptr ) {
200
- metric = default_metric ();
201
- }
190
+ const np_array<int > &assignments, int n_threads) {
202
191
auto data = np_data.template unchecked <2 >();
203
192
auto centers = np_centers.template unchecked <2 >();
204
193
@@ -210,9 +199,9 @@ inline T costFunction(const np_array_nfc<T> &np_data, const np_array_nfc<T> &np_
210
199
omp_set_num_threads (n_threads);
211
200
#endif
212
201
213
- #pragma omp parallel for reduction(+:value) default(none) firstprivate(n_frames, metric, data, centers, assignmentsPtr, dim)
202
+ #pragma omp parallel for reduction(+:value) default(none) firstprivate(n_frames, data, centers, assignmentsPtr, dim)
214
203
for (std::size_t i = 0 ; i < n_frames; i++) {
215
- auto l = metric-> compute (&data (i, 0 ), ¢ers (assignmentsPtr[i], 0 ), dim);
204
+ auto l = Metric:: template compute (&data (i, 0 ), ¢ers (assignmentsPtr[i], 0 ), dim);
216
205
{
217
206
value += l * l;
218
207
}
0 commit comments