Skip to content

Commit beb65be

Browse files
authored
Merge pull request #281 from pabloitu/260-plot-refactoring
Improved plot_magnitude_histogram
2 parents 505c5c8 + 44d3bd7 commit beb65be

4 files changed

Lines changed: 443 additions & 89 deletions

File tree

csep/core/catalogs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,8 @@ def magnitude_counts(self, mag_bins=None, tol=None, retbins=False):
733733
else:
734734
return out
735735
idx = bin1d_vec(self.get_magnitudes(), mag_bins, tol=tol, right_continuous=True)
736-
numpy.add.at(out, idx, 1)
736+
valid = idx >= 0
737+
numpy.add.at(out, idx[valid], 1)
737738
if retbins:
738739
return (mag_bins, out)
739740
else:

csep/plots.py

Lines changed: 168 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -395,13 +395,15 @@ def plot_cumulative_events_versus_time(
395395
pyplot.show()
396396
return ax
397397

398-
399398
def plot_magnitude_histogram(
400-
catalog_forecast: Union["CatalogForecast", List["CSEPCatalog"]],
401-
observation: "CSEPCatalog",
399+
forecast: Union["CatalogForecast", "GriddedForecast"],
400+
observation: Optional["CSEPCatalog"] = None,
402401
magnitude_bins: Optional[Union[List[float], numpy.ndarray]] = None,
403402
percentile: int = 95,
404403
log_scale: bool = True,
404+
normalize: bool = True,
405+
cumulative: bool = False,
406+
intervals: bool = True,
405407
ax: Optional["matplotlib.axes.Axes"] = None,
406408
show: bool = False,
407409
**kwargs: Any,
@@ -416,7 +418,7 @@ def plot_magnitude_histogram(
416418
- :ref:`Catalog-based Forecast Plots<catalog-forecast-evaluation-exploratory>`
417419
418420
Args:
419-
catalog_forecast (CatalogForecast or list of CSEPCatalog): A catalog-based forecast
421+
forecast (CatalogForecast or list of CSEPCatalog or GriddedForecast): A forecast
420422
or a list of observed catalogs.
421423
observation (CSEPCatalog): The observed catalog for comparison.
422424
magnitude_bins (list of float or numpy.ndarray, optional): The bins for magnitude
@@ -426,6 +428,12 @@ def plot_magnitude_histogram(
426428
`95`.
427429
log_scale (bool, optional): Whether to plot the y-axis in logarithmic scale. Defaults to
428430
True.
431+
normalize (bool, optional): Whether to normalize the forecast for the total number in the
432+
observation catalog.
433+
cumulative (bool, optional): Whether to plot cumulative counts N(M >= m). Defaults to
434+
False.
435+
intervals (bool, optional): Whether to display forecast uncertainty intervals.
436+
Defaults to True.
429437
ax (matplotlib.axes.Axes, optional): The axes object to draw the plot on. If `None`, a
430438
new figure and axes are created. Defaults to `None`.
431439
show (bool, optional): Whether to display the plot immediately. Defaults to `False`.
@@ -453,11 +461,12 @@ def plot_magnitude_histogram(
453461
matplotlib.axes.Axes: The axes object containing the plot.
454462
455463
.. versionchanged:: 0.8.0
456-
It now requires a `CatalogForecast` rather than a list of stochastic event sets. The
457-
`plot_args` dictionary is only partially supported and will be removed in v1.0.0
458-
.. versionadded:: 0.8.0
459-
Added `magnitude_bins`, `percentile` and `log_scale` to fine-tune the plot.
460-
Added parameters to customize coloring, formatting and sizing of the plot elements.
464+
It now accepts a `CatalogForecast` or a `GriddedForecast`. An obervation `CSEPCatalog`
465+
is now optional. Added `magnitude_bins`, `percentile` and `log_scale` to fine-tune the
466+
plot. Added parameters to customize coloring, formatting and sizing of the plot
467+
elements. The `plot_args` dictionary is only partially supported and will be removed in
468+
v1.0.0
469+
461470
"""
462471
if "plot_args" in kwargs:
463472
_warning_plot_args("plot_magnitude_histogram")
@@ -467,80 +476,175 @@ def plot_magnitude_histogram(
467476
fig, ax = pyplot.subplots(figsize=plot_args["figsize"]) if ax is None else (ax.figure, ax)
468477

469478
# Get magnitudes from observations and (lazily) from forecast
470-
forecast_mws = list(map(lambda x: x.get_magnitudes(), catalog_forecast))
471-
obs_mw = observation.get_magnitudes()
472-
n_obs = observation.get_number_of_events()
473-
474-
# Get magnitude bins from args, if not from region, or lastly from standard CSEP bins.
475-
if magnitude_bins is None:
479+
if magnitude_bins is not None:
480+
forecast_bins = numpy.asarray(magnitude_bins)
481+
else:
476482
try:
477-
magnitude_bins = observation.region.magnitudes
483+
forecast_bins = getattr(forecast, "magnitudes")
478484
except AttributeError:
479-
magnitude_bins = CSEP_MW_BINS
485+
raise AttributeError("Forecast must be defined on a 'region', having "
486+
"'magnitudes' attribute as left-edge magnitude bins.")
487+
488+
dm_forecast = numpy.median(numpy.diff(forecast_bins))
489+
forecast_centers = forecast_bins + dm_forecast / 2.0
480490

481-
def get_histogram_synthetic_cat(x, mags, normed=True):
482-
n_temp = len(x)
483-
if normed and n_temp != 0:
484-
temp_scale = n_obs / n_temp
485-
hist = numpy.histogram(x, bins=mags)[0] * temp_scale
491+
if observation is not None:
492+
if magnitude_bins is not None:
493+
obs_bins, obs_counts = observation.magnitude_counts(
494+
mag_bins=magnitude_bins, retbins=True
495+
)
496+
elif hasattr(observation, "region") and hasattr(observation.region, "magnitudes"):
497+
obs_bins, obs_counts = observation.magnitude_counts(mag_bins=None, retbins=True)
498+
else:
499+
obs_bins, obs_counts = observation.magnitude_counts(
500+
mag_bins=CSEP_MW_BINS, retbins=True
501+
)
502+
nonzero_idx = numpy.nonzero(obs_counts)[0]
503+
if nonzero_idx.size > 0:
504+
first = nonzero_idx[0]
505+
obs_bins = obs_bins[first:]
506+
obs_counts = obs_counts[first:]
507+
508+
if len(obs_bins) > 1:
509+
dm_obs = numpy.median(numpy.diff(obs_bins))
486510
else:
487-
hist = numpy.histogram(x, bins=mags)[0]
488-
return hist
511+
dm_obs = dm_forecast
512+
obs_centers = obs_bins + dm_obs / 2.0
513+
idxs = numpy.where(obs_centers >= forecast_centers[0])[0]
514+
if idxs.size > 0:
515+
obs_index = idxs[0]
516+
n_obs = numpy.sum(obs_counts[obs_index:])
517+
else:
518+
n_obs = 0
519+
else:
520+
obs_counts = None
521+
obs_centers = None
522+
n_obs = 0
523+
524+
if hasattr(forecast, "catalogs"):
525+
forecast_mws = list(map(lambda x: x.get_magnitudes(), forecast))
526+
527+
def get_histogram_synthetic_cat(x, mags, normed_hist=normalize):
528+
n_syn_events = len(x)
529+
if normed_hist and n_syn_events != 0 and n_obs != 0:
530+
temp_scale = n_obs / n_syn_events
531+
hist = numpy.histogram(x, bins=mags)[0] * temp_scale
532+
else:
533+
hist = numpy.histogram(x, bins=mags)[0]
534+
return hist
489535

490-
# get histogram values
491-
forecast_hist = numpy.array(
492-
list(map(lambda x: get_histogram_synthetic_cat(x, magnitude_bins), forecast_mws))
493-
)
494-
obs_hist, bin_edges = numpy.histogram(obs_mw, bins=magnitude_bins)
495-
bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
536+
# get histogram values
537+
catalog_forecast_bins = numpy.append(forecast_bins, forecast_bins[-1] + dm_forecast)
538+
forecast_hist = numpy.array(
539+
list(map(lambda x: get_histogram_synthetic_cat(x, catalog_forecast_bins), forecast_mws))
540+
)
496541

542+
if cumulative:
543+
hist_for_stats = numpy.cumsum(forecast_hist[:, ::-1], axis=1)[:, ::-1]
544+
else:
545+
hist_for_stats = forecast_hist
546+
547+
forecast_mean = hist_for_stats.mean(axis=0)
548+
if intervals:
549+
lower_p = (100.0 - percentile) / 2.0
550+
upper_p = 100.0 - lower_p
551+
forecast_low = numpy.percentile(hist_for_stats, lower_p, axis=0)
552+
forecast_high = numpy.percentile(hist_for_stats, upper_p, axis=0)
553+
else:
554+
forecast_low = None
555+
forecast_high = None
556+
else:
557+
rates = numpy.asarray(forecast.magnitude_counts())
558+
if len(rates) != len(forecast_bins):
559+
raise ValueError(
560+
"Length of forecast.magnitude_counts() must match number of forecast magnitude bins."
561+
)
562+
if normalize and n_obs != 0:
563+
scale = n_obs / numpy.sum(rates)
564+
else:
565+
scale = 1.0
566+
567+
if cumulative:
568+
lam = numpy.cumsum(rates[::-1])[::-1]
569+
else:
570+
lam = rates
571+
572+
forecast_mean = lam * scale
573+
if intervals:
574+
alpha = (100.0 - percentile) / 200.0
575+
low_counts = poisson.ppf(alpha, lam)
576+
high_counts = poisson.ppf(1.0 - alpha, lam)
577+
forecast_low = low_counts * scale
578+
forecast_high = high_counts * scale
579+
else:
580+
forecast_low = None
581+
forecast_high = None
582+
583+
# Compute statistics for the forecast histograms
497584
# Compute statistics for the forecast histograms
498-
forecast_mean = numpy.mean(forecast_hist, axis=0)
499-
forecast_median = numpy.median(forecast_hist, axis=0)
500-
forecast_low = numpy.percentile(forecast_hist, (100 - percentile) / 2.0, axis=0)
501-
forecast_high = numpy.percentile(forecast_hist, 100 - (100 - percentile) / 2.0, axis=0)
502-
forecast_err_lower = forecast_median - forecast_low
503-
forecast_err_upper = forecast_high - forecast_median
585+
if intervals:
586+
low = numpy.nan_to_num(forecast_low, nan=0.0)
587+
high = numpy.nan_to_num(forecast_high, nan=forecast_mean)
588+
low = numpy.minimum(low, forecast_mean)
589+
high = numpy.maximum(high, forecast_mean)
590+
forecast_err_lower = numpy.clip(forecast_mean - low, 0.0, None)
591+
forecast_err_upper = numpy.clip(high - forecast_mean, 0.0, None)
592+
else:
593+
forecast_err_lower = None
594+
forecast_err_upper = None
595+
596+
# cumulative transform for observation (after n_obs calculation)
597+
if cumulative and obs_counts is not None:
598+
obs_counts = numpy.cumsum(obs_counts[::-1])[::-1]
504599

505600
# Plot observed counts
506-
ax.plot(
507-
bin_centers,
508-
obs_hist,
509-
color=plot_args["color"],
510-
marker="o",
511-
lw=0,
512-
markersize=plot_args["markersize"],
513-
label="Observation",
514-
zorder=3,
515-
)
601+
if obs_counts is not None:
602+
ax.plot(
603+
obs_centers,
604+
obs_counts,
605+
color=plot_args["color"],
606+
marker="o",
607+
lw=0,
608+
markersize=plot_args["markersize"],
609+
label="Observation",
610+
zorder=3,
611+
)
516612
# Plot forecast histograms as bar plot with error bars
517613
ax.plot(
518-
bin_centers,
614+
forecast_centers,
519615
forecast_mean,
520616
".",
521617
markersize=plot_args["markersize"],
522618
color="darkred",
523619
label="Forecast Mean",
524620
)
525-
ax.errorbar(
526-
bin_centers,
527-
forecast_median,
528-
yerr=[forecast_err_lower, forecast_err_upper],
529-
fmt="None",
530-
color="darkred",
531-
markersize=plot_args["markersize"],
532-
capsize=plot_args["capsize"],
533-
linewidth=plot_args["linewidth"],
534-
label="Forecast (95% CI)",
535-
)
621+
622+
if intervals:
623+
ax.errorbar(
624+
forecast_centers,
625+
forecast_mean,
626+
yerr=[forecast_err_lower, forecast_err_upper],
627+
fmt="None",
628+
color="darkred",
629+
markersize=plot_args["markersize"],
630+
capsize=plot_args["capsize"],
631+
linewidth=plot_args["linewidth"],
632+
label=f"Forecast ({percentile}% CI)",
633+
)
536634

537635
# Scale x-axis
538636
if plot_args["xlim"]:
539637
ax.set_xlim(plot_args["xlim"])
540638
else:
541-
ax = _autoscale_histogram(
542-
ax, magnitude_bins, numpy.hstack(forecast_mws), obs_mw, mass=100
543-
)
639+
if observation is not None and hasattr(forecast, "catalogs"):
640+
forecast_mws = [c.get_magnitudes() for c in forecast]
641+
ax = _autoscale_histogram(
642+
ax,
643+
forecast_bins,
644+
numpy.hstack(forecast_mws),
645+
observation.get_magnitudes(),
646+
mass=100,
647+
)
544648
# Scale y-axis
545649
if log_scale:
546650
ax.set_yscale('log')
@@ -624,7 +728,7 @@ def plot_basemap(
624728
"""
625729

626730
if "plot_args" in kwargs:
627-
_warning_plot_args("plot_magnitude_histogram")
731+
_warning_plot_args("plot_basemap")
628732

629733
# Initialize plot
630734
plot_args = {**DEFAULT_PLOT_ARGS, **kwargs}
@@ -767,7 +871,7 @@ def plot_catalog(
767871
the events sizing.
768872
"""
769873
if "plot_args" in kwargs:
770-
_warning_plot_args("plot_magnitude_histogram")
874+
_warning_plot_args("plot_basemap")
771875

772876
# Initialize plot
773877
plot_args = {**DEFAULT_PLOT_ARGS, **kwargs.get("plot_args", {}), **kwargs}
@@ -935,7 +1039,7 @@ def plot_gridded_dataset(
9351039
"""
9361040

9371041
if "plot_args" in kwargs:
938-
_warning_plot_args("plot_magnitude_histogram")
1042+
_warning_plot_args("plot_gridded_dataset")
9391043
# Initialize plot
9401044

9411045
plot_args = {**DEFAULT_PLOT_ARGS, **kwargs.get("plot_args", {}), **kwargs}
@@ -2691,6 +2795,3 @@ def _warning_plot_args(func_name: str):
26912795
DeprecationWarning,
26922796
stacklevel=2
26932797
)
2694-
2695-
2696-

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@
103103
"pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None),
104104
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
105105
"matplotlib": ("https://matplotlib.org/stable", None),
106-
"cartopy": ('https://scitools.org.uk/cartopy/docs/latest/', None)
106+
"cartopy": ('https://cartopy.readthedocs.io/stable/', None)
107107
}
108108

109109
html_theme_options = {}

0 commit comments

Comments
 (0)