diff --git a/deepxde/utils/external.py b/deepxde/utils/external.py index bca1cc1d1..5695dc08f 100644 --- a/deepxde/utils/external.py +++ b/deepxde/utils/external.py @@ -131,7 +131,6 @@ def uniformly_continuous_delta(X, Y, eps): return delta - dx / 2 delta += dx - def saveplot( loss_history, train_state, @@ -141,25 +140,22 @@ def saveplot( train_fname="train.dat", test_fname="test.dat", output_dir=None, + save_format="png", + generate_statistics=False ): """Save/plot the loss history and best trained result. - This function is used to quickly check your results. To better investigate your - result, use ``save_loss_history()`` and ``save_best_state()``. - Args: - loss_history: ``LossHistory`` instance. The first variable returned from - ``Model.train()``. - train_state: ``TrainState`` instance. The second variable returned from - ``Model.train()``. - issave (bool): Set ``True`` (default) to save the loss, training points, - and testing points. - isplot (bool): Set ``True`` (default) to plot loss, metric, and the predicted - solution. + loss_history: LossHistory instance. The first variable returned from Model.train(). + train_state: TrainState instance. The second variable returned from Model.train(). + issave (bool): Set True (default) to save the loss, training points, and testing points. + isplot (bool): Set True (default) to plot loss, metric, and the predicted solution. loss_fname (string): Name of the file to save the loss in. train_fname (string): Name of the file to save the training points in. test_fname (string): Name of the file to save the testing points in. - output_dir (string): If ``None``, use the current working directory. + output_dir (string): If None, use the current working directory. + save_format (string): File format for saving the plot (default is "png"). + generate_statistics (bool): Set True to generate additional statistics (average, standard deviation, etc.). """ if output_dir is None: output_dir = os.getcwd() @@ -175,29 +171,61 @@ def saveplot( save_best_state(train_state, train_fname, test_fname) if isplot: - plot_loss_history(loss_history) - plot_best_state(train_state) - plt.show() - - -def plot_loss_history(loss_history, fname=None): - """Plot the training and testing loss history. - - Note: - You need to call ``plt.show()`` to show the figure. + plot_style = { + 'train_color': 'b', + 'test_color': 'r', + 'train_linestyle': '-', + 'test_linestyle': '--' + } + + plot_loss_history(loss_history, fname=os.path.join(output_dir, f"custom_style_loss.{save_format}"), plot_style=plot_style) + + if generate_statistics: + average_loss_train = np.mean(loss_history.loss_train) + std_loss_train = np.std(loss_history.loss_train) + average_loss_test = np.mean(loss_history.loss_test) + std_loss_test = np.std(loss_history.loss_test) + + print(f"Average Train Loss: {average_loss_train}") + print(f"Standard Deviation Train Loss: {std_loss_train}") + print(f"Average Test Loss: {average_loss_test}") + print(f"Standard Deviation Test Loss: {std_loss_test}") + + +def plot_loss_history(loss_history, fname=None, plot_style=None): + """Plot the training and testing loss history with custom style. Args: - loss_history: ``LossHistory`` instance. The first variable returned from - ``Model.train()``. - fname (string): If `fname` is a string (e.g., 'loss_history.png'), then save the - figure to the file of the file name `fname`. + loss_history: LossHistory instance. The first variable returned from Model.train(). + fname (string): If fname is a string (e.g., 'loss_history.png'), then save the + figure to the file of the file name fname. + plot_style (dict): A dictionary containing style information for the plot. It can + include keys like 'train_color', 'test_color', 'train_linestyle', 'test_linestyle', etc. """ loss_train = np.sum(loss_history.loss_train, axis=1) loss_test = np.sum(loss_history.loss_test, axis=1) plt.figure() - plt.semilogy(loss_history.steps, loss_train, label="Train loss") - plt.semilogy(loss_history.steps, loss_test, label="Test loss") + + # Default plot style settings + default_style = { + 'train_color': 'b', + 'test_color': 'r', + 'train_linestyle': '-', + 'test_linestyle': '--' + } + + # Merge user-defined style with default style + plot_style = {**default_style, **(plot_style or {})} + + train_color = plot_style['train_color'] + test_color = plot_style['test_color'] + train_linestyle = plot_style['train_linestyle'] + test_linestyle = plot_style['test_linestyle'] + + plt.semilogy(loss_history.steps, loss_train, label="Train loss", color=train_color, linestyle=train_linestyle) + plt.semilogy(loss_history.steps, loss_test, label="Test loss", color=test_color, linestyle=test_linestyle) + for i in range(len(loss_history.metrics_test[0])): plt.semilogy( loss_history.steps, @@ -211,6 +239,7 @@ def plot_loss_history(loss_history, fname=None): plt.savefig(fname) + def save_loss_history(loss_history, fname): """Save the training and testing loss history to a file.""" print("Saving loss history to {} ...".format(fname))