Skip to content

Commit 015fba0

Browse files
committed
Refactor mf nn
1 parent f671064 commit 015fba0

File tree

4 files changed

+69
-78
lines changed

4 files changed

+69
-78
lines changed

deepxde/data/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from .func import Func
55
from .func_constraint import FuncConstraint
66
from .ide import IDE
7-
from .mf_dataset import MfDataSet
8-
from .mf_func import MfFunc
7+
from .mf import MfDataSet
8+
from .mf import MfFunc
99
from .op_dataset import OpDataSet
1010
from .pde import PDE
1111
from .pde import TimePDE

deepxde/data/mf_dataset.py renamed to deepxde/data/mf.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,64 @@
77
from sklearn import preprocessing
88

99
from .data import Data
10-
from .. import losses
10+
from .. import losses as losses_module
11+
from ..utils import run_if_any_none
12+
13+
14+
class MfFunc(Data):
15+
"""Multifidelity function approximation.
16+
"""
17+
18+
def __init__(
19+
self, geom, func_lo, func_hi, num_lo, num_hi, num_test, dist_train="uniform"
20+
):
21+
self.geom = geom
22+
self.func_lo = func_lo
23+
self.func_hi = func_hi
24+
self.num_lo = num_lo
25+
self.num_hi = num_hi
26+
self.num_test = num_test
27+
self.dist_train = dist_train
28+
29+
self.X_train = None
30+
self.y_train = None
31+
self.X_test = None
32+
self.y_test = None
33+
34+
def losses(self, targets, outputs, loss_type, model):
35+
loss_f = losses_module.get(loss_type)
36+
loss_lo = loss_f(targets[0][: self.num_lo], outputs[0][: self.num_lo])
37+
loss_hi = loss_f(targets[1][self.num_lo :], outputs[1][self.num_lo :])
38+
return [loss_lo, loss_hi]
39+
40+
@run_if_any_none("X_train", "y_train")
41+
def train_next_batch(self, batch_size=None):
42+
if self.dist_train == "uniform":
43+
self.X_train = np.vstack(
44+
(
45+
self.geom.uniform_points(self.num_lo, True),
46+
self.geom.uniform_points(self.num_hi, True),
47+
)
48+
)
49+
else:
50+
self.X_train = np.vstack(
51+
(
52+
self.geom.random_points(self.num_lo, "sobol"),
53+
self.geom.random_points(self.num_hi, "sobol"),
54+
)
55+
)
56+
y_lo_train = self.func_lo(self.X_train)
57+
y_hi_train = self.func_hi(self.X_train)
58+
self.y_train = [y_lo_train, y_hi_train]
59+
return self.X_train, self.y_train
60+
61+
@run_if_any_none("X_test", "y_test")
62+
def test(self):
63+
self.X_test = self.geom.uniform_points(self.num_test, True)
64+
y_lo_test = self.func_lo(self.X_test)
65+
y_hi_test = self.func_hi(self.X_test)
66+
self.y_test = [y_lo_test, y_hi_test]
67+
return self.X_test, self.y_test
1168

1269

1370
class MfDataSet(Data):
@@ -54,27 +111,28 @@ def __init__(
54111
raise ValueError("No training data.")
55112

56113
self.X_train = None
114+
self.y_train = None
57115
self.scaler_x = None
58116
self._standardize()
59117

60-
def losses(self, targets, outputs, loss, model):
118+
def losses(self, targets, outputs, loss_type, model):
119+
loss_f = losses_module.get(loss_type)
61120
n = tf.cond(
62121
tf.equal(model.net.data_id, 0), lambda: len(self.X_lo_train), lambda: 0
63122
)
64-
loss_lo = losses.get(loss)(targets[0][:n], outputs[0][:n])
65-
loss_hi = losses.get(loss)(targets[1][n:], outputs[1][n:])
123+
loss_lo = loss_f(targets[0][:n], outputs[0][:n])
124+
loss_hi = loss_f(targets[1][n:], outputs[1][n:])
66125
return [loss_lo, loss_hi]
67126

127+
@run_if_any_none("X_train", "y_train")
68128
def train_next_batch(self, batch_size=None):
69-
if self.X_train is not None:
70-
return self.X_train, [self.y_lo_train, self.y_hi_train]
71-
72129
self.X_train = np.vstack((self.X_lo_train, self.X_hi_train))
73130
self.y_lo_train, self.y_hi_train = (
74131
np.vstack((self.y_lo_train, np.zeros_like(self.y_hi_train))),
75132
np.vstack((np.zeros_like(self.y_lo_train), self.y_hi_train)),
76133
)
77-
return self.X_train, [self.y_lo_train, self.y_hi_train]
134+
self.y_train = [self.y_lo_train, self.y_hi_train]
135+
return self.X_train, self.y_train
78136

79137
def test(self):
80138
return self.X_hi_test, [self.y_hi_test, self.y_hi_test]

deepxde/data/mf_func.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

examples/mf_func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def func_hi(x):
1717

1818
geom = dde.geometry.Interval(0, 1)
1919
num_test = 1000
20-
data = dde.data.MfFunc(geom, func_lo, func_hi, 51, 5, num_test)
20+
data = dde.data.MfFunc(geom, func_lo, func_hi, 100, 6, num_test)
2121

2222
activation = "tanh"
2323
initializer = "Glorot uniform"

0 commit comments

Comments
 (0)