12 changes: 6 additions & 6 deletions deepxde/callbacks.py
@@ -151,7 +151,7 @@ def on_epoch_end(self):
             if self.verbose > 0:
                 print(
                     "Epoch {}: {} improved from {:.2e} to {:.2e}, saving model to {} ...\n".format(
-                        self.model.train_state.epoch,
+                        self.model.train_state.iteration,
                         self.monitor,
                         self.best,
                         current,
@@ -224,7 +224,7 @@ def on_train_begin(self):
         self.best = np.inf if self.monitor_op == np.less else -np.inf

     def on_epoch_end(self):
-        if self.model.train_state.epoch < self.start_from_epoch:
+        if self.model.train_state.iteration < self.start_from_epoch:
             return
         current = self.get_monitor_value()
         if self.monitor_op(current - self.min_delta, self.best):
@@ -233,7 +233,7 @@ def on_epoch_end(self):
         else:
             self.wait += 1
             if self.wait >= self.patience:
-                self.stopped_epoch = self.model.train_state.epoch
+                self.stopped_epoch = self.model.train_state.iteration
                 self.model.stop_training = True

     def on_train_end(self):
@@ -274,7 +274,7 @@ def on_epoch_end(self):
             self.model.stop_training = True
             print(
                 "\nStop training as time used up. time used: {:.1f} mins, epoch trained: {}".format(
-                    (time.time() - self.t_start) / 60, self.model.train_state.epoch
+                    (time.time() - self.t_start) / 60, self.model.train_state.iteration
                 )
             )

@@ -347,7 +347,7 @@ def on_train_begin(self):
         self.value = [var.value for var in self.var_list]

         print(
-            self.model.train_state.epoch,
+            self.model.train_state.iteration,
             utils.list_to_str(self.value, precision=self.precision),
             file=self.file,
         )
@@ -420,7 +420,7 @@ def op(inputs, params):
     def on_train_begin(self):
         self.on_predict_end()
         print(
-            self.model.train_state.epoch,
+            self.model.train_state.iteration,
             utils.list_to_str(self.value.flatten().tolist(), precision=self.precision),
             file=self.file,
         )
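All of the callback changes above are read sites: each callback now consults `model.train_state.iteration` rather than the old `.epoch` name. A rough usage sketch (the checkpoint path and thresholds below are hypothetical, and `model` is assumed to be an already-compiled `dde.Model`):

```python
import deepxde as dde

# Hypothetical checkpoint path and early-stopping thresholds, for illustration.
checkpointer = dde.callbacks.ModelCheckpoint(
    "checkpoints/model", save_better_only=True, period=1000
)
early_stopping = dde.callbacks.EarlyStopping(min_delta=1e-6, patience=2000)

# Assumes `model` is a compiled dde.Model; while training, both callbacks
# read model.train_state.iteration instead of the deprecated .epoch alias.
# losshistory, train_state = model.train(
#     iterations=10000, callbacks=[checkpointer, early_stopping]
# )
```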
28 changes: 19 additions & 9 deletions deepxde/model.py
@@ -1,6 +1,7 @@
 __all__ = ["LossHistory", "Model", "TrainState"]

 import pickle
+import warnings
 from collections import OrderedDict

 import numpy as np
@@ -715,7 +716,7 @@ def _train_sgd(self, iterations, display_every, verbose=1):
                 self.train_state.train_aux_vars,
             )

-            self.train_state.epoch += 1
+            self.train_state.iteration += 1
             self.train_state.step += 1
             if self.train_state.step % display_every == 0 or i + 1 == iterations:
                 self._test(verbose=verbose)
@@ -728,7 +729,7 @@ def _train_sgd(self, iterations, display_every, verbose=1):

     def _train_tensorflow_compat_v1_scipy(self, display_every, verbose=1):
         def loss_callback(loss_train, loss_test, *args):
-            self.train_state.epoch += 1
+            self.train_state.iteration += 1
             self.train_state.step += 1
             if self.train_state.step % display_every == 0:
                 self.train_state.loss_train = loss_train
@@ -749,7 +750,7 @@ def loss_callback(loss_train, loss_test, *args):
                         cb.epochs_since_last = 0

                         print(
-                            cb.model.train_state.epoch,
+                            cb.model.train_state.iteration,
                             list_to_str(
                                 [float(arg) for arg in args],
                                 precision=cb.precision,
@@ -792,7 +793,7 @@ def _train_tensorflow_tfp(self, verbose=1):
                 self.train_state.train_aux_vars,
             )
             n_iter += results.num_iterations.numpy()
-            self.train_state.epoch += results.num_iterations.numpy()
+            self.train_state.iteration += results.num_iterations.numpy()
             self.train_state.step += results.num_iterations.numpy()
             self._test(verbose=verbose)

@@ -819,7 +820,7 @@ def _train_pytorch_lbfgs(self, verbose=1):
                 # Converged
                 break

-            self.train_state.epoch += n_iter - prev_n_iter
+            self.train_state.iteration += n_iter - prev_n_iter
             self.train_state.step += n_iter - prev_n_iter
             prev_n_iter = n_iter
             self._test(verbose=verbose)
@@ -851,7 +852,7 @@ def _train_paddle_lbfgs(self, verbose=1):
                 # Converged
                 break

-            self.train_state.epoch += n_iter - prev_n_iter
+            self.train_state.iteration += n_iter - prev_n_iter
             self.train_state.step += n_iter - prev_n_iter
             prev_n_iter = n_iter
             self._test(verbose=verbose)
@@ -1071,7 +1072,7 @@ def save(self, save_path, protocol="backend", verbose=0):
         Returns:
             string: Path where model is saved.
         """
-        save_path = f"{save_path}-{self.train_state.epoch}"
+        save_path = f"{save_path}-{self.train_state.iteration}"
         if protocol == "pickle":
             save_path += ".pkl"
             with open(save_path, "wb") as f:
@@ -1104,7 +1105,7 @@ def save(self, save_path, protocol="backend", verbose=0):
         if verbose > 0:
             print(
                 "Epoch {}: saving model to {} ...\n".format(
-                    self.train_state.epoch, save_path
+                    self.train_state.iteration, save_path
                 )
             )
         return save_path
@@ -1159,7 +1160,7 @@ def print_model(self):

 class TrainState:
     def __init__(self):
-        self.epoch = 0
+        self.iteration = 0
         self.step = 0

         # Current data
@@ -1188,6 +1189,15 @@ def __init__(self):
         self.best_ystd = None
         self.best_metrics = None

+    @property
+    def epoch(self):
+        warnings.warn(
+            "TrainState.epoch is deprecated and will be removed in a future version. Use TrainState.iteration instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.iteration
+
     def set_data_train(self, X_train, y_train, train_aux_vars=None):
         self.X_train = X_train
         self.y_train = y_train
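The new `TrainState.epoch` property keeps old code working while steering callers to the new name. A self-contained sketch of the back-compatibility pattern introduced above, runnable without DeepXDE (the stand-in class mirrors only the two counters):

```python
import warnings


class TrainState:
    """Minimal stand-in mirroring the pattern added in deepxde/model.py."""

    def __init__(self):
        self.iteration = 0  # the new canonical counter
        self.step = 0

    @property
    def epoch(self):
        # The old name still resolves, but warns callers to migrate.
        warnings.warn(
            "TrainState.epoch is deprecated and will be removed in a future "
            "version. Use TrainState.iteration instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.iteration


state = TrainState()
state.iteration = 5000
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(state.epoch)  # -> 5000, served by the deprecated alias
print(caught[0].category is DeprecationWarning)  # -> True
```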
2 changes: 1 addition & 1 deletion docs/demos/pinn_forward/elasticity.plate.rst
@@ -257,7 +257,7 @@ We then train the model for 5000 iterations:

 .. code-block:: python

-    losshistory, train_state = model.train(epochs=5000)
+    losshistory, train_state = model.train(iterations=5000)

 Complete code
 --------------
2 changes: 1 addition & 1 deletion docs/demos/pinn_forward/helmholtz.2d.neumann.hole.rst
@@ -54,7 +54,7 @@ First, the DeepXDE, Numpy and Matplotlib modules are imported:
     import matplotlib.pyplot as plt
     import numpy as np

-We begin by defining the general parameters for the problem. We use a collocation points density of 15 (resp. 30) points per wavelength for the training (resp. testing) data along each direction. The PINN will be trained over 5000 epochs. We define the learning rate, the number of dense layers and nodes, and the activation function.
+We begin by defining the general parameters for the problem. We use a collocation point density of 15 (resp. 30) points per wavelength along each direction for the training (resp. testing) data. The PINN will be trained for 5000 iterations. We define the learning rate, the number of dense layers and nodes, and the activation function.

 .. code-block:: python

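Both demo pages now pass the renamed keyword to `Model.train`. A minimal end-to-end sketch of the updated call, using a hypothetical 1D Poisson setup rather than either demo's actual problem:

```python
import deepxde as dde


def pde(x, y):
    # Residual of u'' - 2 = 0; on [-1, 1] with zero boundary values,
    # the exact solution is u(x) = x**2 - 1.
    dy_xx = dde.grad.hessian(y, x)
    return dy_xx - 2


geom = dde.geometry.Interval(-1, 1)
bc = dde.icbc.DirichletBC(geom, lambda x: 0, lambda _, on_boundary: on_boundary)
data = dde.data.PDE(geom, pde, bc, num_domain=16, num_boundary=2)
net = dde.nn.FNN([1, 32, 32, 1], "tanh", "Glorot uniform")

model = dde.Model(data, net)
model.compile("adam", lr=1e-3)
# The training-length keyword is now `iterations`.
losshistory, train_state = model.train(iterations=5000)
print(train_state.iteration)  # -> 5000
```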