@@ -395,13 +395,15 @@ def plot_cumulative_events_versus_time(
395395 pyplot .show ()
396396 return ax
397397
398-
399398def plot_magnitude_histogram (
400- catalog_forecast : Union ["CatalogForecast" , List [ "CSEPCatalog" ] ],
401- observation : "CSEPCatalog" ,
399+ forecast : Union ["CatalogForecast" , "GriddedForecast" ],
400+ observation : Optional [ "CSEPCatalog" ] = None ,
402401 magnitude_bins : Optional [Union [List [float ], numpy .ndarray ]] = None ,
403402 percentile : int = 95 ,
404403 log_scale : bool = True ,
404+ normalize : bool = True ,
405+ cumulative : bool = False ,
406+ intervals : bool = True ,
405407 ax : Optional ["matplotlib.axes.Axes" ] = None ,
406408 show : bool = False ,
407409 ** kwargs : Any ,
@@ -416,7 +418,7 @@ def plot_magnitude_histogram(
416418 - :ref:`Catalog-based Forecast Plots<catalog-forecast-evaluation-exploratory>`
417419
418420 Args:
419- catalog_forecast (CatalogForecast or list of CSEPCatalog): A catalog-based forecast
421+ forecast (CatalogForecast or list of CSEPCatalog or GriddedForecast ): A forecast
420422 or a list of observed catalogs.
421423 observation (CSEPCatalog): The observed catalog for comparison.
422424 magnitude_bins (list of float or numpy.ndarray, optional): The bins for magnitude
@@ -426,6 +428,12 @@ def plot_magnitude_histogram(
426428 `95`.
427429 log_scale (bool, optional): Whether to plot the y-axis in logarithmic scale. Defaults to
428430 True.
431+ normalize (bool, optional): Whether to normalize the forecast for the total number in the
432+ observation catalog.
433+ cumulative (bool, optional): Whether to plot cumulative counts N(M >= m). Defaults to
434+ False.
435+ intervals (bool, optional): Whether to display forecast uncertainty intervals.
436+ Defaults to True.
429437 ax (matplotlib.axes.Axes, optional): The axes object to draw the plot on. If `None`, a
430438 new figure and axes are created. Defaults to `None`.
431439 show (bool, optional): Whether to display the plot immediately. Defaults to `False`.
@@ -453,11 +461,12 @@ def plot_magnitude_histogram(
453461 matplotlib.axes.Axes: The axes object containing the plot.
454462
455463 .. versionchanged:: 0.8.0
456- It now requires a `CatalogForecast` rather than a list of stochastic event sets. The
457- `plot_args` dictionary is only partially supported and will be removed in v1.0.0
458- .. versionadded:: 0.8.0
459- Added `magnitude_bins`, `percentile` and `log_scale` to fine-tune the plot.
460- Added parameters to customize coloring, formatting and sizing of the plot elements.
464+ It now accepts a `CatalogForecast` or a `GriddedForecast`. An obervation `CSEPCatalog`
465+ is now optional. Added `magnitude_bins`, `percentile` and `log_scale` to fine-tune the
466+ plot. Added parameters to customize coloring, formatting and sizing of the plot
467+ elements. The `plot_args` dictionary is only partially supported and will be removed in
468+ v1.0.0
469+
461470 """
462471 if "plot_args" in kwargs :
463472 _warning_plot_args ("plot_magnitude_histogram" )
@@ -467,80 +476,175 @@ def plot_magnitude_histogram(
467476 fig , ax = pyplot .subplots (figsize = plot_args ["figsize" ]) if ax is None else (ax .figure , ax )
468477
469478 # Get magnitudes from observations and (lazily) from forecast
470- forecast_mws = list (map (lambda x : x .get_magnitudes (), catalog_forecast ))
471- obs_mw = observation .get_magnitudes ()
472- n_obs = observation .get_number_of_events ()
473-
474- # Get magnitude bins from args, if not from region, or lastly from standard CSEP bins.
475- if magnitude_bins is None :
479+ if magnitude_bins is not None :
480+ forecast_bins = numpy .asarray (magnitude_bins )
481+ else :
476482 try :
477- magnitude_bins = observation . region . magnitudes
483+ forecast_bins = getattr ( forecast , " magnitudes" )
478484 except AttributeError :
479- magnitude_bins = CSEP_MW_BINS
485+ raise AttributeError ("Forecast must be defined on a 'region', having "
486+ "'magnitudes' attribute as left-edge magnitude bins." )
487+
488+ dm_forecast = numpy .median (numpy .diff (forecast_bins ))
489+ forecast_centers = forecast_bins + dm_forecast / 2.0
480490
481- def get_histogram_synthetic_cat (x , mags , normed = True ):
482- n_temp = len (x )
483- if normed and n_temp != 0 :
484- temp_scale = n_obs / n_temp
485- hist = numpy .histogram (x , bins = mags )[0 ] * temp_scale
491+ if observation is not None :
492+ if magnitude_bins is not None :
493+ obs_bins , obs_counts = observation .magnitude_counts (
494+ mag_bins = magnitude_bins , retbins = True
495+ )
496+ elif hasattr (observation , "region" ) and hasattr (observation .region , "magnitudes" ):
497+ obs_bins , obs_counts = observation .magnitude_counts (mag_bins = None , retbins = True )
498+ else :
499+ obs_bins , obs_counts = observation .magnitude_counts (
500+ mag_bins = CSEP_MW_BINS , retbins = True
501+ )
502+ nonzero_idx = numpy .nonzero (obs_counts )[0 ]
503+ if nonzero_idx .size > 0 :
504+ first = nonzero_idx [0 ]
505+ obs_bins = obs_bins [first :]
506+ obs_counts = obs_counts [first :]
507+
508+ if len (obs_bins ) > 1 :
509+ dm_obs = numpy .median (numpy .diff (obs_bins ))
486510 else :
487- hist = numpy .histogram (x , bins = mags )[0 ]
488- return hist
511+ dm_obs = dm_forecast
512+ obs_centers = obs_bins + dm_obs / 2.0
513+ idxs = numpy .where (obs_centers >= forecast_centers [0 ])[0 ]
514+ if idxs .size > 0 :
515+ obs_index = idxs [0 ]
516+ n_obs = numpy .sum (obs_counts [obs_index :])
517+ else :
518+ n_obs = 0
519+ else :
520+ obs_counts = None
521+ obs_centers = None
522+ n_obs = 0
523+
524+ if hasattr (forecast , "catalogs" ):
525+ forecast_mws = list (map (lambda x : x .get_magnitudes (), forecast ))
526+
527+ def get_histogram_synthetic_cat (x , mags , normed_hist = normalize ):
528+ n_syn_events = len (x )
529+ if normed_hist and n_syn_events != 0 and n_obs != 0 :
530+ temp_scale = n_obs / n_syn_events
531+ hist = numpy .histogram (x , bins = mags )[0 ] * temp_scale
532+ else :
533+ hist = numpy .histogram (x , bins = mags )[0 ]
534+ return hist
489535
490- # get histogram values
491- forecast_hist = numpy .array (
492- list (map (lambda x : get_histogram_synthetic_cat (x , magnitude_bins ), forecast_mws ))
493- )
494- obs_hist , bin_edges = numpy .histogram (obs_mw , bins = magnitude_bins )
495- bin_centers = (bin_edges [1 :] + bin_edges [:- 1 ]) / 2
536+ # get histogram values
537+ catalog_forecast_bins = numpy .append (forecast_bins , forecast_bins [- 1 ] + dm_forecast )
538+ forecast_hist = numpy .array (
539+ list (map (lambda x : get_histogram_synthetic_cat (x , catalog_forecast_bins ), forecast_mws ))
540+ )
496541
542+ if cumulative :
543+ hist_for_stats = numpy .cumsum (forecast_hist [:, ::- 1 ], axis = 1 )[:, ::- 1 ]
544+ else :
545+ hist_for_stats = forecast_hist
546+
547+ forecast_mean = hist_for_stats .mean (axis = 0 )
548+ if intervals :
549+ lower_p = (100.0 - percentile ) / 2.0
550+ upper_p = 100.0 - lower_p
551+ forecast_low = numpy .percentile (hist_for_stats , lower_p , axis = 0 )
552+ forecast_high = numpy .percentile (hist_for_stats , upper_p , axis = 0 )
553+ else :
554+ forecast_low = None
555+ forecast_high = None
556+ else :
557+ rates = numpy .asarray (forecast .magnitude_counts ())
558+ if len (rates ) != len (forecast_bins ):
559+ raise ValueError (
560+ "Length of forecast.magnitude_counts() must match number of forecast magnitude bins."
561+ )
562+ if normalize and n_obs != 0 :
563+ scale = n_obs / numpy .sum (rates )
564+ else :
565+ scale = 1.0
566+
567+ if cumulative :
568+ lam = numpy .cumsum (rates [::- 1 ])[::- 1 ]
569+ else :
570+ lam = rates
571+
572+ forecast_mean = lam * scale
573+ if intervals :
574+ alpha = (100.0 - percentile ) / 200.0
575+ low_counts = poisson .ppf (alpha , lam )
576+ high_counts = poisson .ppf (1.0 - alpha , lam )
577+ forecast_low = low_counts * scale
578+ forecast_high = high_counts * scale
579+ else :
580+ forecast_low = None
581+ forecast_high = None
582+
583+ # Compute statistics for the forecast histograms
497584 # Compute statistics for the forecast histograms
498- forecast_mean = numpy .mean (forecast_hist , axis = 0 )
499- forecast_median = numpy .median (forecast_hist , axis = 0 )
500- forecast_low = numpy .percentile (forecast_hist , (100 - percentile ) / 2.0 , axis = 0 )
501- forecast_high = numpy .percentile (forecast_hist , 100 - (100 - percentile ) / 2.0 , axis = 0 )
502- forecast_err_lower = forecast_median - forecast_low
503- forecast_err_upper = forecast_high - forecast_median
585+ if intervals :
586+ low = numpy .nan_to_num (forecast_low , nan = 0.0 )
587+ high = numpy .nan_to_num (forecast_high , nan = forecast_mean )
588+ low = numpy .minimum (low , forecast_mean )
589+ high = numpy .maximum (high , forecast_mean )
590+ forecast_err_lower = numpy .clip (forecast_mean - low , 0.0 , None )
591+ forecast_err_upper = numpy .clip (high - forecast_mean , 0.0 , None )
592+ else :
593+ forecast_err_lower = None
594+ forecast_err_upper = None
595+
596+ # cumulative transform for observation (after n_obs calculation)
597+ if cumulative and obs_counts is not None :
598+ obs_counts = numpy .cumsum (obs_counts [::- 1 ])[::- 1 ]
504599
505600 # Plot observed counts
506- ax .plot (
507- bin_centers ,
508- obs_hist ,
509- color = plot_args ["color" ],
510- marker = "o" ,
511- lw = 0 ,
512- markersize = plot_args ["markersize" ],
513- label = "Observation" ,
514- zorder = 3 ,
515- )
601+ if obs_counts is not None :
602+ ax .plot (
603+ obs_centers ,
604+ obs_counts ,
605+ color = plot_args ["color" ],
606+ marker = "o" ,
607+ lw = 0 ,
608+ markersize = plot_args ["markersize" ],
609+ label = "Observation" ,
610+ zorder = 3 ,
611+ )
516612 # Plot forecast histograms as bar plot with error bars
517613 ax .plot (
518- bin_centers ,
614+ forecast_centers ,
519615 forecast_mean ,
520616 "." ,
521617 markersize = plot_args ["markersize" ],
522618 color = "darkred" ,
523619 label = "Forecast Mean" ,
524620 )
525- ax .errorbar (
526- bin_centers ,
527- forecast_median ,
528- yerr = [forecast_err_lower , forecast_err_upper ],
529- fmt = "None" ,
530- color = "darkred" ,
531- markersize = plot_args ["markersize" ],
532- capsize = plot_args ["capsize" ],
533- linewidth = plot_args ["linewidth" ],
534- label = "Forecast (95% CI)" ,
535- )
621+
622+ if intervals :
623+ ax .errorbar (
624+ forecast_centers ,
625+ forecast_mean ,
626+ yerr = [forecast_err_lower , forecast_err_upper ],
627+ fmt = "None" ,
628+ color = "darkred" ,
629+ markersize = plot_args ["markersize" ],
630+ capsize = plot_args ["capsize" ],
631+ linewidth = plot_args ["linewidth" ],
632+ label = f"Forecast ({ percentile } % CI)" ,
633+ )
536634
537635 # Scale x-axis
538636 if plot_args ["xlim" ]:
539637 ax .set_xlim (plot_args ["xlim" ])
540638 else :
541- ax = _autoscale_histogram (
542- ax , magnitude_bins , numpy .hstack (forecast_mws ), obs_mw , mass = 100
543- )
639+ if observation is not None and hasattr (forecast , "catalogs" ):
640+ forecast_mws = [c .get_magnitudes () for c in forecast ]
641+ ax = _autoscale_histogram (
642+ ax ,
643+ forecast_bins ,
644+ numpy .hstack (forecast_mws ),
645+ observation .get_magnitudes (),
646+ mass = 100 ,
647+ )
544648 # Scale y-axis
545649 if log_scale :
546650 ax .set_yscale ('log' )
@@ -624,7 +728,7 @@ def plot_basemap(
624728 """
625729
626730 if "plot_args" in kwargs :
627- _warning_plot_args ("plot_magnitude_histogram " )
731+ _warning_plot_args ("plot_basemap " )
628732
629733 # Initialize plot
630734 plot_args = {** DEFAULT_PLOT_ARGS , ** kwargs }
@@ -767,7 +871,7 @@ def plot_catalog(
767871 the events sizing.
768872 """
769873 if "plot_args" in kwargs :
770- _warning_plot_args ("plot_magnitude_histogram " )
874+ _warning_plot_args ("plot_basemap " )
771875
772876 # Initialize plot
773877 plot_args = {** DEFAULT_PLOT_ARGS , ** kwargs .get ("plot_args" , {}), ** kwargs }
@@ -935,7 +1039,7 @@ def plot_gridded_dataset(
9351039 """
9361040
9371041 if "plot_args" in kwargs :
938- _warning_plot_args ("plot_magnitude_histogram " )
1042+ _warning_plot_args ("plot_gridded_dataset " )
9391043 # Initialize plot
9401044
9411045 plot_args = {** DEFAULT_PLOT_ARGS , ** kwargs .get ("plot_args" , {}), ** kwargs }
@@ -2691,6 +2795,3 @@ def _warning_plot_args(func_name: str):
26912795 DeprecationWarning ,
26922796 stacklevel = 2
26932797 )
2694-
2695-
2696-
0 commit comments