1313from csep .utils .time_utils import decimal_year , datetime_to_utc_epoch
1414from csep .core .catalogs import AbstractBaseCatalog
1515from csep .utils .constants import SECONDS_PER_ASTRONOMICAL_YEAR
16- from csep .utils . plots import plot_spatial_dataset
16+ from csep .plots import plot_gridded_dataset
1717
1818
1919# idea: should this be a SpatialDataSet and the class below SpaceMagnitudeDataSet, bc of functions like
@@ -432,17 +432,27 @@ def load_ascii(cls, ascii_fname, start_date=None, end_date=None, name=None, swap
432432 gds = cls (start_date , end_date , magnitudes = mws , name = name , region = region , data = rates )
433433 return gds
434434
435- def plot (self , ax = None , show = False , log = True , extent = None , set_global = False , plot_args = None ):
436- """ Plot gridded forecast according to plate-carree projection
435+ def plot (self , ax = None , show = False , log = True , extent = None , set_global = False , plot_args = None ,
436+ ** kwargs ):
437+ """ Plot the spatial rate of the forecast
438+
439+ See :func:`csep.utils.plots.plot_gridded_dataset` for a detailed description of the
440+ keyword arguments.
437441
438442 Args:
439- show (bool): if true, show the figure. this call is blocking.
440- plot_args (optional/dict): dictionary containing plotting arguments for making figures
443+ ax (`matplotlib.pyplot.axes`): Previous axes onto which catalog can be drawn
444+ show (bool): If True, shows the figure.
445+ log (bool): If True, plots the base-10 logarithm of the spatial rates
446+ extent (list): Force an extent [lon_min, lon_max, lat_min, lat_max]
447+ set_global (bool): Whether to plot using a global projection
448+ **kwargs (dict): Keyword arguments passed to
449+ :func:`csep.utils.plots.plot_gridded_dataset`
441450
442451 Returns:
443452 axes: matplotlib.Axes.axes
444453 """
445- # no mutable function arguments
454+
455+
446456 if self .start_time is None or self .end_time is None :
447457 time = 'forecast period'
448458 else :
@@ -451,19 +461,24 @@ def plot(self, ax=None, show=False, log=True, extent=None, set_global=False, plo
451461 time = f'{ round (end - start ,3 )} years'
452462
453463 plot_args = plot_args or {}
454- plot_args .setdefault ('figsize' , (10 , 10 ))
455- plot_args .setdefault ('title' , self .name )
456-
464+ plot_args .update ({
465+ 'basemap' : kwargs .pop ('basemap' , 'ESRI_terrain' ) if ax is None else None ,
466+ 'title' : kwargs .pop ('title' , None ) or self .name ,
467+ 'figsize' : kwargs .pop ('figsize' , None ) or (8 , 8 ),
468+ 'plot_region' : True
469+ })
470+ plot_args .update (** kwargs )
457471 # this call requires internet connection and basemap
458472 if log :
459473 plot_args .setdefault ('clabel' , f'log10 M{ self .min_magnitude } + rate per cell per { time } ' )
460474 with numpy .errstate (divide = 'ignore' ):
461- ax = plot_spatial_dataset (numpy .log10 (self .spatial_counts (cartesian = True )), self .region , ax = ax ,
462- show = show , extent = extent , set_global = set_global , plot_args = plot_args )
475+ ax = plot_gridded_dataset (numpy .log10 (self .spatial_counts (cartesian = True )), self .region , ax = ax ,
476+ show = show , extent = extent , set_global = set_global ,
477+ ** plot_args )
463478 else :
464479 plot_args .setdefault ('clabel' , f'M{ self .min_magnitude } + rate per cell per { time } ' )
465- ax = plot_spatial_dataset (self .spatial_counts (cartesian = True ), self .region , ax = ax ,show = show , extent = extent ,
466- set_global = set_global , plot_args = plot_args )
480+ ax = plot_gridded_dataset (self .spatial_counts (cartesian = True ), self .region , ax = ax , show = show , extent = extent ,
481+ set_global = set_global , ** plot_args )
467482 return ax
468483
469484
@@ -654,7 +669,7 @@ def magnitude_counts(self):
654669 self .get_expected_rates ()
655670 return self .expected_rates .magnitude_counts ()
656671
657- def get_event_counts (self , verbose = True ):
672+ def get_event_counts (self , verbose = False ):
658673 """ Returns a numpy array containing the number of event counts for each catalog.
659674
660675 Note: This function can take a while to compute if called without already iterating through a forecast that
@@ -715,7 +730,7 @@ def get_expected_rates(self, verbose=False):
715730 magnitudes = self .magnitudes , name = self .name )
716731 return self .expected_rates
717732
718- def plot (self , plot_args = None , verbose = True , ** kwargs ):
733+ def plot (self , plot_args = None , verbose = False , ** kwargs ):
719734 plot_args = plot_args or {}
720735 if self .expected_rates is None :
721736 self .get_expected_rates (verbose = verbose )
0 commit comments