Skip to content

Commit ef9e228

Browse files
author
Xuye (Chris) Qin
authored
Implements mars.learn.wrappers.ParallelPostFit (#2425)
1 parent 0e11714 commit ef9e228

File tree

10 files changed

+556
-3
lines changed

10 files changed

+556
-3
lines changed

docs/source/reference/learn/reference.rst

+16
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,22 @@ Utilities
260260
utils.validation.check_is_fitted
261261
utils.validation.column_or_1d
262262

263+
.. _learn_misc_ref:
264+
265+
Misc
266+
====
267+
268+
.. automodule:: mars.learn.wrappers
269+
:no-members:
270+
:no-inherited-members:
271+
272+
.. currentmodule:: mars.learn
273+
274+
.. autosummary::
275+
:toctree: generated/
276+
277+
wrappers.ParallelPostFit
278+
263279
.. _lightgbm_ref:
264280

265281
LightGBM Integration

mars/core/entity/utils.py

-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def recursive_tile(tileable: TileableType, *tileables: TileableType) -> \
6464
tileable = raw[0]
6565
tileables = raw[1:]
6666

67-
inputs_set = set(tileable.op.inputs)
6867
to_tile = [tileable] + list(tileables)
6968
q = [t for t in to_tile if t.is_coarse()]
7069
while q:
@@ -79,7 +78,6 @@ def recursive_tile(tileable: TileableType, *tileables: TileableType) -> \
7978
for inp in t.op.inputs:
8079
if has_unknown_shape(inp):
8180
to_update_inputs.append(inp)
82-
if inp not in inputs_set:
8381
chunks.extend(inp.chunks)
8482
if obj is None:
8583
yield chunks + to_update_inputs

mars/learn/metrics/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from ._classification import accuracy_score, log_loss
1717
from ._ranking import roc_curve, auc
1818
from ._regresssion import r2_score
19+
from ._scorer import get_scorer

mars/learn/metrics/_scorer.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Callable, Union
16+
17+
from sklearn.metrics import make_scorer
18+
19+
from . import accuracy_score, log_loss, r2_score
20+
21+
22+
# NOTE: the two assignments below rebind the metric functions imported
# above to their scorer wrappers, so inside this module ``accuracy_score``
# and ``r2_score`` refer to scorers from here on.
accuracy_score = make_scorer(accuracy_score)
r2_score = make_scorer(r2_score)
neg_log_loss_scorer = make_scorer(
    log_loss,
    greater_is_better=False,
    needs_proba=True,
)

# Registry mapping a scoring name to its scorer.
SCORERS = {
    "r2": r2_score,
    "accuracy": accuracy_score,
    "neg_log_loss": neg_log_loss_scorer,
}
33+
34+
35+
def get_scorer(score_func: Union[str, Callable], **kwargs) -> Callable:
    """
    Get a scorer from a scoring name or a metric function.

    Parameters
    ----------
    score_func : str | callable
        Scoring method. A string is looked up in ``SCORERS``; a callable
        is wrapped into a scorer with ``make_scorer``.
    **kwargs
        Extra keyword arguments forwarded to ``make_scorer`` when
        ``score_func`` is a callable. They are ignored when ``score_func``
        is a string.

    Returns
    -------
    scorer : callable
        The scorer.

    Raises
    ------
    ValueError
        If ``score_func`` is a string that is not a key of ``SCORERS``.
    """
    if not isinstance(score_func, str):
        return make_scorer(score_func, **kwargs)

    try:
        return SCORERS[score_func]
    except KeyError:
        # The KeyError carries no information beyond the message below,
        # so suppress it instead of chaining it into the traceback.
        raise ValueError(
            "{} is not a valid scoring value. "
            "Valid options are {}".format(score_func, sorted(SCORERS))
        ) from None
+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 1999-2020 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from sklearn.metrics import r2_score
17+
18+
from .. import get_scorer
19+
20+
21+
def test_get_scorer():
    """Check ``get_scorer`` with an unknown name, a known name and a callable."""
    # A scoring name that is not registered must be rejected.
    with pytest.raises(ValueError):
        get_scorer('unknown')

    # A registered name and a plain metric function both yield a scorer.
    for scoring in ('r2', r2_score):
        assert get_scorer(scoring) is not None

mars/learn/tests/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

mars/learn/tests/test_wrappers.py

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pytest
17+
from sklearn.datasets import make_classification
18+
from sklearn.decomposition import PCA
19+
from sklearn.ensemble import GradientBoostingClassifier
20+
from sklearn.linear_model import LinearRegression, LogisticRegression
21+
22+
from ... import tensor as mt
23+
from ..wrappers import ParallelPostFit
24+
25+
26+
# Module-level data shared by the tests below: the raw arrays returned by
# ``make_classification`` and the same data as Mars tensors (chunk size 100).
raw_x, raw_y = make_classification(n_samples=1000)
X = mt.tensor(raw_x, chunk_size=100)
y = mt.tensor(raw_y, chunk_size=100)
28+
29+
30+
def test_parallel_post_fit_basic(setup):
    """Check output types, scoring and the missing-method error of the wrapper."""
    classifier = ParallelPostFit(GradientBoostingClassifier())
    classifier.fit(X, y)

    # Predictions of the wrapper are Mars tensors.
    predicted = classifier.predict(X)
    assert isinstance(predicted, mt.Tensor)
    probabilities = classifier.predict_proba(X)
    assert isinstance(probabilities, mt.Tensor)

    # The wrapper's score, once fetched, equals the wrapped estimator's score.
    wrapped_score = classifier.score(X, y)
    direct_score = classifier.estimator.score(X, y)
    assert wrapped_score.fetch() == direct_score

    # Calling a method the wrapped estimator lacks raises AttributeError.
    regressor = ParallelPostFit(LinearRegression())
    regressor.fit(X, y)
    expected_message = "The wrapped estimator (.|\n)* 'predict_proba' method."
    with pytest.raises(AttributeError, match=expected_message):
        regressor.predict_proba(X)
46+
47+
48+
def test_parallel_post_fit_predict(setup):
    """Compare the wrapper's prediction methods with a plain estimator's."""
    estimator_params = dict(random_state=0, n_jobs=1, solver="lbfgs")
    base = LogisticRegression(**estimator_params)
    wrap = ParallelPostFit(LogisticRegression(**estimator_params))

    base.fit(X, y)
    wrap.fit(X, y)

    # Each prediction method of the wrapper must agree with the same
    # method of the identically configured plain estimator.
    for method_name in ("predict", "predict_proba", "predict_log_proba"):
        result = getattr(wrap, method_name)(X)
        expected = getattr(base, method_name)(X)
        np.testing.assert_allclose(result, expected)
66+
67+
68+
def test_parallel_post_fit_transform(setup):
    """Compare the wrapper's ``transform`` with a plain PCA's ``transform``."""
    base = PCA(random_state=0)
    wrap = ParallelPostFit(PCA(random_state=0))

    # The plain estimator is fitted on the raw arrays, the wrapper on the
    # Mars tensors built from the same data.
    base.fit(raw_x, raw_y)
    wrap.fit(X, y)

    # ``result`` is the wrapper's output and ``expected`` the plain
    # estimator's, consistent with the sibling tests and with
    # assert_allclose's (actual, desired) argument order.
    result = wrap.transform(X)
    expected = base.transform(X)
    np.testing.assert_allclose(result, expected, atol=.1)
78+
79+
80+
def test_parallel_post_fit_multiclass(setup):
    """Compare wrapper and wrapped-estimator predictions on 3-class data."""
    features, labels = make_classification(n_classes=3, n_informative=4)
    X = mt.tensor(features, chunk_size=50)
    y = mt.tensor(labels, chunk_size=50)

    clf = ParallelPostFit(
        LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs", multi_class="auto")
    )
    clf.fit(X, y)

    # Every prediction method of the wrapper must agree with the same
    # method called directly on the wrapped estimator.
    for method_name in ("predict", "predict_proba", "predict_log_proba"):
        result = getattr(clf, method_name)(X)
        expected = getattr(clf.estimator, method_name)(X)
        np.testing.assert_allclose(result, expected)

mars/learn/utils/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from .collect_ports import collect_ports
1919
from .core import convert_to_tensor_or_dataframe, \
20-
concat_chunks
20+
concat_chunks, copy_learned_attributes
2121
from .validation import check_array, assert_all_finite, \
2222
check_consistent_length, column_or_1d, check_X_y
2323
from .shuffle import shuffle

mars/learn/utils/core.py

+8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import pandas as pd
16+
from sklearn.base import BaseEstimator
1617

1718
from ...tensor import tensor as astensor
1819
from ...dataframe import DataFrame, Series
@@ -32,3 +33,10 @@ def convert_to_tensor_or_dataframe(item):
3233
def concat_chunks(chunks):
    """
    Concatenate ``chunks`` and return the first chunk of the result.

    A tileable is built from ``chunks`` via the op of the first chunk; that
    tileable's op then concatenates the tileable's chunks, and the first
    chunk of the concatenated tileable is returned.
    """
    first_op = chunks[0].op
    tileable = first_op.create_tileable_from_chunks(chunks)
    concatenated = tileable.op.concat_tileable_chunks(tileable)
    return concatenated.chunks[0]
36+
37+
38+
def copy_learned_attributes(from_estimator: BaseEstimator,
                            to_estimator: BaseEstimator):
    """
    Copy learned attributes from one estimator onto another.

    Every instance attribute of ``from_estimator`` whose name ends with an
    underscore is set on ``to_estimator`` under the same name. Values are
    assigned by reference, not copied. Other attributes of ``to_estimator``
    are left untouched.

    Parameters
    ----------
    from_estimator : BaseEstimator
        Estimator to read the attributes from.
    to_estimator : BaseEstimator
        Estimator to set the attributes on.
    """
    # Collect the matching attributes first, then assign them.
    learned = [
        (name, value)
        for name, value in vars(from_estimator).items()
        if name.endswith('_')
    ]
    for name, value in learned:
        setattr(to_estimator, name, value)

0 commit comments

Comments
 (0)