Skip to content

Commit 763c9c7

Browse files
committed
Add standardize=False in dde.data.MfDataSet
1 parent 6de8a9f commit 763c9c7

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

deepxde/data/mf.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import numpy as np
2-
from sklearn import preprocessing
32

43
from .data import Data
54
from ..backend import tf
6-
from ..utils import run_if_any_none
5+
from ..utils import run_if_any_none, standardize
76

87

98
class MfFunc(Data):
@@ -81,6 +80,7 @@ def __init__(
8180
fname_hi_test=None,
8281
col_x=None,
8382
col_y=None,
83+
standardize=False,
8484
):
8585
if X_lo_train is not None:
8686
self.X_lo_train = X_lo_train
@@ -104,8 +104,10 @@ def __init__(
104104

105105
self.X_train = None
106106
self.y_train = None
107+
107108
self.scaler_x = None
108-
self._standardize()
109+
if standardize:
110+
self._standardize()
109111

110112
def losses(self, targets, outputs, loss, model):
111113
n = tf.cond(model.net.training, lambda: len(self.X_lo_train), lambda: 0)
@@ -129,7 +131,7 @@ def test(self):
129131
return self.X_hi_test, [self.y_hi_test, self.y_hi_test]
130132

131133
def _standardize(self):
132-
self.scaler_x = preprocessing.StandardScaler(with_mean=True, with_std=True)
133-
self.X_lo_train = self.scaler_x.fit_transform(self.X_lo_train)
134-
self.X_hi_train = self.scaler_x.transform(self.X_hi_train)
134+
self.scaler_x, self.X_lo_train, self.X_hi_train = standardize(
135+
self.X_lo_train, self.X_hi_train
136+
)
135137
self.X_hi_test = self.scaler_x.transform(self.X_hi_test)

examples/function/mf_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
fname_hi_test=fname_hi_test,
1313
col_x=(0,),
1414
col_y=(1,),
15+
standardize=True,
1516
)
1617

1718
activation = "tanh"

0 commit comments

Comments
 (0)