Skip to content

Commit 4aba392

Browse files
authored
Update custom metrics handling (#160)
* docs * custom metric example
1 parent 5969b00 commit 4aba392

23 files changed

+334
-269
lines changed

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ else()
1616
add_compile_options(-Wall -Wextra -pedantic -Werror)
1717
endif()
1818

19-
include("${PYBIND_CMAKE_DIR}/pybind11Config.cmake")
19+
find_package(pybind11 REQUIRED)
2020

2121
set(common_includes "${CMAKE_CURRENT_LIST_DIR}/deeptime/src/include")
2222

deeptime/clustering/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,11 @@
4747
:toctree: generated/
4848
:template: class_nomodule.rst
4949
50-
Metric
5150
metrics
5251
MetricRegistry
5352
"""
5453

5554
from ._metric import metrics, MetricRegistry
56-
from ._clustering_bindings import Metric
5755
from ._kmeans import KMeans, MiniBatchKMeans, KMeansModel
5856
from ._regspace import RegularSpace
5957
from ._box import BoxDiscretization, BoxDiscretizationModel

deeptime/clustering/_cluster_model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from deeptime.base import Model, Transformer
33

4-
from . import _clustering_bindings as _bd, metrics
4+
from . import metrics
55
from ..util.parallel import handle_n_jobs
66

77

@@ -117,5 +117,6 @@ def transform(self, data, n_jobs=None) -> np.ndarray:
117117
n_jobs = handle_n_jobs(n_jobs)
118118
if data.ndim == 1:
119119
data = data[..., None]
120-
dtraj = _bd.assign(data, self.cluster_centers, n_jobs, metrics[self.metric]())
120+
impl = metrics[self.metric]
121+
dtraj = impl.assign(data, self.cluster_centers, n_jobs)
121122
return dtraj

deeptime/clustering/_kmeans.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ..base import EstimatorTransformer
88
from ._cluster_model import ClusterModel
9-
from . import _clustering_bindings as _bd, metrics
9+
from . import metrics
1010

1111
from ..util.parallel import handle_n_jobs
1212

@@ -41,9 +41,8 @@ def kmeans_plusplus(data, n_clusters: int, metric: str = 'euclidean', callback=N
4141
.. footbibliography::
4242
"""
4343
n_jobs = handle_n_jobs(n_jobs)
44-
metric = metrics[metric]()
45-
return _bd.kmeans.init_centers_kmpp(data, k=n_clusters, random_seed=seed, n_threads=n_jobs,
46-
callback=callback, metric=metric)
44+
impl = metrics[metric]
45+
return impl.kmeans.init_centers_kmpp(data, k=n_clusters, random_seed=seed, n_threads=n_jobs, callback=callback)
4746

4847

4948
class KMeansModel(ClusterModel):
@@ -132,7 +131,8 @@ def score(self, data: np.ndarray, n_jobs: Optional[int] = None) -> float:
132131
the inertia
133132
"""
134133
n_jobs = handle_n_jobs(n_jobs)
135-
return _bd.kmeans.cost_function(data, self.cluster_centers, n_jobs, metrics[self.metric]())
134+
impl = metrics[self.metric]
135+
return impl.kmeans.cost_function(data, self.cluster_centers, n_jobs)
136136

137137

138138
class KMeans(EstimatorTransformer):
@@ -425,9 +425,10 @@ def fit(self, data, initial_centers=None, callback_init_centers=None, callback_l
425425

426426
# run k-means with all the data
427427
converged = False
428-
cluster_centers, code, iterations, cost = _bd.kmeans.cluster_loop(
428+
impl = metrics[self.metric]
429+
cluster_centers, code, iterations, cost = impl.kmeans.cluster_loop(
429430
data, self.initial_centers.copy(), n_jobs, self.max_iter,
430-
self.tolerance, callback_loop, metrics[self.metric]())
431+
self.tolerance, callback_loop)
431432
if code == 0:
432433
converged = True
433434
else:
@@ -514,9 +515,9 @@ def partial_fit(self, data, n_jobs=None):
514515
tolerance=self.tolerance, inertias=np.array([float('inf')]))
515516
if data.ndim == 1:
516517
data = data[:, np.newaxis]
517-
metric_instance = metrics[self.metric]()
518-
self._model._cluster_centers = _bd.kmeans.cluster(data, self._model.cluster_centers, n_jobs, metric_instance)[0]
519-
cost = _bd.kmeans.cost_function(data, self._model.cluster_centers, n_jobs, metric_instance)
518+
impl = metrics[self.metric]
519+
self._model._cluster_centers = impl.kmeans.cluster(data, self._model.cluster_centers, n_jobs)[0]
520+
cost = impl.kmeans.cost_function(data, self._model.cluster_centers, n_jobs)
520521

521522
rel_change = np.abs(cost - self._model.inertia) / cost if cost != 0.0 else 0.0
522523
self._model._inertias = np.append(self._model._inertias, cost)

deeptime/clustering/_metric.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,53 @@ class MetricRegistry:
88
If a custom metric is implemented, it can be registered through a call to
99
:meth:`register <deeptime.clustering.MetricRegistry.register>`.
1010
11-
Note that the registry should not be instantiated directly but rather be accessed
12-
through :data:`metrics <deeptime.clustering.metrics>`.
11+
.. note::
12+
13+
The registry should not be instantiated directly but rather be accessed
14+
through the :data:`metrics <deeptime.clustering.metrics>` singleton.
15+
16+
17+
.. rubric:: Adding a new metric
18+
19+
A new metric may be added by linking against the deeptime clustering C++ library (directory is provided by
20+
``deeptime.capi_includes(inc_clustering=True)``) and subsequently exposing the clustering algorithms with your custom
21+
metric like
22+
23+
.. code-block:: cpp
24+
25+
#include "register_clustering.h"
26+
27+
PYBIND11_MODULE(_clustering_bindings, m) {
28+
m.doc() = "module containing clustering algorithms.";
29+
auto customModule = m.def_submodule("custom");
30+
deeptime::clustering::registerClusteringImplementation<Custom>(customModule);
31+
}
32+
33+
and registering it with the deeptime library through
34+
35+
.. code-block:: python
36+
37+
import deeptime
38+
import bindings # this is your compiled extension, rename as appropriate
39+
40+
deeptime.clustering.metrics.register("custom", bindings.custom)
1341
"""
1442

1543
def __init__(self):
1644
self._registered = None
17-
self.register("euclidean", _bd.EuclideanMetric)
45+
self.register("euclidean", _bd.euclidean)
1846

19-
def register(self, name: str, clazz):
47+
def register(self, name: str, impl):
2048
r""" Adds a new metric to the registry.
2149
2250
Parameters
2351
----------
2452
name : str
2553
The name of the metric.
26-
clazz : class
27-
Reference to the class of the metric.
54+
impl : module
55+
Reference to the implementation module.
2856
"""
29-
self._mapping[name] = clazz
57+
self._mapping[name] = impl
3058

3159
@property
3260
def available(self) -> Tuple[str]:

deeptime/clustering/_regspace.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44

55
from . import metrics
6-
from ._clustering_bindings import regspace as _regspace_ext
76
from ._cluster_model import ClusterModel
87
from ..base import Estimator
98

@@ -142,14 +141,14 @@ def fetch_model(self) -> ClusterModel:
142141

143142
def partial_fit(self, data, n_jobs=None):
144143
r""" Fits data to an existing model. See :meth:`fit`. """
144+
impl = metrics[self.metric]
145145
n_jobs = self.n_jobs if n_jobs is None else handle_n_jobs(n_jobs)
146146
if data.ndim == 1:
147147
data = data[:, np.newaxis]
148148
try:
149-
metric = metrics[self.metric]()
150-
_regspace_ext.cluster(data, self._clustercenters, self.dmin, self.max_centers, n_jobs, metric)
149+
impl.regspace.cluster(data, self._clustercenters, self.dmin, self.max_centers, n_jobs)
151150
self._converged = True
152-
except _regspace_ext.MaxCentersReachedException:
151+
except impl.regspace.MaxCentersReachedException:
153152
warnings.warn('Maximum number of cluster centers reached.'
154153
' Consider increasing max_centers or choose'
155154
' a larger minimum distance, dmin.')

deeptime/clustering/include/bits/kmeans_bits.h

+15-26
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,9 @@ namespace deeptime {
1111
namespace clustering {
1212
namespace kmeans {
1313

14-
template<typename T>
14+
template<typename Metric, typename T>
1515
inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_chunk,
16-
const np_array_nfc<T> &np_centers, int n_threads,
17-
const Metric *metric) {
18-
if (metric == nullptr) {
19-
metric = default_metric();
20-
}
21-
16+
const np_array_nfc<T> &np_centers, int n_threads) {
2217
if (np_chunk.ndim() != 2) {
2318
throw std::runtime_error(R"(Number of dimensions of "chunk" ain't 2.)");
2419
}
@@ -51,9 +46,9 @@ inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_
5146
for (pybind11::ssize_t i = 0; i < n_frames; ++i) {
5247
int argMinDist = 0;
5348
{
54-
T minDist = metric->compute(&chunk(i, 0), &centers(0, 0), dim);
49+
T minDist = Metric::template compute(&chunk(i, 0), &centers(0, 0), dim);
5550
for (std::size_t j = 1; j < n_centers; ++j) {
56-
auto dist = metric->compute(&chunk(i, 0), &centers(j, 0), dim);
51+
auto dist = Metric::template compute(&chunk(i, 0), &centers(j, 0), dim);
5752
if (dist < minDist) {
5853
minDist = dist;
5954
argMinDist = j;
@@ -77,7 +72,7 @@ inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_
7772
for (pybind11::ssize_t i = 0; i < n_frames; ++i) {
7873
std::vector<T> dists(n_centers);
7974
for (std::size_t j = 0; j < n_centers; ++j) {
80-
dists[j] = metric->compute(&chunk(i, 0), &centers(j, 0), dim);
75+
dists[j] = Metric::template compute(&chunk(i, 0), &centers(j, 0), dim);
8176
}
8277
#pragma omp flush(dists)
8378

@@ -106,9 +101,9 @@ inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_
106101
for (auto i = begin; i < end; ++i) {
107102
std::size_t argMinDist = 0;
108103
{
109-
T minDist = metric->compute(&chunk(i, 0), &centers(0, 0), dim);
104+
T minDist = Metric::template compute(&chunk(i, 0), &centers(0, 0), dim);
110105
for (std::size_t j = 1; j < n_centers; ++j) {
111-
auto dist = metric->compute(&chunk(i, 0), &centers(j, 0), dim);
106+
auto dist = Metric::template compute(&chunk(i, 0), &centers(j, 0), dim);
112107
if(dist < minDist) {
113108
minDist = dist;
114109
argMinDist = j;
@@ -151,13 +146,10 @@ inline std::tuple<np_array<T>, np_array<int>> cluster(const np_array_nfc<T> &np_
151146
return std::make_tuple(newCenters, std::move(assignments));
152147
}
153148

154-
template<typename T>
149+
template<typename Metric, typename T>
155150
inline std::tuple<np_array_nfc<T>, int, int, np_array<T>> cluster_loop(
156151
const np_array_nfc<T> &np_chunk, const np_array_nfc<T> &np_centers,
157-
int n_threads, int max_iter, T tolerance, py::object &callback, const Metric *metric) {
158-
if (metric == nullptr) {
159-
metric = default_metric();
160-
}
152+
int n_threads, int max_iter, T tolerance, py::object &callback) {
161153
int it = 0;
162154
bool converged = false;
163155
T rel_change;
@@ -168,10 +160,10 @@ inline std::tuple<np_array_nfc<T>, int, int, np_array<T>> cluster_loop(
168160
inertias.reserve(max_iter);
169161

170162
do {
171-
auto clusterResult = cluster<T>(np_chunk, currentCenters, n_threads, metric);
163+
auto clusterResult = cluster<Metric>(np_chunk, currentCenters, n_threads);
172164
currentCenters = std::get<0>(clusterResult);
173165
const auto &assignments = std::get<1>(clusterResult);
174-
auto cost = costFunction(np_chunk, currentCenters, assignments, n_threads, metric);
166+
auto cost = costFunction<Metric>(np_chunk, currentCenters, assignments, n_threads);
175167
inertias.push_back(cost);
176168
rel_change = (cost != 0.0) ? std::abs(cost - prev_cost) / cost : 0;
177169
prev_cost = cost;
@@ -193,12 +185,9 @@ inline std::tuple<np_array_nfc<T>, int, int, np_array<T>> cluster_loop(
193185
return std::make_tuple(currentCenters, res, it, npInertias);
194186
}
195187

196-
template<typename T>
188+
template<typename Metric, typename T>
197189
inline T costFunction(const np_array_nfc<T> &np_data, const np_array_nfc<T> &np_centers,
198-
const np_array<int> &assignments, int n_threads, const Metric *metric) {
199-
if(metric == nullptr) {
200-
metric = default_metric();
201-
}
190+
const np_array<int> &assignments, int n_threads) {
202191
auto data = np_data.template unchecked<2>();
203192
auto centers = np_centers.template unchecked<2>();
204193

@@ -210,9 +199,9 @@ inline T costFunction(const np_array_nfc<T> &np_data, const np_array_nfc<T> &np_
210199
omp_set_num_threads(n_threads);
211200
#endif
212201

213-
#pragma omp parallel for reduction(+:value) default(none) firstprivate(n_frames, metric, data, centers, assignmentsPtr, dim)
202+
#pragma omp parallel for reduction(+:value) default(none) firstprivate(n_frames, data, centers, assignmentsPtr, dim)
214203
for (std::size_t i = 0; i < n_frames; i++) {
215-
auto l = metric->compute(&data(i, 0), &centers(assignmentsPtr[i], 0), dim);
204+
auto l = Metric::template compute(&data(i, 0), &centers(assignmentsPtr[i], 0), dim);
216205
{
217206
value += l * l;
218207
}

deeptime/clustering/include/bits/metric_base_bits.h

+4-18
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,10 @@
1010
#include <omp.h>
1111
#endif
1212

13-
template<>
14-
inline float Metric::compute_squared<float>(const float* xs, const float* ys, std::size_t dim) const {
15-
return compute_squared_f(xs, ys, dim);
16-
}
17-
18-
template<>
19-
inline double Metric::compute_squared<double>(const double* xs, const double* ys, std::size_t dim) const {
20-
return compute_squared_d(xs, ys, dim);
21-
}
22-
23-
template<typename T>
13+
template<typename Metric, typename T>
2414
inline py::array_t<int> assign_chunk_to_centers(const np_array_nfc<T>& chunk,
2515
const np_array_nfc<T>& centers,
26-
int n_threads,
27-
const Metric* metric) {
28-
if (metric == nullptr) {
29-
metric = default_metric();
30-
}
16+
int n_threads) {
3117
if (chunk.ndim() != 2) {
3218
throw std::invalid_argument("provided chunk does not have two dimensions.");
3319
}
@@ -57,12 +43,12 @@ inline py::array_t<int> assign_chunk_to_centers(const np_array_nfc<T>& chunk,
5743
omp_set_num_threads(n_threads);
5844
#endif
5945

60-
#pragma omp parallel default(none) firstprivate(N_frames, N_centers, centers_buff, input_dim, metric, chunk_buff, dtraj_buff, dists)
46+
#pragma omp parallel default(none) firstprivate(N_frames, N_centers, centers_buff, input_dim, chunk_buff, dtraj_buff, dists)
6147
{
6248
#pragma omp for
6349
for(size_t i = 0; i < N_frames; ++i) {
6450
for(size_t j = 0; j < N_centers; ++j) {
65-
dists[j] = metric->compute(&chunk_buff(i, 0), &centers_buff(j, 0), input_dim);
51+
dists[j] = Metric::template compute<T>(&chunk_buff(i, 0), &centers_buff(j, 0), input_dim);
6652
}
6753

6854
{

0 commit comments

Comments
 (0)