Skip to content

Commit 23c57d7

Browse files
authored
Backend TensorFlow supports auxiliary variables (#348)
1 parent 07773fd commit 23c57d7

File tree

4 files changed

+554
-529
lines changed

4 files changed

+554
-529
lines changed

deepxde/data/pde.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def train_next_batch(self, batch_size=None):
171171
self.train_x = np.vstack((self.train_x, self.train_x_all))
172172
self.train_y = self.soln(self.train_x) if self.soln else None
173173
if self.auxiliary_var_fn is not None:
174-
self.train_aux_vars = self.auxiliary_var_fn(self.train_x)
174+
self.train_aux_vars = self.auxiliary_var_fn(self.train_x).astype(
175+
config.real(np)
176+
)
175177
return self.train_x, self.train_y, self.train_aux_vars
176178

177179
@run_if_all_none("test_x", "test_y", "test_aux_vars")
@@ -182,7 +184,9 @@ def test(self):
182184
self.test_x = self.test_points()
183185
self.test_y = self.soln(self.test_x) if self.soln else None
184186
if self.auxiliary_var_fn is not None:
185-
self.test_aux_vars = self.auxiliary_var_fn(self.test_x)
187+
self.test_aux_vars = self.auxiliary_var_fn(self.test_x).astype(
188+
config.real(np)
189+
)
186190
return self.test_x, self.test_y, self.test_aux_vars
187191

188192
def resample_train_points(self):
@@ -203,7 +207,9 @@ def add_anchors(self, anchors):
203207
self.train_x = np.vstack((self.train_x, self.train_x_all))
204208
self.train_y = self.soln(self.train_x) if self.soln else None
205209
if self.auxiliary_var_fn is not None:
206-
self.train_aux_vars = self.auxiliary_var_fn(self.train_x)
210+
self.train_aux_vars = self.auxiliary_var_fn(self.train_x).astype(
211+
config.real(np)
212+
)
207213

208214
def train_points(self):
209215
X = np.empty((0, self.geom.dim), dtype=config.real(np))

deepxde/maps/tensorflow/nn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def auxiliary_vars(self):
4343
"""Tensors: Any additional variables needed."""
4444
return self._auxiliary_vars
4545

46+
@auxiliary_vars.setter
47+
def auxiliary_vars(self, value):
48+
self._auxiliary_vars = value
49+
4650
def apply_feature_transform(self, transform):
4751
"""Compute the features by appling a transform to the network inputs, i.e.,
4852
features = transform(inputs). Then, outputs = network(features).

deepxde/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,21 @@ def compute_losses(targets, outputs):
142142

143143
# TODO: Avoid creating multiple graphs by using tf.TensorSpec.
144144
@tf.function
145-
def outputs_losses(data_id, inputs, targets):
145+
def outputs_losses(data_id, inputs, targets, auxiliary_vars=None):
146146
self.net.data_id = data_id
147147
self.net.inputs = inputs
148148
self.net.targets = targets
149+
self.net.auxiliary_vars = auxiliary_vars
149150
outputs = self.net(inputs)
150151
losses = compute_losses(targets, outputs)
151152
return outputs, losses
152153

153154
opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
154155

155156
@tf.function
156-
def train_step(data_id, inputs, targets):
157+
def train_step(data_id, inputs, targets, auxiliary_vars=None):
157158
with tf.GradientTape() as tape:
158-
_, losses = outputs_losses(data_id, inputs, targets)
159+
_, losses = outputs_losses(data_id, inputs, targets, auxiliary_vars)
159160
total_loss = tf.math.reduce_sum(losses)
160161
trainable_variables = (
161162
self.net.trainable_variables + self.external_trainable_variables
@@ -504,8 +505,8 @@ def _run(
504505
)
505506
return self.sess.run(fetches, feed_dict=feed_dict)
506507
if backend_name == "tensorflow":
507-
# TODO: Support training, dropout, auxiliary_vars
508-
outs = fetches(data_id, inputs, targets)
508+
# TODO: Support training, dropout
509+
outs = fetches(data_id, inputs, targets, auxiliary_vars)
509510
return None if outs is None else [out.numpy() for out in outs]
510511
if backend_name == "pytorch":
511512
# TODO: Use torch.no_grad() in _test() and predict()

0 commit comments

Comments
 (0)