Skip to content

Commit 29e2b00

Browse files
author
Xuye (Chris) Qin
authored
[BACKPORT] Add mars.learn.cluster.KMeans support (#1426) (#1428)
1 parent 18981d2 commit 29e2b00

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+4515
-44
lines changed

.codecov.yml

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ ignore:
2828
- "mars/lib/uhashring"
2929
- "mars/serialize/protos"
3030
- "mars/learn/contrib/xgboost/tracker.py"
31+
- "mars/learn/cluster/_k_means_fast.*"
32+
- "mars/learn/cluster/_k_means_elkan.pyx"
33+
- "mars/learn/cluster/_k_means_lloyd.pyx"
34+
- "mars/learn/utils/_cython_blas.*"
3135
- "mars/tensor/einsum/einsumfunc.py"
3236
- "**/*.html"
3337
- "**/*.pxd"

.coveragerc

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ omit =
1919
mars/lib/uhashring/*
2020
mars/serialize/protos/*
2121
mars/learn/contrib/xgboost/tracker.py
22+
mars/learn/cluster/_k_means_fast.*
23+
mars/learn/cluster/_k_means_elkan.pyx
24+
mars/learn/cluster/_k_means_lloyd.pyx
25+
mars/learn/utils/_cython_blas.*
2226
mars/tensor/einsum/einsumfunc.py
2327
*.html
2428
*.pxd

.coveragerc-threaded

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ omit =
2222
mars/lib/uhashring/*
2323
mars/serialize/protos/*
2424
mars/learn/contrib/xgboost/tracker.py
25+
mars/learn/cluster/_k_means_fast.*
26+
mars/learn/cluster/_k_means_elkan.pyx
27+
mars/learn/cluster/_k_means_lloyd.pyx
28+
mars/learn/utils/_cython_blas.*
2529
mars/tensor/einsum/einsumfunc.py
2630
*.html
2731
*.pxd

.github/codecov-upstream.yml

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ ignore:
3030
- "mars/lib/uhashring"
3131
- "mars/serialize/protos"
3232
- "mars/learn/contrib/xgboost/tracker.py"
33+
- "mars/learn/cluster/_k_means_fast.*"
34+
- "mars/learn/cluster/_k_means_elkan.pyx"
35+
- "mars/learn/cluster/_k_means_lloyd.pyx"
36+
- "mars/learn/utils/_cython_blas.*"
3337
- "mars/tensor/einsum/einsumfunc.py"
3438
- "**/*.html"
3539
- "**/*.pxd"

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ jobs:
195195

196196
- name: Collect coverage data
197197
run: |
198-
pip install numpy cython gevent coverage
198+
pip install numpy scipy cython gevent coverage
199199
pip install codecov
200200
CYTHON_TRACE=1 python setup.py sdist
201201
export DEFAULT_VENV=$VIRTUAL_ENV

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ mars/optimizes/**/*.c*
8282
mars/scheduler/*.c*
8383
mars/serialize/*.c*
8484
mars/worker/*.c*
85+
mars/learn/cluster/*.c*
86+
mars/learn/utils/*.c*
8587

8688
# files built from protobuf files
8789
mars/opcodes.py
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
mars.learn.cluster.KMeans
2+
=========================
3+
4+
.. currentmodule:: mars.learn.cluster
5+
6+
.. autoclass:: KMeans
7+
8+
9+
.. automethod:: __init__
10+
11+
12+
.. rubric:: Methods
13+
14+
.. autosummary::
15+
16+
~KMeans.__init__
17+
~KMeans.fit
18+
~KMeans.fit_predict
19+
~KMeans.fit_transform
20+
~KMeans.get_params
21+
~KMeans.predict
22+
~KMeans.score
23+
~KMeans.set_params
24+
~KMeans.transform
25+
26+
27+
28+
29+
30+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
mars.learn.cluster.k\_means
2+
===========================
3+
4+
.. currentmodule:: mars.learn.cluster
5+
6+
.. autofunction:: k_means

mars/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import subprocess
1616
import os
1717

18-
version_info = (0, 4, 3)
18+
version_info = (0, 4, 4)
1919
_num_index = max(idx if isinstance(v, int) else 0
2020
for idx, v in enumerate(version_info))
2121
__version__ = '.'.join(map(str, version_info[:_num_index + 1])) + \

mars/executor.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import numpy as np
2929

3030
from .operands import Fetch, ShuffleProxy
31-
from .graph import DAG, DirectedGraph
31+
from .graph import DAG
3232
from .config import options
3333
from .tiles import IterativeChunkGraphBuilder, ChunkGraphBuilder, get_tiled
3434
from .optimizes.runtime.core import RuntimeOptimizer
@@ -712,7 +712,7 @@ def _on_tile_success(before_tile_data, after_tile_data):
712712
chunk_result = self._chunk_result.copy()
713713
tileable_graph_builder = TileableGraphBuilder()
714714
tileable_graph = tileable_graph_builder.build([tileable])
715-
chunk_graph_builder = ChunkGraphBuilder(graph_cls=DirectedGraph, compose=compose,
715+
chunk_graph_builder = ChunkGraphBuilder(compose=compose,
716716
on_tile_success=_on_tile_success)
717717
chunk_graph = chunk_graph_builder.build([tileable], tileable_graph=tileable_graph)
718718
ret = self.execute_graph(chunk_graph, result_keys, n_parallel=n_parallel or n_thread,
@@ -774,6 +774,7 @@ def execute_tileables(self, tileables, fetch=True, n_parallel=None, n_thread=Non
774774
tileable_data_to_chunks = weakref.WeakKeyDictionary()
775775

776776
node_to_fetch = weakref.WeakKeyDictionary()
777+
skipped_tileables = set()
777778

778779
def _generate_fetch_tileable(node):
779780
# Attach chunks to fetch tileables to skip tile.
@@ -786,6 +787,22 @@ def _generate_fetch_tileable(node):
786787

787788
return node
788789

790+
def _skip_executed_tileables(inps):
791+
# skip the input that executed, and not gc collected
792+
new_inps = []
793+
for inp in inps:
794+
if inp.key in self.stored_tileables:
795+
try:
796+
get_tiled(inp)
797+
except KeyError:
798+
new_inps.append(inp)
799+
else:
800+
skipped_tileables.add(inp)
801+
continue
802+
else:
803+
new_inps.append(inp)
804+
return new_inps
805+
789806
def _generate_fetch_if_executed(nd):
790807
# node processor that if the node is executed
791808
# replace it with a fetch node
@@ -830,7 +847,8 @@ def _get_tileable_graph_builder(**kwargs):
830847
with self._gen_local_context(chunk_result):
831848
# build tileable graph
832849
tileable_graph_builder = _get_tileable_graph_builder(
833-
node_processor=_generate_fetch_tileable)
850+
node_processor=_generate_fetch_tileable,
851+
inputs_selector=_skip_executed_tileables)
834852
tileable_graph = tileable_graph_builder.build(tileables)
835853
chunk_graph_builder = IterativeChunkGraphBuilder(
836854
graph_cls=DAG, node_processor=_generate_fetch_if_executed,

mars/learn/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# see https://github.com/pytorch/pytorch/issues/2575
1818
from .contrib import pytorch, tensorflow, xgboost, lightgbm
1919
from .metrics import pairwise
20+
from . import cluster
2021
from . import preprocessing
2122
from . import neighbors
2223
from . import utils

mars/learn/cluster/__init__.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 1999-2020 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
try:
16+
from ._kmeans import KMeans, k_means
17+
18+
def _install():
19+
from ._k_means_common import KMeansInertia, KMeansRelocateEmptyClusters
20+
from ._k_means_elkan_iter import KMeansElkanInitBounds, \
21+
KMeansElkanUpdate, KMeansElkanPostprocess
22+
from ._k_means_init import KMeansPlusPlusInit
23+
from ._k_means_lloyd_iter import KMeansLloydUpdate, KMeansLloydPostprocess
24+
25+
del KMeansInertia, KMeansRelocateEmptyClusters, KMeansElkanInitBounds, \
26+
KMeansElkanUpdate, KMeansElkanPostprocess, KMeansPlusPlusInit, \
27+
KMeansLloydUpdate, KMeansLloydPostprocess
28+
29+
30+
_install()
31+
del _install
32+
except ImportError:
33+
KMeans = None
34+
k_means = None
35+

0 commit comments

Comments
 (0)