1
1
import random
2
2
import warnings
3
+ from contextlib import nullcontext
3
4
from typing import Optional
4
5
5
6
import numpy as np
6
7
7
8
from ..base import EstimatorTransformer
8
9
from ._cluster_model import ClusterModel
9
10
from . import metrics
11
+ from ..util .callbacks import ProgressCallback
10
12
11
13
from ..util .parallel import handle_n_jobs
12
14
@@ -173,6 +175,10 @@ class KMeans(EstimatorTransformer):
173
175
initial_centers: None or np.ndarray[k, dim], default=None
174
176
This is used to resume the kmeans iteration. Note, that if this is set, the init_strategy is ignored and
175
177
the centers are directly passed to the kmeans iteration algorithm.
178
+ progress : object
179
+ Progress bar object that `KMeans` will call to indicate progress to the user. Tested for a tqdm progress bar.
180
+ The interface is checked
181
+ via :meth:`supports_progress_interface <deeptime.util.callbacks.supports_progress_interface>`.
176
182
177
183
References
178
184
----------
@@ -186,7 +192,7 @@ class KMeans(EstimatorTransformer):
186
192
187
193
def __init__ (self , n_clusters : int , max_iter : int = 500 , metric = 'euclidean' ,
188
194
tolerance = 1e-5 , init_strategy : str = 'kmeans++' , fixed_seed = False ,
189
- n_jobs = None , initial_centers = None ):
195
+ n_jobs = None , initial_centers = None , progress = None ):
190
196
super (KMeans , self ).__init__ ()
191
197
192
198
self .n_clusters = n_clusters
@@ -198,6 +204,7 @@ def __init__(self, n_clusters: int, max_iter: int = 500, metric='euclidean',
198
204
self .random_state = np .random .RandomState (self .fixed_seed )
199
205
self .n_jobs = handle_n_jobs (n_jobs )
200
206
self .initial_centers = initial_centers
207
+ self .progress = progress
201
208
202
209
@property
203
210
def initial_centers (self ) -> Optional [np .ndarray ]:
@@ -421,14 +428,30 @@ def fit(self, data, initial_centers=None, callback_init_centers=None, callback_l
421
428
if initial_centers is not None :
422
429
self .initial_centers = initial_centers
423
430
if self .initial_centers is None :
424
- self .initial_centers = self ._pick_initial_centers (data , self .init_strategy , n_jobs , callback_init_centers )
431
+ if self .progress is not None :
432
+ callback = KMeansCallback (self .progress , "KMeans++ initialization" , self .n_clusters ,
433
+ callback_init_centers )
434
+ context = callback
435
+ else :
436
+ callback = callback_init_centers
437
+ context = nullcontext ()
438
+ with context :
439
+ self .initial_centers = self ._pick_initial_centers (data , self .init_strategy , n_jobs , callback )
425
440
426
441
# run k-means with all the data
427
442
converged = False
428
443
impl = metrics [self .metric ]
429
- cluster_centers , code , iterations , cost = impl .kmeans .cluster_loop (
430
- data , self .initial_centers .copy (), n_jobs , self .max_iter ,
431
- self .tolerance , callback_loop )
444
+
445
+ if self .progress is not None :
446
+ callback = KMeansCallback (self .progress , "KMeans iterations" , self .max_iter , callback_loop )
447
+ context = callback
448
+ else :
449
+ callback = callback_loop
450
+ context = nullcontext ()
451
+ with context :
452
+ cluster_centers , code , iterations , cost = impl .kmeans .cluster_loop (
453
+ data , self .initial_centers .copy (), n_jobs , self .max_iter ,
454
+ self .tolerance , callback )
432
455
if code == 0 :
433
456
converged = True
434
457
else :
@@ -526,3 +549,15 @@ def partial_fit(self, data, n_jobs=None):
526
549
self ._model ._converged = True
527
550
528
551
return self
552
+
553
+
554
+ class KMeansCallback (ProgressCallback ):
555
+
556
+ def __init__ (self , progress , description , total , parent_callback = None ):
557
+ super ().__init__ (progress , description = description , total = total )
558
+ self ._parent_callback = parent_callback
559
+
560
+ def __call__ (self , * args , ** kw ):
561
+ super ().__call__ (* args , ** kw )
562
+ if self ._parent_callback is not None :
563
+ self ._parent_callback (* args , ** kw )
0 commit comments