|
1 | 1 | #!/usr/bin/env python |
2 | 2 | import warnings |
3 | | -from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering |
4 | | -import numpy as np |
5 | 3 | import six |
6 | | -from hdbscan import HDBSCAN |
| 4 | +import numpy as np |
| 5 | +from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering |
7 | 6 | from .._shared.helpers import * |
8 | 7 | from .format_data import format_data as formatter |
9 | 8 |
|
# Registry mapping each supported model name to its clustering class.
# cluster() looks models up here by name.
models = {
    'KMeans': KMeans,
    'MiniBatchKMeans': MiniBatchKMeans,
    'AgglomerativeClustering': AgglomerativeClustering,
    'FeatureAgglomeration': FeatureAgglomeration,
    'Birch': Birch,
    'SpectralClustering': SpectralClustering,
}

# hdbscan is an optional dependency: register it only when it can be
# imported, and record its availability in _has_hdbscan.
try:
    from hdbscan import HDBSCAN
except ImportError:
    _has_hdbscan = False
else:
    _has_hdbscan = True
    models['HDBSCAN'] = HDBSCAN

10 | 26 |
|
@memoize
def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
    """
    Fit a clustering model to the data and return the cluster labels.

    Parameters
    ----------
    x : array or list of arrays
        Data to cluster. The arrays are stacked row-wise with ``np.vstack``
        before the model is fit.

    cluster : str, dict or None
        Either a key of the module-level ``models`` dictionary
        (e.g. 'KMeans'), or a dictionary of the form
        ``{'model': <model name>, 'params': <dict of keyword arguments>}``
        whose ``params`` are passed to the model constructor. If None,
        ``x`` is returned unchanged.

    n_clusters : int
        Passed to the model constructor when ``cluster`` is a string other
        than 'HDBSCAN'. Not used when ``cluster`` is a dictionary.

    ndims : None
        Deprecated. Any value other than None only triggers a warning; no
        dimensionality reduction is performed.

    format_data : bool
        If True, ``x`` is passed through the ``format_data`` helper (called
        with ``ppca=True``) before clustering.

    Returns
    -------
    labels : list
        The fitted model's ``labels_`` converted to a list, or ``x`` itself
        when ``cluster`` is None.

    Raises
    ------
    ImportError
        If HDBSCAN is requested but the ``hdbscan`` package is not installed.
    KeyError
        If the requested model name is not in ``models``, or a dictionary
        ``cluster`` lacks the 'model' or 'params' key.
    ValueError
        If ``cluster`` is neither None, a string, nor a dictionary whose
        'model' entry is a string.
    """

    # nothing to do: hand the data back untouched
    if cluster is None:
        return x

    # work out which model was requested; reject unsupported specifications
    # up front instead of failing later on an unbound local variable
    if isinstance(cluster, six.string_types):
        model_name = cluster
    elif isinstance(cluster, dict) and isinstance(cluster['model'], six.string_types):
        model_name = cluster['model']
    else:
        raise ValueError('cluster must be None, a model name, or a dictionary '
                         'with a string "model" entry and a "params" entry.')

    # HDBSCAN is optional, so give a clear error when it is missing
    if model_name == 'HDBSCAN' and not _has_hdbscan:
        raise ImportError('HDBSCAN is not installed. Please install hdbscan>=0.8.11')

    # identity test: an array-like ndims must not be compared elementwise
    if ndims is not None:
        warnings.warn('The ndims argument is now deprecated. Ignoring dimensionality reduction step.')

    if format_data:
        x = formatter(x, ppca=True)

    # if cluster is a dict, use its custom params; otherwise build defaults
    if isinstance(cluster, dict):
        model_params = cluster['params']
    elif model_name == 'HDBSCAN':
        # HDBSCAN is constructed without n_clusters
        model_params = {}
    else:
        model_params = {'n_clusters': n_clusters}

    # initialize model (an unknown name raises KeyError here)
    model = models[model_name](**model_params)

    # fit the model
    model.fit(np.vstack(x))

    # return the labels
    return list(model.labels_)
0 commit comments