
Commit 268ceca

Backend JAX: Avoid re-initializing the existing neural network (#1635)
1 parent 19f94e1 commit 268ceca

File tree

1 file changed (+4, -3)


deepxde/model.py

Lines changed: 4 additions & 3 deletions

@@ -374,9 +374,10 @@ def _compile_jax(self, lr, loss_fn, decay):
         if self.loss_weights is not None:
             raise NotImplementedError("Loss weights are not supported for backend jax.")
         # Initialize the network's parameters
-        key = jax.random.PRNGKey(config.jax_random_seed)
-        self.net.params = self.net.init(key, self.data.test()[0])
-        self.params = [self.net.params, self.external_trainable_variables]
+        if self.params is None:
+            key = jax.random.PRNGKey(config.jax_random_seed)
+            self.net.params = self.net.init(key, self.data.test()[0])
+            self.params = [self.net.params, self.external_trainable_variables]
         # TODO: learning rate decay
         self.opt = optimizers.get(self.opt_name, learning_rate=lr)
         self.opt_state = self.opt.init(self.params)
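
Before this change, every call to compile with the JAX backend drew a fresh PRNGKey and re-ran self.net.init, discarding any parameters the network already held; guarding the initialization with `if self.params is None` keeps existing parameters when compile is called again, for example to switch optimizers or learning rates. The sketch below illustrates the same guarded-initialization pattern in plain Flax/JAX with a hypothetical TinyModel class; it is not DeepXDE's actual Model implementation, only a minimal illustration of why the guard preserves state across repeated compile calls.

# A minimal sketch of the guarded-initialization pattern this commit applies,
# written against plain Flax/JAX rather than DeepXDE itself. TinyModel is a
# hypothetical stand-in for dde.Model, used only to show why the
# `if self.params is None` check preserves parameters across repeated
# compile() calls.
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class TinyModel:
    def __init__(self, net, sample_input):
        self.net = net
        self.sample_input = sample_input
        self.params = None  # filled in by the first compile()

    def compile(self, learning_rate):
        # Only draw fresh random parameters if none exist yet, so a second
        # compile() (e.g. with a new learning rate or optimizer) keeps
        # whatever training has already produced.
        if self.params is None:
            key = jax.random.PRNGKey(0)
            self.params = self.net.init(key, self.sample_input)
        self.opt = optax.adam(learning_rate)
        self.opt_state = self.opt.init(self.params)

net = nn.Dense(features=1)
model = TinyModel(net, jnp.ones((4, 2)))
model.compile(1e-3)   # first compile: initializes self.params
# ... a training loop would update model.params here ...
model.compile(1e-4)   # re-compile: self.params is left untouched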
