Skip to content
Merged
6 changes: 6 additions & 0 deletions python/src/pywy/basic/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ def __init__(self, out: Op):

def get_out(self):
    """Return the value stored in this instance's ``out`` attribute."""
    out = self.out
    return out

class LogisticRegression(Op):
    """``Op`` that is always initialised with ``Op.DType.FLOAT32``.

    Args:
        name: Optional name, handed unchanged to ``Op.__init__``.
    """

    def __init__(self, name=None):
        dtype = Op.DType.FLOAT32
        super().__init__(dtype, name)


15 changes: 14 additions & 1 deletion python/src/pywy/dataquanta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from pywy.operators import *
from pywy.basic.data.record import Record
from pywy.basic.model.option import Option
from pywy.basic.model.models import Model
from pywy.basic.model.models import (Model, LogisticRegression)



class Configuration:
Expand Down Expand Up @@ -193,6 +194,18 @@ def predict(
that._connect(op, 1)
)


def train_logistic_regression(
        self: "DataQuanta[In]",
        labels: "DataQuanta[In]",
        fit_intercept: bool = True
) -> "DataQuanta[Out]":
    """Connect this data quanta and ``labels`` to a new ``LogisticRegression``.

    ``self`` is connected first (with index 0), then ``labels`` (with
    index 1). The resulting op is wrapped in a new ``DataQuanta`` that
    shares this instance's context.

    Args:
        labels: Data quanta connected with index 1.
        fit_intercept: NOTE(review): accepted but never read in this
            method, so it currently has no effect -- confirm whether it
            should be forwarded to the op.

    Returns:
        A ``DataQuanta`` built from ``self.context`` and the new op.
    """
    trainer = LogisticRegression()
    # Order matters: self is wired before labels, as separate calls.
    for index, source in enumerate((self, labels)):
        source._connect(trainer, index)
    return DataQuanta(self.context, trainer)


def store_textfile(self: "DataQuanta[In]", path: str, input_type: GenericTco = None) -> None:
last: List[SinkOperator] = [
cast(
Expand Down
42 changes: 42 additions & 0 deletions python/src/pywy/tests/train_logistic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
from pywy.dataquanta import WayangContext
from pywy.platforms.java import JavaPlugin
from pywy.platforms.spark import SparkPlugin

class TestTrainLogisticRegression(unittest.TestCase):
    """Train a logistic regression on four samples and predict on them."""

    def test_train_and_predict(self):
        """Expect one prediction per sample, each being 0.0 or 1.0."""
        sample_features = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
        sample_labels = [1.0, 1.0, 0.0, 0.0]
        allowed_predictions = [0.0, 1.0]

        ctx = WayangContext().register({JavaPlugin, SparkPlugin})

        features = ctx.load_collection(sample_features)
        labels = ctx.load_collection(sample_labels)

        trained = features.train_logistic_regression(labels)
        result = trained.predict(features).collect()
        print("Predictions:", result)

        self.assertEqual(len(result), 4)
        for value in result:
            self.assertIn(value, allowed_predictions)

# Run this module's tests through unittest when the file is executed directly.
if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.wayang.core.plan.wayangplan._
import org.apache.wayang.core.platform.Platform
import org.apache.wayang.core.util.{Tuple => WayangTuple}
import org.apache.wayang.basic.data.{Tuple2 => WayangTuple2}
import org.apache.wayang.basic.model.DLModel;
import org.apache.wayang.basic.model.{DLModel, LogisticRegressionModel};
import org.apache.wayang.commons.util.profiledb.model.Experiment
import com.google.protobuf.ByteString;
import org.apache.wayang.api.python.function._
Expand Down Expand Up @@ -105,6 +105,17 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
udfLoad: LoadProfileEstimator = null): DataQuanta[NewOut] =
mapPartitionsJava(toSerializablePartitionFunction(udf), selectivity, udfLoad)


/**
  * Feed this instance and the given labels into a new [[LogisticRegressionOperator]].
  *
  * This instance is connected to input 0 and `labels` to input 1 of the operator.
  *
  * @param labels       [[DataQuanta]] connected to input 1
  * @param fitIntercept passed to the [[LogisticRegressionOperator]] constructor
  * @return the new operator, typed as [[DataQuanta]] of [[LogisticRegressionModel]]
  */
def trainLogisticRegression(labels: DataQuanta[java.lang.Double], fitIntercept: Boolean): DataQuanta[LogisticRegressionModel] = {
  val lrOperator = new LogisticRegressionOperator(fitIntercept)
  this.connectTo(lrOperator, 0)
  labels.connectTo(lrOperator, 1)
  // NOTE(review): returning the operator where DataQuanta is expected presumably
  // relies on an implicit conversion defined outside this view -- confirm.
  lrOperator
}




/**
* Feed this instance into a [[MapPartitionsOperator]].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import java.util.{Collection => JavaCollection}
import org.apache.wayang.api.graph.{Edge, EdgeDataQuantaBuilder, EdgeDataQuantaBuilderDecorator}
import org.apache.wayang.api.util.{DataQuantaBuilderCache, TypeTrap}
import org.apache.wayang.basic.data.{Record, Tuple2 => RT2}
import org.apache.wayang.basic.model.{DLModel, Model}
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator}
import org.apache.wayang.basic.model.{DLModel, Model, LogisticRegressionModel}
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator, LogisticRegressionOperator}
import org.apache.wayang.commons.util.profiledb.model.Experiment
import org.apache.wayang.core.function.FunctionDescriptor.{SerializableBiFunction, SerializableBinaryOperator, SerializableFunction, SerializableIntUnaryOperator, SerializablePredicate}
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval
Expand All @@ -38,6 +38,9 @@ import org.apache.wayang.core.plan.wayangplan.{Operator, OutputSlot, UnarySource
import org.apache.wayang.core.platform.Platform
import org.apache.wayang.core.types.DataSetType
import org.apache.wayang.core.util.{Logging, ReflectionUtils, WayangCollections, Tuple => WayangTuple}
import org.apache.wayang.core.plan.wayangplan.OutputSlot



import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag
Expand Down Expand Up @@ -288,6 +291,12 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
option: DLTrainingOperator.Option) =
new DLTrainingDataQuantaBuilder(this, that, model, option)

/**
  * Create a [[LogisticRegressionDataQuantaBuilder]] from this instance and the given one.
  *
  * This instance is cast (unchecked) to a builder of `Array[Double]` and used as the
  * first input; `that` is used as the second input.
  *
  * @param that         [[DataQuantaBuilder]] of `java.lang.Double` used as second input
  * @param fitIntercept passed to the [[LogisticRegressionDataQuantaBuilder]]; defaults to `true`
  * @return a [[LogisticRegressionDataQuantaBuilder]]
  */
def trainLogisticRegression(that: DataQuantaBuilder[_, java.lang.Double], fitIntercept: Boolean = true): LogisticRegressionDataQuantaBuilder = {
  val featuresBuilder = this.asInstanceOf[DataQuantaBuilder[_, Array[Double]]]
  new LogisticRegressionDataQuantaBuilder(featuresBuilder, that, fitIntercept)
}




/**
* Feed the built [[DataQuanta]] of this and the given instance into a
* [[org.apache.wayang.basic.operators.PredictOperator]].
Expand All @@ -298,6 +307,8 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
def predict[ThatOut, Result](that: DataQuantaBuilder[_, ThatOut], resultType: Class[Result]) = {
  // This builder is used as the model input; the cast to a Model builder is unchecked.
  val modelBuilder = this.asInstanceOf[DataQuantaBuilder[_, Model]]
  new PredictDataQuantaBuilder(modelBuilder, that, resultType)
}



/**
* Feed the built [[DataQuanta]] of this and the given instance into a
* [[org.apache.wayang.basic.operators.CoGroupOperator]].
Expand Down Expand Up @@ -1765,6 +1776,33 @@ class FakeDataQuantaBuilder[T](_dataQuanta: DataQuanta[T])(implicit javaPlanBuil
override protected def build: DataQuanta[T] = _dataQuanta
}

/**
  * [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.LogisticRegressionOperator]]s.
  *
  * @param inputDataQuanta0 [[DataQuantaBuilder]] for the features
  * @param inputDataQuanta1 [[DataQuantaBuilder]] for the labels
  * @param fitIntercept     passed on to `trainLogisticRegression` when building; defaults to `true`
  */
class LogisticRegressionDataQuantaBuilder(inputDataQuanta0: DataQuantaBuilder[_, Array[Double]],
                                          inputDataQuanta1: DataQuantaBuilder[_, java.lang.Double],
                                          fitIntercept: Boolean = true)
                                         (implicit javaPlanBuilder: JavaPlanBuilder)
  extends BasicDataQuantaBuilder[LogisticRegressionDataQuantaBuilder, LogisticRegressionModel] {

  // The output type is fixed to LogisticRegressionModel at construction time.
  locally {
    this.outputTypeTrap.dataSetType = dataSetType[LogisticRegressionModel]
  }

  /** Build both inputs (features first, then labels) and feed them into the training step. */
  override protected def build: DataQuanta[LogisticRegressionModel] = {
    val features = inputDataQuanta0.dataQuanta()
    val labels = inputDataQuanta1.dataQuanta()
    features.trainLogisticRegression(labels, fitIntercept)
  }

}




/**
* This is not an actual [[DataQuantaBuilder]] but rather decorates such a [[DataQuantaBuilder]] with a key.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import org.apache.wayang.basic.WayangBasics;
import org.apache.wayang.basic.data.Tuple2;
import org.apache.wayang.basic.operators.*;
import org.apache.wayang.basic.model.LogisticRegressionModel;
import org.apache.wayang.api.DataQuanta;
import org.apache.wayang.api.JavaPlanBuilder;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.api.Job;
import org.apache.wayang.core.api.WayangContext;
Expand All @@ -36,13 +39,18 @@
import org.junit.Assert;
import org.junit.Test;





import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*;

/**
* Test the Spark integration with Wayang.
Expand Down Expand Up @@ -492,6 +500,51 @@ public void testLogisticRegressionOperator() {
}
}

/**
 * Trains a logistic regression model through the {@link JavaPlanBuilder} API, collects the
 * model, and then predicts on the same features that were used for training.
 *
 * Expects exactly one prediction per input sample, each being either 0.0 or 1.0.
 */
@Test
public void testLogisticRegressionWithAPI() {
    WayangContext context = new WayangContext()
            .with(Spark.basicPlugin())
            .with(Spark.mlPlugin());

    JavaPlanBuilder planBuilder = new JavaPlanBuilder(context)
            .withJobName("Logistic Regression Test")
            .withUdfJarOf(this.getClass());

    // Sample training data
    List<double[]> features = Arrays.asList(
            new double[]{0.0, 1.0},
            new double[]{1.0, 0.0},
            new double[]{1.0, 1.0},
            new double[]{0.0, 0.0}
    );
    List<Double> labels = Arrays.asList(1.0, 1.0, 0.0, 0.0);

    // Build the training pipeline and take the first collected model.
    LogisticRegressionModel model = planBuilder
            .loadCollection(features).withName("Load Features")
            .trainLogisticRegression(
                    planBuilder.loadCollection(labels).withName("Load Labels"),
                    true
            )
            .collect()
            .iterator()
            .next();

    // Predict using the model
    Collection<Double> predictions = planBuilder
            .loadCollection(Collections.singletonList(model))
            .predict(planBuilder.loadCollection(features), Double.class)
            .collect();

    // Use the JUnit 4 Assert class that matches the org.junit.Test annotation on this
    // method, instead of the statically imported JUnit 5 assertions.
    Assert.assertEquals(4, predictions.size());
    for (double prediction : predictions) {
        Assert.assertTrue(
                "Unexpected prediction value: " + prediction,
                prediction == 0.0 || prediction == 1.0
        );
    }
}




@Test
public void testKMeans() {
Expand Down
Loading