2222 RBFInterpolator ,
2323)
2424
25+ from ..plots .plot_helpers import show_or_save_plot
26+
2527# Numpy 1.x compatibility,
2628# TODO: remove these lines when all dependencies support numpy>=2.0.0
2729if np .lib .NumpyVersion (np .__version__ ) >= "2.0.0b1" :
@@ -1378,7 +1380,7 @@ def remove_outliers_iqr(self, threshold=1.5):
13781380 )
13791381
13801382 # Define all presentation methods
1381- def __call__ (self , * args ):
1383+ def __call__ (self , * args , filename = None ):
13821384 """Plot the Function if no argument is given. If an
13831385 argument is given, return the value of the function at the desired
13841386 point.
@@ -1392,13 +1394,18 @@ def __call__(self, *args):
13921394 evaluated at all points in the list and a list of floats will be
13931395 returned. If the function is N-D, N arguments must be given, each
13941396 one being an scalar or list.
1397+ filename : str | None, optional
1398+ The path the plot should be saved to. By default None, in which case
1399+ the plot will be shown instead of saved. Supported file endings are:
1400+ eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff
1401+ and webp (these are the formats supported by matplotlib).
13951402
13961403 Returns
13971404 -------
13981405 ans : None, scalar, list
13991406 """
14001407 if len (args ) == 0 :
1401- return self .plot ()
1408+ return self .plot (filename = filename )
14021409 else :
14031410 return self .get_value (* args )
14041411
@@ -1459,8 +1466,11 @@ def plot(self, *args, **kwargs):
14591466 Function.plot_2d if Function is 2-Dimensional and forward arguments
14601467 and key-word arguments."""
14611468 if isinstance (self , list ):
1469+ # Extract filename from kwargs
1470+ filename = kwargs .get ("filename" , None )
1471+
14621472 # Compare multiple plots
1463- Function .compare_plots (self )
1473+ Function .compare_plots (self , filename )
14641474 else :
14651475 if self .__dom_dim__ == 1 :
14661476 self .plot_1d (* args , ** kwargs )
@@ -1488,6 +1498,7 @@ def plot_1d( # pylint: disable=too-many-statements
14881498 force_points = False ,
14891499 return_object = False ,
14901500 equal_axis = False ,
1501+ filename = None ,
14911502 ):
14921503 """Plot 1-Dimensional Function, from a lower limit to an upper limit,
14931504 by sampling the Function several times in the interval. The title of
@@ -1518,6 +1529,11 @@ def plot_1d( # pylint: disable=too-many-statements
15181529 Setting force_points to True will plot all points, as a scatter, in
15191530 which the Function was evaluated in the dataset. Default value is
15201531 False.
1532+ filename : str | None, optional
1533+ The path the plot should be saved to. By default None, in which case
1534+ the plot will be shown instead of saved. Supported file endings are:
1535+ eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff
1536+ and webp (these are the formats supported by matplotlib).
15211537
15221538 Returns
15231539 -------
@@ -1558,7 +1574,7 @@ def plot_1d( # pylint: disable=too-many-statements
15581574 plt .title (self .title )
15591575 plt .xlabel (self .__inputs__ [0 ].title ())
15601576 plt .ylabel (self .__outputs__ [0 ].title ())
1561- plt . show ( )
1577+ show_or_save_plot ( filename )
15621578 if return_object :
15631579 return fig , ax
15641580
@@ -1581,6 +1597,7 @@ def plot_2d( # pylint: disable=too-many-statements
15811597 disp_type = "surface" ,
15821598 alpha = 0.6 ,
15831599 cmap = "viridis" ,
1600+ filename = None ,
15841601 ):
15851602 """Plot 2-Dimensional Function, from a lower limit to an upper limit,
15861603 by sampling the Function several times in the interval. The title of
@@ -1620,6 +1637,11 @@ def plot_2d( # pylint: disable=too-many-statements
16201637 cmap : string, optional
16211638 Colormap of plotted graph, which can be any of the color maps
16221639 available in matplotlib. Default value is viridis.
1640+ filename : str | None, optional
1641+ The path the plot should be saved to. By default None, in which case
1642+ the plot will be shown instead of saved. Supported file endings are:
1643+ eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff
1644+ and webp (these are the formats supported by matplotlib).
16231645
16241646 Returns
16251647 -------
@@ -1692,7 +1714,7 @@ def plot_2d( # pylint: disable=too-many-statements
16921714 axes .set_xlabel (self .__inputs__ [0 ].title ())
16931715 axes .set_ylabel (self .__inputs__ [1 ].title ())
16941716 axes .set_zlabel (self .__outputs__ [0 ].title ())
1695- plt . show ( )
1717+ show_or_save_plot ( filename )
16961718
16971719 @staticmethod
16981720 def compare_plots ( # pylint: disable=too-many-statements
@@ -1707,6 +1729,7 @@ def compare_plots( # pylint: disable=too-many-statements
17071729 force_points = False ,
17081730 return_object = False ,
17091731 show = True ,
1732+ filename = None ,
17101733 ):
17111734 """Plots N 1-Dimensional Functions in the same plot, from a lower
17121735 limit to an upper limit, by sampling the Functions several times in
@@ -1751,6 +1774,11 @@ def compare_plots( # pylint: disable=too-many-statements
17511774 False.
17521775 show : bool, optional
17531776 If True, shows the plot. Default value is True.
1777+ filename : str | None, optional
1778+ The path the plot should be saved to. By default None, in which case
1779+ the plot will be shown instead of saved. Supported file endings are:
1780+ eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff
1781+ and webp (these are the formats supported by matplotlib).
17541782
17551783 Returns
17561784 -------
@@ -1826,7 +1854,7 @@ def compare_plots( # pylint: disable=too-many-statements
18261854 plt .ylabel (ylabel )
18271855
18281856 if show :
1829- plt . show ( )
1857+ show_or_save_plot ( filename )
18301858
18311859 if return_object :
18321860 return fig , ax
0 commit comments