Skip to content

Commit 4a602b5

Browse files
Xuye (Chris) Qinwjsi
Xuye (Chris) Qin
authored and committed
Add pairwise distances support for learn (#926)
1 parent 8f75cb2 commit 4a602b5

34 files changed

+2186
-16
lines changed

mars/dataframe/reduction/core.py

-7
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ class DataFrameReductionOperand(DataFrameOperand):
3333
_numeric_only = BoolField('numeric_only')
3434
_min_count = Int32Field('min_count')
3535

36-
_stage = Int32Field('stage', on_serialize=lambda x: getattr(x, 'value', None),
37-
on_deserialize=OperandStage)
38-
3936
_dtype = DataTypeField('dtype')
4037
_combine_size = Int32Field('combine_size')
4138

@@ -65,10 +62,6 @@ def numeric_only(self):
6562
def min_count(self):
6663
return self._min_count
6764

68-
@property
69-
def stage(self):
70-
return self._stage
71-
7265
@property
7366
def dtype(self):
7467
return self._dtype

mars/learn/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
# register operands
1616
from .utils.shuffle import shuffle
1717
from .contrib import xgboost, tensorflow, pytorch
18+
from .metrics import pairwise
19+
from . import preprocessing
1820

1921
for _mod in [xgboost, tensorflow, pytorch]:
2022
_mod.register_op()
2123

22-
del _mod, shuffle
24+
del _mod, shuffle, pairwise, preprocessing

mars/learn/metrics/__init__.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 1999-2020 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .pairwise import euclidean_distances, pairwise_distances
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 1999-2020 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .euclidean import euclidean_distances
16+
from .haversine import haversine_distances
17+
from .manhattan import manhattan_distances
18+
from .cosine import cosine_distances, cosine_similarity
19+
from .pairwise import pairwise_distances

mars/learn/metrics/pairwise/core.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 1999-2020 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import itertools
16+
17+
import numpy as np
18+
19+
from ....tensor.operands import TensorOperand, TensorOperandMixin
20+
from ....tensor import tensor as astensor
21+
from ....tiles import TilesError
22+
from ....utils import check_chunks_unknown_shape
23+
from ...utils import check_array
24+
25+
26+
class PairwiseDistances(TensorOperand, TensorOperandMixin):
    """Base operand for pairwise computations between two 2-d tensors.

    Bundles input validation (``check_pairwise_arrays``) with tiling helpers.
    The tiling helpers read ``op.x`` and ``op.y``, so concrete operands are
    expected to expose both attributes.
    """

    _op_module_ = 'learn'

    @staticmethod
    def _return_float_dtype(X, Y):
        """
        1. If dtype of X and Y is float32, then dtype float32 is returned.
        2. Else dtype float64 is returned.

        Returns
        -------
        (X, Y, dtype)
            ``X`` converted to a tensor, ``Y`` converted to a tensor (or left
            as ``None`` when it was ``None``), and the chosen float dtype.
        """

        X = astensor(X)

        if Y is None:
            # only one input: its dtype alone decides the result
            Y_dtype = X.dtype
        else:
            Y = astensor(Y)
            Y_dtype = Y.dtype

        if X.dtype == Y_dtype == np.float32:
            dtype = np.float32
        else:
            # ``np.float`` was only an alias of the builtin ``float``
            # (i.e. float64); the alias was deprecated in NumPy 1.20 and
            # removed in 1.24, so name the dtype explicitly.
            dtype = np.float64

        return X, Y, dtype

    @staticmethod
    def check_pairwise_arrays(X, Y, precomputed=False, dtype=None):
        """Validate ``X`` and ``Y`` for a pairwise computation.

        Both inputs are passed through ``check_array`` with
        ``accept_sparse=True``.  When ``Y`` is ``None`` or is the same object
        as ``X``, a single validated tensor is returned for both.

        Parameters
        ----------
        X, Y : array-like
            Inputs; ``Y`` may be ``None``.
        precomputed : bool
            If true, require ``X.shape[1] == Y.shape[0]``; otherwise require
            ``X.shape[1] == Y.shape[1]``.
        dtype : optional
            dtype handed to ``check_array``; when ``None`` the dtype chosen
            by ``_return_float_dtype`` is used.

        Returns
        -------
        (X, Y) : the validated inputs.

        Raises
        ------
        ValueError
            If the shapes do not satisfy the requirement above.
        """
        X, Y, dtype_float = PairwiseDistances._return_float_dtype(X, Y)

        estimator = 'check_pairwise_arrays'
        if dtype is None:
            dtype = dtype_float

        if Y is X or Y is None:
            # validate once and share the result so callers can rely on
            # an ``X is Y`` identity test afterwards
            X = Y = check_array(X, accept_sparse=True, dtype=dtype,
                                estimator=estimator)
        else:
            X = check_array(X, accept_sparse=True, dtype=dtype,
                            estimator=estimator)
            Y = check_array(Y, accept_sparse=True, dtype=dtype,
                            estimator=estimator)

        if precomputed:
            if X.shape[1] != Y.shape[0]:
                raise ValueError("Precomputed metric requires shape "
                                 "(n_queries, n_indexed). Got (%d, %d) "
                                 "for %d indexed." %
                                 (X.shape[0], X.shape[1], Y.shape[0]))
        elif X.shape[1] != Y.shape[1]:
            raise ValueError("Incompatible dimension for X and Y matrices: "
                             "X.shape[1] == %d while Y.shape[1] == %d" % (
                                 X.shape[1], Y.shape[1]))

        return X, Y

    @classmethod
    def _tile_one_chunk(cls, op):
        """Tile ``op`` into a single output chunk covering the whole output.

        The chunk is built from the first chunk of ``op.x`` and of ``op.y``.
        NOTE(review): presumably only meant for inputs that each consist of
        exactly one chunk -- confirm against callers.
        """
        out = op.outputs[0]
        chunk_op = op.copy().reset_key()
        chunk = chunk_op.new_chunk([op.x.chunks[0], op.y.chunks[0]],
                                   shape=out.shape, order=out.order,
                                   index=(0, 0))
        new_op = op.copy()
        return new_op.new_tensors(op.inputs, shape=out.shape,
                                  order=out.order, chunks=[chunk],
                                  nsplits=tuple((s,) for s in out.shape))

    @classmethod
    def _tile_chunks(cls, op, x, y):
        """Tile ``op`` into one output chunk per (x row-chunk, y row-chunk) pair.

        Only column-chunk index 0 of ``x`` and ``y`` is read, so both inputs
        need a single chunk along axis 1 (see ``_rechunk_cols_into_one``).
        """
        out = op.outputs[0]
        out_chunks = []
        for idx in itertools.product(range(x.chunk_shape[0]),
                                     range(y.chunk_shape[0])):
            xi, yi = idx

            chunk_op = op.copy().reset_key()
            chunk_inputs = [x.cix[xi, 0], y.cix[yi, 0]]
            # output chunk is (rows of the x chunk) x (rows of the y chunk)
            out_chunk = chunk_op.new_chunk(
                chunk_inputs, shape=(chunk_inputs[0].shape[0],
                                     chunk_inputs[1].shape[0],),
                order=out.order, index=idx)
            out_chunks.append(out_chunk)

        new_op = op.copy()
        return new_op.new_tensors(op.inputs, shape=out.shape,
                                  order=out.order, chunks=out_chunks,
                                  nsplits=(x.nsplits[0], y.nsplits[0]))

    @classmethod
    def _rechunk_cols_into_one(cls, x, y):
        """Rechunk ``x`` and ``y`` so each has a single chunk along axis 1.

        Nothing is done when both already have one column chunk.  If ``y`` is
        the same object as ``x``, the rechunked ``x`` is reused for ``y``.

        Returns
        -------
        (x, y) : the possibly rechunked tensors.
        """
        y_is_x = y is x
        if x.chunk_shape[1] != 1 or y.chunk_shape[1] != 1:
            # NOTE(review): presumably raises TilesError while chunk shapes
            # are still unknown -- confirm in check_chunks_unknown_shape
            check_chunks_unknown_shape([x, y], TilesError)

            x = x.rechunk({1: x.shape[1]})._inplace_tile()
            if y_is_x:
                y = x
            else:
                y = y.rechunk({1: y.shape[1]})._inplace_tile()

        return x, y

mars/learn/metrics/pairwise/cosine.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 1999-2020 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
17+
from .... import opcodes as OperandDef
18+
from ....serialize import KeyField
19+
from .... import tensor as mt
20+
from ....tensor.core import TensorOrder
21+
from ....tensor.utils import recursive_tile
22+
from ...preprocessing import normalize
23+
from .core import PairwiseDistances
24+
25+
26+
class CosineDistances(PairwiseDistances):
    """Operand computing pairwise cosine distances between two tensors.

    The result has shape ``(x.shape[0], y.shape[0])``.
    """

    _op_type_ = OperandDef.PAIRWISE_COSINE_DISTANCES

    # serialized references to the two input tensors
    _x = KeyField('x')
    _y = KeyField('y')

    def __init__(self, x=None, y=None, dtype=None, gpu=None, **kw):
        super().__init__(_x=x, _y=y, _dtype=dtype, _gpu=gpu, **kw)

    @property
    def x(self):
        """First input tensor."""
        return self._x

    @property
    def y(self):
        """Second input tensor."""
        return self._y

    def _set_inputs(self, inputs):
        """Re-bind ``x`` and ``y`` to the first two entries of the inputs."""
        super()._set_inputs(inputs)
        self._x, self._y = self._inputs[0], self._inputs[1]

    def __call__(self, x, y=None):
        """Validate the inputs and create the output tensor (C order)."""
        x, y = self.check_pairwise_arrays(x, y)
        out_shape = (x.shape[0], y.shape[0])
        return self.new_tensor([x, y], shape=out_shape,
                               order=TensorOrder.C_ORDER)

    @classmethod
    def tile(cls, op):
        """Express the distance as ``1 - cosine_similarity`` and tile that."""
        lhs, rhs = op.x, op.y
        same_input = lhs is rhs

        similarity = cosine_similarity(lhs) if same_input \
            else cosine_similarity(lhs, rhs)

        # distance = 1 - similarity, restricted to the range [0, 2]
        distances = mt.clip((similarity * -1) + 1, 0, 2)

        if same_input:
            # force the diagonal to exactly 0; the return value is not used,
            # so this relies on fill_diagonal updating `distances` in place
            mt.fill_diagonal(distances, 0.0)

        return [recursive_tile(distances)]
65+
66+
67+
def cosine_similarity(X, Y=None, dense_output=True):
    """Return the cosine similarity between the samples of X and Y.

    The similarity is the dot product of the normalized inputs:

        K(X, Y) = <X, Y> / (||X||*||Y||)

    On L2-normalized data, this function is equivalent to linear_kernel.

    Read more in the :ref:`User Guide <cosine_similarity>`.

    Parameters
    ----------
    X : Tensor or sparse tensor, shape: (n_samples_X, n_features)
        Input data.

    Y : Tensor or sparse tensor, shape: (n_samples_Y, n_features)
        Input data. When ``None``, similarities are computed between all
        pairs of samples in ``X``.

    dense_output : boolean (optional), default True
        If true, ``todense()`` is applied to the result before returning.
        If ``False``, the output is sparse if both input tensors are sparse.

    Returns
    -------
    kernel matrix : Tensor
        A tensor with shape (n_samples_X, n_samples_Y).
    """
    X, Y = PairwiseDistances.check_pairwise_arrays(X, Y)

    X_normalized = normalize(X, copy=True)
    # identical inputs share a single normalized tensor
    Y_normalized = X_normalized if X is Y else normalize(Y, copy=True)

    K = X_normalized.dot(Y_normalized.T)
    return K.todense() if dense_output else K
109+
110+
111+
def cosine_distances(X, Y=None):
    """Return the cosine distance between the samples of X and Y.

    Cosine distance is defined as 1.0 minus the cosine similarity.

    Read more in the :ref:`User Guide <metrics>`.

    Parameters
    ----------
    X : array_like, sparse matrix
        with shape (n_samples_X, n_features).

    Y : array_like, sparse matrix (optional)
        with shape (n_samples_Y, n_features).

    Returns
    -------
    distance matrix : Tensor
        A tensor with shape (n_samples_X, n_samples_Y).

    See also
    --------
    mars.learn.metrics.pairwise.cosine_similarity
    mars.tensor.spatial.distance.cosine : dense matrices only
    """
    # the operand is created with a float64 dtype; calling it validates the
    # inputs and builds the result tensor
    out_dtype = np.dtype(np.float64)
    return CosineDistances(x=X, y=Y, dtype=out_dtype)(X, y=Y)

0 commit comments

Comments
 (0)