-
Notifications
You must be signed in to change notification settings - Fork 894
Manual loss weights adaptation in TF2.0 #1656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
haison19952013
wants to merge
18
commits into
lululxvi:master
Choose a base branch
from
haison19952013:manual_loss_weight_adaptation_in_TFv2
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 16 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
4e858ac
Add loss_weights as the arguments of functions in class Model
haison19952013 8b8ce82
add PrintLossWeight and ManualDynamicLossWeight to track and modify t…
haison19952013 f7f4156
details of what to change
haison19952013 d3abfd4
This one is for easily referring to the functions and modules
haison19952013 6a69f42
example of using ManualDynamicLossWeight
haison19952013 9726e90
correct Loss_idx to loss_idx
haison19952013 da34d52
remove test files
haison19952013 c7ca762
Add an inverse example using the manual loss weights adaptation
haison19952013 d588ead
comments sys
haison19952013 8372190
remove unnecessary file
haison19952013 4859894
This works for my local machine
haison19952013 e8edf00
Revert "This works for my local machine"
haison19952013 2c4d160
Format code using black
haison19952013 137a646
remove unnecessary tf import
haison19952013 f3360d0
Format
haison19952013 1db2400
Close sentence with period
haison19952013 c598b80
remove space lines
haison19952013 5a438bb
format for Codacy check
haison19952013 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -119,7 +119,9 @@ def compile( | |
| print("Compiling model...") | ||
| self.opt_name = optimizer | ||
| loss_fn = losses_module.get(loss) | ||
| self.loss_weights = loss_weights | ||
| self.loss_weights = tf.convert_to_tensor( | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| loss_weights, dtype=config.default_float() | ||
| ) | ||
| if external_trainable_variables is None: | ||
| self.external_trainable_variables = [] | ||
| else: | ||
|
|
@@ -202,7 +204,9 @@ def _compile_tensorflow(self, lr, loss_fn, decay): | |
| def outputs(training, inputs): | ||
| return self.net(inputs, training=training) | ||
|
|
||
| def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn): | ||
| def outputs_losses( | ||
| training, inputs, targets, auxiliary_vars, losses_fn, loss_weights | ||
| ): | ||
| self.net.auxiliary_vars = auxiliary_vars | ||
| # Don't call outputs() decorated by @tf.function above, otherwise the | ||
| # gradient of outputs wrt inputs will be lost here. | ||
|
|
@@ -218,29 +222,41 @@ def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn): | |
| losses += [tf.math.reduce_sum(self.net.losses)] | ||
| losses = tf.convert_to_tensor(losses) | ||
| # Weighted losses | ||
| if self.loss_weights is not None: | ||
| losses *= self.loss_weights | ||
| if loss_weights is not None: | ||
| losses *= loss_weights | ||
| return outputs_, losses | ||
|
|
||
| @tf.function(jit_compile=config.xla_jit) | ||
| def outputs_losses_train(inputs, targets, auxiliary_vars): | ||
| def outputs_losses_train(inputs, targets, auxiliary_vars, loss_weights): | ||
| return outputs_losses( | ||
| True, inputs, targets, auxiliary_vars, self.data.losses_train | ||
| True, | ||
| inputs, | ||
| targets, | ||
| auxiliary_vars, | ||
| self.data.losses_train, | ||
| loss_weights, | ||
| ) | ||
|
|
||
| @tf.function(jit_compile=config.xla_jit) | ||
| def outputs_losses_test(inputs, targets, auxiliary_vars): | ||
| def outputs_losses_test(inputs, targets, auxiliary_vars, loss_weights): | ||
| return outputs_losses( | ||
| False, inputs, targets, auxiliary_vars, self.data.losses_test | ||
| False, | ||
| inputs, | ||
| targets, | ||
| auxiliary_vars, | ||
| self.data.losses_test, | ||
| loss_weights, | ||
| ) | ||
|
|
||
| opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay) | ||
|
|
||
| @tf.function(jit_compile=config.xla_jit) | ||
| def train_step(inputs, targets, auxiliary_vars): | ||
| def train_step(inputs, targets, auxiliary_vars, loss_weights): | ||
| # inputs and targets are np.ndarray and automatically converted to Tensor. | ||
| with tf.GradientTape() as tape: | ||
| losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] | ||
| losses = outputs_losses_train( | ||
| inputs, targets, auxiliary_vars, loss_weights | ||
| )[1] | ||
| total_loss = tf.math.reduce_sum(losses) | ||
| trainable_variables = ( | ||
| self.net.trainable_variables + self.external_trainable_variables | ||
|
|
@@ -531,7 +547,7 @@ def _outputs(self, training, inputs): | |
| outs = self.outputs(self.net.params, training, inputs) | ||
| return utils.to_numpy(outs) | ||
|
|
||
| def _outputs_losses(self, training, inputs, targets, auxiliary_vars): | ||
| def _outputs_losses(self, training, inputs, targets, auxiliary_vars, loss_weights): | ||
| if training: | ||
| outputs_losses = self.outputs_losses_train | ||
| else: | ||
|
|
@@ -540,7 +556,7 @@ def _outputs_losses(self, training, inputs, targets, auxiliary_vars): | |
| feed_dict = self.net.feed_dict(training, inputs, targets, auxiliary_vars) | ||
| return self.sess.run(outputs_losses, feed_dict=feed_dict) | ||
| if backend_name == "tensorflow": | ||
| outs = outputs_losses(inputs, targets, auxiliary_vars) | ||
| outs = outputs_losses(inputs, targets, auxiliary_vars, loss_weights) | ||
| elif backend_name == "pytorch": | ||
| self.net.requires_grad_(requires_grad=False) | ||
| outs = outputs_losses(inputs, targets, auxiliary_vars) | ||
|
|
@@ -552,12 +568,12 @@ def _outputs_losses(self, training, inputs, targets, auxiliary_vars): | |
| outs = outputs_losses(inputs, targets, auxiliary_vars) | ||
| return utils.to_numpy(outs[0]), utils.to_numpy(outs[1]) | ||
|
|
||
| def _train_step(self, inputs, targets, auxiliary_vars): | ||
| def _train_step(self, inputs, targets, auxiliary_vars, loss_weights): | ||
| if backend_name == "tensorflow.compat.v1": | ||
| feed_dict = self.net.feed_dict(True, inputs, targets, auxiliary_vars) | ||
| self.sess.run(self.train_step, feed_dict=feed_dict) | ||
| elif backend_name in ["tensorflow", "paddle"]: | ||
| self.train_step(inputs, targets, auxiliary_vars) | ||
| self.train_step(inputs, targets, auxiliary_vars, loss_weights) | ||
| elif backend_name == "pytorch": | ||
| self.train_step(inputs, targets, auxiliary_vars) | ||
| elif backend_name == "jax": | ||
|
|
@@ -669,6 +685,7 @@ def _train_sgd(self, iterations, display_every): | |
| self.train_state.X_train, | ||
| self.train_state.y_train, | ||
| self.train_state.train_aux_vars, | ||
| self.loss_weights, | ||
| ) | ||
|
|
||
| self.train_state.epoch += 1 | ||
|
|
@@ -827,12 +844,14 @@ def _test(self): | |
| self.train_state.X_train, | ||
| self.train_state.y_train, | ||
| self.train_state.train_aux_vars, | ||
| self.loss_weights, | ||
| ) | ||
| self.train_state.y_pred_test, self.train_state.loss_test = self._outputs_losses( | ||
| False, | ||
| self.train_state.X_test, | ||
| self.train_state.y_test, | ||
| self.train_state.test_aux_vars, | ||
| self.loss_weights, | ||
| ) | ||
|
|
||
| if isinstance(self.train_state.y_test, (list, tuple)): | ||
|
|
||
81 changes: 81 additions & 0 deletions
81
examples/pinn_inverse/elliptic_inverse_field_manual_dynamic_loss_weights.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| """Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" | ||
|
|
||
| # import sys | ||
| import deepxde as dde | ||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| from deepxde.callbacks import PrintLossWeight, ManualDynamicLossWeight | ||
|
|
||
| dde.config.disable_xla_jit() | ||
| from deepxde.backend import set_default_backend | ||
|
|
||
| set_default_backend("tensorflow") | ||
|
|
||
|
|
||
| def gen_traindata(num): | ||
| # generate num equally-spaced points from -1 to 1 | ||
| xvals = np.linspace(-1, 1, num).reshape(num, 1) | ||
| uvals = np.sin(np.pi * xvals) | ||
| return xvals, uvals | ||
|
|
||
|
|
||
| def pde(x, y): | ||
| u, q = y[:, 0:1], y[:, 1:2] | ||
| du_xx = dde.grad.hessian(y, x, component=0, i=0, j=0) | ||
| return -du_xx + q | ||
|
|
||
|
|
||
| def sol(x): | ||
| # solution is u(x) = sin(pi*x), q(x) = -pi^2 * sin(pi*x) | ||
| return np.sin(np.pi * x) | ||
|
|
||
|
|
||
| geom = dde.geometry.Interval(-1, 1) | ||
| bc = dde.icbc.DirichletBC(geom, sol, lambda _, on_boundary: on_boundary, component=0) | ||
| ob_x, ob_u = gen_traindata(100) | ||
| observe_u = dde.icbc.PointSetBC(ob_x, ob_u, component=0) | ||
|
|
||
| data = dde.data.PDE( | ||
| geom, | ||
| pde, | ||
| [bc, observe_u], | ||
| num_domain=200, | ||
| num_boundary=2, | ||
| anchors=ob_x, | ||
| num_test=1000, | ||
| ) | ||
|
|
||
| net = dde.nn.FNN([1, 40, 40, 40, 2], "tanh", "Glorot uniform") | ||
| PrintLossWeight_cb = PrintLossWeight(period=1) | ||
| ManualDynamicLossWeight_cb = ManualDynamicLossWeight( | ||
| epoch2change=5000, value=1, loss_idx=0 | ||
| ) | ||
| model = dde.Model(data, net) | ||
| model.compile("adam", lr=0.0001, loss_weights=[0, 100, 1000]) | ||
| losshistory, train_state = model.train( | ||
| iterations=20000, | ||
| display_every=1, | ||
| callbacks=[PrintLossWeight_cb, ManualDynamicLossWeight_cb], | ||
| ) | ||
| # dde.saveplot(losshistory, train_state, issave=True, isplot=True) | ||
|
|
||
| # view results | ||
| x = geom.uniform_points(500) | ||
| yhat = model.predict(x) | ||
| uhat, qhat = yhat[:, 0:1], yhat[:, 1:2] | ||
|
|
||
| utrue = np.sin(np.pi * x) | ||
| print("l2 relative error for u: " + str(dde.metrics.l2_relative_error(utrue, uhat))) | ||
| plt.figure() | ||
| plt.plot(x, utrue, "-", label="u_true") | ||
| plt.plot(x, uhat, "--", label="u_NN") | ||
| plt.legend() | ||
|
|
||
| qtrue = -np.pi**2 * np.sin(np.pi * x) | ||
| print("l2 relative error for q: " + str(dde.metrics.l2_relative_error(qtrue, qhat))) | ||
| plt.figure() | ||
| plt.plot(x, qtrue, "-", label="q_true") | ||
| plt.plot(x, qhat, "--", label="q_NN") | ||
| plt.legend() | ||
|
|
||
| plt.show() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.