Skip to content

Commit 9d9fbea

Browse files
wjsi and Xuye (Chris) Qin authored
[BACKPORT] Fix accuracy_score for distributed execution (#1945) (#1948)
Co-authored-by: wenjun.swj <[email protected]> Co-authored-by: Xuye (Chris) Qin <[email protected]>
1 parent 2eca131 commit 9d9fbea

File tree

6 files changed

+109
-20
lines changed

6 files changed

+109
-20
lines changed

mars/dataframe/core.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,14 @@ def __str__(self):
512512
def __repr__(self):
513513
return self._to_str(representation=True)
514514

515+
def _to_mars_tensor(self, dtype=None, order='K', extract_multi_index=False):
516+
tensor = self.to_tensor(extract_multi_index=extract_multi_index)
517+
dtype = dtype if dtype is not None else tensor.dtype
518+
return tensor.astype(dtype=dtype, order=order, copy=False)
519+
520+
def __mars_tensor__(self, dtype=None, order='K'):
521+
return self._to_mars_tensor(dtype=dtype, order=order)
522+
515523
@property
516524
def dtype(self):
517525
return getattr(self, '_dtype', None) or self.op.dtype
@@ -553,13 +561,8 @@ def __new__(cls, data: Union[pd.Index, IndexData], **_):
553561
def __len__(self):
554562
return len(self._data)
555563

556-
def _to_mars_tensor(self, dtype=None, order='K', extract_multi_index=False):
557-
tensor = self._data.to_tensor(extract_multi_index=extract_multi_index)
558-
dtype = dtype if dtype is not None else tensor.dtype
559-
return tensor.astype(dtype=dtype, order=order, copy=False)
560-
561564
def __mars_tensor__(self, dtype=None, order='K'):
562-
return self._to_mars_tensor(dtype=dtype, order=order)
565+
return self._data.__mars_tensor__(dtype=dtype, order=order)
563566

564567
def _get_df_or_series(self):
565568
obj = getattr(self, '_df_or_series', None)
@@ -597,6 +600,10 @@ def names(self, value):
597600
else:
598601
self.rename(value, inplace=True)
599602

603+
@property
604+
def values(self):
605+
return self.to_tensor()
606+
600607
def to_frame(self, index: bool = True, name=None):
601608
"""
602609
Create a DataFrame with a column containing the Index.
@@ -663,7 +670,7 @@ def to_frame(self, index: bool = True, name=None):
663670
else:
664671
columns = [name or self.name or 0]
665672
index_ = self if index else None
666-
return dataframe_from_tensor(self._to_mars_tensor(self, extract_multi_index=True),
673+
return dataframe_from_tensor(self._data._to_mars_tensor(self, extract_multi_index=True),
667674
index=index_, columns=columns)
668675

669676
def to_series(self, index=None, name=None):
@@ -889,6 +896,11 @@ def from_tensor(in_tensor, index=None, name=None):
889896
class SeriesData(_BatchedFetcher, BaseSeriesData):
890897
_type_name = 'Series'
891898

899+
def __mars_tensor__(self, dtype=None, order='K'):
900+
tensor = self.to_tensor()
901+
dtype = dtype if dtype is not None else tensor.dtype
902+
return tensor.astype(dtype=dtype, order=order, copy=False)
903+
892904
@classmethod
893905
def cls(cls, provider):
894906
if provider.type == ProviderType.protobuf:
@@ -1002,9 +1014,22 @@ def __len__(self):
10021014
return len(self._data)
10031015

10041016
def __mars_tensor__(self, dtype=None, order='K'):
1005-
tensor = self._data.to_tensor()
1006-
dtype = dtype if dtype is not None else tensor.dtype
1007-
return tensor.astype(dtype=dtype, order=order, copy=False)
1017+
return self._data.__mars_tensor__(dtype=dtype, order=order)
1018+
1019+
def keys(self):
1020+
"""
1021+
Return alias for index.
1022+
1023+
Returns
1024+
-------
1025+
Index
1026+
Index of the Series.
1027+
"""
1028+
return self.index
1029+
1030+
@property
1031+
def values(self):
1032+
return self.to_tensor()
10081033

10091034
def iteritems(self, batch_size=10000, session=None):
10101035
"""
@@ -1288,6 +1313,9 @@ def __str__(self):
12881313
def __repr__(self):
12891314
return self._to_str(representation=True)
12901315

1316+
def __mars_tensor__(self, dtype=None, order='K'):
1317+
return self.to_tensor().astype(dtype=dtype, order=order, copy=False)
1318+
12911319
def _repr_html_(self):
12921320
if len(self._executed_sessions) == 0:
12931321
# not executed before, fall back to normal repr
@@ -1350,7 +1378,7 @@ def from_records(self, records, **kw):
13501378
return self._data.from_records(records, **kw)
13511379

13521380
def __mars_tensor__(self, dtype=None, order='K'):
1353-
return self._data.to_tensor().astype(dtype=dtype, order=order, copy=False)
1381+
return self._data.__mars_tensor__(dtype=dtype, order=order)
13541382

13551383
def __getattr__(self, key):
13561384
try:
@@ -1409,6 +1437,23 @@ def columns(self, new_columns):
14091437
new_df = op(self)
14101438
self.data = new_df.data
14111439

1440+
def keys(self):
1441+
"""
1442+
Get the 'info axis' (see Indexing for more).
1443+
1444+
This is index for Series, columns for DataFrame.
1445+
1446+
Returns
1447+
-------
1448+
Index
1449+
Info axis.
1450+
"""
1451+
return self.columns
1452+
1453+
@property
1454+
def values(self):
1455+
return self.to_tensor()
1456+
14121457
@property
14131458
def dtypes(self):
14141459
"""

mars/dataframe/tests/test_core.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,28 @@ def testToFrameOrSeries(self):
9494
r = index.to_series(name='new_name')
9595
result = self.executor.execute_dataframe(r, concat=True)[0]
9696
pd.testing.assert_series_equal(raw.to_series(name='new_name'), result)
97+
98+
def testKeyValue(self):
99+
raw = pd.DataFrame(np.random.rand(4, 3), columns=list('ABC'))
100+
df = DataFrame(raw)
101+
102+
result = self.executor.execute_dataframe(df.values, concat=True)[0]
103+
np.testing.assert_array_equal(result, raw.values)
104+
105+
result = self.executor.execute_dataframe(df.keys(), concat=True)[0]
106+
pd.testing.assert_index_equal(result, raw.keys())
107+
108+
raw = pd.Series(np.random.rand(10))
109+
s = Series(raw)
110+
111+
result = self.executor.execute_dataframe(s.values, concat=True)[0]
112+
np.testing.assert_array_equal(result, raw.values)
113+
114+
result = self.executor.execute_dataframe(s.keys(), concat=True)[0]
115+
pd.testing.assert_index_equal(result, raw.keys())
116+
117+
raw = pd.Index(np.random.rand(10))
118+
idx = Index(raw)
119+
120+
result = self.executor.execute_dataframe(idx.values, concat=True)[0]
121+
np.testing.assert_array_equal(result, raw.values)

mars/learn/metrics/_check_targets.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from ... import opcodes as OperandDef
1818
from ... import tensor as mt
19-
from ...core import Base, Entity, ExecutableTuple, get_output_types
19+
from ...core import Base, Entity, ExecutableTuple
2020
from ...context import get_context
2121
from ...serialize import AnyField, KeyField
2222
from ...tiles import TilesError
@@ -39,9 +39,7 @@ def __init__(self, y_true=None, y_pred=None, type_true=None, type_pred=None, **k
3939
super().__init__(_y_true=y_true, _y_pred=y_pred,
4040
_type_true=type_true, _type_pred=type_pred, **kw)
4141
# scalar(y_type), y_true, y_pred
42-
self.output_types = \
43-
[OutputType.tensor] + get_output_types(*[y_true, y_pred],
44-
unknown_as=OutputType.tensor)
42+
self.output_types = [OutputType.tensor] * 3
4543

4644
@property
4745
def output_limit(self):
@@ -111,7 +109,7 @@ def tile(cls, op):
111109
type_true, type_pred = ctx.get_chunk_results(
112110
[op.type_true.chunks[0].key,
113111
op.type_pred.chunks[0].key])
114-
except KeyError:
112+
except (KeyError, AttributeError):
115113
raise TilesError('type_true and type_pred '
116114
'needs to be executed first')
117115

mars/learn/metrics/_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def tile(cls, op):
9090
ctx = get_context()
9191
try:
9292
type_true = ctx.get_chunk_results([op.type_true.chunks[0].key])[0]
93-
except KeyError:
93+
except (KeyError, AttributeError):
9494
raise TilesError('type_true needed to be executed first')
9595

9696
y_true, y_pred = op.y_true, op.y_pred

mars/learn/metrics/tests/integrated/test_ranking.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@
2020
import pandas as pd
2121
try:
2222
import sklearn
23-
from sklearn.metrics import roc_curve as sklearn_roc_curve, auc as sklearn_auc
23+
from sklearn.metrics import roc_curve as sklearn_roc_curve, auc as sklearn_auc, \
24+
accuracy_score as sklearn_accuracy_score
2425
except ImportError:
2526
sklearn = None
2627

2728
from mars import dataframe as md
28-
from mars.learn.metrics import roc_curve, auc
29+
from mars.learn.metrics import roc_curve, auc, accuracy_score
2930
from mars.tests.integrated.base import IntegrationTestBase
3031
from mars.session import new_session
3132

@@ -55,3 +56,22 @@ def testRocCurveAuc(self):
5556
pos_label=2)
5657
expect_m = sklearn_auc(sk_fpr, sk_tpr)
5758
self.assertAlmostEqual(m.fetch(session=sess), expect_m)
59+
60+
def testAccuracyScore(self):
61+
service_ep = 'http://127.0.0.1:' + self.web_port
62+
timeout = 120 if 'CI' in os.environ else -1
63+
with new_session(service_ep) as sess:
64+
run_kwargs = {'timeout': timeout}
65+
66+
rs = np.random.RandomState(0)
67+
raw = pd.DataFrame({'a': rs.randint(0, 10, (10,)),
68+
'b': rs.randint(0, 10, (10,))})
69+
70+
df = md.DataFrame(raw)
71+
y = df['a'].to_tensor().astype('int')
72+
pred = df['b'].astype('int')
73+
74+
score = accuracy_score(y, pred, session=sess, run_kwargs=run_kwargs)
75+
expect = sklearn_accuracy_score(raw['a'].to_numpy().astype('int'),
76+
raw['b'].to_numpy().astype('int'))
77+
self.assertAlmostEqual(score.fetch(session=sess), expect)

mars/learn/utils/multiclass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,8 @@ def type_of_target(y):
398398
'multilabel-indicator'
399399
"""
400400
valid_types = (Sequence, spmatrix) if spmatrix is not None else (Sequence,)
401-
valid = ((isinstance(y, valid_types) or hasattr(y, '__array__'))
401+
valid = ((isinstance(y, valid_types) or
402+
hasattr(y, '__array__') or hasattr(y, '__mars_tensor__'))
402403
and not isinstance(y, str))
403404

404405
if not valid:

0 commit comments

Comments
 (0)