Skip to content

Commit 02f3256

Browse files
authored
Merge pull request #69 from cseptesting/68-improve-report-module
Improved report module
2 parents 3a93e6f + b59eebf commit 02f3256

File tree

23 files changed

+472
-151
lines changed

23 files changed

+472
-151
lines changed

floatcsep/experiment.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,6 @@ def set_tasks(self) -> None:
394394
tstring=time_i,
395395
models=self.models,
396396
)
397-
print("000")
398397
task_graph.add(task=task_j)
399398

400399
# Set up the Forecasts creation

floatcsep/infrastructure/repositories.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -205,29 +205,26 @@ def set_test_cats(self, tstring: str, fmt: str = "json") -> None:
205205
"""
206206

207207
test_cat_name = self.registry.get_test_catalog_key(tstring)
208-
if not exists(test_cat_name):
209-
log.debug(
210-
f"[Catalogs] Filtering testing catalog and saving to "
211-
f"{self.registry.rel(test_cat_name)}"
212-
)
213-
start, end = str2timewindow(tstring)
214-
sub_cat = self.catalog.filter(
215-
[
216-
f"origin_time < {end.timestamp() * 1000}",
217-
f"origin_time >= {start.timestamp() * 1000}",
218-
f"magnitude >= {self.mag_min}",
219-
f"magnitude < {self.mag_max}",
220-
],
221-
in_place=False,
222-
)
223-
if self.region:
224-
sub_cat.filter_spatial(region=self.region, in_place=True)
225208

226-
writer = getattr(CatalogSerializer, fmt)
227-
writer(catalog=sub_cat, filename=test_cat_name)
209+
log.debug(
210+
f"[Catalogs] Filtering testing catalog and saving to "
211+
f"{self.registry.rel(test_cat_name)}"
212+
)
213+
start, end = str2timewindow(tstring)
214+
sub_cat = self.catalog.filter(
215+
[
216+
f"origin_time < {end.timestamp() * 1000}",
217+
f"origin_time >= {start.timestamp() * 1000}",
218+
f"magnitude >= {self.mag_min}",
219+
f"magnitude < {self.mag_max}",
220+
],
221+
in_place=False,
222+
)
223+
if self.region:
224+
sub_cat.filter_spatial(region=self.region, in_place=True)
228225

229-
else:
230-
log.debug(f"[Catalogs] Using test catalog from {self.registry.rel(test_cat_name)}")
226+
writer = getattr(CatalogSerializer, fmt)
227+
writer(catalog=sub_cat, filename=test_cat_name)
231228

232229
def filter_catalog(
233230
self,
@@ -344,26 +341,28 @@ def __init__(self, registry: ModelFileRegistry, **kwargs):
344341
def load_forecast(
345342
self,
346343
tstring: Union[str, list],
344+
name=None,
347345
region=None,
348346
n_sims=None,
349347
) -> Union[CatalogForecast, list[CatalogForecast]]:
350348
"""
351349
Returns a forecast object or a sequence of them for a set of time window strings.
352350
353351
Args:
354-
tstring (str, list): String representing the time-window
352+
tstring (str, list): String representing the time-window.
353+
name (str): Name of the forecast model.
355354
region (optional): A region, in case the forecast requires to be filtered lazily.
356-
n_sims (optional: The number of simulations/synthetic catalogs of the forecast
355+
n_sims (optional: The number of simulations/synthetic catalogs of the forecast.
357356
358357
Returns:
359358
The CSEP CatalogForecast object or a list of them.
360359
"""
361360
if isinstance(tstring, str):
362-
return self._load_single_forecast(tstring, region=region, n_sims=n_sims)
361+
return self._load_single_forecast(tstring, name=name, region=region, n_sims=n_sims)
363362
else:
364363
return [self._load_single_forecast(t, region) for t in tstring]
365364

366-
def _load_single_forecast(self, tstring: str, region=None, n_sims=None):
365+
def _load_single_forecast(self, tstring: str, name=None, region=None, n_sims=None):
367366
start_date, end_date = str2timewindow(tstring)
368367

369368
fc_path = self.registry.get_forecast_key(tstring)
@@ -372,6 +371,7 @@ def _load_single_forecast(self, tstring: str, region=None, n_sims=None):
372371

373372
forecast_ = f_parser(
374373
fc_path,
374+
name=name,
375375
start_time=start_date,
376376
end_time=end_date,
377377
n_cat=n_sims,

floatcsep/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def get_forecast(
364364
If None, will return a forecast for all regions.
365365
366366
"""
367-
return self.repository.load_forecast(tstring, region=region)
367+
return self.repository.load_forecast(tstring, name=self.name, region=region)
368368

369369
def create_forecast(self, tstring: str, **kwargs) -> None:
370370
"""

floatcsep/postprocess/plot_handler.py

Lines changed: 72 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -85,29 +85,42 @@ def plot_forecasts(experiment: "Experiment") -> None:
8585
plot_forecast_config["projection"]: ccrs.Projection = parse_projection(
8686
plot_forecast_config.get("projection")
8787
)
88+
plot_forecast_config["title"] = None
8889

8990
for model in experiment.models:
9091
for window in time_windows:
9192
forecast = model.get_forecast(window, region=experiment.region)
9293
ax = forecast.plot(plot_args=plot_forecast_config)
9394

94-
# If catalog option is passed, catalog is plotted on top of the forecast
9595
if plot_forecast_config.get("catalog"):
9696
cat_args = plot_forecast_config.get("catalog", {})
9797
if cat_args is True:
9898
cat_args = {}
99-
experiment.catalog_repo.get_test_cat(window).plot(
99+
overlay_args = {
100+
**cat_args,
101+
"basemap": plot_forecast_config.get("basemap", None),
102+
}
103+
ax = experiment.catalog_repo.get_test_cat(window).plot(
100104
ax=ax,
101105
extent=ax.get_extent(),
102-
plot_args=cat_args.update(
103-
{
104-
"basemap": plot_forecast_config.get("basemap", None),
105-
"title": ax.get_title(),
106-
}
107-
),
106+
plot_args=overlay_args,
108107
)
109-
fig_path = experiment.registry.get_figure_key(window, "forecasts", model.name)
110-
pyplot.savefig(fig_path, dpi=plot_forecast_config.get("dpi", 300))
108+
109+
fig = ax.get_figure()
110+
fig.canvas.draw()
111+
112+
dpi = plot_forecast_config.get("dpi", 300)
113+
png_path = (
114+
experiment.registry.get_figure_key(window, "forecasts", model.name) + ".png"
115+
)
116+
fig.savefig(
117+
png_path,
118+
dpi=dpi,
119+
bbox_inches="tight",
120+
pad_inches=0.02,
121+
facecolor="white",
122+
)
123+
pyplot.close(fig)
111124

112125

113126
def plot_catalogs(experiment: "Experiment") -> None:
@@ -171,33 +184,73 @@ def plot_catalogs(experiment: "Experiment") -> None:
171184
if test_catalog.get_number_of_events() == 0:
172185
log.debug(f"Catalog has zero events in {experiment_timewindow}")
173186
return
187+
dpi = plot_catalog_config.get("dpi", 300)
174188

175189
# Plot catalog map
176190
ax = test_catalog.plot(plot_args=plot_catalog_config)
191+
fig = ax.get_figure()
192+
fig.canvas.draw()
177193
cat_map_path = experiment.registry.get_figure_key("main_catalog_map")
178-
ax.get_figure().savefig(cat_map_path, dpi=plot_catalog_config.get("dpi", 300))
194+
fig.savefig(
195+
cat_map_path,
196+
dpi=dpi,
197+
bbox_inches="tight", # <— trim outer margins
198+
pad_inches=0.02, # <— tiny padding to avoid clipping
199+
facecolor="white",
200+
)
201+
pyplot.close(fig)
179202

180203
# Plot catalog time series vs. magnitude
181204
ax = magnitude_vs_time(test_catalog)
205+
fig = ax.get_figure()
206+
fig.canvas.draw()
182207
cat_time_path = experiment.registry.get_figure_key("main_catalog_time")
183-
ax.get_figure().savefig(cat_time_path, dpi=plot_catalog_config.get("dpi", 300))
208+
fig.savefig(
209+
cat_time_path,
210+
dpi=dpi,
211+
bbox_inches="tight",
212+
pad_inches=0.02,
213+
facecolor="white",
214+
)
215+
pyplot.close(fig)
184216

185217
# If selected, plot the test catalogs for each of the time windows
186218
if plot_catalog_config.get("all_time_windows"):
187219
for tw in experiment.time_windows:
188-
test_catalog = experiment.catalog_repo.get_test_cat(timewindow2str(tw))
220+
tw_str = timewindow2str(tw)
221+
test_catalog = experiment.catalog_repo.get_test_cat(tw_str)
189222

190-
if test_catalog.get_number_of_events() != 0:
191-
log.debug(f"Catalog has zero events in {tw}. Skip plotting")
223+
if test_catalog.get_number_of_events() == 0:
224+
log.debug(f"Catalog has zero events in {tw_str}. Skip plotting")
192225
continue
193226

227+
# Map
194228
ax = test_catalog.plot(plot_args=plot_catalog_config)
195-
cat_map_path = experiment.registry.get_figure_key(tw, "catalog_map")
196-
ax.get_figure().savefig(cat_map_path, dpi=plot_catalog_config.get("dpi", 300))
229+
fig = ax.get_figure()
230+
fig.canvas.draw()
231+
cat_map_path = experiment.registry.get_figure_key(tw_str, "catalog_map") + ".png"
232+
fig.savefig(
233+
cat_map_path,
234+
dpi=dpi,
235+
bbox_inches="tight",
236+
pad_inches=0.02,
237+
facecolor="white",
238+
)
239+
pyplot.close(fig)
197240

241+
# Time series
198242
ax = magnitude_vs_time(test_catalog)
199-
cat_time_path = experiment.registry.get_figure_key(tw, "catalog_time")
200-
ax.get_figure().savefig(cat_time_path, dpi=plot_catalog_config.get("dpi", 300))
243+
fig = ax.get_figure()
244+
fig.canvas.draw()
245+
cat_time_path = experiment.registry.get_figure_key(tw_str, "catalog_time") + ".png"
246+
fig.savefig(
247+
cat_time_path,
248+
dpi=dpi,
249+
bbox_inches="tight",
250+
pad_inches=0.02,
251+
facecolor="white",
252+
)
253+
pyplot.close(fig)
201254

202255

203256
def plot_custom(experiment: "Experiment"):

0 commit comments

Comments
 (0)