-
Notifications
You must be signed in to change notification settings - Fork 894
Add new useful plot features #1544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -131,7 +131,6 @@ def uniformly_continuous_delta(X, Y, eps): | |
| return delta - dx / 2 | ||
| delta += dx | ||
|
|
||
|
|
||
| def saveplot( | ||
| loss_history, | ||
| train_state, | ||
|
|
@@ -141,25 +140,22 @@ def saveplot( | |
| train_fname="train.dat", | ||
| test_fname="test.dat", | ||
| output_dir=None, | ||
| save_format="png", | ||
| generate_statistics=False | ||
| ): | ||
| """Save/plot the loss history and best trained result. | ||
|
|
||
| This function is used to quickly check your results. To better investigate your | ||
| result, use ``save_loss_history()`` and ``save_best_state()``. | ||
|
|
||
| Args: | ||
| loss_history: ``LossHistory`` instance. The first variable returned from | ||
| ``Model.train()``. | ||
| train_state: ``TrainState`` instance. The second variable returned from | ||
| ``Model.train()``. | ||
| issave (bool): Set ``True`` (default) to save the loss, training points, | ||
| and testing points. | ||
| isplot (bool): Set ``True`` (default) to plot loss, metric, and the predicted | ||
| solution. | ||
| loss_history: LossHistory instance. The first variable returned from Model.train(). | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why modifying this? |
||
| train_state: TrainState instance. The second variable returned from Model.train(). | ||
| issave (bool): Set True (default) to save the loss, training points, and testing points. | ||
| isplot (bool): Set True (default) to plot loss, metric, and the predicted solution. | ||
| loss_fname (string): Name of the file to save the loss in. | ||
| train_fname (string): Name of the file to save the training points in. | ||
| test_fname (string): Name of the file to save the testing points in. | ||
| output_dir (string): If ``None``, use the current working directory. | ||
| output_dir (string): If None, use the current working directory. | ||
| save_format (string): File format for saving the plot (default is "png"). | ||
| generate_statistics (bool): Set True to generate additional statistics (average, standard deviation, etc.). | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. max 88 characters per line |
||
| """ | ||
| if output_dir is None: | ||
| output_dir = os.getcwd() | ||
|
|
@@ -175,29 +171,61 @@ def saveplot( | |
| save_best_state(train_state, train_fname, test_fname) | ||
|
|
||
| if isplot: | ||
| plot_loss_history(loss_history) | ||
| plot_best_state(train_state) | ||
| plt.show() | ||
|
|
||
|
|
||
| def plot_loss_history(loss_history, fname=None): | ||
| """Plot the training and testing loss history. | ||
|
|
||
| Note: | ||
| You need to call ``plt.show()`` to show the figure. | ||
| plot_style = { | ||
| 'train_color': 'b', | ||
| 'test_color': 'r', | ||
| 'train_linestyle': '-', | ||
| 'test_linestyle': '--' | ||
| } | ||
|
|
||
| plot_loss_history(loss_history, fname=os.path.join(output_dir, f"custom_style_loss.{save_format}"), plot_style=plot_style) | ||
|
|
||
| if generate_statistics: | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are these statistics useful? |
||
| average_loss_train = np.mean(loss_history.loss_train) | ||
| std_loss_train = np.std(loss_history.loss_train) | ||
| average_loss_test = np.mean(loss_history.loss_test) | ||
| std_loss_test = np.std(loss_history.loss_test) | ||
|
|
||
| print(f"Average Train Loss: {average_loss_train}") | ||
| print(f"Standard Deviation Train Loss: {std_loss_train}") | ||
| print(f"Average Test Loss: {average_loss_test}") | ||
| print(f"Standard Deviation Test Loss: {std_loss_test}") | ||
|
|
||
|
|
||
| def plot_loss_history(loss_history, fname=None, plot_style=None): | ||
| """Plot the training and testing loss history with custom style. | ||
|
|
||
| Args: | ||
| loss_history: ``LossHistory`` instance. The first variable returned from | ||
| ``Model.train()``. | ||
| fname (string): If `fname` is a string (e.g., 'loss_history.png'), then save the | ||
| figure to the file of the file name `fname`. | ||
| loss_history: LossHistory instance. The first variable returned from Model.train(). | ||
| fname (string): If fname is a string (e.g., 'loss_history.png'), then save the | ||
| figure to the file of the file name fname. | ||
| plot_style (dict): A dictionary containing style information for the plot. It can | ||
| include keys like 'train_color', 'test_color', 'train_linestyle', 'test_linestyle', etc. | ||
| """ | ||
| loss_train = np.sum(loss_history.loss_train, axis=1) | ||
| loss_test = np.sum(loss_history.loss_test, axis=1) | ||
|
|
||
| plt.figure() | ||
| plt.semilogy(loss_history.steps, loss_train, label="Train loss") | ||
| plt.semilogy(loss_history.steps, loss_test, label="Test loss") | ||
|
|
||
| # Default plot style settings | ||
| default_style = { | ||
| 'train_color': 'b', | ||
| 'test_color': 'r', | ||
| 'train_linestyle': '-', | ||
| 'test_linestyle': '--' | ||
| } | ||
|
|
||
| # Merge user-defined style with default style | ||
| plot_style = {**default_style, **(plot_style or {})} | ||
|
|
||
| train_color = plot_style['train_color'] | ||
| test_color = plot_style['test_color'] | ||
| train_linestyle = plot_style['train_linestyle'] | ||
| test_linestyle = plot_style['test_linestyle'] | ||
|
|
||
| plt.semilogy(loss_history.steps, loss_train, label="Train loss", color=train_color, linestyle=train_linestyle) | ||
| plt.semilogy(loss_history.steps, loss_test, label="Test loss", color=test_color, linestyle=test_linestyle) | ||
|
|
||
| for i in range(len(loss_history.metrics_test[0])): | ||
| plt.semilogy( | ||
| loss_history.steps, | ||
|
|
@@ -211,6 +239,7 @@ def plot_loss_history(loss_history, fname=None): | |
| plt.savefig(fname) | ||
|
|
||
|
|
||
|
|
||
| def save_loss_history(loss_history, fname): | ||
| """Save the training and testing loss history to a file.""" | ||
| print("Saving loss history to {} ...".format(fname)) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.