Skip to content
7 changes: 7 additions & 0 deletions python/src/pywy/basic/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,10 @@ def __init__(self, out: Op):

def get_out(self):
return self.out

class LogisticRegression(Op):
def __init__(self, name=None):
super().__init__(Op.DType.FLOAT32, name)

def inputs_required(self):
Comment thread
zkaoudi marked this conversation as resolved.
Outdated
return 2
7 changes: 0 additions & 7 deletions python/src/pywy/basic/model/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,6 @@ def get_dType(self):
def inputs_required(self):
return 1

class LogisticRegression(Op):
def __init__(self, name=None):
super().__init__(Op.DType.FLOAT32, name)

def inputs_required(self):
return 2 # features + labels


class Softmax(Op):
def __init__(self, name=None):
Expand Down
6 changes: 2 additions & 4 deletions python/src/pywy/dataquanta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from pywy.operators import *
from pywy.basic.data.record import Record
from pywy.basic.model.option import Option
from pywy.basic.model.models import Model
from pywy.basic.model.ops import LogisticRegression
from pywy.basic.model.models import (Model, LogisticRegression)



class Configuration:
Expand Down Expand Up @@ -200,8 +200,6 @@ def train_logistic_regression(
labels: "DataQuanta[In]",
fit_intercept: bool = True
) -> "DataQuanta[Out]":
from pywy.basic.model.ops import LogisticRegression

op = LogisticRegression()
self._connect(op, 0)
labels._connect(op, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
val operator = new LogisticRegressionOperator(fitIntercept)
this.connectTo(operator, 0)
labels.connectTo(operator, 1)
new DataQuanta[LogisticRegressionModel](operator)
operator
}


Expand Down
Loading