Skip to content

Commit f8d4a8e

Browse files
authored
Merge pull request #388 from joker-star-l/dev
Add three ML operators which map to SparkML
2 parents e353317 + 9e02abe commit f8d4a8e

23 files changed

+1219
-60
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model;
20+
21+
public interface DecisionTreeClassificationModel extends Model<double[], Integer> {
22+
23+
int getDepth();
24+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model;
20+
21+
public interface KMeansModel extends Model<double[], Integer> {
22+
23+
public int getK();
24+
25+
public double[][] getClusterCenters();
26+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model;
20+
21+
public interface LinearRegressionModel extends Model<double[], Double> {
22+
23+
public double[] getCoefficients();
24+
25+
public double getIntercept();
26+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.model;
20+
21+
import java.io.Serializable;
22+
23+
/**
 * A type that represents an ML model. It declares no methods of its own; it only ties an input
 * type to an output type and requires implementations to be {@link Serializable}.
 *
 * @param <X> the type of the data the model is applied to
 * @param <Y> the type of the value the model yields for such data
 */
public interface Model<X, Y> extends Serializable {

}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.operators;
20+
21+
import org.apache.wayang.basic.data.Tuple2;
22+
import org.apache.wayang.basic.model.DecisionTreeClassificationModel;
23+
import org.apache.wayang.core.plan.wayangplan.UnaryToUnaryOperator;
24+
import org.apache.wayang.core.types.DataSetType;
25+
26+
public class DecisionTreeClassificationOperator extends UnaryToUnaryOperator<Tuple2<double[], Integer>, DecisionTreeClassificationModel> {
27+
28+
public DecisionTreeClassificationOperator() {
29+
super(DataSetType.createDefaultUnchecked(Tuple2.class),
30+
DataSetType.createDefaultUnchecked(DecisionTreeClassificationModel.class),
31+
false);
32+
}
33+
34+
public DecisionTreeClassificationOperator(DecisionTreeClassificationOperator that) {
35+
super(that);
36+
}
37+
}

wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/KMeansOperator.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,22 @@
1818

1919
package org.apache.wayang.basic.operators;
2020

21-
import org.apache.wayang.basic.data.Tuple2;
21+
import org.apache.wayang.basic.model.KMeansModel;
2222
import org.apache.wayang.core.api.Configuration;
2323
import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimator;
2424
import org.apache.wayang.core.plan.wayangplan.UnaryToUnaryOperator;
2525
import org.apache.wayang.core.types.DataSetType;
2626

2727
import java.util.Optional;
2828

29-
public class KMeansOperator extends UnaryToUnaryOperator<double[], Tuple2<double[], Integer>> {
29+
public class KMeansOperator extends UnaryToUnaryOperator<double[], KMeansModel> {
30+
3031
// TODO other parameters
3132
protected int k;
3233

3334
public KMeansOperator(int k) {
3435
super(DataSetType.createDefaultUnchecked(double[].class),
35-
DataSetType.createDefaultUnchecked(Tuple2.class),
36+
DataSetType.createDefaultUnchecked(KMeansModel.class),
3637
false);
3738
this.k = k;
3839
}
@@ -46,8 +47,6 @@ public int getK() {
4647
return k;
4748
}
4849

49-
// TODO support fit and transform
50-
5150
@Override
5251
public Optional<CardinalityEstimator> createCardinalityEstimator(int outputIndex, Configuration configuration) {
5352
// TODO
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.operators;
20+
21+
import org.apache.wayang.basic.data.Tuple2;
22+
import org.apache.wayang.basic.model.LinearRegressionModel;
23+
import org.apache.wayang.core.api.Configuration;
24+
import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimator;
25+
import org.apache.wayang.core.plan.wayangplan.UnaryToUnaryOperator;
26+
import org.apache.wayang.core.types.DataSetType;
27+
28+
import java.util.Optional;
29+
30+
public class LinearRegressionOperator extends UnaryToUnaryOperator<Tuple2<double[], Double>, LinearRegressionModel> {
31+
32+
// TODO other parameters
33+
protected boolean fitIntercept;
34+
35+
public LinearRegressionOperator(boolean fitIntercept) {
36+
super(DataSetType.createDefaultUnchecked(Tuple2.class),
37+
DataSetType.createDefaultUnchecked(LinearRegressionModel.class),
38+
false);
39+
this.fitIntercept = fitIntercept;
40+
}
41+
42+
public LinearRegressionOperator(LinearRegressionOperator that) {
43+
super(that);
44+
this.fitIntercept = that.fitIntercept;
45+
}
46+
47+
public boolean getFitIntercept() {
48+
return fitIntercept;
49+
}
50+
51+
@Override
52+
public Optional<CardinalityEstimator> createCardinalityEstimator(int outputIndex, Configuration configuration) {
53+
// TODO
54+
return super.createCardinalityEstimator(outputIndex, configuration);
55+
}
56+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.basic.operators;
20+
21+
import com.fasterxml.jackson.core.type.TypeReference;
22+
import org.apache.wayang.basic.data.Tuple2;
23+
import org.apache.wayang.basic.model.Model;
24+
import org.apache.wayang.core.api.Configuration;
25+
import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimator;
26+
import org.apache.wayang.core.plan.wayangplan.BinaryToUnaryOperator;
27+
import org.apache.wayang.core.types.DataSetType;
28+
import org.apache.wayang.core.util.TypeConverter;
29+
30+
import java.util.Optional;
31+
32+
public class ModelTransformOperator<X, Y> extends BinaryToUnaryOperator<Model<X, Y>, X, Tuple2<X, Y>> {
33+
34+
public static ModelTransformOperator<double[], Integer> kMeans() {
35+
// The type of TypeReference cannot be omitted, to avoid the following error.
36+
// error: cannot infer type arguments for TypeReference<T>, reason: cannot use '<>' with anonymous inner classes
37+
return new ModelTransformOperator<>(new TypeReference<double[]>() {}, new TypeReference<Tuple2<double[], Integer>>() {});
38+
}
39+
40+
public static ModelTransformOperator<double[], Double> linearRegression() {
41+
return new ModelTransformOperator<>(new TypeReference<double[]>() {}, new TypeReference<Tuple2<double[], Double>>() {});
42+
}
43+
44+
public static ModelTransformOperator<double[], Integer> decisionTreeClassification() {
45+
return new ModelTransformOperator<>(new TypeReference<double[]>() {}, new TypeReference<Tuple2<double[], Integer>>() {});
46+
}
47+
48+
public ModelTransformOperator(DataSetType<X> inType, DataSetType<Tuple2<X, Y>> outType) {
49+
// TODO createDefaultUnchecked or createDefault?
50+
super(DataSetType.createDefaultUnchecked(Model.class), inType, outType, false);
51+
}
52+
53+
public ModelTransformOperator(Class<X> inType, Class<Tuple2<X, Y>> outType) {
54+
this(DataSetType.createDefault(inType), DataSetType.createDefault(outType));
55+
}
56+
57+
public ModelTransformOperator(TypeReference<X> inType, TypeReference<Tuple2<X, Y>> outType) {
58+
this(TypeConverter.convert(inType), TypeConverter.convert(outType));
59+
}
60+
61+
public ModelTransformOperator(ModelTransformOperator<X, Y> that) {
62+
super(that);
63+
}
64+
65+
@Override
66+
public Optional<CardinalityEstimator> createCardinalityEstimator(int outputIndex, Configuration configuration) {
67+
// TODO
68+
return super.createCardinalityEstimator(outputIndex, configuration);
69+
}
70+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.core.util;
20+
21+
import com.fasterxml.jackson.core.type.TypeReference;

import java.lang.reflect.Array;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
25+
26+
public class TypeConverter {
27+
public static <T> Class<T> convert(TypeReference<T> ref) {
28+
Type type = ref.getType();
29+
if (type instanceof ParameterizedType) {
30+
return (Class<T>)((ParameterizedType) type).getRawType();
31+
}
32+
return (Class<T>) type;
33+
}
34+
}

wayang-platforms/wayang-spark/code/main/java/org/apache/wayang/spark/mapping/Mappings.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import org.apache.wayang.core.mapping.Mapping;
2222
import org.apache.wayang.spark.mapping.graph.PageRankMapping;
23-
import org.apache.wayang.spark.mapping.ml.KMeansMapping;
23+
import org.apache.wayang.spark.mapping.ml.*;
2424

2525
import java.util.Arrays;
2626
import java.util.Collection;
@@ -65,7 +65,10 @@ public class Mappings {
6565
);
6666

6767
public static Collection<Mapping> ML_MAPPINGS = Arrays.asList(
68-
new KMeansMapping()
68+
new KMeansMapping(),
69+
new LinearRegressionMapping(),
70+
new DecisionTreeClassificationMapping(),
71+
new ModelTransformMapping()
6972
);
7073

7174
}

0 commit comments

Comments
 (0)