Skip to content

Commit cd5a9d1

Browse files
committed
Use Map base class
1 parent 015fba0 commit cd5a9d1

File tree

5 files changed

+49
-43
lines changed

5 files changed

+49
-43
lines changed

deepxde/maps/fnn.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from . import activations
1010
from . import initializers
1111
from . import regularizers
12+
from .map import Map
1213
from .. import config
1314
from ..utils import timing
1415

1516

16-
class FNN(object):
17+
class FNN(Map):
1718
"""feed-forward neural networks
1819
"""
1920

@@ -33,10 +34,7 @@ def __init__(
3334
self.dropout_rate = dropout_rate
3435
self.batch_normalization = batch_normalization
3536

36-
self.training, self.dropout = None, None
37-
self.data_id = None # 0: train data, 1: test data
38-
self.x, self.y, self.y_ = None, None, None
39-
self.build()
37+
super(FNN, self).__init__()
4038

4139
@property
4240
def inputs(self):
@@ -53,9 +51,6 @@ def targets(self):
5351
@timing
5452
def build(self):
5553
print("\nBuilding feed-forward neural network...")
56-
self.training = tf.placeholder(tf.bool)
57-
self.dropout = tf.placeholder(tf.bool)
58-
self.data_id = tf.placeholder(tf.uint8)
5954
self.x = tf.placeholder(config.real(tf), [None, self.layer_size[0]])
6055

6156
y = self.x

deepxde/maps/map.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import tensorflow as tf
6+
7+
from ..utils import timing
8+
9+
10+
class Map(object):
11+
"""Map base class."""
12+
13+
def __init__(self):
14+
if not hasattr(self, "regularizer"):
15+
self.regularizer = None
16+
17+
self.training = tf.placeholder(tf.bool)
18+
self.dropout = tf.placeholder(tf.bool)
19+
self.data_id = tf.placeholder(tf.uint8) # 0: train data, 1: test data
20+
21+
self.build()
22+
23+
@property
24+
def inputs(self):
25+
"""Return the mapping inputs."""
26+
27+
@property
28+
def outputs(self):
29+
"""Return the mapping outputs."""
30+
31+
@property
32+
def targets(self):
33+
"""Return the targets of the mapping outputs."""
34+
35+
@timing
36+
def build(self):
37+
"""Construct the mapping."""

deepxde/maps/mfnn.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from . import activations
88
from . import initializers
99
from . import regularizers
10+
from .map import Map
1011
from .. import config
1112
from ..utils import timing
1213

1314

14-
class MfNN(object):
15+
class MfNN(Map):
1516
"""Multifidelity neural networks
1617
"""
1718

@@ -31,16 +32,7 @@ def __init__(
3132
self.regularizer = regularizers.get(regularization)
3233
self.residue = residue
3334

34-
self.training = None
35-
self.dropout = None
36-
self.data_id = None
37-
self.X = None
38-
self.y_lo = None
39-
self.y_hi = None
40-
self.target_lo = None
41-
self.target_hi = None
42-
43-
self.build()
35+
super(MfNN, self).__init__()
4436

4537
@property
4638
def inputs(self):
@@ -57,9 +49,6 @@ def targets(self):
5749
@timing
5850
def build(self):
5951
print("Building multifidelity neural network...")
60-
self.training = tf.placeholder(tf.bool)
61-
self.dropout = tf.placeholder(tf.bool)
62-
self.data_id = tf.placeholder(tf.uint8)
6352
self.X = tf.placeholder(config.real(tf), [None, self.layer_size_lo[0]])
6453

6554
# Low fidelity

deepxde/maps/opnn.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from . import activations
88
from . import initializers
99
from . import regularizers
10+
from .map import Map
1011
from .. import config
1112
from ..utils import timing
1213

1314

14-
class OpNN(object):
15+
class OpNN(Map):
1516
"""Operator neural networks
1617
"""
1718

@@ -37,15 +38,7 @@ def __init__(
3738
)
3839
self.regularizer = regularizers.get(regularization)
3940

40-
self.training = None
41-
self.dropout = None
42-
self.data_id = None
43-
self.X_func = None
44-
self.X_loc = None
45-
self.y = None
46-
self.target = None
47-
48-
self.build()
41+
super(OpNN, self).__init__()
4942

5043
@property
5144
def inputs(self):
@@ -62,9 +55,6 @@ def targets(self):
6255
@timing
6356
def build(self):
6457
print("Building operator neural network...")
65-
self.training = tf.placeholder(tf.bool)
66-
self.dropout = tf.placeholder(tf.bool)
67-
self.data_id = tf.placeholder(tf.uint8)
6858
self.X_func = tf.placeholder(config.real(tf), [None, self.layer_size_func[0]])
6959
self.X_loc = tf.placeholder(config.real(tf), [None, self.layer_size_loc[0]])
7060

deepxde/maps/resnet.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from . import activations
88
from . import initializers
99
from . import regularizers
10+
from .map import Map
1011
from .. import config
1112
from ..utils import timing
1213

1314

14-
class ResNet(object):
15+
class ResNet(Map):
1516
"""Residual neural network
1617
"""
1718

@@ -33,10 +34,7 @@ def __init__(
3334
self.kernel_initializer = initializers.get(kernel_initializer)
3435
self.regularizer = regularizers.get(regularization)
3536

36-
self.training, self.dropout = None, None
37-
self.data_id = None # 0: train data, 1: test data
38-
self.x, self.y, self.y_ = None, None, None
39-
self.build()
37+
super(ResNet, self).__init__()
4038

4139
@property
4240
def inputs(self):
@@ -53,9 +51,6 @@ def targets(self):
5351
@timing
5452
def build(self):
5553
print("Building residual neural network...")
56-
self.training = tf.placeholder(tf.bool)
57-
self.dropout = tf.placeholder(tf.bool)
58-
self.data_id = tf.placeholder(tf.uint8)
5954
self.x = tf.placeholder(config.real(tf), [None, self.input_size])
6055

6156
y = self.dense(self.x, self.num_neurons, activation=self.activation)

0 commit comments

Comments
 (0)